feat: Integrate IndexTTS2 model and update related schemas and frontend components
This commit is contained in:
@@ -0,0 +1,427 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize import (
|
||||
ResidualVQ,
|
||||
VectorQuantize,
|
||||
FactorizedVectorQuantize,
|
||||
LookupFreeQuantize,
|
||||
)
|
||||
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.vocos import Vocos
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this brings model speed up 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
||||
Snake1d(dim),
|
||||
WNConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, dim: int = 16, stride: int = 1):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
ResidualUnit(dim // 2, dilation=1),
|
||||
ResidualUnit(dim // 2, dilation=3),
|
||||
ResidualUnit(dim // 2, dilation=9),
|
||||
Snake1d(dim // 2),
|
||||
WNConv1d(
|
||||
dim // 2,
|
||||
dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CodecEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
up_ratios: list = [4, 5, 5, 6],
|
||||
out_channels: int = 256,
|
||||
use_tanh: bool = False,
|
||||
cfg=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
d_model = cfg.d_model if cfg is not None else d_model
|
||||
up_ratios = cfg.up_ratios if cfg is not None else up_ratios
|
||||
out_channels = cfg.out_channels if cfg is not None else out_channels
|
||||
use_tanh = cfg.use_tanh if cfg is not None else use_tanh
|
||||
|
||||
# Create first convolution
|
||||
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in up_ratios:
|
||||
d_model *= 2
|
||||
self.block += [EncoderBlock(d_model, stride=stride)]
|
||||
|
||||
# Create last convolution
|
||||
self.block += [
|
||||
Snake1d(d_model),
|
||||
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
|
||||
]
|
||||
|
||||
if use_tanh:
|
||||
self.block += [nn.Tanh()]
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.apply(init_weights)
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(input_dim),
|
||||
WNConvTranspose1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=stride // 2 + stride % 2,
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
ResidualUnit(output_dim, dilation=1),
|
||||
ResidualUnit(output_dim, dilation=3),
|
||||
ResidualUnit(output_dim, dilation=9),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CodecDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 256,
|
||||
upsample_initial_channel: int = 1536,
|
||||
up_ratios: list = [5, 5, 4, 2],
|
||||
num_quantizers: int = 8,
|
||||
codebook_size: int = 1024,
|
||||
codebook_dim: int = 256,
|
||||
quantizer_type: str = "vq",
|
||||
quantizer_dropout: float = 0.5,
|
||||
commitment: float = 0.25,
|
||||
codebook_loss_weight: float = 1.0,
|
||||
use_l2_normlize: bool = False,
|
||||
codebook_type: str = "euclidean",
|
||||
kmeans_init: bool = False,
|
||||
kmeans_iters: int = 10,
|
||||
decay: float = 0.8,
|
||||
eps: float = 1e-5,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
weight_init: bool = False,
|
||||
use_vocos: bool = False,
|
||||
vocos_dim: int = 384,
|
||||
vocos_intermediate_dim: int = 1152,
|
||||
vocos_num_layers: int = 8,
|
||||
n_fft: int = 800,
|
||||
hop_size: int = 200,
|
||||
padding: str = "same",
|
||||
cfg=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
in_channels = (
|
||||
cfg.in_channels
|
||||
if cfg is not None and hasattr(cfg, "in_channels")
|
||||
else in_channels
|
||||
)
|
||||
upsample_initial_channel = (
|
||||
cfg.upsample_initial_channel
|
||||
if cfg is not None and hasattr(cfg, "upsample_initial_channel")
|
||||
else upsample_initial_channel
|
||||
)
|
||||
up_ratios = (
|
||||
cfg.up_ratios
|
||||
if cfg is not None and hasattr(cfg, "up_ratios")
|
||||
else up_ratios
|
||||
)
|
||||
num_quantizers = (
|
||||
cfg.num_quantizers
|
||||
if cfg is not None and hasattr(cfg, "num_quantizers")
|
||||
else num_quantizers
|
||||
)
|
||||
codebook_size = (
|
||||
cfg.codebook_size
|
||||
if cfg is not None and hasattr(cfg, "codebook_size")
|
||||
else codebook_size
|
||||
)
|
||||
codebook_dim = (
|
||||
cfg.codebook_dim
|
||||
if cfg is not None and hasattr(cfg, "codebook_dim")
|
||||
else codebook_dim
|
||||
)
|
||||
quantizer_type = (
|
||||
cfg.quantizer_type
|
||||
if cfg is not None and hasattr(cfg, "quantizer_type")
|
||||
else quantizer_type
|
||||
)
|
||||
quantizer_dropout = (
|
||||
cfg.quantizer_dropout
|
||||
if cfg is not None and hasattr(cfg, "quantizer_dropout")
|
||||
else quantizer_dropout
|
||||
)
|
||||
commitment = (
|
||||
cfg.commitment
|
||||
if cfg is not None and hasattr(cfg, "commitment")
|
||||
else commitment
|
||||
)
|
||||
codebook_loss_weight = (
|
||||
cfg.codebook_loss_weight
|
||||
if cfg is not None and hasattr(cfg, "codebook_loss_weight")
|
||||
else codebook_loss_weight
|
||||
)
|
||||
use_l2_normlize = (
|
||||
cfg.use_l2_normlize
|
||||
if cfg is not None and hasattr(cfg, "use_l2_normlize")
|
||||
else use_l2_normlize
|
||||
)
|
||||
codebook_type = (
|
||||
cfg.codebook_type
|
||||
if cfg is not None and hasattr(cfg, "codebook_type")
|
||||
else codebook_type
|
||||
)
|
||||
kmeans_init = (
|
||||
cfg.kmeans_init
|
||||
if cfg is not None and hasattr(cfg, "kmeans_init")
|
||||
else kmeans_init
|
||||
)
|
||||
kmeans_iters = (
|
||||
cfg.kmeans_iters
|
||||
if cfg is not None and hasattr(cfg, "kmeans_iters")
|
||||
else kmeans_iters
|
||||
)
|
||||
decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
|
||||
eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
|
||||
threshold_ema_dead_code = (
|
||||
cfg.threshold_ema_dead_code
|
||||
if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
|
||||
else threshold_ema_dead_code
|
||||
)
|
||||
weight_init = (
|
||||
cfg.weight_init
|
||||
if cfg is not None and hasattr(cfg, "weight_init")
|
||||
else weight_init
|
||||
)
|
||||
use_vocos = (
|
||||
cfg.use_vocos
|
||||
if cfg is not None and hasattr(cfg, "use_vocos")
|
||||
else use_vocos
|
||||
)
|
||||
vocos_dim = (
|
||||
cfg.vocos_dim
|
||||
if cfg is not None and hasattr(cfg, "vocos_dim")
|
||||
else vocos_dim
|
||||
)
|
||||
vocos_intermediate_dim = (
|
||||
cfg.vocos_intermediate_dim
|
||||
if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
|
||||
else vocos_intermediate_dim
|
||||
)
|
||||
vocos_num_layers = (
|
||||
cfg.vocos_num_layers
|
||||
if cfg is not None and hasattr(cfg, "vocos_num_layers")
|
||||
else vocos_num_layers
|
||||
)
|
||||
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
|
||||
hop_size = (
|
||||
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
|
||||
)
|
||||
padding = (
|
||||
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
|
||||
)
|
||||
|
||||
if quantizer_type == "vq":
|
||||
self.quantizer = ResidualVQ(
|
||||
input_dim=in_channels,
|
||||
num_quantizers=num_quantizers,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_type=quantizer_type,
|
||||
quantizer_dropout=quantizer_dropout,
|
||||
commitment=commitment,
|
||||
codebook_loss_weight=codebook_loss_weight,
|
||||
use_l2_normlize=use_l2_normlize,
|
||||
codebook_type=codebook_type,
|
||||
kmeans_init=kmeans_init,
|
||||
kmeans_iters=kmeans_iters,
|
||||
decay=decay,
|
||||
eps=eps,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code,
|
||||
weight_init=weight_init,
|
||||
)
|
||||
elif quantizer_type == "fvq":
|
||||
self.quantizer = ResidualVQ(
|
||||
input_dim=in_channels,
|
||||
num_quantizers=num_quantizers,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_type=quantizer_type,
|
||||
quantizer_dropout=quantizer_dropout,
|
||||
commitment=commitment,
|
||||
codebook_loss_weight=codebook_loss_weight,
|
||||
use_l2_normlize=use_l2_normlize,
|
||||
)
|
||||
elif quantizer_type == "lfq":
|
||||
self.quantizer = ResidualVQ(
|
||||
input_dim=in_channels,
|
||||
num_quantizers=num_quantizers,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_type=quantizer_type,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown quantizer type {quantizer_type}")
|
||||
|
||||
if not use_vocos:
|
||||
# Add first conv layer
|
||||
channels = upsample_initial_channel
|
||||
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(up_ratios):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
if use_vocos:
|
||||
self.model = Vocos(
|
||||
input_channels=in_channels,
|
||||
dim=vocos_dim,
|
||||
intermediate_dim=vocos_intermediate_dim,
|
||||
num_layers=vocos_num_layers,
|
||||
adanorm_num_embeddings=None,
|
||||
n_fft=n_fft,
|
||||
hop_size=hop_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
|
||||
"""
|
||||
if vq is True, x = encoder output, then return quantized output;
|
||||
else, x = quantized output, then return decoder output
|
||||
"""
|
||||
if vq is True:
|
||||
if eval_vq:
|
||||
self.quantizer.eval()
|
||||
(
|
||||
quantized_out,
|
||||
all_indices,
|
||||
all_commit_losses,
|
||||
all_codebook_losses,
|
||||
all_quantized,
|
||||
) = self.quantizer(x, n_quantizers=n_quantizers)
|
||||
return (
|
||||
quantized_out,
|
||||
all_indices,
|
||||
all_commit_losses,
|
||||
all_codebook_losses,
|
||||
all_quantized,
|
||||
)
|
||||
|
||||
return self.model(x)
|
||||
|
||||
def quantize(self, x, n_quantizers=None):
|
||||
self.quantizer.eval()
|
||||
quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
|
||||
return quantized_out, vq
|
||||
|
||||
# TODO: check consistency of vq2emb and quantize
|
||||
def vq2emb(self, vq, n_quantizers=None):
|
||||
return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)
|
||||
|
||||
def decode(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def latent2dist(self, x, n_quantizers=None):
|
||||
return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.apply(init_weights)
|
||||
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import (
|
||||
FactorizedVectorQuantize,
|
||||
)
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
|
||||
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class FactorizedVectorQuantize(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
codebook_size,
|
||||
codebook_dim,
|
||||
commitment=0.005,
|
||||
codebook_loss_weight=1.0,
|
||||
use_l2_normlize=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.commitment = commitment
|
||||
self.codebook_loss_weight = codebook_loss_weight
|
||||
self.use_l2_normlize = use_l2_normlize
|
||||
|
||||
if self.input_dim != self.codebook_dim:
|
||||
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
||||
self.out_project = WNConv1d(
|
||||
self.codebook_dim, self.input_dim, kernel_size=1
|
||||
)
|
||||
|
||||
else:
|
||||
self.in_project = nn.Identity()
|
||||
self.out_project = nn.Identity()
|
||||
|
||||
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
z: torch.Tensor[B x D x T]
|
||||
|
||||
Returns
|
||||
-------
|
||||
z_q: torch.Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
commit_loss: Tensor[B]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook entries
|
||||
codebook_loss: Tensor[B]
|
||||
Codebook loss to update the codebook
|
||||
indices: torch.Tensor[B x T]
|
||||
Codebook indices (quantized discrete representation of input)
|
||||
z_e: torch.Tensor[B x D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"""
|
||||
|
||||
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
||||
z_e = self.in_project(z)
|
||||
z_q, indices = self.decode_latents(z_e)
|
||||
|
||||
# Compute commitment loss and codebook loss
|
||||
if self.training:
|
||||
commit_loss = (
|
||||
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
||||
* self.commitment
|
||||
)
|
||||
codebook_loss = (
|
||||
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
||||
* self.codebook_loss_weight
|
||||
)
|
||||
else:
|
||||
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
||||
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
||||
|
||||
z_q = z_e + (z_q - z_e).detach()
|
||||
|
||||
z_q = self.out_project(z_q)
|
||||
|
||||
return z_q, commit_loss, codebook_loss, indices, z_e
|
||||
|
||||
def embed_code(self, embed_id):
|
||||
return F.embedding(embed_id, self.codebook.weight)
|
||||
|
||||
def decode_code(self, embed_id):
|
||||
return self.embed_code(embed_id).transpose(1, 2)
|
||||
|
||||
def decode_latents(self, latents):
|
||||
encodings = rearrange(latents, "b d t -> (b t) d")
|
||||
codebook = self.codebook.weight
|
||||
|
||||
# L2 normalize encodings and codebook
|
||||
if self.use_l2_normlize:
|
||||
encodings = F.normalize(encodings)
|
||||
codebook = F.normalize(codebook)
|
||||
|
||||
# Compute euclidean distance between encodings and codebook,
|
||||
# if use_l2_normlize is True, the distance is equal to cosine distance
|
||||
dist = (
|
||||
encodings.pow(2).sum(1, keepdim=True)
|
||||
- 2 * encodings @ codebook.t()
|
||||
+ codebook.pow(2).sum(1, keepdim=True).t()
|
||||
)
|
||||
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
||||
z_q = self.decode_code(indices)
|
||||
|
||||
return z_q, indices
|
||||
|
||||
def vq2emb(self, vq, out_proj=True):
|
||||
emb = self.decode_code(vq)
|
||||
if out_proj:
|
||||
emb = self.out_project(emb)
|
||||
return emb
|
||||
|
||||
def latent2dist(self, latents):
|
||||
encodings = rearrange(latents, "b d t -> (b t) d")
|
||||
codebook = self.codebook.weight
|
||||
|
||||
# L2 normalize encodings and codebook
|
||||
if self.use_l2_normlize:
|
||||
encodings = F.normalize(encodings)
|
||||
codebook = F.normalize(codebook)
|
||||
|
||||
# Compute euclidean distance between encodings and codebook,
|
||||
# if use_l2_normlize is True, the distance is equal to cosine distance
|
||||
dist = (
|
||||
encodings.pow(2).sum(1, keepdim=True)
|
||||
- 2 * encodings @ codebook.t()
|
||||
+ codebook.pow(2).sum(1, keepdim=True).t()
|
||||
) # (b*t, k)
|
||||
|
||||
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
||||
dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
|
||||
z_q = self.decode_code(indices)
|
||||
|
||||
return -dist, indices, z_q
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class LookupFreeQuantize(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
codebook_size,
|
||||
codebook_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
|
||||
assert 2**codebook_dim == codebook_size
|
||||
|
||||
if self.input_dim != self.codebook_dim:
|
||||
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
||||
self.out_project = WNConv1d(
|
||||
self.codebook_dim, self.input_dim, kernel_size=1
|
||||
)
|
||||
|
||||
else:
|
||||
self.in_project = nn.Identity()
|
||||
self.out_project = nn.Identity()
|
||||
|
||||
def forward(self, z):
|
||||
z_e = self.in_project(z)
|
||||
z_e = F.sigmoid(z_e)
|
||||
|
||||
z_q = z_e + (torch.round(z_e) - z_e).detach()
|
||||
|
||||
z_q = self.out_project(z_q)
|
||||
|
||||
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
||||
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
||||
|
||||
bits = (
|
||||
2
|
||||
** torch.arange(self.codebook_dim, device=z.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1)
|
||||
.long()
|
||||
) # (1, d, 1)
|
||||
indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
|
||||
|
||||
return z_q, commit_loss, codebook_loss, indices, z_e
|
||||
|
||||
def vq2emb(self, vq, out_proj=True):
|
||||
emb = torch.zeros(
|
||||
vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
|
||||
) # (B, d, T)
|
||||
for i in range(self.codebook_dim):
|
||||
emb[:, i, :] = (vq % 2).float()
|
||||
vq = vq // 2
|
||||
if out_proj:
|
||||
emb = self.out_project(emb)
|
||||
return emb
|
||||
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import (
|
||||
FactorizedVectorQuantize,
|
||||
)
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
|
||||
|
||||
|
||||
class ResidualVQ(nn.Module):
|
||||
"""
|
||||
Introduced in SoundStream: An end2end neural audio codec
|
||||
https://arxiv.org/abs/2107.03312
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 256,
|
||||
num_quantizers: int = 8,
|
||||
codebook_size: int = 1024,
|
||||
codebook_dim: int = 256,
|
||||
quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
|
||||
quantizer_dropout: float = 0.5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.num_quantizers = num_quantizers
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.quantizer_type = quantizer_type
|
||||
self.quantizer_dropout = quantizer_dropout
|
||||
|
||||
if quantizer_type == "vq":
|
||||
VQ = VectorQuantize
|
||||
elif quantizer_type == "fvq":
|
||||
VQ = FactorizedVectorQuantize
|
||||
elif quantizer_type == "lfq":
|
||||
VQ = LookupFreeQuantize
|
||||
else:
|
||||
raise ValueError(f"Unknown quantizer type {quantizer_type}")
|
||||
|
||||
self.quantizers = nn.ModuleList(
|
||||
[
|
||||
VQ(
|
||||
input_dim=input_dim,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(num_quantizers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, z, n_quantizers: int = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
n_quantizers : int, optional
|
||||
No. of quantizers to use
|
||||
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
||||
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
||||
when in training mode, and a random number of quantizers is used.
|
||||
Returns
|
||||
-------
|
||||
"quantized_out" : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"all_indices" : Tensor[N x B x T]
|
||||
Codebook indices for each codebook
|
||||
(quantized discrete representation of input)
|
||||
"all_commit_losses" : Tensor[N]
|
||||
"all_codebook_losses" : Tensor[N]
|
||||
"all_quantized" : Tensor[N x B x D x T]
|
||||
"""
|
||||
|
||||
quantized_out = 0.0
|
||||
residual = z
|
||||
|
||||
all_commit_losses = []
|
||||
all_codebook_losses = []
|
||||
all_indices = []
|
||||
all_quantized = []
|
||||
|
||||
if n_quantizers is None:
|
||||
n_quantizers = self.num_quantizers
|
||||
|
||||
if self.training:
|
||||
n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
|
||||
dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
|
||||
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
||||
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
||||
n_quantizers = n_quantizers.to(z.device)
|
||||
|
||||
for i, quantizer in enumerate(self.quantizers):
|
||||
if self.training is False and i >= n_quantizers:
|
||||
break
|
||||
|
||||
z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
||||
residual
|
||||
)
|
||||
|
||||
# Create mask to apply quantizer dropout
|
||||
mask = (
|
||||
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
||||
)
|
||||
quantized_out = quantized_out + z_q_i * mask[:, None, None]
|
||||
residual = residual - z_q_i
|
||||
|
||||
commit_loss_i = (commit_loss_i * mask).mean()
|
||||
codebook_loss_i = (codebook_loss_i * mask).mean()
|
||||
|
||||
all_commit_losses.append(commit_loss_i)
|
||||
all_codebook_losses.append(codebook_loss_i)
|
||||
all_indices.append(indices_i)
|
||||
all_quantized.append(z_q_i)
|
||||
|
||||
all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
|
||||
torch.stack,
|
||||
(all_commit_losses, all_codebook_losses, all_indices, all_quantized),
|
||||
)
|
||||
|
||||
return (
|
||||
quantized_out,
|
||||
all_indices,
|
||||
all_commit_losses,
|
||||
all_codebook_losses,
|
||||
all_quantized,
|
||||
)
|
||||
|
||||
def vq2emb(self, vq, n_quantizers=None):
|
||||
quantized_out = 0.0
|
||||
if n_quantizers is None:
|
||||
n_quantizers = self.num_quantizers
|
||||
for idx, quantizer in enumerate(self.quantizers):
|
||||
if idx >= n_quantizers:
|
||||
break
|
||||
quantized_out += quantizer.vq2emb(vq[idx])
|
||||
return quantized_out
|
||||
|
||||
def latent2dist(self, z, n_quantizers=None):
|
||||
quantized_out = 0.0
|
||||
residual = z
|
||||
|
||||
all_dists = []
|
||||
all_indices = []
|
||||
|
||||
if n_quantizers is None:
|
||||
n_quantizers = self.num_quantizers
|
||||
|
||||
for i, quantizer in enumerate(self.quantizers):
|
||||
if self.training is False and i >= n_quantizers:
|
||||
break
|
||||
dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
|
||||
all_dists.append(dist_i)
|
||||
all_indices.append(indices_i)
|
||||
|
||||
quantized_out = quantized_out + z_q_i
|
||||
residual = residual - z_q_i
|
||||
|
||||
all_dists = torch.stack(all_dists)
|
||||
all_indices = torch.stack(all_indices)
|
||||
|
||||
return all_dists, all_indices
|
||||
@@ -0,0 +1,401 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, p=2, dim=-1)
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def laplace_smoothing(x, n_categories, eps=1e-5):
|
||||
return (x + eps) / (x.sum() + n_categories * eps)
|
||||
|
||||
|
||||
def sample_vectors(samples, num):
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num,), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
||||
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if use_cosine_sim:
|
||||
dists = samples @ means.t()
|
||||
else:
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
||||
means, "c d -> () c d"
|
||||
)
|
||||
dists = -(diffs**2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
if use_cosine_sim:
|
||||
new_means = l2norm(new_means)
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
|
||||
return means, bins
|
||||
|
||||
|
||||
class EuclideanCodebook(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
codebook_size,
|
||||
kmeans_init=False,
|
||||
kmeans_iters=10,
|
||||
decay=0.8,
|
||||
eps=1e-5,
|
||||
threshold_ema_dead_code=2,
|
||||
weight_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.decay = decay
|
||||
init_fn = torch.randn if not weight_init else torch.zeros
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
if weight_init:
|
||||
nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.eps = eps
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
|
||||
self.register_buffer(
|
||||
"initted", torch.Tensor([not kmeans_init])
|
||||
) # if kmeans_init is True, then initted is False; otherwise, initted is True
|
||||
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
def init_embed_(self, data):
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed)
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
def replace(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
if self.threshold_ema_dead_code == 0:
|
||||
return
|
||||
|
||||
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace(batch_samples, mask=expired_codes)
|
||||
|
||||
def forward(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
flatten = rearrange(x, "... d -> (...) d")
|
||||
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
||||
|
||||
if not self.initted:
|
||||
self.init_embed_(flatten)
|
||||
|
||||
dist = -(
|
||||
flatten.pow(2).sum(1, keepdim=True)
|
||||
- 2 * flatten @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
||||
embed_ind = embed_ind.view(*shape[:-1])
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
|
||||
if self.training:
|
||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
||||
embed_sum = (
|
||||
flatten.t() @ embed_onehot
|
||||
) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
|
||||
* self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
self.expire_codes_(x)
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
def vq2emb(self, vq):
|
||||
quantize = F.embedding(vq, self.embed)
|
||||
return quantize
|
||||
|
||||
def latent2dist(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
flatten = rearrange(x, "... d -> (...) d")
|
||||
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
||||
|
||||
if not self.initted:
|
||||
self.init_embed_(flatten)
|
||||
|
||||
dist = -(
|
||||
flatten.pow(2).sum(1, keepdim=True)
|
||||
- 2 * flatten @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
embed_ind = embed_ind.view(*shape[:-1])
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
|
||||
dist = dist.view(*shape[:-1], -1)
|
||||
|
||||
return dist, embed_ind, quantize
|
||||
|
||||
|
||||
class SimpleCodebook(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
codebook_size,
|
||||
use_l2_normlize=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.codebook_size = codebook_size
|
||||
self.use_l2_normlize = use_l2_normlize
|
||||
|
||||
self.embed = nn.Embedding(self.codebook_size, self.dim)
|
||||
|
||||
def forward(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
flatten = rearrange(x, "... d -> (...) d")
|
||||
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
||||
|
||||
if self.use_l2_normlize:
|
||||
flatten = F.normalize(flatten)
|
||||
embed = F.normalize(embed)
|
||||
|
||||
dist = -(
|
||||
flatten.pow(2).sum(1, keepdim=True)
|
||||
- 2 * flatten @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
embed_ind = embed_ind.view(*shape[:-1])
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
def vq2emb(self, vq):
|
||||
quantize = F.embedding(vq, self.embed.weight)
|
||||
return quantize
|
||||
|
||||
def latent2dist(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
flatten = rearrange(x, "... d -> (...) d")
|
||||
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
||||
|
||||
if self.use_l2_normlize:
|
||||
flatten = F.normalize(flatten)
|
||||
embed = F.normalize(embed)
|
||||
|
||||
dist = -(
|
||||
flatten.pow(2).sum(1, keepdim=True)
|
||||
- 2 * flatten @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
embed_ind = embed_ind.view(*shape[:-1])
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
|
||||
dist = dist.view(*shape[:-1], -1)
|
||||
|
||||
return dist, embed_ind, quantize
|
||||
|
||||
|
||||
class VectorQuantize(nn.Module):
|
||||
"""Vector quantization and factorized vecotor quantization implementation
|
||||
Args:
|
||||
input_dim (int): Dimension of input.
|
||||
codebook_size (int): Codebook size.
|
||||
codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
|
||||
if use codebook_type == "euclidean", otherwise, if you want to use
|
||||
factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
|
||||
commitment (float): Weight for commitment loss.
|
||||
use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
|
||||
we suggest use it as True if you want to use factorized vector quantization
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
codebook_size,
|
||||
codebook_dim,
|
||||
commitment=0.005,
|
||||
codebook_loss_weight=1.0,
|
||||
use_l2_normlize=False,
|
||||
codebook_type="euclidean", # "euclidean" or "simple"
|
||||
kmeans_init=False,
|
||||
kmeans_iters=10,
|
||||
decay=0.8,
|
||||
eps=1e-5,
|
||||
threshold_ema_dead_code=2,
|
||||
weight_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.commitment = commitment
|
||||
self.codebook_loss_weight = codebook_loss_weight
|
||||
self.use_l2_normlize = use_l2_normlize
|
||||
self.codebook_type = codebook_type
|
||||
self.kmeans_init = kmeans_init
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.weight_init = weight_init
|
||||
|
||||
if self.input_dim != self.codebook_dim:
|
||||
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
||||
self.out_project = WNConv1d(
|
||||
self.codebook_dim, self.input_dim, kernel_size=1
|
||||
)
|
||||
|
||||
else:
|
||||
self.in_project = nn.Identity()
|
||||
self.out_project = nn.Identity()
|
||||
|
||||
if self.codebook_type == "euclidean":
|
||||
self.codebook = EuclideanCodebook(
|
||||
self.codebook_dim,
|
||||
codebook_size=self.codebook_size,
|
||||
kmeans_init=self.kmeans_init,
|
||||
kmeans_iters=self.kmeans_iters,
|
||||
decay=self.decay,
|
||||
eps=self.eps,
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
weight_init=self.weight_init,
|
||||
)
|
||||
elif self.codebook_type == "simple":
|
||||
self.codebook = SimpleCodebook(
|
||||
self.codebook_dim,
|
||||
codebook_size=self.codebook_size,
|
||||
use_l2_normlize=self.use_l2_normlize,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"codebook_type {self.codebook_type} is not implemented!"
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
z: torch.Tensor[B x D x T]
|
||||
|
||||
Returns
|
||||
-------
|
||||
z_q: torch.Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
commit_loss: Tensor[B]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook entries
|
||||
codebook_loss: Tensor[B]
|
||||
Codebook loss to update the codebook
|
||||
indices: torch.Tensor[B x T]
|
||||
Codebook indices (quantized discrete representation of input)
|
||||
z_e: torch.Tensor[B x D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"""
|
||||
|
||||
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
||||
z_e = self.in_project(z)
|
||||
z_q, indices = self.decode_latents(z_e)
|
||||
|
||||
# Compute commitment loss and codebook loss
|
||||
if self.training:
|
||||
commit_loss = (
|
||||
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
||||
* self.commitment
|
||||
)
|
||||
codebook_loss = (
|
||||
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
||||
* self.codebook_loss_weight
|
||||
)
|
||||
else:
|
||||
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
||||
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
||||
|
||||
z_q = z_e + (z_q - z_e).detach()
|
||||
|
||||
z_q = self.out_project(z_q)
|
||||
|
||||
return z_q, commit_loss, codebook_loss, indices, z_e
|
||||
|
||||
def decode_latents(self, latents):
|
||||
encodings = rearrange(latents, "b d t -> b t d")
|
||||
z_q, indices = self.codebook(encodings)
|
||||
z_q = z_q.transpose(1, 2)
|
||||
return z_q, indices
|
||||
|
||||
def vq2emb(self, vq, out_proj=True):
|
||||
emb = self.codebook.vq2emb(vq)
|
||||
emb = emb.transpose(1, 2)
|
||||
if out_proj:
|
||||
emb = self.out_project(emb)
|
||||
return emb
|
||||
|
||||
def latent2dist(self, latents):
|
||||
latents = rearrange(latents, "b d t -> b t d")
|
||||
dist, embed_ind, quantize = self.codebook.latent2dist(latents)
|
||||
return dist, embed_ind, quantize.transpose(1, 2)
|
||||
@@ -0,0 +1,881 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
from torch import nn, view_as_real, view_as_complex
|
||||
from torch import nn
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
||||
import librosa
|
||||
|
||||
|
||||
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
||||
"""
|
||||
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
||||
"""
|
||||
return torch.log(torch.clip(x, min=clip_val))
|
||||
|
||||
|
||||
def symlog(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sign(x) * torch.log1p(x.abs())
|
||||
|
||||
|
||||
def symexp(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
||||
|
||||
|
||||
class STFT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
center=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.center = center
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
window = torch.hann_window(win_length)
|
||||
self.register_buffer("window", window)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (B, T * hop_length)
|
||||
|
||||
if not self.center:
|
||||
pad = self.win_length - self.hop_length
|
||||
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
|
||||
|
||||
stft_spec = torch.stft(
|
||||
x,
|
||||
self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
return_complex=False,
|
||||
) # (B, n_fft // 2 + 1, T, 2)
|
||||
|
||||
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
|
||||
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
|
||||
|
||||
log_mag = torch.log(
|
||||
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
|
||||
) # (B, n_fft // 2 + 1, T)
|
||||
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
|
||||
|
||||
return log_mag, phase
|
||||
|
||||
|
||||
class ISTFT(nn.Module):
|
||||
"""
|
||||
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
||||
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
||||
See issue: https://github.com/pytorch/pytorch/issues/62323
|
||||
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
||||
The NOLA constraint is met as we trim padded samples anyway.
|
||||
|
||||
Args:
|
||||
n_fft (int): Size of Fourier transform.
|
||||
hop_length (int): The distance between neighboring sliding window frames.
|
||||
win_length (int): The size of window frame and STFT filter.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
||||
):
|
||||
super().__init__()
|
||||
if padding not in ["center", "same"]:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
self.padding = padding
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
window = torch.hann_window(win_length)
|
||||
self.register_buffer("window", window)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
||||
|
||||
Args:
|
||||
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
||||
N is the number of frequency bins, and T is the number of time frames.
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
||||
"""
|
||||
if self.padding == "center":
|
||||
# Fallback to pytorch native implementation
|
||||
return torch.istft(
|
||||
spec,
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
self.window,
|
||||
center=True,
|
||||
)
|
||||
elif self.padding == "same":
|
||||
pad = (self.win_length - self.hop_length) // 2
|
||||
else:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
|
||||
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
||||
B, N, T = spec.shape
|
||||
|
||||
# Inverse FFT
|
||||
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
||||
ifft = ifft * self.window[None, :, None]
|
||||
|
||||
# Overlap and Add
|
||||
output_size = (T - 1) * self.hop_length + self.win_length
|
||||
y = torch.nn.functional.fold(
|
||||
ifft,
|
||||
output_size=(1, output_size),
|
||||
kernel_size=(1, self.win_length),
|
||||
stride=(1, self.hop_length),
|
||||
)[:, 0, 0, pad:-pad]
|
||||
|
||||
# Window envelope
|
||||
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
||||
window_envelope = torch.nn.functional.fold(
|
||||
window_sq,
|
||||
output_size=(1, output_size),
|
||||
kernel_size=(1, self.win_length),
|
||||
stride=(1, self.hop_length),
|
||||
).squeeze()[pad:-pad]
|
||||
|
||||
# Normalize
|
||||
assert (window_envelope > 1e-11).all()
|
||||
y = y / window_envelope
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class MDCT(nn.Module):
|
||||
"""
|
||||
Modified Discrete Cosine Transform (MDCT) module.
|
||||
|
||||
Args:
|
||||
frame_len (int): Length of the MDCT frame.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
"""
|
||||
|
||||
def __init__(self, frame_len: int, padding: str = "same"):
|
||||
super().__init__()
|
||||
if padding not in ["center", "same"]:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
self.padding = padding
|
||||
self.frame_len = frame_len
|
||||
N = frame_len // 2
|
||||
n0 = (N + 1) / 2
|
||||
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
||||
self.register_buffer("window", window)
|
||||
|
||||
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
||||
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
||||
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
||||
# https://github.com/pytorch/pytorch/issues/71613
|
||||
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
||||
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
||||
|
||||
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
||||
|
||||
Args:
|
||||
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
||||
and T is the length of the audio.
|
||||
|
||||
Returns:
|
||||
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
||||
and N is the number of frequency bins.
|
||||
"""
|
||||
if self.padding == "center":
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (self.frame_len // 2, self.frame_len // 2)
|
||||
)
|
||||
elif self.padding == "same":
|
||||
# hop_length is 1/2 frame_len
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (self.frame_len // 4, self.frame_len // 4)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
|
||||
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
||||
N = self.frame_len // 2
|
||||
x = x * self.window.expand(x.shape)
|
||||
X = torch.fft.fft(
|
||||
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
|
||||
)[..., :N]
|
||||
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
||||
return torch.real(res) * np.sqrt(2)
|
||||
|
||||
|
||||
class IMDCT(nn.Module):
|
||||
"""
|
||||
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
||||
|
||||
Args:
|
||||
frame_len (int): Length of the MDCT frame.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
"""
|
||||
|
||||
def __init__(self, frame_len: int, padding: str = "same"):
|
||||
super().__init__()
|
||||
if padding not in ["center", "same"]:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
self.padding = padding
|
||||
self.frame_len = frame_len
|
||||
N = frame_len // 2
|
||||
n0 = (N + 1) / 2
|
||||
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
||||
self.register_buffer("window", window)
|
||||
|
||||
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
||||
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
||||
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
||||
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
||||
|
||||
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
||||
|
||||
Args:
|
||||
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
||||
L is the number of frames, and N is the number of frequency bins.
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
||||
"""
|
||||
B, L, N = X.shape
|
||||
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
||||
Y[..., :N] = X
|
||||
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
||||
y = torch.fft.ifft(
|
||||
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
|
||||
)
|
||||
y = (
|
||||
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
|
||||
* np.sqrt(N)
|
||||
* np.sqrt(2)
|
||||
)
|
||||
result = y * self.window.expand(y.shape)
|
||||
output_size = (1, (L + 1) * N)
|
||||
audio = torch.nn.functional.fold(
|
||||
result.transpose(1, 2),
|
||||
output_size=output_size,
|
||||
kernel_size=(1, self.frame_len),
|
||||
stride=(1, self.frame_len // 2),
|
||||
)[:, 0, 0, :]
|
||||
|
||||
if self.padding == "center":
|
||||
pad = self.frame_len // 2
|
||||
elif self.padding == "same":
|
||||
pad = self.frame_len // 4
|
||||
else:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
|
||||
audio = audio[:, pad:-pad]
|
||||
return audio
|
||||
|
||||
|
||||
class FourierHead(nn.Module):
|
||||
"""Base class for inverse fourier modules."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
||||
L is the sequence length, and H denotes the model dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement the forward method.")
|
||||
|
||||
|
||||
class ISTFTHead(FourierHead):
|
||||
"""
|
||||
ISTFT Head module for predicting STFT complex coefficients.
|
||||
|
||||
Args:
|
||||
dim (int): Hidden dimension of the model.
|
||||
n_fft (int): Size of Fourier transform.
|
||||
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
||||
the resolution of the input features.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
||||
super().__init__()
|
||||
out_dim = n_fft + 2
|
||||
self.out = torch.nn.Linear(dim, out_dim)
|
||||
self.istft = ISTFT(
|
||||
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the ISTFTHead module.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
||||
L is the sequence length, and H denotes the model dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
||||
"""
|
||||
x = self.out(x).transpose(1, 2)
|
||||
mag, p = x.chunk(2, dim=1)
|
||||
mag = torch.exp(mag)
|
||||
mag = torch.clip(
|
||||
mag, max=1e2
|
||||
) # safeguard to prevent excessively large magnitudes
|
||||
# wrapping happens here. These two lines produce real and imaginary value
|
||||
x = torch.cos(p)
|
||||
y = torch.sin(p)
|
||||
# recalculating phase here does not produce anything new
|
||||
# only costs time
|
||||
# phase = torch.atan2(y, x)
|
||||
# S = mag * torch.exp(phase * 1j)
|
||||
# better directly produce the complex value
|
||||
S = mag * (x + 1j * y)
|
||||
audio = self.istft(S)
|
||||
return audio
|
||||
|
||||
|
||||
class IMDCTSymExpHead(FourierHead):
|
||||
"""
|
||||
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
||||
|
||||
Args:
|
||||
dim (int): Hidden dimension of the model.
|
||||
mdct_frame_len (int): Length of the MDCT frame.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
||||
based on perceptual scaling. Defaults to None.
|
||||
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
mdct_frame_len: int,
|
||||
padding: str = "same",
|
||||
sample_rate: Optional[int] = None,
|
||||
clip_audio: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
out_dim = mdct_frame_len // 2
|
||||
self.out = nn.Linear(dim, out_dim)
|
||||
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
||||
self.clip_audio = clip_audio
|
||||
|
||||
if sample_rate is not None:
|
||||
# optionally init the last layer following mel-scale
|
||||
m_max = _hz_to_mel(sample_rate // 2)
|
||||
m_pts = torch.linspace(0, m_max, out_dim)
|
||||
f_pts = _mel_to_hz(m_pts)
|
||||
scale = 1 - (f_pts / f_pts.max())
|
||||
|
||||
with torch.no_grad():
|
||||
self.out.weight.mul_(scale.view(-1, 1))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the IMDCTSymExpHead module.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
||||
L is the sequence length, and H denotes the model dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
||||
"""
|
||||
x = self.out(x)
|
||||
x = symexp(x)
|
||||
x = torch.clip(
|
||||
x, min=-1e2, max=1e2
|
||||
) # safeguard to prevent excessively large magnitudes
|
||||
audio = self.imdct(x)
|
||||
if self.clip_audio:
|
||||
audio = torch.clip(x, min=-1.0, max=1.0)
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
class IMDCTCosHead(FourierHead):
|
||||
"""
|
||||
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
||||
|
||||
Args:
|
||||
dim (int): Hidden dimension of the model.
|
||||
mdct_frame_len (int): Length of the MDCT frame.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
mdct_frame_len: int,
|
||||
padding: str = "same",
|
||||
clip_audio: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.clip_audio = clip_audio
|
||||
self.out = nn.Linear(dim, mdct_frame_len)
|
||||
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the IMDCTCosHead module.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
||||
L is the sequence length, and H denotes the model dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
||||
"""
|
||||
x = self.out(x)
|
||||
m, p = x.chunk(2, dim=2)
|
||||
m = torch.exp(m).clip(
|
||||
max=1e2
|
||||
) # safeguard to prevent excessively large magnitudes
|
||||
audio = self.imdct(m * torch.cos(p))
|
||||
if self.clip_audio:
|
||||
audio = torch.clip(x, min=-1.0, max=1.0)
|
||||
return audio
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
intermediate_dim (int): Dimensionality of the intermediate layer.
|
||||
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
||||
Defaults to None.
|
||||
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
||||
None means non-conditional LayerNorm. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
layer_scale_init_value: float,
|
||||
adanorm_num_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=7, padding=3, groups=dim
|
||||
) # depthwise conv
|
||||
self.adanorm = adanorm_num_embeddings is not None
|
||||
if adanorm_num_embeddings:
|
||||
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(
|
||||
dim, intermediate_dim
|
||||
) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
||||
if self.adanorm:
|
||||
assert cond_embedding_id is not None
|
||||
x = self.norm(x, cond_embedding_id)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
if self.gamma is not None:
|
||||
x = self.gamma * x
|
||||
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
||||
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
"""
|
||||
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
||||
|
||||
Args:
|
||||
num_embeddings (int): Number of embeddings.
|
||||
embedding_dim (int): Dimension of the embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = embedding_dim
|
||||
self.scale = nn.Embedding(
|
||||
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
||||
)
|
||||
self.shift = nn.Embedding(
|
||||
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
||||
)
|
||||
torch.nn.init.ones_(self.scale.weight)
|
||||
torch.nn.init.zeros_(self.shift.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
||||
scale = self.scale(cond_embedding_id)
|
||||
shift = self.shift(cond_embedding_id)
|
||||
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
||||
x = x * scale + shift
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock1(nn.Module):
|
||||
"""
|
||||
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
||||
but without upsampling layers.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
||||
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
||||
Defaults to (1, 3, 5).
|
||||
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
||||
Defaults to 0.1.
|
||||
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int, int] = (1, 3, 5),
|
||||
lrelu_slope: float = 0.1,
|
||||
layer_scale_init_value: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.lrelu_slope = lrelu_slope
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=self.get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=self.get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=self.get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=self.get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=self.get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=self.get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.gamma = nn.ParameterList(
|
||||
[
|
||||
(
|
||||
nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
||||
)
|
||||
if layer_scale_init_value is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
||||
)
|
||||
if layer_scale_init_value is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
||||
)
|
||||
if layer_scale_init_value is not None
|
||||
else None
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
||||
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
||||
xt = c1(xt)
|
||||
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
||||
xt = c2(xt)
|
||||
if gamma is not None:
|
||||
xt = gamma * xt
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
@staticmethod
|
||||
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class Backbone(nn.Module):
|
||||
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
||||
C denotes output features, and L is the sequence length.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
||||
and H denotes the model dimension.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement the forward method.")
|
||||
|
||||
|
||||
class VocosBackbone(Backbone):
|
||||
"""
|
||||
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
||||
|
||||
Args:
|
||||
input_channels (int): Number of input features channels.
|
||||
dim (int): Hidden dimension of the model.
|
||||
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
||||
num_layers (int): Number of ConvNeXtBlock layers.
|
||||
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
||||
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
||||
None means non-conditional model. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
num_layers: int,
|
||||
layer_scale_init_value: Optional[float] = None,
|
||||
adanorm_num_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
||||
self.adanorm = adanorm_num_embeddings is not None
|
||||
if adanorm_num_embeddings:
|
||||
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
||||
self.convnext = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(
|
||||
dim=dim,
|
||||
intermediate_dim=intermediate_dim,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
adanorm_num_embeddings=adanorm_num_embeddings,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
bandwidth_id = kwargs.get("bandwidth_id", None)
|
||||
x = self.embed(x)
|
||||
if self.adanorm:
|
||||
assert bandwidth_id is not None
|
||||
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
||||
else:
|
||||
x = self.norm(x.transpose(1, 2))
|
||||
x = x.transpose(1, 2)
|
||||
for conv_block in self.convnext:
|
||||
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
||||
x = self.final_layer_norm(x.transpose(1, 2))
|
||||
return x
|
||||
|
||||
|
||||
class VocosResNetBackbone(Backbone):
|
||||
"""
|
||||
Vocos backbone module built with ResBlocks.
|
||||
|
||||
Args:
|
||||
input_channels (int): Number of input features channels.
|
||||
dim (int): Hidden dimension of the model.
|
||||
num_blocks (int): Number of ResBlock1 blocks.
|
||||
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels,
|
||||
dim,
|
||||
num_blocks,
|
||||
layer_scale_init_value=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.embed = weight_norm(
|
||||
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
||||
)
|
||||
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
||||
self.resnet = nn.Sequential(
|
||||
*[
|
||||
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
x = self.embed(x)
|
||||
x = self.resnet(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class Vocos(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 256,
|
||||
dim: int = 384,
|
||||
intermediate_dim: int = 1152,
|
||||
num_layers: int = 8,
|
||||
n_fft: int = 800,
|
||||
hop_size: int = 200,
|
||||
padding: str = "same",
|
||||
adanorm_num_embeddings=None,
|
||||
cfg=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
input_channels = (
|
||||
cfg.input_channels
|
||||
if cfg is not None and hasattr(cfg, "input_channels")
|
||||
else input_channels
|
||||
)
|
||||
dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
|
||||
intermediate_dim = (
|
||||
cfg.intermediate_dim
|
||||
if cfg is not None and hasattr(cfg, "intermediate_dim")
|
||||
else intermediate_dim
|
||||
)
|
||||
num_layers = (
|
||||
cfg.num_layers
|
||||
if cfg is not None and hasattr(cfg, "num_layers")
|
||||
else num_layers
|
||||
)
|
||||
adanorm_num_embeddings = (
|
||||
cfg.adanorm_num_embeddings
|
||||
if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
|
||||
else adanorm_num_embeddings
|
||||
)
|
||||
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
|
||||
hop_size = (
|
||||
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
|
||||
)
|
||||
padding = (
|
||||
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
|
||||
)
|
||||
|
||||
self.backbone = VocosBackbone(
|
||||
input_channels=input_channels,
|
||||
dim=dim,
|
||||
intermediate_dim=intermediate_dim,
|
||||
num_layers=num_layers,
|
||||
adanorm_num_embeddings=adanorm_num_embeddings,
|
||||
)
|
||||
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.head(x)
|
||||
|
||||
return x[:, None, :]
|
||||
Reference in New Issue
Block a user