feat: Integrate IndexTTS2 model and update related schemas and frontend components
This commit is contained in:
400
qwen3-tts-backend/indextts/s2mel/dac/model/dac.py
Normal file
400
qwen3-tts-backend/indextts/s2mel/dac/model/dac.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from audiotools import AudioSignal
|
||||
from audiotools.ml import BaseModel
|
||||
from torch import nn
|
||||
|
||||
from .base import CodecMixin
|
||||
from indextts.s2mel.dac.nn.layers import Snake1d
|
||||
from indextts.s2mel.dac.nn.layers import WNConv1d
|
||||
from indextts.s2mel.dac.nn.layers import WNConvTranspose1d
|
||||
from indextts.s2mel.dac.nn.quantize import ResidualVectorQuantize
|
||||
from .encodec import SConv1d, SConvTranspose1d, SLSTM
|
||||
|
||||
|
||||
def init_weights(m):
    """Initialise an ``nn.Conv1d`` module in place; ignore everything else.

    Written to be handed to ``nn.Module.apply``, which calls it once per
    submodule.

    Parameters
    ----------
    m : nn.Module
        Module to initialise. Only instances of ``nn.Conv1d`` are touched:
        the weight is drawn from a truncated normal with ``std=0.02`` and
        the bias, when present, is set to zero.
    """
    if not isinstance(m, nn.Conv1d):
        return

    nn.init.trunc_normal_(m.weight, std=0.02)
    # A convolution created with ``bias=False`` exposes ``bias`` as None;
    # filling it unconditionally would raise instead of initialising.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
    """Two Snake1d-activated ``SConv1d`` layers wrapped in a skip connection.

    The first convolution uses kernel size 7 with the requested dilation,
    the second uses kernel size 1. Both map ``dim`` channels to ``dim``
    channels and are built with ``norm='weight_norm'``.

    Parameters
    ----------
    dim : int
        Number of input and output channels.
    dilation : int
        Dilation of the kernel-size-7 convolution.
    causal : bool
        Forwarded to both ``SConv1d`` layers.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        # SConv1d is used for both the causal and the non-causal setting.
        conv_cls = SConv1d
        kernel = 7
        padding = ((kernel - 1) * dilation) // 2
        layers = [
            Snake1d(dim),
            conv_cls(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=padding,
                causal=causal,
                norm='weight_norm',
            ),
            Snake1d(dim),
            conv_cls(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + self.block(x)``.

        When the block output is shorter than ``x`` along the last axis,
        half of the length difference (rounded down) is trimmed from each
        end of ``x`` before the addition.
        """
        residual = self.block(x)
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + residual
|
||||
|
||||
|
||||
class EncoderBlock(nn.Module):
    """Residual stack followed by a strided convolution.

    Three ``ResidualUnit`` layers (dilations 1, 3 and 9) run at ``dim // 2``
    channels, followed by a ``Snake1d`` activation and an ``SConv1d`` with
    the given stride that maps ``dim // 2`` channels to ``dim`` channels.

    Parameters
    ----------
    dim : int
        Number of output channels; the block's input has ``dim // 2``.
    stride : int
        Stride of the final convolution. Its kernel size is ``2 * stride``
        and its padding ``ceil(stride / 2)``.
    causal : bool
        Forwarded to every residual unit and to the final convolution.
    """

    def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
        super().__init__()
        # SConv1d is used for both the causal and the non-causal setting.
        conv_cls = SConv1d
        in_channels = dim // 2

        stages = [
            ResidualUnit(in_channels, dilation=rate, causal=causal)
            for rate in (1, 3, 9)
        ]
        stages.append(Snake1d(in_channels))
        stages.append(
            conv_cls(
                in_channels,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm='weight_norm',
            )
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.block(x)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Convolutional encoder built from ``EncoderBlock`` stages.

    Layer order: an input ``SConv1d`` (1 -> ``d_model`` channels, kernel 7),
    one ``EncoderBlock`` per entry of ``strides`` (the channel count doubles
    before each block), an optional ``SLSTM``, then ``Snake1d`` and an output
    ``SConv1d`` (kernel 3) projecting to ``d_latent`` channels.

    Parameters
    ----------
    d_model : int
        Channel count after the first convolution.
    strides : list
        Stride of each encoder block, in order.
    d_latent : int
        Channel count of the encoder output.
    causal : bool
        Forwarded to every convolution and encoder block.
    lstm : int
        Second positional argument given to ``SLSTM``; a falsy value skips
        the LSTM entirely. Also stored as ``use_lstm``.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        # SConv1d is used for both the causal and the non-causal setting.
        conv_cls = SConv1d

        # Input convolution.
        layers = [
            conv_cls(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')
        ]

        # Each encoder block doubles the channels and downsamples by `stride`.
        for stride in strides:
            d_model *= 2
            layers.append(EncoderBlock(d_model, stride=stride, causal=causal))

        # Optional recurrent layer.
        self.use_lstm = lstm
        if lstm:
            layers.append(SLSTM(d_model, lstm))

        # Output activation and projection to the latent width.
        layers.append(Snake1d(d_model))
        layers.append(
            conv_cls(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm')
        )

        # Wrap the layers into nn.Sequential.
        self.block = nn.Sequential(*layers)
        # Channel count reached after the last encoder block.
        self.enc_dim = d_model

    def forward(self, x):
        """Run ``x`` through the whole encoder stack."""
        return self.block(x)

    def reset_cache(self):
        """Call ``reset_cache()`` on every ``SConv1d`` / ``SLSTM`` in the stack.

        The module tree under ``self.block`` is walked depth-first in child
        order. The walk does not descend into a matching module.
        """
        pending = [self.block]
        while pending:
            module = pending.pop()
            if isinstance(module, (SConv1d, SLSTM)):
                module.reset_cache()
                continue
            # Push children reversed so they are popped in declaration order.
            pending.extend(reversed(list(module.children())))
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
    """Strided transposed convolution followed by a residual stack.

    A ``Snake1d`` activation and an ``SConvTranspose1d`` with the given
    stride map ``input_dim`` channels to ``output_dim`` channels; three
    ``ResidualUnit`` layers (dilations 1, 3 and 9) follow at ``output_dim``
    channels.

    Parameters
    ----------
    input_dim : int
        Number of input channels.
    output_dim : int
        Number of output channels.
    stride : int
        Stride of the transposed convolution. Its kernel size is
        ``2 * stride`` and its padding ``ceil(stride / 2)``.
    causal : bool
        Forwarded to the transposed convolution and every residual unit.
    """

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
        super().__init__()
        # SConvTranspose1d is used for both the causal and non-causal setting.
        upsample_cls = SConvTranspose1d

        stages = [
            Snake1d(input_dim),
            upsample_cls(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm='weight_norm',
            ),
        ]
        for rate in (1, 3, 9):
            stages.append(ResidualUnit(output_dim, dilation=rate, causal=causal))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.block(x)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Convolutional decoder built from ``DecoderBlock`` stages.

    Layer order: an input ``SConv1d`` (``input_channel`` -> ``channels``,
    kernel 7), an optional ``SLSTM``, one ``DecoderBlock`` per entry of
    ``rates`` (block ``i`` maps ``channels // 2**i`` to
    ``channels // 2**(i + 1)`` channels), then ``Snake1d``, an output
    ``SConv1d`` (kernel 7) to ``d_out`` channels and ``nn.Tanh``.

    Parameters
    ----------
    input_channel : int
        Number of channels of the decoder input.
    channels : int
        Channel count after the first convolution.
    rates : sequence of int
        Stride of each decoder block, in order. May be empty, in which case
        no decoder block is added and the output layers run at ``channels``.
    d_out : int
        Number of output channels.
    causal : bool
        Forwarded to every convolution and decoder block.
    lstm : int
        Passed to ``SLSTM`` as ``num_layers``; a falsy value skips the LSTM.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        # SConv1d is used for both the causal and the non-causal setting.
        conv1d_type = SConv1d

        # Input convolution.
        layers = [
            conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')
        ]

        if lstm:
            layers += [SLSTM(channels, num_layers=lstm)]

        # Bind output_dim before the loop: with an empty `rates` the loop body
        # never runs and the final layers below would otherwise reference an
        # unbound local.
        output_dim = channels

        # Upsampling blocks; the channel count halves at each block.
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]

        # Output activation, projection and squashing.
        layers += [
            Snake1d(output_dim),
            conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole decoder stack."""
        return self.model(x)
|
||||
|
||||
|
||||
class DAC(BaseModel, CodecMixin):
    """Audio codec model: ``Encoder`` -> ``ResidualVectorQuantize`` -> ``Decoder``.

    Parameters
    ----------
    encoder_dim : int
        Channel count after the encoder's first convolution.
    encoder_rates : List[int]
        Strides of the encoder blocks. Their product is stored as
        ``hop_length``.
    latent_dim : int, optional
        Channel count of the encoder output and quantizer input. When None
        it is set to ``encoder_dim * 2 ** len(encoder_rates)``.
    decoder_dim : int
        Channel count after the decoder's first convolution.
    decoder_rates : List[int]
        Strides of the decoder blocks.
    n_codebooks : int
    codebook_size : int
    codebook_dim : Union[int, list]
    quantizer_dropout : bool
        These four are forwarded unchanged to ``ResidualVectorQuantize``.
    sample_rate : int
        Sample rate the model works at; ``preprocess`` asserts against it.
    lstm : int
        Forwarded to both ``Encoder`` and ``Decoder``.
    causal : bool
        Forwarded to both ``Encoder`` and ``Decoder``.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        lstm: int = 2,
        causal: bool = False,
    ):
        super().__init__()

        # Default latent width: the encoder doubles its channels per block.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        # Plain configuration attributes.
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.latent_dim = latent_dim
        self.hop_length = np.prod(encoder_rates)

        # Submodules are created in encoder -> quantizer -> decoder order.
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)

        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            lstm=lstm,
            causal=causal,
        )

        # Re-initialise every nn.Conv1d found in the module tree.
        self.apply(init_weights)

        # get_delay is not defined in this class; it comes from a base class.
        self.delay = self.get_delay()

    def preprocess(self, audio_data, sample_rate):
        """Right-pad ``audio_data`` to a whole number of hops.

        Parameters
        ----------
        audio_data : Tensor
            Audio whose last axis is time.
        sample_rate : int or None
            Sample rate of ``audio_data``. None means ``self.sample_rate``;
            any other value must equal ``self.sample_rate`` (asserted).

        Returns
        -------
        Tensor
            ``audio_data`` zero-padded on the right of the last axis so its
            length is a multiple of ``self.hop_length``.
        """
        effective_rate = self.sample_rate if sample_rate is None else sample_rate
        assert effective_rate == self.sample_rate

        n_samples = audio_data.shape[-1]
        hop = self.hop_length
        right_pad = math.ceil(n_samples / hop) * hop - n_samples
        return nn.functional.pad(audio_data, (0, right_pad))

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode audio and quantize the resulting latents.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        tuple
            The five values produced by the quantizer, in this order:

            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before
                quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to
                codebook entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
        """
        encoded = self.encoder(audio_data)
        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            encoded, n_quantizers
        )
        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode a quantized latent back to audio.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        Tensor
            Output of the decoder (decoded audio data). It is not trimmed
            to any target length here.
        """
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass: pad, encode, quantize and decode.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "audio" : Tensor[B x 1 x length]
                Decoded audio data, cut to the length of the unpadded input.
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before
                quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to
                codebook entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        # Remember the unpadded length so the padding can be cut off again.
        length = audio_data.shape[-1]
        padded = self.preprocess(audio_data, sample_rate)
        z, codes, latents, commitment_loss, codebook_loss = self.encode(
            padded, n_quantizers
        )

        decoded = self.decode(z)
        return {
            "audio": decoded[..., :length],
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: print the model with parameter counts, measure the
    # receptive field through a backward pass, then run compress/decompress.
    import numpy as np
    from functools import partial

    def _count_params(module):
        # Total number of scalar parameters held by `module`.
        return sum([np.prod(p.size()) for p in module.parameters()])

    def _repr_with_params(o, p):
        # Original extra_repr text followed by the parameter count in millions.
        return o + f" {p/1e6:<.3f}M params."

    model = DAC().to("cpu")

    # Patch extra_repr on every module so print(model) shows parameter counts.
    for _, module in model.named_modules():
        setattr(
            module,
            "extra_repr",
            partial(_repr_with_params, o=module.extra_repr(), p=_count_params(module)),
        )
    print(model)
    print("Total # of params: ", _count_params(model))

    length = 88200 * 2
    signal = torch.randn(1, 1, length).to(model.device)
    signal.requires_grad_(True)
    signal.retain_grad()

    # Make a forward pass
    reconstruction = model(signal)["audio"]
    print("Input shape:", signal.shape)
    print("Output shape:", reconstruction.shape)

    # Gradient that is 1 at the middle output sample and 0 elsewhere
    seed_grad = torch.zeros_like(reconstruction)
    seed_grad[:, :, seed_grad.shape[-1] // 2] = 1

    # Make a backward pass
    reconstruction.backward(seed_grad)

    # Count the input positions that received a non-zero gradient
    touched = signal.grad.squeeze(0)
    touched = (touched != 0).sum(0)  # sum across features
    rf = (touched != 0).sum()

    print(f"Receptive field: {rf.item()}")

    # Round trip through compress/decompress on one minute of random audio
    clip = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(clip, verbose=True), verbose=True)
|
||||
Reference in New Issue
Block a user