[2noise/ChatTTS]Windows环境下显示Windows not yet supported for torch.compile,而WSL环境下性能不佳,4090下只能以40s/it的速度生成

2024-06-05 21 views
9

硬件: cpu:5950x gpu:rtx 4090 系统: windows 11 wsl: ubuntu 22.04

windows环境下报错:

发生异常: RuntimeError Windows not yet supported for torch.compile File "E:\git\ChatTTS\ChatTTS\core.py", line 102, in _load gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True) File "E:\git\ChatTTS\ChatTTS\core.py", line 61, in load_models self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs) File "E:\git\ChatTTS\test.py", line 10, in <module> chat.load_models() RuntimeError: Windows not yet supported for torch.compile

并不知道为什么会需要调用torch.compile。

而切换WSL发现性能不佳,初步判断可能是并没有用gpu进行推理,但通过检查torch.cuda.is_available,发现结果为True。输入如下:

INFO:ChatTTS.core:Load from cache: /home/xxxx/.cache/huggingface/hub/models--2Noise--ChatTTS/snapshots/c0aa9139945a4d7bb1c84f07785db576f2bb1bfa INFO:ChatTTS.core:use cuda:0 INFO:ChatTTS.core:vocos loaded. INFO:ChatTTS.core:dvae loaded. INFO:ChatTTS.core:gpt loaded. INFO:ChatTTS.core:decoder loaded. INFO:ChatTTS.core:tokenizer loaded. INFO:ChatTTS.core:All initialized. INFO:ChatTTS.core:All initialized. 0%|▎ | 1/384 [00:48<5:09:58, 48.56s/it]W0531 07:15:26.417000 140610319078464 torch/_dynamo/exc.py:184] [0/1] Backend compiler failed with a fake tensor exception at W0531 07:15:26.417000 140610319078464 torch/_dynamo/exc.py:184] [0/1] File "/home/xxxx/miniconda3/envs/tts/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 998, in forward W0531 07:15:26.417000 140610319078464 torch/_dynamo/exc.py:184] [0/1] return BaseModelOutputWithPast( W0531 07:15:26.417000 140610319078464 torch/_dynamo/exc.py:184] [0/1] Adding a graph break. W0531 07:15:43.909000 140610319078464 torch/_dynamo/exc.py:184] [0/1_1] Backend compiler failed with a fake tensor exception at W0531 07:15:43.909000 140610319078464 torch/_dynamo/exc.py:184] [0/1_1] File "/home/xxxx/miniconda3/envs/tts/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 998, in forward W0531 07:15:43.909000 140610319078464 torch/_dynamo/exc.py:184] [0/1_1] return BaseModelOutputWithPast( W0531 07:15:43.909000 140610319078464 torch/_dynamo/exc.py:184] [0/1_1] Adding a graph break. W0531 07:15:47.588000 140610319078464 torch/_dynamo/exc.py:184] [3/0] Backend compiler failed with a fake tensor exception at W0531 07:15:47.588000 140610319078464 torch/_dynamo/exc.py:184] [3/0] File "/home/xxxx/miniconda3/envs/tts/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 738, in forward W0531 07:15:47.588000 140610319078464 torch/_dynamo/exc.py:184] [3/0] return outputs W0531 07:15:47.588000 140610319078464 torch/_dynamo/exc.py:184] [3/0] Adding a graph break. 1%|▌ | 2/384 [02:02<6:44:00, 63.46s/it]W0531 07:16:25.436000 140610319078464 torch/_dynamo/exc.py:184] [3/20] Backend compiler failed with a fake tensor exception at W0531 07:16:25.436000 140610319078464 torch/_dynamo/exc.py:184] [3/20] File "/home/xxxx/miniconda3/envs/tts/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 738, in forward W0531 07:16:25.436000 140610319078464 torch/_dynamo/exc.py:184] [3/20] return outputs W0531 07:16:25.436000 140610319078464 torch/_dynamo/exc.py:184] [3/20] Adding a graph break. 1%|▉ | 3/384 [02:42<5:34:20, 52.65s/it]

回答

8

问题大致解决,解决方法是将 def _load中的compile bool = False,windows下可以运行,并且性能改善,大概30it/s。初步认为即便在linux环境下,torch.compile依然有很多限制导致报错。想了解一下为什么这里一定要用到torch.compile。

2

这一步编译应该是为了加速推理用的,需要选择正确的编译版本

8

这一步编译应该是为了加速推理用的,需要选择正确的编译版本

但这反而是问题,至少默认编译器编译反而没有提升性能,而且还有兼容性问题。关闭反而速度提升。

9

linux下, 4090大概能有140-150its/s. 其他平台没有测试, 可以酌情手动关闭

6

linux下, 4090大概能有140-150its/s. 其他平台没有测试, 可以酌情手动关闭

是纯linux环境吧,目前wsl性能不佳可能就是兼容性问题了。我换个pytorch版本试试能不能解决性能问题。

5

问题大致解决,解决方法是将 def _load中的compile bool = False,windows下可以运行,并且性能改善,大概30it/s。初步认为即便在linux环境下,torch.compile依然有很多限制导致报错。想了解一下为什么这里一定要用到torch.compile。

我把compile 改为False后还是不行,继续报错。目前还没解决,pytorch是2.1版本,不知道换个版本会不会好一些

8

问题大致解决,解决方法是将 def _load中的compile bool = False,windows下可以运行,并且性能改善,大概30it/s。初步认为即便在linux环境下,torch.compile依然有很多限制导致报错。想了解一下为什么这里一定要用到torch.compile。

我把compile 改为False后还是不行,继续报错。目前还没解决,pytorch是2.1版本,不知道换个版本会不会好一些

根据警告信息,有可能不是pytorch而是transformers的问题。你换个transformer版本试试吧。

4

你的环境应该有问题。我也是在WSL中,但是显示的速度至少60its/s。而且计算进度条有问题,很多走到一半就已经算完了,考虑这点的话,速度其实应该为120its/s。这个才符合wsl和原生Linux的速度差异。

3

你的环境应该有问题。我也是在WSL中,但是显示的速度至少60its/s。而且计算进度条有问题,很多走到一半就已经算完了,考虑这点的话,速度其实应该为120its/s。这个才符合wsl和原生Linux的速度差异。

对了,我想起来个原因,你的环境安装cuda-toolkit了吗?我好像没有安。

3

你的环境应该有问题。我也是在WSL中,但是显示的速度至少60its/s。而且计算进度条有问题,很多走到一半就已经算完了,考虑这点的话,速度其实应该为120its/s。这个才符合wsl和原生Linux的速度差异。

最新的尝试是安装了cuda-toolkit,版本如下: nvcc: NVIDIA (R) Cuda compiler driver Copyright (c) 2005-2024 NVIDIA Corporation Built on Wed_Apr_17_19:19:55_PDT_2024 Cuda compilation tools, release 12.5, V12.5.40 Build cuda_12.5.r12.5/compiler.34177558_0

这次性能大致正常了,推理速度跑到了90it/s,但是在开始可能是因为需要编译的原因,会停顿很久。算上时间其实跟不编译的时间差不多,你是否也有类似的问题。