[2noise/ChatTTS]报错:stft input and window must be on the same device but got self on cpu and window on cuda

2024-10-17 78 views
4
chat.load(compile=False)

加载模型后提示使用 cuda:0。执行示例代码中 Zero-shot 部分的 chat.sample_audio_speaker 时，会报

stft input and window must be on the same device but got self on cpu and window on cuda:0

错误。需要修改 ChatTTS\ChatTTS\model\dvae.py 文件，在第 200 行（MelSpectrogramFeatures.forward 中调用 self.mel_spec(audio) 之前）插入

audio = audio.to('cuda')

才能解决。但如果 chat.load() 载入时检测到显存不足而改用了 CPU，就必须删掉这一行才能正常运行，否则会反过来报错：stft 的输入在 GPU 上，而 window 在 CPU 上。也就是说，把设备硬编码为 'cuda' 的做法无法同时兼容 GPU 和 CPU 两种情况。

RuntimeError                              Traceback (most recent call last)
Cell In[5], line 3
      1 from tools.audio import load_audio
----> 3 spk_smp = chat.sample_audio_speaker(load_audio("zm.mp3", 24000))
      4 print(spk_smp)  # save it in order to load the speaker without sample audio next time
      6 params_infer_code = ChatTTS.Chat.InferCodeParams(
      7     spk_smp=spk_smp,
      8     txt_smp="每到夏天就有一个特别有名的小吃,就是冰粉。现在的冰粉从最传统的红糖冰粉已经发展到了现在的糍粑冰粉跟玫瑰冰粉各种各样的冰粉然后今天的天气也很。",
      9 )

File D:\ChatTTS\ChatTTS\core.py:163, in Chat.sample_audio_speaker(self, wav, device)
    162 def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor], device: str = 'cuda:0') -> str:
--> 163     return self.speaker.encode_prompt(self.dvae.sample_audio(wav))

File D:\ChatTTS\.conda\lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File D:\ChatTTS\ChatTTS\model\dvae.py:292, in DVAE.sample_audio(self, wav, device)
    290 if isinstance(wav, np.ndarray):
    291     wav = torch.from_numpy(wav)
--> 292 return self(wav, "encode").squeeze_(0)

File D:\ChatTTS\ChatTTS\model\dvae.py:248, in DVAE.__call__(self, inp, mode)
    245 def __call__(
    246     self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
    247 ) -> torch.Tensor:
--> 248     return super().__call__(inp, mode)

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File D:\ChatTTS\.conda\lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File D:\ChatTTS\ChatTTS\model\dvae.py:255, in DVAE.forward(self, inp, mode)
    250 @torch.inference_mode()
    251 def forward(
    252     self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
    253 ) -> torch.Tensor:
    254     if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
--> 255         mel = self.preprocessor_mel(inp)
    256         x: torch.Tensor = self.downsample_conv(
    257             torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
    258         ).unsqueeze_(0)
    259         del mel

File D:\ChatTTS\ChatTTS\model\dvae.py:197, in MelSpectrogramFeatures.__call__(self, audio)
    196 def __call__(self, audio: torch.Tensor) -> torch.Tensor:
--> 197     return super().__call__(audio)

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File D:\ChatTTS\ChatTTS\model\dvae.py:200, in MelSpectrogramFeatures.forward(self, audio)
    199 def forward(self, audio: torch.Tensor) -> torch.Tensor:
--> 200     mel: torch.Tensor = self.mel_spec(audio)
    201     features = torch.log(torch.clip(mel, min=1e-5))
    202     return features

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File D:\ChatTTS\.conda\lib\site-packages\torchaudio\transforms\_transforms.py:619, in MelSpectrogram.forward(self, waveform)
    611 def forward(self, waveform: Tensor) -> Tensor:
    612     r"""
    613     Args:
    614         waveform (Tensor): Tensor of audio of dimension (..., time).
   (...)
    617         Tensor: Mel frequency spectrogram of size (..., `n_mels, time).
    618     """
--> 619     specgram = self.spectrogram(waveform)
    620     mel_specgram = self.mel_scale(specgram)
    621     return mel_specgram

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File D:\ChatTTS\.conda\lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File D:\ChatTTS\.conda\lib\site-packages\torchaudio\transforms\_transforms.py:110, in Spectrogram.forward(self, waveform)
    100 def forward(self, waveform: Tensor) -> Tensor:
    101     r"""
    102     Args:
    103         waveform (Tensor): Tensor of audio of dimension (..., time).
   (...)
    108         Fourier bins, and time is the number of window hops (n_frame).
    109     """
--> 110     return F.spectrogram(
    111         waveform,
    112         self.pad,
    113         self.window,
    114         self.n_fft,
    115         self.hop_length,
    116         self.win_length,
    117         self.power,
    118         self.normalized,
    119         self.center,
    120         self.pad_mode,
    121         self.onesided,
    122     )

File D:\ChatTTS\.conda\lib\site-packages\torchaudio\functional\functional.py:126, in spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized, center, pad_mode, onesided, return_complex)
    123 waveform = waveform.reshape(-1, shape[-1])
    125 # default values are consistent with librosa.core.spectrum._spectrogram
--> 126 spec_f = torch.stft(
    127     input=waveform,
    128     n_fft=n_fft,
    129     hop_length=hop_length,
    130     win_length=win_length,
    131     window=window,
    132     center=center,
    133     pad_mode=pad_mode,
    134     normalized=frame_length_norm,
    135     onesided=onesided,
    136     return_complex=True,
    137 )
    139 # unpack batch
    140 spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

File D:\ChatTTS\.conda\lib\site-packages\torch\functional.py:665, in stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex)
    663     input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
    664     input = input.view(input.shape[-signal_dim:])
--> 665 return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
    666                 normalized, onesided, return_complex)

RuntimeError: stft input and window must be on the same device but got self on cpu and window on cuda:0

回答

3

请尝试最新提交。

6

ref #603