feat: Integrate IndexTTS2 model and update related schemas and frontend components
qwen3-tts-backend/indextts/infer_indextts2.py (new file, 618 lines)

@@ -0,0 +1,618 @@
import os

os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'

import re
import textstat
import time
from subprocess import CalledProcessError
from typing import Dict, List, Tuple
import librosa

import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm
from huggingface_hub import hf_hub_download
import safetensors.torch  # import the torch submodule explicitly so safetensors.torch.load_model below is guaranteed to resolve
from transformers import SeamlessM4TFeatureExtractor

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model_v2 import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.maskgct_utils import build_semantic_model, build_semantic_codec, load_config

from indextts.s2mel.modules.commons import load_checkpoint2, MyModel
from indextts.s2mel.modules.bigvgan import bigvgan
from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus
from indextts.s2mel.modules.audio import mel_spectrogram

from indextts.utils.front import TextNormalizer, TextTokenizer

def contains_chinese(text):
    # Regex check: Chinese characters and digits are both treated as zh
    if re.search(r'[\u4e00-\u9fff0-9]', text):
        return True
    return False


def get_text_syllable_num(text):
    chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
    number_char_pattern = re.compile(r'[0-9]')
    syllable_num = 0
    tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z]+|[0-9]+', text)
    # print(tokens)
    if contains_chinese(text):
        for token in tokens:
            # Chinese characters and digits count one syllable per character;
            # Latin words are counted with textstat.
            if chinese_char_pattern.search(token) or number_char_pattern.search(token):
                syllable_num += len(token)
            else:
                syllable_num += textstat.syllable_count(token)
    else:
        syllable_num = textstat.syllable_count(text)

    return syllable_num


def get_text_tts_dur(text):
    """Estimate the plausible synthesis duration range for `text`. Returns (max_dur, min_dur) in seconds."""
    min_speed = 3  # syllables per second; 2.18
    max_speed = 5.50

    ratio = 0.8517 if contains_chinese(text) else 1.0

    syllable_num = get_text_syllable_num(text)
    # duration = syllables / speed, so the fastest speed yields the shortest
    # duration and the slowest speed the longest.
    min_dur = syllable_num * ratio / max_speed
    max_dur = syllable_num * ratio / min_speed

    return max_dur, min_dur


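# Illustrative helper (added for documentation, not called by the pipeline):
# shows how the bounds from get_text_tts_dur can gate a requested synthesis
# duration before converting it to a semantic-token budget. The 50 tokens/sec
# rate mirrors the `int(target_dur * 50)` conversion inside IndexTTS2.infer.
def example_duration_to_num_codes(text, target_dur):
    max_dur, min_dur = get_text_tts_dur(text)
    if not (min_dur <= target_dur <= max_dur):
        print('target_dur should be in [{:.2f}, {:.2f}], got {}'.format(min_dur, max_dur, target_dur))
        return None
    return int(target_dur * 50)  # ~50 semantic codes per second of audio

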
class IndexTTS2:
    def __init__(
        self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, use_cuda_kernel=None, use_deepspeed=False
    ):
        """
        Args:
            cfg_path (str): path to the config file.
            model_dir (str): path to the model directory.
            is_fp16 (bool): whether to use fp16.
            device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it is set automatically based on the availability of CUDA or MPS.
            use_cuda_kernel (None | bool): whether to use BigVGAN's custom fused-activation CUDA kernel; only for CUDA devices.
            use_deepspeed (bool): whether to use DeepSpeed for GPT inference; note that in fp16 mode DeepSpeed is enabled automatically whenever it can be imported.
        """
        if device is not None:
            self.device = device
            self.is_fp16 = False if device == "cpu" else is_fp16
            self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
        elif torch.cuda.is_available():
            self.device = "cuda:0"
            self.is_fp16 = is_fp16
            self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
        elif hasattr(torch, "mps") and torch.backends.mps.is_available():
            self.device = "mps"
            self.is_fp16 = False  # float16 on MPS adds overhead compared to float32
            self.use_cuda_kernel = False
        else:
            self.device = "cpu"
            self.is_fp16 = False
            self.use_cuda_kernel = False
            print(">> Be patient, it may take a while to run in CPU mode.")

        self.cfg = OmegaConf.load(cfg_path)
        self.model_dir = model_dir
        self.dtype = torch.float16 if self.is_fp16 else None
        self.stop_mel_token = self.cfg.gpt.stop_mel_token

        self.gpt = UnifiedVoice(**self.cfg.gpt)
        self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
        load_checkpoint(self.gpt, self.gpt_path)
        if self.is_fp16:
            self.gpt.half()
        self.gpt = self.gpt.to(self.device)
        self.gpt.eval()
        print(">> GPT weights restored from:", self.gpt_path)
        if self.is_fp16:
            # In fp16 mode, DeepSpeed is enabled automatically whenever it can be imported.
            try:
                import deepspeed

                use_deepspeed = True
            except (ImportError, OSError, CalledProcessError) as e:
                use_deepspeed = False
                print(f">> Failed to load DeepSpeed; falling back to standard inference: {e}")

            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
        else:
            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)

        if self.use_cuda_kernel:
            # preload the CUDA kernel for BigVGAN
            try:
                from indextts.BigVGAN.alias_free_activation.cuda import load

                anti_alias_activation_cuda = load.load()
                print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
            except Exception:
                print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
                self.use_cuda_kernel = False

        self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
        self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(os.path.join(self.model_dir, self.cfg.w2v_stat))
        self.semantic_model = self.semantic_model.to(self.device)
        self.semantic_model.eval()
        self.semantic_mean = self.semantic_mean.to(self.device)
        self.semantic_std = self.semantic_std.to(self.device)

        semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
        semantic_code_ckpt = hf_hub_download("amphion/MaskGCT", filename="semantic_codec/model.safetensors")
        safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
        self.semantic_codec = semantic_codec.to(self.device)
        self.semantic_codec.eval()
        print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))

        s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
        s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
        s2mel, _, _, _ = load_checkpoint2(
            s2mel,
            None,
            s2mel_path,
            load_only_params=True,
            ignore_modules=[],
            is_distributed=False,
        )
        self.s2mel = s2mel.to(self.device)
        self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
        self.s2mel.eval()
        print(">> s2mel weights restored from:", s2mel_path)

        # load campplus_model
        campplus_ckpt_path = hf_hub_download(
            "funasr/campplus", filename="campplus_cn_common.bin"
        )
        campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
        campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
        self.campplus_model = campplus_model.to(self.device)
        self.campplus_model.eval()
        print(">> campplus_model weights restored from:", campplus_ckpt_path)

        bigvgan_name = self.cfg.vocoder.name
        self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
        self.bigvgan = self.bigvgan.to(self.device)
        self.bigvgan.remove_weight_norm()
        self.bigvgan.eval()
        print(">> bigvgan weights restored from:", bigvgan_name)

        self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
        self.normalizer = TextNormalizer()
        self.normalizer.load()
        print(">> TextNormalizer loaded")
        self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
        print(">> bpe model loaded from:", self.bpe_path)

        emo_matrix = torch.load(os.path.join(self.model_dir, self.cfg.emo_matrix))
        self.emo_matrix = emo_matrix.to(self.device)
        self.emo_num = list(self.cfg.get('emo_num', []))

        mel_fn_args = {
            "n_fft": self.cfg.s2mel['preprocess_params']['spect_params']['n_fft'],
            "win_size": self.cfg.s2mel['preprocess_params']['spect_params']['win_length'],
            "hop_size": self.cfg.s2mel['preprocess_params']['spect_params']['hop_length'],
            "num_mels": self.cfg.s2mel['preprocess_params']['spect_params']['n_mels'],
            "sampling_rate": self.cfg.s2mel["preprocess_params"]["sr"],
            "fmin": self.cfg.s2mel['preprocess_params']['spect_params'].get('fmin', 0),
            # NOTE: any configured fmax other than "None" is mapped to 8000 Hz here.
            "fmax": None if self.cfg.s2mel['preprocess_params']['spect_params'].get('fmax', "None") == "None" else 8000,
            "center": False
        }
        self.mel_fn = lambda x: mel_spectrogram(x, **mel_fn_args)

        # Cache the reference-audio features so repeated calls with the same
        # prompt can skip re-extraction:
        self.cache_spk_cond = None
        self.cache_s2mel_style = None
        self.cache_s2mel_prompt = None
        self.cache_spk_audio_prompt = None
        self.cache_emo_cond = None
        self.cache_emo_audio_prompt = None
        self.cache_mel = None

        # Optional progress-reporting hook (e.g. a gradio Progress instance)
        self.gr_progress = None
        self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None

    @torch.no_grad()
    def get_emb(self, input_features, attention_mask):
        # Layer-17 hidden states of the w2v-bert model serve as semantic
        # features, normalized by precomputed dataset statistics.
        vq_emb = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        feat = vq_emb.hidden_states[17]  # (B, T, C)
        feat = (feat - self.semantic_mean) / self.semantic_std
        return feat

    def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
        """
        Shrink runs of special tokens (silent_token and stop_mel_token) in codes.
        codes: [B, T]
        """
        code_lens = []
        codes_list = []
        device = codes.device
        dtype = codes.dtype
        isfix = False
        for i in range(0, codes.shape[0]):
            code = codes[i]
            if not torch.any(code == self.stop_mel_token).item():
                len_ = code.size(0)
            else:
                stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
                len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)

            count = torch.sum(code == silent_token).item()
            if count > max_consecutive:
                # code = code.cpu().tolist()
                ncode_idx = []
                n = 0
                for k in range(len_):
                    assert code[k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should have been trimmed off here"
                    if code[k] != silent_token:
                        ncode_idx.append(k)
                        n = 0
                    elif code[k] == silent_token and n < 10:
                        # keep at most 10 consecutive silent tokens
                        ncode_idx.append(k)
                        n += 1
                    # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
                    #     n += 1
                # new code
                len_ = len(ncode_idx)
                codes_list.append(code[ncode_idx])
                isfix = True
            else:
                # shrink to len_
                codes_list.append(code[:len_])
            code_lens.append(len_)
        if isfix:
            if len(codes_list) > 1:
                codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
            else:
                codes = codes_list[0].unsqueeze(0)
        else:
            # unchanged
            pass
        # clip codes to max length
        max_len = max(code_lens)
        if max_len < codes.shape[1]:
            codes = codes[:, :max_len]
        code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
        return codes, code_lens

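    # A toy illustration of remove_long_silence (assumed values, batch of one,
    # silent token 52): once the total silent-token count exceeds
    # max_consecutive, each run of 52s is capped at 10 in a row, e.g.
    #
    #   codes = torch.tensor([[7, 3] + [52] * 40 + [9, tts.stop_mel_token]])
    #   codes, code_lens = tts.remove_long_silence(codes, silent_token=52)
    #   # codes -> [[7, 3] + [52] * 10 + [9]], code_lens -> tensor([13])
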
    def _set_gr_progress(self, value, desc):
        if self.gr_progress is not None:
            self.gr_progress(value, desc=desc)

    # Original (one-shot) inference mode
    def infer(self, spk_audio_prompt, text, output_path,
              emo_audio_prompt=None, emo_alpha=1.0,
              emo_vector=None,
              use_emo_text=False, emo_text=None, emo_text_weight=1.0,
              use_speed=False, target_dur=None,
              verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs):
        print(">> start inference...")
        self._set_gr_progress(0, "start inference...")
        if verbose:
            print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt},"
                  f" emo_audio_prompt:{emo_audio_prompt}, emo_alpha:{emo_alpha}, "
                  f"emo_vector:{emo_vector}, use_emo_text:{use_emo_text}, "
                  f"emo_text:{emo_text}, use_speed:{use_speed}, target_dur:{target_dur}")
        start_time = time.perf_counter()

        if emo_vector is not None:
            # An explicit emotion vector cannot be combined with an emotion
            # reference audio or a non-default blending weight.
            assert emo_audio_prompt is None
            assert emo_alpha == 1.0
            emo_vector_sum = sum(emo_vector)
            if self.emo_num and len(emo_vector) == len(self.emo_num):
                # Expand per-class weights to per-row weights of emo_matrix.
                expanded = []
                for w, n in zip(emo_vector, self.emo_num):
                    expanded.extend([w] * n)
                weight_vector = torch.tensor(expanded, dtype=torch.float32).to(self.device)
            else:
                weight_vector = torch.tensor(emo_vector, dtype=torch.float32).to(self.device)
            emovec_mat = weight_vector.unsqueeze(1) * self.emo_matrix
            emovec_mat = torch.sum(emovec_mat, 0)
            emovec_mat = emovec_mat.unsqueeze(0)
            print(f">> emovec_mat norm: {emovec_mat.norm().item():.4f}, emo_vector_sum: {emo_vector_sum:.4f}")

        if emo_audio_prompt is None:
            # Fall back to the speaker prompt as the emotion reference.
            emo_audio_prompt = spk_audio_prompt
            assert emo_alpha == 1.0

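        # Worked example of the expansion above (emo_num values are assumed for
        # illustration): with emo_num = [2, 3], emo_vector = [0.3, 0.7] expands
        # row-wise to weight_vector = [0.3, 0.3, 0.7, 0.7, 0.7], and emovec_mat
        # becomes the weighted sum of the matching emo_matrix rows, shaped [1, D].
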
        num_codes = None
        if use_speed:
            assert target_dur is not None, "When use_speed is set to True, the target duration (target_dur) in seconds must be specified."
            '''
            max_dur, min_dur = get_text_tts_dur(text)
            if target_dur >= min_dur and target_dur <= max_dur:
                num_codes = torch.tensor([int(target_dur * 50)], device=self.device)
            else:
                print('target_dur should be in [{}, {}], got {}'.format(min_dur, max_dur, target_dur))
                return
            '''

            num_codes = torch.tensor([int(target_dur * 50)], device=self.device)
            print("Target synthesis duration: {}s, target token count: {}".format(str(target_dur), str(int(target_dur * 50))))

        # Recompute the reference-audio features only when the prompt changes,
        # to speed up repeated calls.
        if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt:
            audio, sr = librosa.load(spk_audio_prompt)
            audio = torch.tensor(audio).unsqueeze(0)
            audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
            audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)

            inputs = self.extract_features(audio_16k, sampling_rate=16000, return_tensors="pt")
            input_features = inputs["input_features"]
            attention_mask = inputs["attention_mask"]
            input_features = input_features.to(self.device)
            attention_mask = attention_mask.to(self.device)
            spk_cond_emb = self.get_emb(input_features, attention_mask)

            _, S_ref = self.semantic_codec.quantize(spk_cond_emb)
            ref_mel = self.mel_fn(audio_22k.to(spk_cond_emb.device).float())
            ref_target_lengths = torch.LongTensor([ref_mel.size(2)]).to(ref_mel.device)
            feat = torchaudio.compliance.kaldi.fbank(audio_16k.to(ref_mel.device),
                                                     num_mel_bins=80,
                                                     dither=0,
                                                     sample_frequency=16000)
            feat = feat - feat.mean(dim=0, keepdim=True)  # fbank energy features for the style encoder, e.g. [922, 80]
            style = self.campplus_model(feat.unsqueeze(0))  # global style embedding of the reference audio, [1, 192]

            prompt_condition = self.s2mel.models['length_regulator'](S_ref,
                                                                     ylens=ref_target_lengths,
                                                                     n_quantizers=3,
                                                                     f0=None)[0]

            self.cache_spk_cond = spk_cond_emb.detach()
            self.cache_s2mel_style = style.detach()
            self.cache_s2mel_prompt = prompt_condition.detach()
            self.cache_spk_audio_prompt = spk_audio_prompt
            self.cache_mel = ref_mel.detach()
        else:
            style = self.cache_s2mel_style
            prompt_condition = self.cache_s2mel_prompt
            spk_cond_emb = self.cache_spk_cond
            ref_mel = self.cache_mel

        if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
            emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000)
            emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
            emo_input_features = emo_inputs["input_features"]
            emo_attention_mask = emo_inputs["attention_mask"]
            emo_input_features = emo_input_features.to(self.device)
            emo_attention_mask = emo_attention_mask.to(self.device)
            emo_cond_emb = self.get_emb(emo_input_features, emo_attention_mask)

            self.cache_emo_cond = emo_cond_emb.detach()
            self.cache_emo_audio_prompt = emo_audio_prompt
        else:
            emo_cond_emb = self.cache_emo_cond

        self._set_gr_progress(0.1, "text processing...")
        text_tokens_list = self.tokenizer.tokenize(text)
        # Duration control only works on a single segment; disable it if the
        # text does not fit in one sentence.
        if use_speed and len(text_tokens_list) > max_text_tokens_per_sentence:
            use_speed = False
        if not use_speed:
            sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence)
        else:
            sentences = [text_tokens_list]
        if verbose:
            print("text_tokens_list:", text_tokens_list)
            print("sentences count:", len(sentences))
            print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
            print(*sentences, sep="\n")
        do_sample = generation_kwargs.pop("do_sample", True)
        top_p = generation_kwargs.pop("top_p", 0.8)
        top_k = generation_kwargs.pop("top_k", 30)
        temperature = generation_kwargs.pop("temperature", 0.8)
        autoregressive_batch_size = 1
        length_penalty = generation_kwargs.pop("length_penalty", 0.0)
        num_beams = generation_kwargs.pop("num_beams", 3)
        repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
        max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 1500)
        sampling_rate = 22050

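        # Per-sentence synthesis pipeline (overview added for readability):
        #   1. GPT (UnifiedVoice) autoregressively generates semantic codes
        #      from the text plus speaker/emotion conditions;
        #   2. a GPT forward pass produces latents for the generated codes;
        #   3. the s2mel flow-matching model (CFM) renders a mel-spectrogram,
        #      conditioned on the reference mel and style embedding;
        #   4. BigVGAN vocodes the mel-spectrogram into a 22.05 kHz waveform.
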
        wavs = []
        gpt_gen_time = 0
        gpt_forward_time = 0
        s2mel_time = 0
        bigvgan_time = 0
        progress = 0
        has_warned = False
        for sent in sentences:
            text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
            text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
            if verbose:
                print(text_tokens)
                print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
                # debug tokenizer round-trip
                text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                print("text_token_syms is same as sentence tokens", text_token_syms == sent)

            m_start_time = time.perf_counter()
            with torch.no_grad():
                with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    # Blend the speaker- and emotion-derived emotion vectors.
                    emovec = self.gpt.merge_emovec(
                        spk_cond_emb,
                        emo_cond_emb,
                        torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
                        torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
                        alpha=emo_alpha
                    )

                    if emo_vector is not None:
                        emovec = emovec_mat + (1 - emo_vector_sum) * emovec
                        # emovec = emovec_mat

                    codes = self.gpt.inference_speech(
                        spk_cond_emb,
                        text_tokens,
                        emo_cond_emb,
                        cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
                        emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
                        emo_vec=emovec,
                        use_speed=use_speed,
                        num_codes=num_codes,
                        do_sample=True,  # NOTE: the popped do_sample default is not forwarded; sampling is always on
                        top_p=top_p,
                        top_k=top_k,
                        temperature=temperature,
                        num_return_sequences=autoregressive_batch_size,
                        length_penalty=length_penalty,
                        num_beams=num_beams,
                        repetition_penalty=repetition_penalty,
                        max_generate_length=max_mel_tokens,
                        **generation_kwargs
                    )

                gpt_gen_time += time.perf_counter() - m_start_time
                if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
                    warnings.warn(
                        f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                        f"Input text tokens: {text_tokens.shape[1]}. "
                        f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
                        category=RuntimeWarning
                    )
                    has_warned = True

                code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
                # if verbose:
                #     print(codes, type(codes))
                #     print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
                #     print(f"code len: {code_lens}")

                # Trim each sequence at its stop_mel_token (code_lens is
                # recomputed per sequence here).
                code_lens = []
                for code in codes:
                    if self.stop_mel_token not in code:
                        code_lens.append(len(code))
                        code_len = len(code)
                    else:
                        len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
                        code_len = len_ - 1
                        code_lens.append(code_len)
                codes = codes[:, :code_len]
                code_lens = torch.LongTensor(code_lens)
                code_lens = code_lens.to(self.device)
                if verbose:
                    print(codes, type(codes))
                    print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
                    print(f"code len: {code_lens}")

                m_start_time = time.perf_counter()
                # Convert the boolean speed flag into the per-sample tensor the
                # GPT forward pass expects.
                if use_speed:
                    use_speed = torch.ones(spk_cond_emb.size(0)).to(spk_cond_emb.device).long()
                else:
                    use_speed = torch.zeros(spk_cond_emb.size(0)).to(spk_cond_emb.device).long()
                with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    latent = self.gpt(
                        spk_cond_emb,
                        text_tokens,
                        torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
                        codes,
                        torch.tensor([codes.shape[-1]], device=text_tokens.device),
                        emo_cond_emb,
                        cond_mel_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
                        emo_cond_mel_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
                        emo_vec=emovec,
                        use_speed=use_speed,
                    )
                    gpt_forward_time += time.perf_counter() - m_start_time

                    m_start_time = time.perf_counter()
                    diffusion_steps = 25
                    inference_cfg_rate = 0.7
                    latent = self.s2mel.models['gpt_layer'](latent)
                    S_infer = self.semantic_codec.quantizer.vq2emb(codes.unsqueeze(1))
                    S_infer = S_infer.transpose(1, 2)
                    S_infer = S_infer + latent
                    # Expand the semantic-token length to the mel-frame length
                    # (empirical mel-frames-per-token ratio).
                    target_lengths = (code_lens * 1.72).long()

                    cond = self.s2mel.models['length_regulator'](S_infer,
                                                                 ylens=target_lengths,
                                                                 n_quantizers=3,
                                                                 f0=None)[0]
                    cat_condition = torch.cat([prompt_condition, cond], dim=1)
                    vc_target = self.s2mel.models['cfm'].inference(cat_condition,
                                                                   torch.LongTensor([cat_condition.size(1)]).to(cond.device),
                                                                   ref_mel, style, None, diffusion_steps,
                                                                   inference_cfg_rate=inference_cfg_rate)
                    # Drop the prompt region; keep only the newly generated frames.
                    vc_target = vc_target[:, :, ref_mel.size(-1):]
                    s2mel_time += time.perf_counter() - m_start_time

                    m_start_time = time.perf_counter()
                    wav = self.bigvgan(vc_target.float()).squeeze().unsqueeze(0)
                    print(wav.shape)
                    bigvgan_time += time.perf_counter() - m_start_time
                    wav = wav.squeeze(1)

                    wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
                    if verbose:
                        print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
                    # wavs.append(wav[:, :-512])
                    wavs.append(wav.cpu())  # move to CPU before collecting
        end_time = time.perf_counter()

        self._set_gr_progress(0.9, "save audio...")
        wav = torch.cat(wavs, dim=1)
        wav_length = wav.shape[-1] / sampling_rate
        print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
        print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
        print(f">> s2mel_time: {s2mel_time:.2f} seconds")
        print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
        print(f">> Total inference time: {end_time - start_time:.2f} seconds")
        print(f">> Generated audio length: {wav_length:.2f} seconds")
        print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")

        # save audio
        wav = wav.cpu()  # move to CPU
        if output_path:
            # Save the audio directly to the given path
            if os.path.isfile(output_path):
                os.remove(output_path)
                print(">> remove old wav file:", output_path)
            if os.path.dirname(output_path) != "":
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
            import soundfile as sf
            sf.write(output_path, wav.squeeze().cpu().numpy().astype('int16'), sampling_rate, subtype='PCM_16')
            print(">> wav file saved to:", output_path)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return output_path
        else:
            # Return in the (sampling_rate, ndarray) format Gradio expects
            wav_data = wav.type(torch.int16)
            wav_data = wav_data.numpy().T
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return (sampling_rate, wav_data)


if __name__ == "__main__":
    prompt_wav = "test_data/input.wav"
    # text = "晕 XUAN4 是 一 种 GAN3 觉"
    # text = '大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
    # text = "There is a vehicle arriving in dock number 7?"
    # Chinese test sentence: "Welcome, everyone, to try indextts2 and give us your opinions and feedback. Thank you."
    text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。'

    tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, use_cuda_kernel=False)
    tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
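
    # Additional usage sketches (illustrative; the reference file paths are
    # assumptions, and an emotion vector's layout must match `emo_num` in
    # checkpoints/config.yaml):
    #
    # 1) Emotion transfer from a second reference clip, blended at 65%:
    # tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen_emo.wav",
    #           emo_audio_prompt="test_data/emo_ref.wav", emo_alpha=0.65)
    #
    # 2) Duration control: target roughly 4 seconds (about 4 * 50 = 200 semantic codes):
    # tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen_4s.wav",
    #           use_speed=True, target_dur=4.0)
    #
    # 3) Sampling overrides are forwarded through **generation_kwargs:
    # tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen_sampled.wav",
    #           temperature=1.0, top_p=0.9, top_k=50, max_mel_tokens=1800)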