feat: Integrate IndexTTS2 model and update related schemas and frontend components
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Torch modules."""
|
||||
|
||||
# flake8: noqa
|
||||
from .conv import (
|
||||
pad1d,
|
||||
unpad1d,
|
||||
NormConv1d,
|
||||
NormConvTranspose1d,
|
||||
NormConv2d,
|
||||
NormConvTranspose2d,
|
||||
SConv1d,
|
||||
SConvTranspose1d,
|
||||
)
|
||||
from .lstm import SLSTM
|
||||
from .seanet import SEANetEncoder, SEANetDecoder
|
||||
@@ -0,0 +1,346 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Convolutional layers wrappers and utilities."""
|
||||
|
||||
import math
|
||||
import typing as tp
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import spectral_norm, weight_norm
|
||||
|
||||
from .norm import ConvLayerNorm
|
||||
|
||||
|
||||
CONV_NORMALIZATIONS = frozenset(
|
||||
[
|
||||
"none",
|
||||
"weight_norm",
|
||||
"spectral_norm",
|
||||
"time_layer_norm",
|
||||
"layer_norm",
|
||||
"time_group_norm",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
|
||||
assert norm in CONV_NORMALIZATIONS
|
||||
if norm == "weight_norm":
|
||||
return weight_norm(module)
|
||||
elif norm == "spectral_norm":
|
||||
return spectral_norm(module)
|
||||
else:
|
||||
# We already check was in CONV_NORMALIZATION, so any other choice
|
||||
# doesn't need reparametrization.
|
||||
return module
|
||||
|
||||
|
||||
def get_norm_module(
|
||||
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
|
||||
) -> nn.Module:
|
||||
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
||||
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
||||
"""
|
||||
assert norm in CONV_NORMALIZATIONS
|
||||
if norm == "layer_norm":
|
||||
assert isinstance(module, nn.modules.conv._ConvNd)
|
||||
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
||||
elif norm == "time_group_norm":
|
||||
if causal:
|
||||
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
||||
assert isinstance(module, nn.modules.conv._ConvNd)
|
||||
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
||||
else:
|
||||
return nn.Identity()
|
||||
|
||||
|
||||
def get_extra_padding_for_conv1d(
|
||||
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
||||
) -> int:
|
||||
"""See `pad_for_conv1d`."""
|
||||
length = x.shape[-1]
|
||||
n_frames = (length - kernel_size + padding_total) / stride + 1
|
||||
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
||||
return ideal_length - length
|
||||
|
||||
|
||||
def pad_for_conv1d(
|
||||
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
||||
):
|
||||
"""Pad for a convolution to make sure that the last window is full.
|
||||
Extra padding is added at the end. This is required to ensure that we can rebuild
|
||||
an output of the same length, as otherwise, even with padding, some time steps
|
||||
might get removed.
|
||||
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
||||
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
||||
1 2 3 # (output frames of a convolution, last 0 is never used)
|
||||
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
||||
1 2 3 4 # once you removed padding, we are missing one time step !
|
||||
"""
|
||||
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
||||
return F.pad(x, (0, extra_padding))
|
||||
|
||||
|
||||
def pad1d(
|
||||
x: torch.Tensor,
|
||||
paddings: tp.Tuple[int, int],
|
||||
mode: str = "zero",
|
||||
value: float = 0.0,
|
||||
):
|
||||
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
||||
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
||||
"""
|
||||
length = x.shape[-1]
|
||||
padding_left, padding_right = paddings
|
||||
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
||||
if mode == "reflect":
|
||||
max_pad = max(padding_left, padding_right)
|
||||
extra_pad = 0
|
||||
if length <= max_pad:
|
||||
extra_pad = max_pad - length + 1
|
||||
x = F.pad(x, (0, extra_pad))
|
||||
padded = F.pad(x, paddings, mode, value)
|
||||
end = padded.shape[-1] - extra_pad
|
||||
return padded[..., :end]
|
||||
else:
|
||||
return F.pad(x, paddings, mode, value)
|
||||
|
||||
|
||||
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
||||
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
||||
padding_left, padding_right = paddings
|
||||
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
||||
assert (padding_left + padding_right) <= x.shape[-1]
|
||||
end = x.shape[-1] - padding_right
|
||||
return x[..., padding_left:end]
|
||||
|
||||
|
||||
class NormConv1d(nn.Module):
|
||||
"""Wrapper around Conv1d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
||||
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
||||
self.norm_type = norm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class NormConv2d(nn.Module):
|
||||
"""Wrapper around Conv2d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
norm: str = "none",
|
||||
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
||||
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
||||
self.norm_type = norm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class NormConvTranspose1d(nn.Module):
|
||||
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.convtr = apply_parametrization_norm(
|
||||
nn.ConvTranspose1d(*args, **kwargs), norm
|
||||
)
|
||||
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
||||
self.norm_type = norm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.convtr(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class NormConvTranspose2d(nn.Module):
|
||||
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
norm: str = "none",
|
||||
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.convtr = apply_parametrization_norm(
|
||||
nn.ConvTranspose2d(*args, **kwargs), norm
|
||||
)
|
||||
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.convtr(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class SConv1d(nn.Module):
|
||||
"""Conv1d with some builtin handling of asymmetric or causal padding
|
||||
and normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
||||
pad_mode: str = "reflect",
|
||||
):
|
||||
super().__init__()
|
||||
# warn user on unusual setup between dilation and stride
|
||||
if stride > 1 and dilation > 1:
|
||||
warnings.warn(
|
||||
"SConv1d has been initialized with stride > 1 and dilation > 1"
|
||||
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
|
||||
)
|
||||
self.conv = NormConv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_kwargs,
|
||||
)
|
||||
self.causal = causal
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
kernel_size = self.conv.conv.kernel_size[0]
|
||||
stride = self.conv.conv.stride[0]
|
||||
dilation = self.conv.conv.dilation[0]
|
||||
padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
||||
extra_padding = get_extra_padding_for_conv1d(
|
||||
x, kernel_size, stride, padding_total
|
||||
)
|
||||
if self.causal:
|
||||
# Left padding for causal
|
||||
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
||||
else:
|
||||
# Asymmetric padding required for odd strides
|
||||
padding_right = padding_total // 2
|
||||
padding_left = padding_total - padding_right
|
||||
x = pad1d(
|
||||
x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
|
||||
)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class SConvTranspose1d(nn.Module):
|
||||
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
||||
and normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
trim_right_ratio: float = 1.0,
|
||||
norm_kwargs: tp.Dict[str, tp.Any] = {},
|
||||
):
|
||||
super().__init__()
|
||||
self.convtr = NormConvTranspose1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
causal=causal,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_kwargs,
|
||||
)
|
||||
self.causal = causal
|
||||
self.trim_right_ratio = trim_right_ratio
|
||||
assert (
|
||||
self.causal or self.trim_right_ratio == 1.0
|
||||
), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
||||
assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
|
||||
|
||||
def forward(self, x):
|
||||
kernel_size = self.convtr.convtr.kernel_size[0]
|
||||
stride = self.convtr.convtr.stride[0]
|
||||
padding_total = kernel_size - stride
|
||||
|
||||
y = self.convtr(x)
|
||||
|
||||
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
||||
# removed at the very end, when keeping only the right length for the output,
|
||||
# as removing it here would require also passing the length at the matching layer
|
||||
# in the encoder.
|
||||
if self.causal:
|
||||
# Trim the padding on the right according to the specified ratio
|
||||
# if trim_right_ratio = 1.0, trim everything from right
|
||||
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
||||
padding_left = padding_total - padding_right
|
||||
y = unpad1d(y, (padding_left, padding_right))
|
||||
else:
|
||||
# Asymmetric padding required for odd strides
|
||||
padding_right = padding_total // 2
|
||||
padding_left = padding_total - padding_right
|
||||
y = unpad1d(y, (padding_left, padding_right))
|
||||
return y
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""LSTM layers module."""
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SLSTM(nn.Module):
|
||||
"""
|
||||
LSTM without worrying about the hidden state, nor the layout of the data.
|
||||
Expects input as convolutional layout.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int,
|
||||
num_layers: int = 2,
|
||||
skip: bool = True,
|
||||
bidirectional: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.bidirectional = bidirectional
|
||||
self.skip = skip
|
||||
self.lstm = nn.LSTM(
|
||||
dimension, dimension, num_layers, bidirectional=bidirectional
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(2, 0, 1)
|
||||
y, _ = self.lstm(x)
|
||||
if self.bidirectional:
|
||||
x = x.repeat(1, 1, 2)
|
||||
if self.skip:
|
||||
y = y + x
|
||||
y = y.permute(1, 2, 0)
|
||||
return y
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Normalization modules."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvLayerNorm(nn.LayerNorm):
|
||||
"""
|
||||
Convolution-friendly LayerNorm that moves channels to last dimensions
|
||||
before running the normalization and moves them back to original position right after.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
|
||||
):
|
||||
super().__init__(normalized_shape, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = einops.rearrange(x, "b ... t -> b t ...")
|
||||
x = super().forward(x)
|
||||
x = einops.rearrange(x, "b t ... -> b ... t")
|
||||
return
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# flake8: noqa
|
||||
from .vq import QuantizedResult, ResidualVectorQuantizer
|
||||
@@ -0,0 +1,317 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Arithmetic coder."""
|
||||
|
||||
import io
|
||||
import math
|
||||
import random
|
||||
import typing as tp
|
||||
import torch
|
||||
|
||||
from ..binary import BitPacker, BitUnpacker
|
||||
|
||||
|
||||
def build_stable_quantized_cdf(
|
||||
pdf: torch.Tensor,
|
||||
total_range_bits: int,
|
||||
roundoff: float = 1e-8,
|
||||
min_range: int = 2,
|
||||
check: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Turn the given PDF into a quantized CDF that splits
|
||||
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
|
||||
to the PDF.
|
||||
|
||||
Args:
|
||||
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
|
||||
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
|
||||
during the coding process is `[0, 2 ** total_range_bits - 1]`.
|
||||
roundoff (float): will round the pdf up to that level to remove difference coming
|
||||
from e.g. evaluating the Language Model on different architectures.
|
||||
min_range (int): minimum range width. Should always be at least 2 for numerical
|
||||
stability. Use this to avoid pathological behavior is a value
|
||||
that is expected to be rare actually happens in real life.
|
||||
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
|
||||
"""
|
||||
pdf = pdf.detach()
|
||||
if roundoff:
|
||||
pdf = (pdf / roundoff).floor() * roundoff
|
||||
# interpolate with uniform distribution to achieve desired minimum probability.
|
||||
total_range = 2**total_range_bits
|
||||
cardinality = len(pdf)
|
||||
alpha = min_range * cardinality / total_range
|
||||
assert alpha <= 1, "you must reduce min_range"
|
||||
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
|
||||
ranges += min_range
|
||||
quantized_cdf = torch.cumsum(ranges, dim=-1)
|
||||
if min_range < 2:
|
||||
raise ValueError("min_range must be at least 2.")
|
||||
if check:
|
||||
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
|
||||
if (
|
||||
(quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
|
||||
).any() or quantized_cdf[0] < min_range:
|
||||
raise ValueError("You must increase your total_range_bits.")
|
||||
return quantized_cdf
|
||||
|
||||
|
||||
class ArithmeticCoder:
|
||||
"""ArithmeticCoder,
|
||||
Let us take a distribution `p` over `N` symbols, and assume we have a stream
|
||||
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
|
||||
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
|
||||
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
|
||||
sequence `(s_t)` by doing the following:
|
||||
|
||||
1) Initialize the current range to` [0 ** 2 B - 1]`.
|
||||
2) For each time step t, split the current range into contiguous chunks,
|
||||
one for each possible outcome, with size roughly proportional to `p`.
|
||||
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
|
||||
would be `{[0, 2], [3, 3]}`.
|
||||
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
|
||||
4) When done encoding all the values, just select any value remaining in the range.
|
||||
|
||||
You will notice that this procedure can fail: for instance if at any point in time
|
||||
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
|
||||
possible outcome. Intuitively, the more likely a value is, the less the range width
|
||||
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
|
||||
coding scheme, likely outcomes would take less bits, and more of them can be coded
|
||||
with a fixed budget.
|
||||
|
||||
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
|
||||
when the current range decreases below a given limit (given by `total_range_bits`), without
|
||||
having to redo all the computations. If we encode mostly likely values, we will seldom
|
||||
need to inject new bits, but a single rare value can deplete our stock of entropy!
|
||||
|
||||
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
|
||||
code works for any sequence `(p_t)` possibly different for each timestep.
|
||||
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
|
||||
the KL between the true distribution and `p_t`, the most efficient the coding will be.
|
||||
|
||||
Args:
|
||||
fo (IO[bytes]): file-like object to which the bytes will be written to.
|
||||
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
|
||||
Any time the current range width fall under this limit, new bits will
|
||||
be injected to rescale the initial range.
|
||||
"""
|
||||
|
||||
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
|
||||
assert total_range_bits <= 30
|
||||
self.total_range_bits = total_range_bits
|
||||
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
|
||||
self.low: int = 0
|
||||
self.high: int = 0
|
||||
self.max_bit: int = -1
|
||||
self._dbg: tp.List[tp.Any] = []
|
||||
self._dbg2: tp.List[tp.Any] = []
|
||||
|
||||
@property
|
||||
def delta(self) -> int:
|
||||
"""Return the current range width."""
|
||||
return self.high - self.low + 1
|
||||
|
||||
def _flush_common_prefix(self):
|
||||
# If self.low and self.high start with the sames bits,
|
||||
# those won't change anymore as we always just increase the range
|
||||
# by powers of 2, and we can flush them out to the bit stream.
|
||||
assert self.high >= self.low, (self.low, self.high)
|
||||
assert self.high < 2 ** (self.max_bit + 1)
|
||||
while self.max_bit >= 0:
|
||||
b1 = self.low >> self.max_bit
|
||||
b2 = self.high >> self.max_bit
|
||||
if b1 == b2:
|
||||
self.low -= b1 << self.max_bit
|
||||
self.high -= b1 << self.max_bit
|
||||
assert self.high >= self.low, (self.high, self.low, self.max_bit)
|
||||
assert self.low >= 0
|
||||
self.max_bit -= 1
|
||||
self.packer.push(b1)
|
||||
else:
|
||||
break
|
||||
|
||||
def push(self, symbol: int, quantized_cdf: torch.Tensor):
|
||||
"""Push the given symbol on the stream, flushing out bits
|
||||
if possible.
|
||||
|
||||
Args:
|
||||
symbol (int): symbol to encode with the AC.
|
||||
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
|
||||
to build this from your pdf estimate.
|
||||
"""
|
||||
while self.delta < 2**self.total_range_bits:
|
||||
self.low *= 2
|
||||
self.high = self.high * 2 + 1
|
||||
self.max_bit += 1
|
||||
|
||||
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
|
||||
range_high = quantized_cdf[symbol].item() - 1
|
||||
effective_low = int(
|
||||
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
effective_high = int(
|
||||
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
assert self.low <= self.high
|
||||
self.high = self.low + effective_high
|
||||
self.low = self.low + effective_low
|
||||
assert self.low <= self.high, (
|
||||
effective_low,
|
||||
effective_high,
|
||||
range_low,
|
||||
range_high,
|
||||
)
|
||||
self._dbg.append((self.low, self.high))
|
||||
self._dbg2.append((self.low, self.high))
|
||||
outs = self._flush_common_prefix()
|
||||
assert self.low <= self.high
|
||||
assert self.max_bit >= -1
|
||||
assert self.max_bit <= 61, self.max_bit
|
||||
return outs
|
||||
|
||||
def flush(self):
|
||||
"""Flush the remaining information to the stream."""
|
||||
while self.max_bit >= 0:
|
||||
b1 = (self.low >> self.max_bit) & 1
|
||||
self.packer.push(b1)
|
||||
self.max_bit -= 1
|
||||
self.packer.flush()
|
||||
|
||||
|
||||
class ArithmeticDecoder:
|
||||
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
|
||||
|
||||
Note that this must be called with **exactly** the same parameters and sequence
|
||||
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
|
||||
|
||||
If the AC encoder current range is [L, H], with `L` and `H` having the some common
|
||||
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
|
||||
For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
|
||||
`[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
|
||||
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
|
||||
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
|
||||
and we will need to read new bits from the stream and repeat the process.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
|
||||
self.total_range_bits = total_range_bits
|
||||
self.low: int = 0
|
||||
self.high: int = 0
|
||||
self.current: int = 0
|
||||
self.max_bit: int = -1
|
||||
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
|
||||
# Following is for debugging
|
||||
self._dbg: tp.List[tp.Any] = []
|
||||
self._dbg2: tp.List[tp.Any] = []
|
||||
self._last: tp.Any = None
|
||||
|
||||
@property
|
||||
def delta(self) -> int:
|
||||
return self.high - self.low + 1
|
||||
|
||||
def _flush_common_prefix(self):
|
||||
# Given the current range [L, H], if both have a common prefix,
|
||||
# we know we can remove it from our representation to avoid handling large numbers.
|
||||
while self.max_bit >= 0:
|
||||
b1 = self.low >> self.max_bit
|
||||
b2 = self.high >> self.max_bit
|
||||
if b1 == b2:
|
||||
self.low -= b1 << self.max_bit
|
||||
self.high -= b1 << self.max_bit
|
||||
self.current -= b1 << self.max_bit
|
||||
assert self.high >= self.low
|
||||
assert self.low >= 0
|
||||
self.max_bit -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
|
||||
"""Pull a symbol, reading as many bits from the stream as required.
|
||||
This returns `None` when the stream has been exhausted.
|
||||
|
||||
Args:
|
||||
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
|
||||
to build this from your pdf estimate. This must be **exatly**
|
||||
the same cdf as the one used at encoding time.
|
||||
"""
|
||||
while self.delta < 2**self.total_range_bits:
|
||||
bit = self.unpacker.pull()
|
||||
if bit is None:
|
||||
return None
|
||||
self.low *= 2
|
||||
self.high = self.high * 2 + 1
|
||||
self.current = self.current * 2 + bit
|
||||
self.max_bit += 1
|
||||
|
||||
def bin_search(low_idx: int, high_idx: int):
|
||||
# Binary search is not just for coding interviews :)
|
||||
if high_idx < low_idx:
|
||||
raise RuntimeError("Binary search failed")
|
||||
mid = (low_idx + high_idx) // 2
|
||||
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
|
||||
range_high = quantized_cdf[mid].item() - 1
|
||||
effective_low = int(
|
||||
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
effective_high = int(
|
||||
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
low = effective_low + self.low
|
||||
high = effective_high + self.low
|
||||
if self.current >= low:
|
||||
if self.current <= high:
|
||||
return (mid, low, high, self.current)
|
||||
else:
|
||||
return bin_search(mid + 1, high_idx)
|
||||
else:
|
||||
return bin_search(low_idx, mid - 1)
|
||||
|
||||
self._last = (self.low, self.high, self.current, self.max_bit)
|
||||
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
|
||||
self._dbg.append((self.low, self.high, self.current))
|
||||
self._flush_common_prefix()
|
||||
self._dbg2.append((self.low, self.high, self.current))
|
||||
|
||||
return sym
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(1234)
|
||||
random.seed(1234)
|
||||
for _ in range(4):
|
||||
pdfs = []
|
||||
cardinality = random.randrange(4000)
|
||||
steps = random.randrange(100, 500)
|
||||
fo = io.BytesIO()
|
||||
encoder = ArithmeticCoder(fo)
|
||||
symbols = []
|
||||
for step in range(steps):
|
||||
pdf = torch.softmax(torch.randn(cardinality), dim=0)
|
||||
pdfs.append(pdf)
|
||||
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
||||
symbol = torch.multinomial(pdf, 1).item()
|
||||
symbols.append(symbol)
|
||||
encoder.push(symbol, q_cdf)
|
||||
encoder.flush()
|
||||
|
||||
fo.seek(0)
|
||||
decoder = ArithmeticDecoder(fo)
|
||||
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
|
||||
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
||||
decoded_symbol = decoder.pull(q_cdf)
|
||||
assert decoded_symbol == symbol, idx
|
||||
assert decoder.pull(torch.zeros(1)) is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
@@ -0,0 +1,388 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# This implementation is inspired from
|
||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
||||
# which is released under MIT License. Hereafter, the original license:
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Phil Wang
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""Core vector quantization implementation."""
|
||||
import typing as tp
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .distrib import broadcast_tensors, rank
|
||||
|
||||
|
||||
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
||||
return val if val is not None else d
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay: float):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
||||
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
||||
|
||||
|
||||
def uniform_init(*shape: int):
|
||||
t = torch.empty(shape)
|
||||
nn.init.kaiming_uniform_(t)
|
||||
return t
|
||||
|
||||
|
||||
def sample_vectors(samples, num: int):
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num,), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
||||
dim, dtype = samples.shape[-1], samples.dtype
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
||||
dists = -(diffs**2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
|
||||
return means, bins
|
||||
|
||||
|
||||
class EuclideanCodebook(nn.Module):
|
||||
"""Codebook with Euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension.
|
||||
codebook_size (int): Codebook size.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
||||
If set to true, run the k-means algorithm on the first training batch and use
|
||||
the learned centroids as initialization.
|
||||
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
kmeans_init: int = False,
|
||||
kmeans_iters: int = 10,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
|
||||
uniform_init if not kmeans_init else torch.zeros
|
||||
)
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.epsilon = epsilon
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
|
||||
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
||||
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_embed_(self, data):
|
||||
if self.inited:
|
||||
return
|
||||
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed.clone())
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
# broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
if self.threshold_ema_dead_code == 0:
|
||||
return
|
||||
|
||||
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
# broadcast_tensors(self.buffers())
|
||||
|
||||
def preprocess(self, x):
|
||||
x = rearrange(x, "... d -> (...) d")
|
||||
return x
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
def postprocess_emb(self, embed_ind, shape):
|
||||
return embed_ind.view(*shape[:-1])
|
||||
|
||||
def dequantize(self, embed_ind):
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
return quantize
|
||||
|
||||
def encode(self, x):
|
||||
shape = x.shape
|
||||
# pre-process
|
||||
x = self.preprocess(x)
|
||||
# quantize
|
||||
embed_ind = self.quantize(x)
|
||||
# post-process
|
||||
embed_ind = self.postprocess_emb(embed_ind, shape)
|
||||
return embed_ind
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self.dequantize(embed_ind)
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
x = self.preprocess(x)
|
||||
|
||||
self.init_embed_(x)
|
||||
|
||||
embed_ind = self.quantize(x)
|
||||
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
||||
embed_ind = self.postprocess_emb(embed_ind, shape)
|
||||
quantize = self.dequantize(embed_ind)
|
||||
|
||||
if self.training:
|
||||
# We do the expiry of code at that point as buffers are in sync
|
||||
# and all the workers will take the same decision.
|
||||
self.expire_codes_(x)
|
||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
||||
* self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
|
||||
class VectorQuantization(nn.Module):
|
||||
"""Vector quantization implementation.
|
||||
Currently supports only euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension
|
||||
codebook_size (int): Codebook size
|
||||
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
commitment_weight (float): Weight for commitment loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
codebook_dim: tp.Optional[int] = None,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
commitment_weight: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (
|
||||
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
|
||||
self._codebook = EuclideanCodebook(
|
||||
dim=_codebook_dim,
|
||||
codebook_size=codebook_size,
|
||||
kmeans_init=kmeans_init,
|
||||
kmeans_iters=kmeans_iters,
|
||||
decay=decay,
|
||||
epsilon=epsilon,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code,
|
||||
)
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self._codebook.embed
|
||||
|
||||
def encode(self, x):
|
||||
x = rearrange(x, "b d n -> b n d")
|
||||
x = self.project_in(x)
|
||||
embed_in = self._codebook.encode(x)
|
||||
return embed_in
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self._codebook.decode(embed_ind)
|
||||
quantize = self.project_out(quantize)
|
||||
quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
x = rearrange(x, "b d n -> b n d")
|
||||
x = self.project_in(x)
|
||||
|
||||
quantize, embed_ind = self._codebook(x)
|
||||
|
||||
if self.training:
|
||||
quantize = x + (quantize - x).detach()
|
||||
|
||||
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
||||
|
||||
if self.training:
|
||||
if self.commitment_weight > 0:
|
||||
commit_loss = F.mse_loss(quantize.detach(), x)
|
||||
loss = loss + commit_loss * self.commitment_weight
|
||||
|
||||
quantize = self.project_out(quantize)
|
||||
quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
|
||||
class ResidualVectorQuantization(nn.Module):
|
||||
"""Residual vector quantization implementation.
|
||||
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, *, num_quantizers, **kwargs):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
|
||||
):
|
||||
quantized_out = 0.0
|
||||
residual = x
|
||||
|
||||
all_losses = []
|
||||
all_indices = []
|
||||
out_quantized = []
|
||||
|
||||
n_q = n_q or len(self.layers)
|
||||
|
||||
for i, layer in enumerate(self.layers[:n_q]):
|
||||
quantized, indices, loss = layer(residual)
|
||||
residual = residual - quantized
|
||||
quantized_out = quantized_out + quantized
|
||||
|
||||
all_indices.append(indices)
|
||||
all_losses.append(loss)
|
||||
if layers and i in layers:
|
||||
out_quantized.append(quantized)
|
||||
|
||||
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
||||
return quantized_out, out_indices, out_losses, out_quantized
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
st = st or 0
|
||||
for layer in self.layers[st:n_q]:
|
||||
indices = layer.encode(residual)
|
||||
quantized = layer.decode(indices)
|
||||
residual = residual - quantized
|
||||
all_indices.append(indices)
|
||||
out_indices = torch.stack(all_indices)
|
||||
return out_indices
|
||||
|
||||
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
||||
for i, indices in enumerate(q_indices):
|
||||
layer = self.layers[st + i]
|
||||
quantized = layer.decode(indices)
|
||||
quantized_out = quantized_out + quantized
|
||||
return quantized_out
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Torch distributed utilities."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def rank():
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def world_size():
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def is_distributed():
|
||||
return world_size() > 1
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
|
||||
if is_distributed():
|
||||
return torch.distributed.all_reduce(tensor, op)
|
||||
|
||||
|
||||
def _is_complex_or_float(tensor):
|
||||
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
||||
|
||||
|
||||
def _check_number_of_params(params: tp.List[torch.Tensor]):
|
||||
# utility function to check that the number of params in all workers is the same,
|
||||
# and thus avoid a deadlock with distributed all reduce.
|
||||
if not is_distributed() or not params:
|
||||
return
|
||||
# print('params[0].device ', params[0].device)
|
||||
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
||||
all_reduce(tensor)
|
||||
if tensor.item() != len(params) * world_size():
|
||||
# If not all the workers have the same number, for at least one of them,
|
||||
# this inequality will be verified.
|
||||
raise RuntimeError(
|
||||
f"Mismatch in number of params: ours is {len(params)}, "
|
||||
"at least one worker has a different one."
|
||||
)
|
||||
|
||||
|
||||
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
|
||||
"""Broadcast the tensors from the given parameters to all workers.
|
||||
This can be used to ensure that all workers have the same model to start with.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
||||
_check_number_of_params(tensors)
|
||||
handles = []
|
||||
for tensor in tensors:
|
||||
# src = int(rank()) # added code
|
||||
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
|
||||
handles.append(handle)
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
|
||||
def sync_buffer(buffers, average=True):
|
||||
"""
|
||||
Sync grad for buffers. If average is False, broadcast instead of averaging.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
handles = []
|
||||
for buffer in buffers:
|
||||
if torch.is_floating_point(buffer.data):
|
||||
if average:
|
||||
handle = torch.distributed.all_reduce(
|
||||
buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
|
||||
)
|
||||
else:
|
||||
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
|
||||
handles.append((buffer, handle))
|
||||
for buffer, handle in handles:
|
||||
handle.wait()
|
||||
if average:
|
||||
buffer.data /= world_size
|
||||
|
||||
|
||||
def sync_grad(params):
|
||||
"""
|
||||
Simpler alternative to DistributedDataParallel, that doesn't rely
|
||||
on any black magic. For simple models it can also be as fast.
|
||||
Just call this on your model parameters after the call to backward!
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
handles = []
|
||||
for p in params:
|
||||
if p.grad is not None:
|
||||
handle = torch.distributed.all_reduce(
|
||||
p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
|
||||
)
|
||||
handles.append((p, handle))
|
||||
for p, handle in handles:
|
||||
handle.wait()
|
||||
p.grad.data /= world_size()
|
||||
|
||||
|
||||
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
|
||||
"""Average a dictionary of metrics across all workers, using the optional
|
||||
`count` as unormalized weight.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return metrics
|
||||
keys, values = zip(*metrics.items())
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
||||
tensor *= count
|
||||
all_reduce(tensor)
|
||||
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
||||
return dict(zip(keys, averaged))
|
||||
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Residual vector quantizer implementation."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import math
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .core_vq import ResidualVectorQuantization
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizedResult:
|
||||
quantized: torch.Tensor
|
||||
codes: torch.Tensor
|
||||
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
|
||||
penalty: tp.Optional[torch.Tensor] = None
|
||||
metrics: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class ResidualVectorQuantizer(nn.Module):
|
||||
"""Residual Vector Quantizer.
|
||||
Args:
|
||||
dimension (int): Dimension of the codebooks.
|
||||
n_q (int): Number of residual vector quantizers used.
|
||||
bins (int): Codebook size.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int = 256,
|
||||
n_q: int = 8,
|
||||
bins: int = 1024,
|
||||
decay: float = 0.99,
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_q = n_q
|
||||
self.dimension = dimension
|
||||
self.bins = bins
|
||||
self.decay = decay
|
||||
self.kmeans_init = kmeans_init
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.vq = ResidualVectorQuantization(
|
||||
dim=self.dimension,
|
||||
codebook_size=self.bins,
|
||||
num_quantizers=self.n_q,
|
||||
decay=self.decay,
|
||||
kmeans_init=self.kmeans_init,
|
||||
kmeans_iters=self.kmeans_iters,
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n_q: tp.Optional[int] = None,
|
||||
layers: tp.Optional[list] = None,
|
||||
) -> QuantizedResult:
|
||||
"""Residual vector quantization on the given input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
|
||||
layers (list): Layer that need to return quantized. Defalt: None.
|
||||
Returns:
|
||||
QuantizedResult:
|
||||
The quantized (or approximately quantized) representation with
|
||||
the associated numbert quantizers and layer quantized required to return.
|
||||
"""
|
||||
n_q = n_q if n_q else self.n_q
|
||||
if layers and max(layers) >= n_q:
|
||||
raise ValueError(
|
||||
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(
|
||||
x, n_q=n_q, layers=layers
|
||||
)
|
||||
return quantized, codes, torch.mean(commit_loss), quantized_list
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
|
||||
st (int): Start to encode input from which layers. Default: 0.
|
||||
"""
|
||||
n_q = n_q if n_q else self.n_q
|
||||
st = st or 0
|
||||
codes = self.vq.encode(x, n_q=n_q, st=st)
|
||||
return codes
|
||||
|
||||
def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
|
||||
"""Decode the given codes to the quantized representation.
|
||||
Args:
|
||||
codes (torch.Tensor): Input indices for each quantizer.
|
||||
st (int): Start to decode input codes from which layers. Default: 0.
|
||||
"""
|
||||
quantized = self.vq.decode(codes, st=st)
|
||||
return quantized
|
||||
@@ -0,0 +1,414 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# This source file is copied from https://github.com/facebookresearch/encodec
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Encodec SEANet-based encoder and decoder implementation."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
from . import SConv1d, SConvTranspose1d, SLSTM
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
class SEANetResnetBlock(nn.Module):
|
||||
"""Residual block from SEANet model.
|
||||
Args:
|
||||
dim (int): Dimension of the input/output
|
||||
kernel_sizes (list): List of kernel sizes for the convolutions.
|
||||
dilations (list): List of dilations for the convolutions.
|
||||
activation (str): Activation function.
|
||||
activation_params (dict): Parameters to provide to the activation function
|
||||
norm (str): Normalization method.
|
||||
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
||||
causal (bool): Whether to use fully causal convolution.
|
||||
pad_mode (str): Padding mode for the convolutions.
|
||||
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
|
||||
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_sizes: tp.List[int] = [3, 1],
|
||||
dilations: tp.List[int] = [1, 1],
|
||||
activation: str = "ELU",
|
||||
activation_params: dict = {"alpha": 1.0},
|
||||
norm: str = "weight_norm",
|
||||
norm_params: tp.Dict[str, tp.Any] = {},
|
||||
causal: bool = False,
|
||||
pad_mode: str = "reflect",
|
||||
compress: int = 2,
|
||||
true_skip: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
assert len(kernel_sizes) == len(
|
||||
dilations
|
||||
), "Number of kernel sizes should match number of dilations"
|
||||
act = getattr(nn, activation) if activation != "Snake" else Snake1d
|
||||
hidden = dim // compress
|
||||
block = []
|
||||
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
||||
in_chs = dim if i == 0 else hidden
|
||||
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
||||
block += [
|
||||
act(**activation_params) if activation != "Snake" else act(in_chs),
|
||||
SConv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
self.block = nn.Sequential(*block)
|
||||
self.shortcut: nn.Module
|
||||
if true_skip:
|
||||
self.shortcut = nn.Identity()
|
||||
else:
|
||||
self.shortcut = SConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=1,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.shortcut(x) + self.block(x)
|
||||
|
||||
|
||||
class SEANetEncoder(nn.Module):
|
||||
"""SEANet encoder.
|
||||
Args:
|
||||
channels (int): Audio channels.
|
||||
dimension (int): Intermediate representation dimension.
|
||||
n_filters (int): Base width for the model.
|
||||
n_residual_layers (int): nb of residual layers.
|
||||
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
||||
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
||||
that must match the decoder order
|
||||
activation (str): Activation function.
|
||||
activation_params (dict): Parameters to provide to the activation function
|
||||
norm (str): Normalization method.
|
||||
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
||||
kernel_size (int): Kernel size for the initial convolution.
|
||||
last_kernel_size (int): Kernel size for the initial convolution.
|
||||
residual_kernel_size (int): Kernel size for the residual layers.
|
||||
dilation_base (int): How much to increase the dilation with each layer.
|
||||
causal (bool): Whether to use fully causal convolution.
|
||||
pad_mode (str): Padding mode for the convolutions.
|
||||
true_skip (bool): Whether to use true skip connection or a simple
|
||||
(streamable) convolution as the skip connection in the residual network blocks.
|
||||
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
||||
lstm (int): Number of LSTM layers at the end of the encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 1,
|
||||
dimension: int = 128,
|
||||
n_filters: int = 32,
|
||||
n_residual_layers: int = 1,
|
||||
ratios: tp.List[int] = [8, 5, 4, 2],
|
||||
activation: str = "ELU",
|
||||
activation_params: dict = {"alpha": 1.0},
|
||||
norm: str = "weight_norm",
|
||||
norm_params: tp.Dict[str, tp.Any] = {},
|
||||
kernel_size: int = 7,
|
||||
last_kernel_size: int = 7,
|
||||
residual_kernel_size: int = 3,
|
||||
dilation_base: int = 2,
|
||||
causal: bool = False,
|
||||
pad_mode: str = "reflect",
|
||||
true_skip: bool = False,
|
||||
compress: int = 2,
|
||||
lstm: int = 2,
|
||||
bidirectional: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.dimension = dimension
|
||||
self.n_filters = n_filters
|
||||
self.ratios = list(reversed(ratios))
|
||||
del ratios
|
||||
self.n_residual_layers = n_residual_layers
|
||||
self.hop_length = np.prod(self.ratios) # 计算乘积
|
||||
|
||||
act = getattr(nn, activation) if activation != "Snake" else Snake1d
|
||||
mult = 1
|
||||
model: tp.List[nn.Module] = [
|
||||
SConv1d(
|
||||
channels,
|
||||
mult * n_filters,
|
||||
kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
]
|
||||
# Downsample to raw audio scale
|
||||
for i, ratio in enumerate(self.ratios):
|
||||
# Add residual layers
|
||||
for j in range(n_residual_layers):
|
||||
model += [
|
||||
SEANetResnetBlock(
|
||||
mult * n_filters,
|
||||
kernel_sizes=[residual_kernel_size, 1],
|
||||
dilations=[dilation_base**j, 1],
|
||||
norm=norm,
|
||||
norm_params=norm_params,
|
||||
activation=activation,
|
||||
activation_params=activation_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
compress=compress,
|
||||
true_skip=true_skip,
|
||||
)
|
||||
]
|
||||
|
||||
# Add downsampling layers
|
||||
model += [
|
||||
(
|
||||
act(**activation_params)
|
||||
if activation != "Snake"
|
||||
else act(mult * n_filters)
|
||||
),
|
||||
SConv1d(
|
||||
mult * n_filters,
|
||||
mult * n_filters * 2,
|
||||
kernel_size=ratio * 2,
|
||||
stride=ratio,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
mult *= 2
|
||||
|
||||
if lstm:
|
||||
model += [
|
||||
SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
|
||||
]
|
||||
|
||||
mult = mult * 2 if bidirectional else mult
|
||||
model += [
|
||||
(
|
||||
act(**activation_params)
|
||||
if activation != "Snake"
|
||||
else act(mult * n_filters)
|
||||
),
|
||||
SConv1d(
|
||||
mult * n_filters,
|
||||
dimension,
|
||||
last_kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class SEANetDecoder(nn.Module):
|
||||
"""SEANet decoder.
|
||||
Args:
|
||||
channels (int): Audio channels.
|
||||
dimension (int): Intermediate representation dimension.
|
||||
n_filters (int): Base width for the model.
|
||||
n_residual_layers (int): nb of residual layers.
|
||||
ratios (Sequence[int]): kernel size and stride ratios
|
||||
activation (str): Activation function.
|
||||
activation_params (dict): Parameters to provide to the activation function
|
||||
final_activation (str): Final activation function after all convolutions.
|
||||
final_activation_params (dict): Parameters to provide to the activation function
|
||||
norm (str): Normalization method.
|
||||
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
||||
kernel_size (int): Kernel size for the initial convolution.
|
||||
last_kernel_size (int): Kernel size for the initial convolution.
|
||||
residual_kernel_size (int): Kernel size for the residual layers.
|
||||
dilation_base (int): How much to increase the dilation with each layer.
|
||||
causal (bool): Whether to use fully causal convolution.
|
||||
pad_mode (str): Padding mode for the convolutions.
|
||||
true_skip (bool): Whether to use true skip connection or a simple
|
||||
(streamable) convolution as the skip connection in the residual network blocks.
|
||||
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
||||
lstm (int): Number of LSTM layers at the end of the encoder.
|
||||
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
|
||||
If equal to 1.0, it means that all the trimming is done at the right.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 1,
|
||||
dimension: int = 128,
|
||||
n_filters: int = 32,
|
||||
n_residual_layers: int = 1,
|
||||
ratios: tp.List[int] = [8, 5, 4, 2],
|
||||
activation: str = "ELU",
|
||||
activation_params: dict = {"alpha": 1.0},
|
||||
final_activation: tp.Optional[str] = None,
|
||||
final_activation_params: tp.Optional[dict] = None,
|
||||
norm: str = "weight_norm",
|
||||
norm_params: tp.Dict[str, tp.Any] = {},
|
||||
kernel_size: int = 7,
|
||||
last_kernel_size: int = 7,
|
||||
residual_kernel_size: int = 3,
|
||||
dilation_base: int = 2,
|
||||
causal: bool = False,
|
||||
pad_mode: str = "reflect",
|
||||
true_skip: bool = False,
|
||||
compress: int = 2,
|
||||
lstm: int = 2,
|
||||
trim_right_ratio: float = 1.0,
|
||||
bidirectional: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dimension = dimension
|
||||
self.channels = channels
|
||||
self.n_filters = n_filters
|
||||
self.ratios = ratios
|
||||
del ratios
|
||||
self.n_residual_layers = n_residual_layers
|
||||
self.hop_length = np.prod(self.ratios)
|
||||
|
||||
act = getattr(nn, activation) if activation != "Snake" else Snake1d
|
||||
mult = int(2 ** len(self.ratios))
|
||||
model: tp.List[nn.Module] = [
|
||||
SConv1d(
|
||||
dimension,
|
||||
mult * n_filters,
|
||||
kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
]
|
||||
|
||||
if lstm:
|
||||
model += [
|
||||
SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
|
||||
]
|
||||
|
||||
# Upsample to raw audio scale
|
||||
for i, ratio in enumerate(self.ratios):
|
||||
# Add upsampling layers
|
||||
model += [
|
||||
(
|
||||
act(**activation_params)
|
||||
if activation != "Snake"
|
||||
else act(mult * n_filters)
|
||||
),
|
||||
SConvTranspose1d(
|
||||
mult * n_filters,
|
||||
mult * n_filters // 2,
|
||||
kernel_size=ratio * 2,
|
||||
stride=ratio,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
trim_right_ratio=trim_right_ratio,
|
||||
),
|
||||
]
|
||||
# Add residual layers
|
||||
for j in range(n_residual_layers):
|
||||
model += [
|
||||
SEANetResnetBlock(
|
||||
mult * n_filters // 2,
|
||||
kernel_sizes=[residual_kernel_size, 1],
|
||||
dilations=[dilation_base**j, 1],
|
||||
activation=activation,
|
||||
activation_params=activation_params,
|
||||
norm=norm,
|
||||
norm_params=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
compress=compress,
|
||||
true_skip=true_skip,
|
||||
)
|
||||
]
|
||||
|
||||
mult //= 2
|
||||
|
||||
# Add final layers
|
||||
model += [
|
||||
act(**activation_params) if activation != "Snake" else act(n_filters),
|
||||
SConv1d(
|
||||
n_filters,
|
||||
channels,
|
||||
last_kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
# Add optional final activation to decoder (eg. tanh)
|
||||
if final_activation is not None:
|
||||
final_act = getattr(nn, final_activation)
|
||||
final_activation_params = final_activation_params or {}
|
||||
model += [final_act(**final_activation_params)]
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, z):
|
||||
y = self.model(z)
|
||||
return y
|
||||
|
||||
|
||||
def test():
|
||||
import torch
|
||||
|
||||
encoder = SEANetEncoder()
|
||||
decoder = SEANetDecoder()
|
||||
x = torch.randn(1, 1, 24000)
|
||||
z = encoder(x)
|
||||
print("z ", z.shape)
|
||||
assert 1 == 2
|
||||
assert list(z.shape) == [1, 128, 75], z.shape
|
||||
y = decoder(z)
|
||||
assert y.shape == x.shape, (x.shape, y.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
Reference in New Issue
Block a user