When the input text is short, zero-shot generation produces 0 seconds of noise. For example, with the text "一二线城市用地紧张.", the output is 0-second noise. After extending the text to "一二线城市用地紧张,房价居高不下,百姓负担重,不敢消费,生活压力大,苦于维持生计", fairly complete speech is generated, but an extra "就" syllable is inserted. When I run over the following speech dataset, most of the zero-shot generations are 0-second noise:
aishell_dataset = MsDataset.load('modelscope/speech_asr_aishell1_testsets', subset_name='default', split='test')
Could you please help look into the cause, and whether it can be fixed? Thanks!
Software versions: ChatTTS branch: main, commit: 51ec0c7; Python 3.10; torch 2.3.0; torchaudio 2.3.0; OS: Ubuntu
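For quick triage, here is a minimal sketch of just the short-text case. It assumes chat, spk_smp, and label are already set up exactly as in the full script below; the 24000 divisor is ChatTTS's output sample rate:

short_text = "一二线城市用地紧张."
wavs = chat.infer(
    short_text,
    params_infer_code=ChatTTS.Chat.InferCodeParams(spk_smp=spk_smp, txt_smp=label),
)
# The short input comes back as (near) 0-second output instead of speech:
print(f"generated {wavs[0].shape[0] / 24000:.2f}s of audio")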
Below is the full reproduction code:
import os
import shutil
import sys  # was missing in the original snippet; needed for sys.exit below
import time

import numpy as np
import pybase16384 as b14
import torch
import torchaudio
from modelscope.msdatasets import MsDataset
from tqdm import tqdm

import ChatTTS
from tools.audio import load_audio, pcm_arr_to_mp3_view
from tools.logger import get_logger

logger = get_logger("Command")
logger.info("Initializing ChatTTS...")
chat = ChatTTS.Chat(get_logger("ChatTTS"))
def seed_everything(seed, device='cuda'):
    """Set random seeds to ensure reproducibility.

    Parameters:
        seed (int): the random seed
        device (str): 'cuda' seeds the CUDA RNGs as well
    """
    import random
    # Seed the Python, NumPy, and PyTorch random number generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # If using CUDA, also seed the CUDA random number generators
    if device == 'cuda' and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # seed all GPUs
        # Make cuDNN deterministic (optional)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
device = 'cuda'
seed_everything(42, device)

if chat.load(source='custom',
             custom_path='./pretrained_models/chatTTS/',
             device=device,
             compile=False,  # set to True for better performance
             use_flash_attn=False):
    logger.info("Models loaded successfully.")
else:
    logger.error("Models load failed.")
    sys.exit(1)
try:
    # normalizer_zh_tn was called without an import in the original snippet;
    # it lives in the repo's tools.normalizer package (adjust the path if
    # your checkout differs) and requires WeTextProcessing.
    from tools.normalizer import normalizer_zh_tn
    chat.normalizer.register("zh", normalizer_zh_tn())
except ValueError as e:
    logger.error(e)
except ImportError:
    logger.warning("Package WeTextProcessing not found!")
    logger.warning(
        "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
    )
def save_mp3_file(wav, out_dir, filename):
    data = pcm_arr_to_mp3_view(wav)
    mp3_filename = os.path.join(out_dir, filename)
    with open(mp3_filename, "wb") as f:
        f.write(data)
    logger.info(f"Audio saved to {mp3_filename}")
t0 = time.time()
aishell_dataset = MsDataset.load('modelscope/speech_asr_aishell1_testsets',
                                 subset_name='default',
                                 split='test')
t1 = time.time()
print(f'time-cost of downloading dataset: {t1 - t0}')

inference_times = 50
total_mse = []
cost_time = []
fail_cnt = 0
failed_audio_path = []

params_refine_text = ChatTTS.Chat.RefineTextParams(
    prompt='[oral_2][laugh_0][break_6]',
)
for i, item in tqdm(enumerate(aishell_dataset),
                    desc="speech_asr_aishell1_testsets Inference",
                    total=inference_times, unit="sample"):
    if i == inference_times:
        break
    try:
        speech_path, label = item['Audio:FILE'], item["Text:LABEL"]
        print(f'speech_path: {speech_path}')
        init_wave, _ = torchaudio.load(speech_path)
        start_time = time.time()
        # NOTE: AISHELL-1 clips are 16 kHz; the commented alternative below
        # resamples the reference clip to 24 kHz first.
        spk_smp = chat.sample_audio_speaker(init_wave.cpu().numpy()[0])
        # spk_smp = chat.sample_audio_speaker(load_audio(speech_path, 24000))
        params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_smp=spk_smp,
            # temperature=0.9,
            txt_smp=label,
        )
        wavs = chat.infer(label,
                          params_infer_code=params_infer_code,
                          params_refine_text=params_refine_text,
                          )
        for index, wav in enumerate(wavs):
            arr = speech_path.split('/')
            folder = arr[-2]
            file_name = arr[-1]
            output_dir = './result/seedeverything/0822-1111'
            out_dir = os.path.join(output_dir, device, folder, 'generated')
            os.makedirs(out_dir, exist_ok=True)
            save_mp3_file(wav, out_dir, file_name)
            out_dir = os.path.join(output_dir, device, folder, 'original')
            os.makedirs(out_dir, exist_ok=True)
            shutil.copy(speech_path, out_dir)
        cost_time.append(time.time() - start_time)
        pred_wav = wavs[0]
        # Compare the overlapping prefix of the reference and generated
        # waveforms (note they may differ in sample rate, see above)
        min_shape = min(init_wave.shape[1], pred_wav.shape[0])
        mse = torch.mean((init_wave[0][:min_shape]
                          - torch.from_numpy(pred_wav[:min_shape]))**2).item()
        total_mse.append(mse)
    except Exception as e:
        fail_cnt += 1
        failed_audio_path.append(speech_path)
        print(e, i)
print(f'total mse: {sum(total_mse)}, fail_cnt = {fail_cnt}')
print(f'failed audio paths: {failed_audio_path}')
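For completeness, a small helper one could drop into the loop above to tally the 0-second outputs described at the top. The helper and its 0.1 s threshold are illustrative assumptions, not part of the original repro; 24000 is ChatTTS's output sample rate:

def is_zero_length(wav, sr=24000, min_seconds=0.1):
    # Treat anything shorter than min_seconds as the "0-second noise"
    # failure described above (threshold is an arbitrary choice).
    return wav.shape[0] < int(min_seconds * sr)

# e.g. right after `wavs = chat.infer(...)`:
#     if is_zero_length(wavs[0]):
#         zero_len_cnt += 1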