refactor: rename canto-backend → backend, canto-frontend → frontend
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@@ -0,0 +1,7 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .fvq import *
from .rvq import *
@@ -0,0 +1,116 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


class FactorizedVectorQuantize(nn.Module):
    def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        if dim != self.codebook_dim:
            self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
            self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
        else:
            self.in_proj = nn.Identity()
            self.out_proj = nn.Identity()
        self._codebook = nn.Embedding(codebook_size, self.codebook_dim)

    @property
    def codebook(self):
        return self._codebook

    def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B]
            Commitment loss (plus codebook loss during training) used to train
            the encoder to predict vectors closer to codebook entries
        """
        # transpose since we use linear

        z = rearrange(z, "b d t -> b t d")

        # Factorized codes project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x T x D)
        z_e = rearrange(z_e, "b t d -> b d t")
        z_q, indices = self.decode_latents(z_e)

        if self.training:
            commitment_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
            commit_loss = commitment_loss + codebook_loss
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = rearrange(z_q, "b d t -> b t d")
        z_q = self.out_proj(z_q)
        z_q = rearrange(z_q, "b t d -> b d t")

        return z_q, indices, commit_loss

    def vq2emb(self, vq, proj=True):
        emb = self.embed_code(vq)
        if proj:
            emb = self.out_proj(emb)
        return emb.transpose(1, 2)

    def get_emb(self):
        return self.codebook.weight

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)
        # L2 normalize encodings and codebook
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
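For reference, a minimal, hypothetical usage sketch of the quantizer above (not part of the commit; the dim, codebook_size, codebook_dim, and commitment values are illustrative only, shapes follow the forward() docstring):

# Hypothetical usage sketch; all constructor values below are assumed.
import torch

quantizer = FactorizedVectorQuantize(
    dim=256,             # assumed encoder latent dimension
    codebook_size=1024,  # assumed number of codebook entries
    codebook_dim=8,      # assumed low-dimensional code space
    commitment=0.25,     # assumed commitment loss weight
)
z = torch.randn(2, 256, 50)  # (B x D x T)
z_q, indices, commit_loss = quantizer(z)
# z_q: (2, 256, 50), indices: (2, 50), commit_loss: (2,)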
@@ -0,0 +1,87 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
from torch import nn
from .fvq import FactorizedVectorQuantize


class ResidualVQ(nn.Module):
    """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(self, *, num_quantizers, codebook_size, **kwargs):
        super().__init__()
        VQ = FactorizedVectorQuantize
        if type(codebook_size) == int:
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
        )
        self.num_quantizers = num_quantizers
        self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
        self.dropout_type = kwargs.get("dropout_type", None)

    def forward(self, x, n_quantizers=None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        if self.training:
            n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
            if self.dropout_type == "linear":
                dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
            elif self.dropout_type == "exp":
                dropout = torch.randint(
                    1, int(math.log2(self.num_quantizers)), (x.shape[0],)
                )
                dropout = torch.pow(2, dropout)
            n_dropout = int(x.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)

        for idx, layer in enumerate(self.layers):
            if not self.training and idx >= n_quantizers:
                break
            quantized, indices, loss = layer(residual)

            mask = (
                torch.full((x.shape[0],), fill_value=idx, device=x.device)
                < n_quantizers
            )

            residual = residual - quantized

            quantized_out = quantized_out + quantized * mask[:, None, None]

            # loss
            loss = (loss * mask).mean()

            all_indices.append(indices)
            all_losses.append(loss)
            all_quantized.append(quantized)
        all_losses, all_indices, all_quantized = map(
            torch.stack, (all_losses, all_indices, all_quantized)
        )
        return quantized_out, all_indices, all_losses, all_quantized

    def vq2emb(self, vq):
        # vq: [n_quantizers, B, T]
        quantized_out = 0.0
        for idx, layer in enumerate(self.layers):
            quantized = layer.vq2emb(vq[idx])
            quantized_out += quantized
        return quantized_out

    def get_emb(self):
        embs = []
        for idx, layer in enumerate(self.layers):
            embs.append(layer.get_emb())
        return embs
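Similarly, a minimal, hypothetical sketch of how ResidualVQ stacks the factorized quantizers (not part of the commit; note that each codebook_size entry is treated as an exponent, so a value of 10 yields 2**10 codes per layer):

# Hypothetical usage sketch; all constructor values below are assumed.
import torch

rvq = ResidualVQ(
    num_quantizers=4,
    codebook_size=10,    # each layer gets 2**10 = 1024 codes
    dim=256,             # forwarded to FactorizedVectorQuantize
    codebook_dim=8,
    commitment=0.25,
)
rvq.eval()  # inference path; training additionally applies quantizer dropout
x = torch.randn(2, 256, 50)  # (B x D x T)
quantized_out, all_indices, all_losses, all_quantized = rvq(x)
# quantized_out: (2, 256, 50), all_indices: (4, 2, 50), all_losses: (4,)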