1+ import soundfile as sf
2+ import torch ,pdb ,time ,argparse ,os ,warnings ,sys ,librosa
3+ import numpy as np
4+ import onnxruntime as ort
5+ from scipy .io .wavfile import write
6+ from tqdm import tqdm
7+ import torch
8+ import torch .nn as nn
9+
dim_c = 4  # spectrogram channels: 2 stereo channels x (real, imag)


class Conv_TDF_net_trim:
    """Spectrogram front/back-end for the Conv-TDF ONNX separation model.

    Holds the STFT geometry the exported network expects and converts
    between waveforms and the packed real-valued spectrogram layout
    (batch, dim_c, dim_f, dim_t) fed to / produced by the ONNX session.
    """

    def __init__(self, device, model_name, target_name,
                 L, dim_f, dim_t, n_fft, hop=1024):
        super().__init__()

        self.dim_f = dim_f
        self.dim_t = 2 ** dim_t  # number of STFT frames per chunk
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        # With center=True, an STFT over this many samples yields exactly
        # self.dim_t frames.
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
        self.target_name = target_name
        self.blender = 'blender' in model_name

        # target '*' means the network emits all four stems at once.
        channels = dim_c * 4 if target_name == '*' else dim_c
        # Zeros re-appended in istft() for the high-frequency bins the
        # network never sees (bins above dim_f).
        self.freq_pad = torch.zeros([1, channels, self.n_bins - self.dim_f, self.dim_t]).to(device)

        self.n = L // 2

    def stft(self, x):
        """Pack waveforms into the network's real-valued spectrogram layout.

        x: any tensor reshapeable to (-1, chunk_size); stereo pairs must be
        adjacent so they fold into the channel axis. Returns a tensor of
        shape (batch, dim_c, dim_f, dim_t).
        """
        waves = x.reshape([-1, self.chunk_size])
        spec = torch.stft(waves, n_fft=self.n_fft, hop_length=self.hop,
                          window=self.window, center=True, return_complex=True)
        spec = torch.view_as_real(spec)       # (B, bins, frames, 2)
        spec = spec.permute([0, 3, 1, 2])     # (B, 2, bins, frames)
        # Fold each stereo pair of rows into the channel axis (2 x 2 = dim_c).
        spec = spec.reshape([-1, dim_c, self.n_bins, self.dim_t])
        # Drop the high-frequency bins the model was not trained on.
        return spec[:, :, :self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert stft(): restore padded bins, then synthesize waveforms.

        Returns (batch, c, chunk_size) where c is 2 (stereo) or 8 for the
        four-stem '*' target.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])
        full = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == '*' else 2
        # Unfold channels back into separate (real, imag) spectrograms.
        spec = full.reshape([-1, 2, self.n_bins, self.dim_t])
        spec = spec.permute([0, 2, 3, 1]).contiguous()
        wave = torch.istft(torch.view_as_complex(spec), n_fft=self.n_fft,
                           hop_length=self.hop, window=self.window, center=True)
        return wave.reshape([-1, c, self.chunk_size])
def get_models(device, dim_f, dim_t, n_fft):
    """Build the Conv-TDF wrapper configured for vocal extraction."""
    config = dict(
        device=device,
        model_name='Conv-TDF',
        target_name='vocals',
        L=11,
        dim_f=dim_f,
        dim_t=dim_t,
        n_fft=n_fft,
    )
    return Conv_TDF_net_trim(**config)
