refactor: rename canto-backend → backend, canto-frontend → frontend
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,210 @@
|
||||
# Copyright (c) 2024 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from concurrent.futures import ALL_COMPLETED
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from indextts.utils.maskgct.models.codec.amphion_codec.quantize import ResidualVQ
|
||||
from indextts.utils.maskgct.models.codec.kmeans.vocos import VocosBackbone
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def compute_codebook_perplexity(indices, codebook_size):
|
||||
indices = indices.flatten()
|
||||
prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
|
||||
perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
|
||||
return perp
|
||||
|
||||
|
||||
class RepCodec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
codebook_size=8192,
|
||||
hidden_size=1024,
|
||||
codebook_dim=8,
|
||||
vocos_dim=384,
|
||||
vocos_intermediate_dim=2048,
|
||||
vocos_num_layers=12,
|
||||
num_quantizers=1,
|
||||
downsample_scale=1,
|
||||
cfg=None,
|
||||
):
|
||||
super().__init__()
|
||||
codebook_size = (
|
||||
cfg.codebook_size
|
||||
if cfg is not None and hasattr(cfg, "codebook_size")
|
||||
else codebook_size
|
||||
)
|
||||
codebook_dim = (
|
||||
cfg.codebook_dim
|
||||
if cfg is not None and hasattr(cfg, "codebook_dim")
|
||||
else codebook_dim
|
||||
)
|
||||
hidden_size = (
|
||||
cfg.hidden_size
|
||||
if cfg is not None and hasattr(cfg, "hidden_size")
|
||||
else hidden_size
|
||||
)
|
||||
vocos_dim = (
|
||||
cfg.vocos_dim
|
||||
if cfg is not None and hasattr(cfg, "vocos_dim")
|
||||
else vocos_dim
|
||||
)
|
||||
vocos_intermediate_dim = (
|
||||
cfg.vocos_intermediate_dim
|
||||
if cfg is not None and hasattr(cfg, "vocos_dim")
|
||||
else vocos_intermediate_dim
|
||||
)
|
||||
vocos_num_layers = (
|
||||
cfg.vocos_num_layers
|
||||
if cfg is not None and hasattr(cfg, "vocos_dim")
|
||||
else vocos_num_layers
|
||||
)
|
||||
num_quantizers = (
|
||||
cfg.num_quantizers
|
||||
if cfg is not None and hasattr(cfg, "num_quantizers")
|
||||
else num_quantizers
|
||||
)
|
||||
downsample_scale = (
|
||||
cfg.downsample_scale
|
||||
if cfg is not None and hasattr(cfg, "downsample_scale")
|
||||
else downsample_scale
|
||||
)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.vocos_dim = vocos_dim
|
||||
self.vocos_intermediate_dim = vocos_intermediate_dim
|
||||
self.vocos_num_layers = vocos_num_layers
|
||||
self.num_quantizers = num_quantizers
|
||||
self.downsample_scale = downsample_scale
|
||||
|
||||
if self.downsample_scale != None and self.downsample_scale > 1:
|
||||
self.down = nn.Conv1d(
|
||||
self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.up = nn.Conv1d(
|
||||
self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
VocosBackbone(
|
||||
input_channels=self.hidden_size,
|
||||
dim=self.vocos_dim,
|
||||
intermediate_dim=self.vocos_intermediate_dim,
|
||||
num_layers=self.vocos_num_layers,
|
||||
adanorm_num_embeddings=None,
|
||||
),
|
||||
nn.Linear(self.vocos_dim, self.hidden_size),
|
||||
)
|
||||
self.decoder = nn.Sequential(
|
||||
VocosBackbone(
|
||||
input_channels=self.hidden_size,
|
||||
dim=self.vocos_dim,
|
||||
intermediate_dim=self.vocos_intermediate_dim,
|
||||
num_layers=self.vocos_num_layers,
|
||||
adanorm_num_embeddings=None,
|
||||
),
|
||||
nn.Linear(self.vocos_dim, self.hidden_size),
|
||||
)
|
||||
|
||||
self.quantizer = ResidualVQ(
|
||||
input_dim=hidden_size,
|
||||
num_quantizers=num_quantizers,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_type="fvq",
|
||||
quantizer_dropout=0.0,
|
||||
commitment=0.15,
|
||||
codebook_loss_weight=1.0,
|
||||
use_l2_normlize=True,
|
||||
)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# downsample
|
||||
if self.downsample_scale != None and self.downsample_scale > 1:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.down(x)
|
||||
x = F.gelu(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# encoder
|
||||
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
# vq
|
||||
(
|
||||
quantized_out,
|
||||
all_indices,
|
||||
all_commit_losses,
|
||||
all_codebook_losses,
|
||||
_,
|
||||
) = self.quantizer(x)
|
||||
|
||||
# decoder
|
||||
x = self.decoder(quantized_out)
|
||||
|
||||
# up
|
||||
if self.downsample_scale != None and self.downsample_scale > 1:
|
||||
x = x.transpose(1, 2)
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
x_rec = self.up(x).transpose(1, 2)
|
||||
|
||||
codebook_loss = (all_codebook_losses + all_commit_losses).mean()
|
||||
all_indices = all_indices
|
||||
|
||||
return x_rec, codebook_loss, all_indices
|
||||
|
||||
def quantize(self, x):
|
||||
|
||||
if self.downsample_scale != None and self.downsample_scale > 1:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.down(x)
|
||||
x = F.gelu(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
(
|
||||
quantized_out,
|
||||
all_indices,
|
||||
all_commit_losses,
|
||||
all_codebook_losses,
|
||||
_,
|
||||
) = self.quantizer(x)
|
||||
|
||||
if all_indices.shape[0] == 1:
|
||||
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
|
||||
return all_indices, quantized_out.transpose(1, 2)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.apply(init_weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
|
||||
print(repcodec)
|
||||
print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
|
||||
x = torch.randn(5, 10, 1024)
|
||||
x_rec, codebook_loss, all_indices = repcodec(x)
|
||||
print(x_rec.shape, codebook_loss, all_indices.shape)
|
||||
vq_id, emb = repcodec.quantize(x)
|
||||
print(vq_id.shape, emb.shape)
|
||||
Reference in New Issue
Block a user