[2noise/ChatTTS] With short input text, zero-shot mostly generates 0 seconds of noise

2024-10-17

When the input text is short, zero-shot inference produces 0 seconds of noise. For example, with the text "一二线城市用地紧张.", the output is 0 seconds of noise. When the text is extended to "一二线城市用地紧张,房价居高不下,百姓负担重,不敢消费,生活压力大,苦于维持生计", a fairly complete utterance is generated, but an extra "就" is pronounced. With the speech dataset below, most zero-shot outputs are 0 seconds of noise:

aishell_dataset = MsDataset.load('modelscope/speech_asr_aishell1_testsets', subset_name='default', split='test')

Could you help me figure out the cause and whether it can be fixed? Thanks!

Software versions: ChatTTS branch: main, commit: 51ec0c7, Python 3.10, torch 2.3.0, torchaudio 2.3.0, OS: Ubuntu

The reproduction code follows:

import os
import random
import shutil
import sys
import time

import numpy as np
import pybase16384 as b14
import torch
import torchaudio
from modelscope.msdatasets import MsDataset
from tqdm import tqdm

import ChatTTS
from tools.audio import load_audio, pcm_arr_to_mp3_view
from tools.logger import get_logger
# Import path assumed from the ChatTTS examples; adjust if your checkout differs.
from tools.normalizer import normalizer_zh_tn

logger = get_logger("Command")

logger.info("Initializing ChatTTS...")
chat = ChatTTS.Chat(get_logger("ChatTTS"))

def seed_everything(seed, device='cuda'):
    """
    Set random seeds to ensure reproducibility.

    Parameters:
        seed (int): The random seed
        device (str): 'cuda' to also seed the CUDA RNGs
    """

    # Set the seed for the Python random number generator
    random.seed(seed)

    # Set the seed for the NumPy random number generator
    np.random.seed(seed)

    # Set the seed for the PyTorch random number generator
    torch.manual_seed(seed)

    # If using CUDA, set the seed for the CUDA random number generator
    if device == 'cuda':
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # Set the seed for all GPUs

        # Set reproducibility (optional)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

device = 'cuda'
seed_everything(42, device)

if chat.load(source='custom',
             custom_path='./pretrained_models/chatTTS/',
             device=device,
             compile=False,  # Set to True for better performance
             use_flash_attn=False):
    logger.info("Models loaded successfully.")
else:
    logger.error("Models load failed.")
    sys.exit(1)

try:
    chat.normalizer.register("zh", normalizer_zh_tn())
except ValueError as e:
    logger.error(e)
except Exception:
    logger.warning("Package WeTextProcessing not found!")
    logger.warning(
        "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
    )

def save_mp3_file(wav, out_dir, filename):
    data = pcm_arr_to_mp3_view(wav)
    mp3_filename = os.path.join(out_dir, filename)
    with open(mp3_filename, "wb") as f:
        f.write(data)
    logger.info(f"Audio saved to {mp3_filename}")

t0 = time.time()
aishell_dataset = MsDataset.load('modelscope/speech_asr_aishell1_testsets', subset_name='default', split='test')
t1 = time.time()
print(f'time-cost of downloading dataset: {t1 - t0}')
inference_times = 50
total_mse = []
cost_time = []
fail_cnt = 0
failed_audio_path = []

params_refine_text = ChatTTS.Chat.RefineTextParams(
    prompt='[oral_2][laugh_0][break_6]',
)
for i, item in tqdm(enumerate(aishell_dataset), desc="speech_asr_aishell1_testsets Inference", total=inference_times, unit="sample"):
    if i == inference_times:
        break
    try:
        speech_path, label = item['Audio:FILE'], item['Text:LABEL']
        print(f'speech_path: {speech_path}')
        init_wave, _ = torchaudio.load(speech_path)
        start_time = time.time()
        # Note: this feeds the prompt at its native 16 kHz sample rate; the
        # commented alternative resamples to the 24 kHz that ChatTTS uses.
        spk_smp = chat.sample_audio_speaker(init_wave.cpu().numpy()[0])
        # spk_smp = chat.sample_audio_speaker(load_audio(speech_path, 24000))

        params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_smp=spk_smp,
            # temperature=0.9,
            txt_smp=label,
        )

        wavs = chat.infer(label,
                          params_infer_code=params_infer_code,
                          params_refine_text=params_refine_text,
                          )

        for index, wav in enumerate(wavs):
            arr = speech_path.split('/')
            folder = arr[-2]
            file_name = arr[-1]
            output_dir = './result/seedeverything/0822-1111'
            out_dir = os.path.join(output_dir, device, folder, 'generated')
            os.makedirs(out_dir, exist_ok=True)
            save_mp3_file(wav, out_dir, file_name)

            out_dir = os.path.join(output_dir, device, folder, 'original')
            os.makedirs(out_dir, exist_ok=True)
            shutil.copy(speech_path, out_dir)

        cost_time.append(time.time() - start_time)
        # The original (16 kHz) and generated (24 kHz) audio have different
        # sample rates, so this raw sample-wise MSE is only a rough signal.
        pred_wav = torch.from_numpy(wavs[0])
        min_shape = min(init_wave.shape[1], pred_wav.shape[0])
        mse = torch.mean((init_wave[0][:min_shape]
                          - pred_wav[:min_shape]) ** 2).item()
        total_mse.append(mse)
    except Exception as e:
        fail_cnt += 1
        failed_audio_path.append(speech_path)
        print(e, i)

print(f'total mse: {sum(total_mse)}, fail_cnt = {fail_cnt}')
print(f'failed audio paths: {failed_audio_path}')

Answers

Answer 1:

I haven't worked with ChatTTS, but I have dealt with TorToise-TTS; both are based on AR LLMs, so they are much the same.

You could check whether, when a silent segment is generated, the token at the end position is the sequence_end token with a very high probability, or whether the very first predicted token is already the sequence-end token. If so, leading silence at the start of your prompt audio may not have been cleaned up, causing the model to wrongly conclude during in-context inference that it has reached the end of the sequence.

If cleaning the prompt does not solve the problem, you can also try changing the sampling algorithm; tweaking its parameters can work around this issue. For example, set the temperature very high: the final token probability distribution gets flattened, every token's chance of being sampled becomes similar, and the model is less likely to sample sequence_end right at the start.

The ultimate fix is tuning the model itself so that it does not predict the sequence_end token too early. The above is for reference only.
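For illustration, a minimal sketch of both suggestions (not ChatTTS internals; the prompt file and the VAD trigger level are hypothetical and would need tuning per dataset):

import torch
import torchaudio

# Suggestion 1: trim leading silence/non-speech from the prompt audio before
# passing it to sample_audio_speaker(). torchaudio.functional.vad trims
# non-speech from the front of the waveform; trigger_level here is a guess.
wave, sr = torchaudio.load("prompt.wav")  # hypothetical prompt file
trimmed = torchaudio.functional.vad(wave, sample_rate=sr, trigger_level=7.0)

# Suggestion 2 (toy numbers): a higher temperature flattens the softmax over
# tokens, so an early step is less likely to sample the sequence_end token.
logits = torch.tensor([4.0, 1.0, 0.5, 0.2])  # pretend index 0 is sequence_end
for t in (0.7, 1.0, 3.0):
    probs = torch.softmax(logits / t, dim=-1)
    print(f"T={t}: p(sequence_end)={probs[0].item():.3f}")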

Answer 2:

The label must conform to ChatTTS's input format, i.e., it should be the result after refine_text.

Follow-up:

How do I obtain the refined result?
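For reference, a minimal sketch of one way to do this, assuming the refine_text_only and skip_refine_text flags of Chat.infer shown in the ChatTTS README (verify they exist on commit 51ec0c7):

# First pass: refine only, returning the text with control tokens inserted.
refined = chat.infer(label,
                     refine_text_only=True,
                     params_refine_text=params_refine_text)
print(refined)

# Second pass: synthesize from the refined text, skipping refinement.
wavs = chat.infer(refined,
                  skip_refine_text=True,
                  params_infer_code=params_infer_code)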