@@ -36,7 +36,7 @@ def __init__(self):
3636 self .iscolab ,
3737 self .noparallel ,
3838 self .noautoopen ,
39- self .dml
39+ self .dml ,
4040 ) = self .arg_parse ()
4141 self .instead = ""
4242 self .x_pad , self .x_query , self .x_center , self .x_max = self .device_config ()
@@ -71,7 +71,7 @@ def arg_parse() -> tuple:
7171 cmd_opts .colab ,
7272 cmd_opts .noparallel ,
7373 cmd_opts .noautoopen ,
74- cmd_opts .dml
74+ cmd_opts .dml ,
7575 )
7676
7777 # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
@@ -149,26 +149,38 @@ def device_config(self) -> tuple:
149149 if self .dml :
150150 print ("use DirectML instead" )
151151 try :
152- os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
152+ os .rename (
153+ "runtime\Lib\site-packages\onnxruntime" ,
154+ "runtime\Lib\site-packages\onnxruntime-cuda" ,
155+ )
153156 except :
154157 pass
155158 try :
156- os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
159+ os .rename (
160+ "runtime\Lib\site-packages\onnxruntime-dml" ,
161+ "runtime\Lib\site-packages\onnxruntime" ,
162+ )
157163 except :
158-
159164 pass
160165 import torch_directml
166+
161167 self .device = torch_directml .device (torch_directml .default_device ())
162168 self .is_half = False
163169 else :
164170 if self .instead :
165171 print (f"use { self .instead } instead" )
166172 try :
167- os .rename ("runtime\Lib\site-packages\onnxruntime" ,"runtime\Lib\site-packages\onnxruntime-cuda" )
173+ os .rename (
174+ "runtime\Lib\site-packages\onnxruntime" ,
175+ "runtime\Lib\site-packages\onnxruntime-cuda" ,
176+ )
168177 except :
169178 pass
170179 try :
171- os .rename ("runtime\Lib\site-packages\onnxruntime-dml" ,"runtime\Lib\site-packages\onnxruntime" )
180+ os .rename (
181+ "runtime\Lib\site-packages\onnxruntime-dml" ,
182+ "runtime\Lib\site-packages\onnxruntime" ,
183+ )
172184 except :
173185 pass
174186 return x_pad , x_query , x_center , x_max
0 commit comments