1+ import os
12import argparse
23import sys
34import torch
@@ -35,11 +36,13 @@ def __init__(self):
3536 self .iscolab ,
3637 self .noparallel ,
3738 self .noautoopen ,
39+ self .dml
3840 ) = self .arg_parse ()
39- self .instead = ""
41+ self .instead = ""
4042 self .x_pad , self .x_query , self .x_center , self .x_max = self .device_config ()
4143
42- def arg_parse (self ) -> tuple :
44+ @staticmethod
45+ def arg_parse () -> tuple :
4346 exe = sys .executable or "python"
4447 parser = argparse .ArgumentParser ()
4548 parser .add_argument ("--port" , type = int , default = 7865 , help = "Listen port" )
@@ -61,13 +64,14 @@ def arg_parse(self) -> tuple:
6164 cmd_opts = parser .parse_args ()
6265
6366 cmd_opts .port = cmd_opts .port if 0 <= cmd_opts .port <= 65535 else 7865
64- self . dml = cmd_opts . dml
67+
6568 return (
6669 cmd_opts .pycmd ,
6770 cmd_opts .port ,
6871 cmd_opts .colab ,
6972 cmd_opts .noparallel ,
7073 cmd_opts .noautoopen ,
74+ cmd_opts .dml
7175 )
7276
7377 # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
@@ -112,12 +116,12 @@ def device_config(self) -> tuple:
112116 f .write (strr )
113117 elif self .has_mps ():
114118 print ("No supported Nvidia GPU found" )
115- self .device = self .instead = "mps"
119+ self .device = self .instead = "mps"
116120 self .is_half = False
117121 use_fp32_config ()
118122 else :
119123 print ("No supported Nvidia GPU found" )
120- self .device = self .instead = "cpu"
124+ self .device = self .instead = "cpu"
121125 self .is_half = False
122126 use_fp32_config ()
123127
@@ -137,25 +141,34 @@ def device_config(self) -> tuple:
137141 x_center = 38
138142 x_max = 41
139143
140- if self .gpu_mem != None and self .gpu_mem <= 4 :
144+ if self .gpu_mem is not None and self .gpu_mem <= 4 :
141145 x_pad = 1
142146 x_query = 5
143147 x_center = 30
144148 x_max = 32
145- if ( self .dml == True ) :
149+ if self .dml :
146150 print ("use DirectML instead" )
147- try :os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
148- except :pass
149- try :os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
150- except :pass
151+ try :
152+ os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
153+ except :
154+ pass
155+ try :
156+ os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
157+ except :
158+
159+ pass
151160 import torch_directml
152- self .device = torch_directml .device (torch_directml .default_device ())
153- self .is_half = False
161+ self .device = torch_directml .device (torch_directml .default_device ())
162+ self .is_half = False
154163 else :
155- if (self .instead ):
156- print ("use %s instead" % self .instead )
157- try :os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
158- except :pass
159- try :os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
160- except :pass
164+ if self .instead :
165+ print (f"use { self .instead } instead" )
166+ try :
167+ os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
168+ except :
169+ pass
170+ try :
171+ os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
172+ except :
173+ pass
161174 return x_pad , x_query , x_center , x_max
0 commit comments