57+
# Silence library warning chatter for cleaner console output.
warnings.filterwarnings("ignore")
# Torch-side pre/post-processing is run on CPU (see Predictor.__init__);
# `device` picks the first CUDA GPU when available.
cpu = torch.device('cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
61+
class Predictor:
    """Chunked inference driver for the Conv-TDF separation model.

    Splits a long waveform into overlapping chunks, runs each chunk's
    spectrogram through the ONNX network, and stitches the outputs back
    together, trimming the overlap margins.
    """

    def __init__(self, args):
        # args must expose: onnx, dim_f, dim_t, n_fft, chunks, margin,
        # denoise (MDXNetDereverb below is such an object).
        self.args = args
        # Torch-side wrapper used only for STFT/iSTFT and chunk geometry;
        # the network weights themselves run through onnxruntime below.
        self.model_ = get_models(device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft)
        self.model = ort.InferenceSession(os.path.join(args.onnx, self.model_.target_name + '.onnx'), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        print('onnx load done')

    def demix(self, mix):
        """Split mix (2, n_samples) into margin-padded chunks and separate.

        Returns the concatenated separation result (see the shape note at
        the end of this method).
        """
        samples = mix.shape[-1]
        margin = self.args.margin
        # args.chunks is expressed in seconds of 44.1 kHz audio.
        chunk_size = self.args.chunks * 44100
        assert not margin == 0, 'margin cannot be zero!'
        if margin > chunk_size:
            margin = chunk_size

        segmented_mix = {}

        # chunks == 0 (or input shorter than one chunk): single chunk.
        if self.args.chunks == 0 or samples < chunk_size:
            chunk_size = samples

        counter = -1
        for skip in range(0, samples, chunk_size):
            counter += 1

            # First chunk has no left margin; later chunks overlap their
            # predecessor by `margin` samples on the left.
            s_margin = 0 if counter == 0 else margin
            end = min(skip + chunk_size + margin, samples)

            start = skip - s_margin

            # Keyed by the chunk's nominal offset (without margin).
            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break

        sources = self.demix_base(segmented_mix, margin_size=margin)
        '''
        mix:(2,big_sample)
        segmented_mix:offset->(2,small_sample)
        sources:(1,2,big_sample)
        '''
        return sources

    def demix_base(self, mixes, margin_size):
        """Run the network on each chunk (offset -> (2, n) array) and
        concatenate the de-margined outputs along the time axis."""
        chunked_sources = []
        progress_bar = tqdm(total=len(mixes))
        progress_bar.set_description("Processing")
        for mix in mixes:
            cmix = mixes[mix]
            sources = []
            n_sample = cmix.shape[1]
            model = self.model_
            trim = model.n_fft // 2
            # Usable samples per model window after trimming edge context.
            gen_size = model.chunk_size - 2 * trim
            pad = gen_size - n_sample % gen_size
            # Zero-pad so the chunk splits into whole model windows, with
            # `trim` samples of context on both ends.
            mix_p = np.concatenate((np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1)
            mix_waves = []
            i = 0
            while i < n_sample + pad:
                # Consecutive windows overlap by 2*trim samples.
                waves = np.array(mix_p[:, i:i + model.chunk_size])
                mix_waves.append(waves)
                i += gen_size
            mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu)
            with torch.no_grad():
                _ort = self.model
                spek = model.stft(mix_waves)
                if self.args.denoise:
                    # Average the network's output for x and -x; residual
                    # noise that does not flip sign with the input cancels.
                    spec_pred = -_ort.run(None, {'input': -spek.cpu().numpy()})[0] * 0.5 + _ort.run(None, {'input': spek.cpu().numpy()})[0] * 0.5
                    tar_waves = model.istft(torch.tensor(spec_pred))
                else:
                    tar_waves = model.istft(torch.tensor(_ort.run(None, {'input': spek.cpu().numpy()})[0]))
                # Drop edge context, flatten windows back to (2, time),
                # then cut the zero padding added above.
                tar_signal = tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]

            # Trim the left/right overlap margins, except at the ends.
            start = 0 if mix == 0 else margin_size
            end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
            if margin_size == 0:
                end = None
            sources.append(tar_signal[:, start:end])

            progress_bar.update(1)

            chunked_sources.append(sources)
        _sources = np.concatenate(chunked_sources, axis=-1)
        # del self.model
        progress_bar.close()
        return _sources

    def prediction(self, m, vocal_root, others_root):
        """Separate audio file `m` and write the two stems to disk.

        Writes <vocal_root>/<basename>_main_vocal.wav (mixture minus the
        predicted stem) and <others_root>/<basename>_others.wav.
        """
        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)
        basename = os.path.basename(m)
        mix, rate = librosa.load(m, mono=False, sr=44100)
        if mix.ndim == 1:
            # Mono input: duplicate the channel to form a stereo pair.
            mix = np.asfortranarray([mix, mix])
        mix = mix.T
        sources = self.demix(mix.T)
        opt = sources[0].T
        # The network predicts the "others" stem; vocal = mixture - others.
        sf.write("%s/%s_main_vocal.wav" % (vocal_root, basename), mix - opt, rate)
        sf.write("%s/%s_others.wav" % (others_root, basename), opt, rate)
156+
class MDXNetDereverb:
    """Entry point for FoxJoy's MDX-Net dereverberation model.

    Acts as the `args` configuration object consumed by Predictor.
    """

    def __init__(self, chunks):
        # Directory holding the exported ONNX checkpoint.
        self.onnx = "uvr5_weights/onnx_dereverb_By_FoxJoy"
        self.shifts = 10  # 'Predict with randomised equivariant stabilisation'
        self.mixing = "min_mag"  # ['default','min_mag','max_mag']
        self.chunks = chunks  # chunk length in seconds of 44.1 kHz audio
        self.margin = 44100  # one second of overlap between chunks
        # STFT geometry the ONNX model was exported with.
        self.dim_t = 9
        self.dim_f = 3072
        self.n_fft = 6144
        self.denoise = True
        # Must come last: Predictor reads the attributes above during init.
        self.pred = Predictor(self)

    def _path_audio_(self, input, vocal_root, others_root):
        """Separate one audio file into the two output directories."""
        self.pred.prediction(input, vocal_root, others_root)
172+
if __name__ == '__main__':
    from time import time as _now

    # Smoke test: dereverb one file and report wall-clock time.
    dereverb = MDXNetDereverb(15)
    started = _now()
    dereverb._path_audio_(
        "雪雪伴奏对消HP5.wav",
        "vocal",
        "others",
    )
    print(_now() - started)
184+
185+
'''
Usage:

    runtime\python.exe MDXNet.py

GPU memory notes on a 6 GB card ("chunks setting: VRAM before -> after"):
    15 (dim_t=9): 0.8G -> 6.8G
    14:           0.8G -> 6.5G
    25:           out of memory

    half precision, chunks=15: 0.7G -> 6.6G, 22.69s
    fp32,           chunks=15: 0.7G -> 6.6G, 20.85s

'''
0 commit comments