feat: Integrate IndexTTS2 model and update related schemas and frontend components
This commit is contained in:
@@ -0,0 +1,184 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
|
||||
# Licensed under Apache License 2.0
|
||||
|
||||
from .modules.seanet import SEANetEncoder, SEANetDecoder
|
||||
from .modules.quantization import ResidualVectorQuantizer
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SpeechTokenizer(nn.Module):
    """Waveform codec built from a SEANet encoder, an RVQ and a SEANet decoder.

    The encoder output is quantized by a ``ResidualVectorQuantizer`` and the
    quantized representation is fed to the decoder. ``self.transform`` maps
    the first requested quantized layer to ``semantic_dimension`` (or is an
    identity when that already equals ``dimension``).
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict-like (parsed JSON)
            Model config. Read via ``config.get`` for the keys
            ``n_filters``, ``dimension``, ``strides``, ``lstm_layers``,
            ``bidirectional``, ``dilation_base``, ``residual_kernel_size``,
            ``n_residual_layers``, ``activation``, ``sample_rate``, ``n_q``,
            ``semantic_dimension`` and ``codebook_size``.
        """
        super().__init__()
        dimension = config.get("dimension")
        semantic_dimension = config.get("semantic_dimension")
        strides = config.get("strides")

        self.encoder = SEANetEncoder(
            n_filters=config.get("n_filters"),
            dimension=dimension,
            ratios=strides,
            lstm=config.get("lstm_layers"),
            bidirectional=config.get("bidirectional"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        # Product of all stride ratios.
        self.downsample_rate = np.prod(strides)

        # Only project when the two configured dimensions differ.
        if dimension != semantic_dimension:
            self.transform = nn.Linear(dimension, semantic_dimension)
        else:
            self.transform = nn.Identity()

        self.quantizer = ResidualVectorQuantizer(
            dimension=dimension,
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        # NOTE: the decoder is always built with bidirectional=False,
        # regardless of the "bidirectional" entry used for the encoder.
        self.decoder = SEANetDecoder(
            n_filters=config.get("n_filters"),
            dimension=dimension,
            ratios=strides,
            lstm=config.get("lstm_layers"),
            bidirectional=False,
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """
        Build a model from a JSON config file and load its weights.

        Parameters
        ----------
        config_path : str
            Path of model configuration file.
        ckpt_path : str
            Path of model checkpoint.

        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model.
        """
        import json

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(cfg)
        # SECURITY NOTE(review): torch.load unpickles the file; only load
        # checkpoints from trusted sources (consider weights_only=True where
        # the installed torch version supports it).
        params = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(params)
        return model

    def forward(self, x: torch.Tensor, n_q: int = None, layers: list = None):
        """
        Encode, quantize and decode a batch of waveforms.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is the first layer.

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from residual vector quantizers.
        feature : torch.Tensor
            First entry of the quantized results returned for ``layers``,
            rearranged to (batch, timesteps, dimension) and passed through
            ``self.transform``.
        """
        # A fresh list per call: avoids sharing one mutable default object
        # between calls (and with the quantizer).
        layers = [0] if layers is None else layers
        n_q = n_q if n_q else self.n_q
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, n_q=n_q, layers=layers
        )
        feature = rearrange(quantized_list[0], "b d t -> b t d")
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature

    def forward_feature(self, x: torch.Tensor, layers: list = None):
        """
        Return the quantized outputs of the requested RVQ layers.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is all layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized of required layers.
        """
        e = self.encoder(x)
        # None or an empty list both select every layer 0..n_q-1.
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
        return quantized_list

    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
        """
        Encode waveforms to RVQ code indices.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)
        """
        e = self.encoder(x)
        if st is None:
            st = 0
        # A falsy n_q (None or 0) falls back to the configured number.
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Reconstruct waveforms from RVQ code indices.

        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
|
||||
Reference in New Issue
Block a user