feat: Integrate IndexTTS2 model and update related schemas and frontend components

This commit is contained in:
2026-03-12 13:30:53 +08:00
parent e5b5a16364
commit 8aec4f6f44
151 changed files with 40077 additions and 85 deletions

View File

@@ -0,0 +1,219 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import torch
from torch import nn
class JDCNet(nn.Module):
"""
Joint Detection and Classification Network model for singing voice melody.
"""
def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
super().__init__()
self.num_class = num_class
# input = (b, 1, 31, 513), b = batch size
self.conv_block = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
), # out: (b, 64, 31, 513)
nn.BatchNorm2d(num_features=64),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
)
# res blocks
self.res_block1 = ResBlock(
in_channels=64, out_channels=128
) # (b, 128, 31, 128)
self.res_block2 = ResBlock(
in_channels=128, out_channels=192
) # (b, 192, 31, 32)
self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
# pool block
self.pool_block = nn.Sequential(
nn.BatchNorm2d(num_features=256),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
nn.Dropout(p=0.2),
)
# maxpool layers (for auxiliary network inputs)
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
self.detector_conv = nn.Sequential(
nn.Conv2d(640, 256, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.Dropout(p=0.2),
)
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
self.bilstm_classifier = nn.LSTM(
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
) # (b, 31, 512)
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
self.bilstm_detector = nn.LSTM(
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
) # (b, 31, 512)
# input: (b * 31, 512)
self.classifier = nn.Linear(
in_features=512, out_features=self.num_class
) # (b * 31, num_class)
# input: (b * 31, 512)
self.detector = nn.Linear(
in_features=512, out_features=2
) # (b * 31, 2) - binary classifier
# initialize weights
self.apply(self.init_weights)
def get_feature_GAN(self, x):
seq_len = x.shape[-2]
x = x.float().transpose(-1, -2)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
return poolblock_out.transpose(-1, -2)
def get_feature(self, x):
seq_len = x.shape[-2]
x = x.float().transpose(-1, -2)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
return self.pool_block[2](poolblock_out)
def forward(self, x):
"""
Returns:
classification_prediction, detection_prediction
sizes: (b, 31, 722), (b, 31, 2)
"""
###############################
# forward pass for classifier #
###############################
seq_len = x.shape[-1]
x = x.float().transpose(-1, -2)
convblock_out = self.conv_block(x)
resblock1_out = self.res_block1(convblock_out)
resblock2_out = self.res_block2(resblock1_out)
resblock3_out = self.res_block3(resblock2_out)
poolblock_out = self.pool_block[0](resblock3_out)
poolblock_out = self.pool_block[1](poolblock_out)
GAN_feature = poolblock_out.transpose(-1, -2)
poolblock_out = self.pool_block[2](poolblock_out)
# (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
classifier_out = (
poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
)
classifier_out, _ = self.bilstm_classifier(
classifier_out
) # ignore the hidden states
classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
classifier_out = self.classifier(classifier_out)
classifier_out = classifier_out.view(
(-1, seq_len, self.num_class)
) # (b, 31, num_class)
# sizes: (b, 31, 722), (b, 31, 2)
# classifier output consists of predicted pitch classes per frame
# detector output consists of: (isvoice, notvoice) estimates per frame
return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
@staticmethod
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
for p in m.parameters():
if p.data is None:
continue
if len(p.shape) >= 2:
nn.init.orthogonal_(p.data)
else:
nn.init.normal_(p.data)
class ResBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
super().__init__()
self.downsample = in_channels != out_channels
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
self.pre_conv = nn.Sequential(
nn.BatchNorm2d(num_features=in_channels),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
)
# conv layers
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(leaky_relu_slope, inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
)
# 1 x 1 convolution layer to match the feature dimensions
self.conv1by1 = None
if self.downsample:
self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
def forward(self, x):
x = self.pre_conv(x)
if self.downsample:
x = self.conv(x) + self.conv1by1(x)
else:
x = self.conv(x) + x
return x

View File

@@ -0,0 +1,437 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py
import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from . import commons
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x

View File

@@ -0,0 +1,331 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import os.path
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def slice_segments_audio(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
dtype=torch.long
)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
def log_norm(x, mean=-4, std=4, dim=2):
"""
normalized log mel -> mel -> norm -> log(norm)
"""
x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
return x
from huggingface_hub import hf_hub_download
def load_F0_models(path):
# load F0 model
from .JDC.model import JDCNet
F0_model = JDCNet(num_class=1, seq_len=192)
if not os.path.exists(path):
path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
params = torch.load(path, map_location="cpu")["net"]
F0_model.load_state_dict(params)
_ = F0_model.train()
return F0_model
# Generators
from modules.dac.model.dac import Encoder, Decoder
from .quantize import FAquantizer, FApredictors
# Discriminators
from modules.dac.model.discriminator import Discriminator
def build_model(args):
encoder = Encoder(
d_model=args.DAC.encoder_dim,
strides=args.DAC.encoder_rates,
d_latent=1024,
causal=args.causal,
lstm=args.lstm,
)
quantizer = FAquantizer(
in_dim=1024,
n_p_codebooks=1,
n_c_codebooks=args.n_c_codebooks,
n_t_codebooks=2,
n_r_codebooks=3,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0.5,
causal=args.causal,
separate_prosody_encoder=args.separate_prosody_encoder,
timbre_norm=args.timbre_norm,
)
fa_predictors = FApredictors(
in_dim=1024,
use_gr_content_f0=args.use_gr_content_f0,
use_gr_prosody_phone=args.use_gr_prosody_phone,
use_gr_residual_f0=True,
use_gr_residual_phone=True,
use_gr_timbre_content=True,
use_gr_timbre_prosody=args.use_gr_timbre_prosody,
use_gr_x_timbre=True,
norm_f0=args.norm_f0,
timbre_norm=args.timbre_norm,
use_gr_content_global_f0=args.use_gr_content_global_f0,
)
decoder = Decoder(
input_channel=1024,
channels=args.DAC.decoder_dim,
rates=args.DAC.decoder_rates,
causal=args.causal,
lstm=args.lstm,
)
discriminator = Discriminator(
rates=[],
periods=[2, 3, 5, 7, 11],
fft_sizes=[2048, 1024, 512],
sample_rate=args.DAC.sr,
bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
)
nets = Munch(
encoder=encoder,
quantizer=quantizer,
decoder=decoder,
discriminator=discriminator,
fa_predictors=fa_predictors,
)
return nets
def load_checkpoint(
model,
optimizer,
path,
load_only_params=True,
ignore_modules=[],
is_distributed=False,
):
state = torch.load(path, map_location="cpu")
params = state["net"]
for key in model:
if key in params and key not in ignore_modules:
if not is_distributed:
# strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
for k in list(params[key].keys()):
if k.startswith("module."):
params[key][k[len("module.") :]] = params[key][k]
del params[key][k]
print("%s loaded" % key)
model[key].load_state_dict(params[key], strict=True)
_ = [model[key].eval() for key in model]
if not load_only_params:
epoch = state["epoch"] + 1
iters = state["iters"]
optimizer.load_state_dict(state["optimizer"])
optimizer.load_scheduler_state_dict(state["scheduler"])
else:
epoch = state["epoch"] + 1
iters = state["iters"]
return model, optimizer, epoch, iters
def recursive_munch(d):
if isinstance(d, dict):
return Munch((k, recursive_munch(v)) for k, v in d.items())
elif isinstance(d, list):
return [recursive_munch(v) for v in d]
else:
return d

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function
import torch
from torch import nn
class GradientReversal(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.save_for_backward(x, alpha)
return x
@staticmethod
def backward(ctx, grad_output):
grad_input = None
_, alpha = ctx.saved_tensors
if ctx.needs_input_grad[0]:
grad_input = -alpha * grad_output
return grad_input, None
revgrad = GradientReversal.apply
class GradientReversal(nn.Module):
def __init__(self, alpha):
super().__init__()
self.alpha = torch.tensor(alpha, requires_grad=False)
def forward(self, x):
return revgrad(x, self.alpha)

View File

@@ -0,0 +1,460 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F
import random
random.seed(0)
def _get_activation_fn(activ):
if activ == "relu":
return nn.ReLU()
elif activ == "lrelu":
return nn.LeakyReLU(0.2)
elif activ == "swish":
return lambda x: x * torch.sigmoid(x)
else:
raise RuntimeError(
"Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
)
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
def forward(self, x):
return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain="linear",
param=None,
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight,
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
)
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class CausualConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=1,
dilation=1,
bias=True,
w_init_gain="linear",
param=None,
):
super(CausualConv, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2) * 2
else:
self.padding = padding * 2
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight,
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
)
def forward(self, x):
x = self.conv(x)
x = x[:, :, : -self.padding]
return x
class CausualBlock(nn.Module):
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
super(CausualBlock, self).__init__()
self.blocks = nn.ModuleList(
[
self._get_conv(
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
)
for i in range(n_conv)
]
)
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
layers = [
CausualConv(
hidden_dim,
hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation,
),
_get_activation_fn(activ),
nn.BatchNorm1d(hidden_dim),
nn.Dropout(p=dropout_p),
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
_get_activation_fn(activ),
nn.Dropout(p=dropout_p),
]
return nn.Sequential(*layers)
class ConvBlock(nn.Module):
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
super().__init__()
self._n_groups = 8
self.blocks = nn.ModuleList(
[
self._get_conv(
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
)
for i in range(n_conv)
]
)
def forward(self, x):
for block in self.blocks:
res = x
x = block(x)
x += res
return x
def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
layers = [
ConvNorm(
hidden_dim,
hidden_dim,
kernel_size=3,
padding=dilation,
dilation=dilation,
),
_get_activation_fn(activ),
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
nn.Dropout(p=dropout_p),
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
_get_activation_fn(activ),
nn.Dropout(p=dropout_p),
]
return nn.Sequential(*layers)
class LocationLayer(nn.Module):
def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
super(LocationLayer, self).__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(
2,
attention_n_filters,
kernel_size=attention_kernel_size,
padding=padding,
bias=False,
stride=1,
dilation=1,
)
self.location_dense = LinearNorm(
attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
)
def forward(self, attention_weights_cat):
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose(1, 2)
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Module):
def __init__(
self,
attention_rnn_dim,
embedding_dim,
attention_dim,
attention_location_n_filters,
attention_location_kernel_size,
):
super(Attention, self).__init__()
self.query_layer = LinearNorm(
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.memory_layer = LinearNorm(
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(
attention_location_n_filters, attention_location_kernel_size, attention_dim
)
self.score_mask_value = -float("inf")
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights + processed_memory)
)
energies = energies.squeeze(-1)
return energies
def forward(
self,
attention_hidden_state,
memory,
processed_memory,
attention_weights_cat,
mask,
):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cummulative attention weights
mask: binary mask for padded data
"""
alignment = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat
)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class ForwardAttentionV2(nn.Module):
def __init__(
self,
attention_rnn_dim,
embedding_dim,
attention_dim,
attention_location_n_filters,
attention_location_kernel_size,
):
super(ForwardAttentionV2, self).__init__()
self.query_layer = LinearNorm(
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.memory_layer = LinearNorm(
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
)
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(
attention_location_n_filters, attention_location_kernel_size, attention_dim
)
self.score_mask_value = -float(1e20)
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights + processed_memory)
)
energies = energies.squeeze(-1)
return energies
def forward(
self,
attention_hidden_state,
memory,
processed_memory,
attention_weights_cat,
mask,
log_alpha,
):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cummulative attention weights
mask: binary mask for padded data
"""
log_energy = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat
)
# log_energy =
if mask is not None:
log_energy.data.masked_fill_(mask, self.score_mask_value)
# attention_weights = F.softmax(alignment, dim=1)
# content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
# log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
# log_total_score = log_alpha + content_score
# previous_attention_weights = attention_weights_cat[:,0,:]
log_alpha_shift_padded = []
max_time = log_energy.size(1)
for sft in range(2):
shifted = log_alpha[:, : max_time - sft]
shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
def __init__(self, n=2):
super(PhaseShuffle2d, self).__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x, move=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :, :move]
right = x[:, :, :, move:]
shuffled = torch.cat([right, left], dim=3)
return shuffled
class PhaseShuffle1d(nn.Module):
def __init__(self, n=2):
super(PhaseShuffle1d, self).__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x, move=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :move]
right = x[:, :, move:]
shuffled = torch.cat([right, left], dim=2)
return shuffled
class MFCC(nn.Module):
def __init__(self, n_mfcc=40, n_mels=80):
super(MFCC, self).__init__()
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = "ortho"
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer("dct_mat", dct_mat)
def forward(self, mel_specgram):
if len(mel_specgram.shape) == 2:
mel_specgram = mel_specgram.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)
return mfcc

View File

@@ -0,0 +1,741 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from modules.dac.nn.quantize import ResidualVectorQuantize
from torch import nn
from .wavenet import WN
from .style_encoder import StyleEncoder
from .gradient_reversal import GradientReversal
import torch
import torchaudio
import torchaudio.functional as audio_F
import numpy as np
from ..alias_free_torch import *
from torch.nn.utils import weight_norm
from torch import nn, sin, pow
from einops.layers.torch import Rearrange
from modules.dac.model.encodec import SConv1d
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta := x + 1/b * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
return x + self.block(x)
class CNNLSTM(nn.Module):
def __init__(self, indim, outdim, head, global_pred=False):
super().__init__()
self.global_pred = global_pred
self.model = nn.Sequential(
ResidualUnit(indim, dilation=1),
ResidualUnit(indim, dilation=2),
ResidualUnit(indim, dilation=3),
Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
Rearrange("b c t -> b t c"),
)
self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
def forward(self, x):
# x: [B, C, T]
x = self.model(x)
if self.global_pred:
x = torch.mean(x, dim=1, keepdim=False)
outs = [head(x) for head in self.heads]
return outs
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
class MFCC(nn.Module):
def __init__(self, n_mfcc=40, n_mels=80):
super(MFCC, self).__init__()
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = "ortho"
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer("dct_mat", dct_mat)
def forward(self, mel_specgram):
if len(mel_specgram.shape) == 2:
mel_specgram = mel_specgram.unsqueeze(0)
unsqueezed = True
else:
unsqueezed = False
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
if unsqueezed:
mfcc = mfcc.squeeze(0)
return mfcc
class FAquantizer(nn.Module):
def __init__(
self,
in_dim=1024,
n_p_codebooks=1,
n_c_codebooks=2,
n_t_codebooks=2,
n_r_codebooks=3,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0.5,
causal=False,
separate_prosody_encoder=False,
timbre_norm=False,
):
super(FAquantizer, self).__init__()
conv1d_type = SConv1d # if causal else nn.Conv1d
self.prosody_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_p_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.content_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_c_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
if not timbre_norm:
self.timbre_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_t_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
else:
self.timbre_encoder = StyleEncoder(
in_dim=80, hidden_dim=512, out_dim=in_dim
)
self.timbre_linear = nn.Linear(1024, 1024 * 2)
self.timbre_linear.bias.data[:1024] = 1
self.timbre_linear.bias.data[1024:] = 0
self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)
self.residual_quantizer = ResidualVectorQuantize(
input_dim=in_dim,
n_codebooks=n_r_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
if separate_prosody_encoder:
self.melspec_linear = conv1d_type(
in_channels=20, out_channels=256, kernel_size=1, causal=causal
)
self.melspec_encoder = WN(
hidden_channels=256,
kernel_size=5,
dilation_rate=1,
n_layers=8,
gin_channels=0,
p_dropout=0.2,
causal=causal,
)
self.melspec_linear2 = conv1d_type(
in_channels=256, out_channels=1024, kernel_size=1, causal=causal
)
else:
pass
self.separate_prosody_encoder = separate_prosody_encoder
self.prob_random_mask_residual = 0.75
SPECT_PARAMS = {
"n_fft": 2048,
"win_length": 1200,
"hop_length": 300,
}
MEL_PARAMS = {
"n_mels": 80,
}
self.to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
)
self.mel_mean, self.mel_std = -4, 4
self.frame_rate = 24000 / 300
self.hop_length = 300
self.is_timbre_norm = timbre_norm
if timbre_norm:
self.forward = self.forward_v2
def preprocess(self, wave_tensor, n_bins=20):
mel_tensor = self.to_mel(wave_tensor.squeeze(1))
mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]
@torch.no_grad()
def decode(self, codes):
code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
z_c = self.content_quantizer.from_codes(code_c)[0]
z_p = self.prosody_quantizer.from_codes(code_p)[0]
z_t = self.timbre_quantizer.from_codes(code_t)[0]
z = z_c + z_p + z_t
return z, [z_c, z_p, z_t]
@torch.no_grad()
def encode(self, x, wave_segments, n_c=1):
outs = 0
if self.separate_prosody_encoder:
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature # (B, T, 20)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(
f0_input,
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
.to(f0_input.device)
.bool(),
)
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(f0_input, 1)
outs += z_p.detach()
else:
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(x, 1)
outs += z_p.detach()
(
z_c,
codes_c,
latents_c,
commitment_loss_c,
codebook_loss_c,
) = self.content_quantizer(x, n_c)
outs += z_c.detach()
timbre_residual_feature = x - z_p.detach() - z_c.detach()
(
z_t,
codes_t,
latents_t,
commitment_loss_t,
codebook_loss_t,
) = self.timbre_quantizer(timbre_residual_feature, 2)
outs += z_t # we should not detach timbre
residual_feature = timbre_residual_feature - z_t
(
z_r,
codes_r,
latents_r,
commitment_loss_r,
codebook_loss_r,
) = self.residual_quantizer(residual_feature, 3)
return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
def forward(
self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
):
# timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
# timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
outs = 0
if self.separate_prosody_encoder:
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature # (B, T, 20)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(
f0_input,
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
.to(f0_input.device)
.bool(),
)
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(f0_input, 1)
outs += z_p.detach()
else:
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(x, 1)
outs += z_p.detach()
(
z_c,
codes_c,
latents_c,
commitment_loss_c,
codebook_loss_c,
) = self.content_quantizer(x, n_c)
outs += z_c.detach()
timbre_residual_feature = x - z_p.detach() - z_c.detach()
(
z_t,
codes_t,
latents_t,
commitment_loss_t,
codebook_loss_t,
) = self.timbre_quantizer(timbre_residual_feature, n_t)
outs += z_t # we should not detach timbre
residual_feature = timbre_residual_feature - z_t
(
z_r,
codes_r,
latents_r,
commitment_loss_r,
codebook_loss_r,
) = self.residual_quantizer(residual_feature, 3)
bsz = z_r.shape[0]
res_mask = np.random.choice(
[0, 1],
size=bsz,
p=[
self.prob_random_mask_residual,
1 - self.prob_random_mask_residual,
],
)
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
noise_must_on = noise_added_flags * recon_noisy_flags
noise_must_off = noise_added_flags * (~recon_noisy_flags)
res_mask[noise_must_on] = 1
res_mask[noise_must_off] = 0
outs += z_r * res_mask
quantized = [z_p, z_c, z_t, z_r]
commitment_losses = (
commitment_loss_p
+ commitment_loss_c
+ commitment_loss_t
+ commitment_loss_r
)
codebook_losses = (
codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
)
return outs, quantized, commitment_losses, codebook_losses
def forward_v2(
self,
x,
wave_segments,
n_c=1,
n_t=2,
full_waves=None,
wave_lens=None,
return_codes=False,
):
# timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
if full_waves is None:
mel = self.preprocess(wave_segments, n_bins=80)
timbre = self.timbre_encoder(
mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
)
else:
mel = self.preprocess(full_waves, n_bins=80)
timbre = self.timbre_encoder(
mel,
sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
)
outs = 0
if self.separate_prosody_encoder:
prosody_feature = self.preprocess(wave_segments)
f0_input = prosody_feature # (B, T, 20)
f0_input = self.melspec_linear(f0_input)
f0_input = self.melspec_encoder(
f0_input,
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
.to(f0_input.device)
.bool(),
)
f0_input = self.melspec_linear2(f0_input)
common_min_size = min(f0_input.size(2), x.size(2))
f0_input = f0_input[:, :, :common_min_size]
x = x[:, :, :common_min_size]
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(f0_input, 1)
outs += z_p.detach()
else:
(
z_p,
codes_p,
latents_p,
commitment_loss_p,
codebook_loss_p,
) = self.prosody_quantizer(x, 1)
outs += z_p.detach()
(
z_c,
codes_c,
latents_c,
commitment_loss_c,
codebook_loss_c,
) = self.content_quantizer(x, n_c)
outs += z_c.detach()
residual_feature = x - z_p.detach() - z_c.detach()
(
z_r,
codes_r,
latents_r,
commitment_loss_r,
codebook_loss_r,
) = self.residual_quantizer(residual_feature, 3)
bsz = z_r.shape[0]
res_mask = np.random.choice(
[0, 1],
size=bsz,
p=[
self.prob_random_mask_residual,
1 - self.prob_random_mask_residual,
],
)
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
if not self.training:
res_mask = torch.ones_like(res_mask)
outs += z_r * res_mask
quantized = [z_p, z_c, z_r]
codes = [codes_p, codes_c, codes_r]
commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r
style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
outs = outs.transpose(1, 2)
outs = self.timbre_norm(outs)
outs = outs.transpose(1, 2)
outs = outs * gamma + beta
if return_codes:
return outs, quantized, commitment_losses, codebook_losses, timbre, codes
else:
return outs, quantized, commitment_losses, codebook_losses, timbre
def voice_conversion(self, z, ref_wave):
ref_mel = self.preprocess(ref_wave, n_bins=80)
ref_timbre = self.timbre_encoder(
ref_mel,
sequence_mask(
torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
ref_mel.size(-1),
).unsqueeze(1),
)
style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1)
gamma, beta = style.chunk(2, 1) # (B, d, 1)
outs = z.transpose(1, 2)
outs = self.timbre_norm(outs)
outs = outs.transpose(1, 2)
outs = outs * gamma + beta
return outs
class FApredictors(nn.Module):
def __init__(
self,
in_dim=1024,
use_gr_content_f0=False,
use_gr_prosody_phone=False,
use_gr_residual_f0=False,
use_gr_residual_phone=False,
use_gr_timbre_content=True,
use_gr_timbre_prosody=True,
use_gr_x_timbre=False,
norm_f0=True,
timbre_norm=False,
use_gr_content_global_f0=False,
):
super(FApredictors, self).__init__()
self.f0_predictor = CNNLSTM(in_dim, 1, 2)
self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
if timbre_norm:
self.timbre_predictor = nn.Linear(in_dim, 20000)
else:
self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
self.use_gr_content_f0 = use_gr_content_f0
self.use_gr_prosody_phone = use_gr_prosody_phone
self.use_gr_residual_f0 = use_gr_residual_f0
self.use_gr_residual_phone = use_gr_residual_phone
self.use_gr_timbre_content = use_gr_timbre_content
self.use_gr_timbre_prosody = use_gr_timbre_prosody
self.use_gr_x_timbre = use_gr_x_timbre
self.rev_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
)
self.rev_content_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
)
self.rev_timbre_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
)
self.norm_f0 = norm_f0
self.timbre_norm = timbre_norm
if timbre_norm:
self.forward = self.forward_v2
self.global_f0_predictor = nn.Linear(in_dim, 1)
self.use_gr_content_global_f0 = use_gr_content_global_f0
if use_gr_content_global_f0:
self.rev_global_f0_predictor = nn.Sequential(
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
)
def forward(self, quantized):
prosody_latent = quantized[0]
content_latent = quantized[1]
timbre_latent = quantized[2]
residual_latent = quantized[3]
content_pred = self.phone_predictor(content_latent)[0]
if self.norm_f0:
spk_pred = self.timbre_predictor(timbre_latent)[0]
f0_pred, uv_pred = self.f0_predictor(prosody_latent)
else:
spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
prosody_rev_latent = torch.zeros_like(quantized[0])
if self.use_gr_content_f0:
prosody_rev_latent += quantized[1]
if self.use_gr_timbre_prosody:
prosody_rev_latent += quantized[2]
if self.use_gr_residual_f0:
prosody_rev_latent += quantized[3]
rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
content_rev_latent = torch.zeros_like(quantized[1])
if self.use_gr_prosody_phone:
content_rev_latent += quantized[0]
if self.use_gr_timbre_content:
content_rev_latent += quantized[2]
if self.use_gr_residual_phone:
content_rev_latent += quantized[3]
rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
if self.norm_f0:
timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
else:
timbre_rev_latent = quantized[1] + quantized[3]
if self.use_gr_x_timbre:
x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
else:
x_spk_pred = None
preds = {
"f0": f0_pred,
"uv": uv_pred,
"content": content_pred,
"timbre": spk_pred,
}
rev_preds = {
"rev_f0": rev_f0_pred,
"rev_uv": rev_uv_pred,
"rev_content": rev_content_pred,
"x_timbre": x_spk_pred,
}
return preds, rev_preds
def forward_v2(self, quantized, timbre):
prosody_latent = quantized[0]
content_latent = quantized[1]
residual_latent = quantized[2]
content_pred = self.phone_predictor(content_latent)[0]
spk_pred = self.timbre_predictor(timbre)
f0_pred, uv_pred = self.f0_predictor(prosody_latent)
prosody_rev_latent = torch.zeros_like(prosody_latent)
if self.use_gr_content_f0:
prosody_rev_latent += content_latent
if self.use_gr_residual_f0:
prosody_rev_latent += residual_latent
rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
content_rev_latent = torch.zeros_like(content_latent)
if self.use_gr_prosody_phone:
content_rev_latent += prosody_latent
if self.use_gr_residual_phone:
content_rev_latent += residual_latent
rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
timbre_rev_latent = prosody_latent + content_latent + residual_latent
if self.use_gr_x_timbre:
x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
else:
x_spk_pred = None
preds = {
"f0": f0_pred,
"uv": uv_pred,
"content": content_pred,
"timbre": spk_pred,
}
rev_preds = {
"rev_f0": rev_f0_pred,
"rev_uv": rev_uv_pred,
"rev_content": rev_content_pred,
"x_timbre": x_spk_pred,
}
return preds, rev_preds

View File

@@ -0,0 +1,110 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
from . import attentions
from torch import nn
import torch
from torch.nn import functional as F
class Mish(nn.Module):
def __init__(self):
super(Mish, self).__init__()
def forward(self, x):
return x * torch.tanh(F.softplus(x))
class Conv1dGLU(nn.Module):
"""
Conv1d + GLU(Gated Linear Unit) with residual connection.
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
"""
def __init__(self, in_channels, out_channels, kernel_size, dropout):
super(Conv1dGLU, self).__init__()
self.out_channels = out_channels
self.conv1 = nn.Conv1d(
in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.conv1(x)
x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
x = x1 * torch.sigmoid(x2)
x = residual + self.dropout(x)
return x
class StyleEncoder(torch.nn.Module):
def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
super().__init__()
self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.kernel_size = 5
self.n_head = 2
self.dropout = 0.1
self.spectral = nn.Sequential(
nn.Conv1d(self.in_dim, self.hidden_dim, 1),
Mish(),
nn.Dropout(self.dropout),
nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
Mish(),
nn.Dropout(self.dropout),
)
self.temporal = nn.Sequential(
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
)
self.slf_attn = attentions.MultiHeadAttention(
self.hidden_dim,
self.hidden_dim,
self.n_head,
p_dropout=self.dropout,
proximal_bias=False,
proximal_init=True,
)
self.atten_drop = nn.Dropout(self.dropout)
self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
def forward(self, x, mask=None):
# spectral
x = self.spectral(x) * mask
# temporal
x = self.temporal(x) * mask
# self-attention
attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
y = self.slf_attn(x, x, attn_mask=attn_mask)
x = x + self.atten_drop(y)
# fc
x = self.fc(x)
# temoral average pooling
w = self.temporal_avg_pool(x, mask=mask)
return w
def temporal_avg_pool(self, x, mask=None):
if mask is None:
out = torch.mean(x, dim=2)
else:
len_ = mask.sum(dim=2)
x = x.sum(dim=2)
out = torch.div(x, len_)
return out

View File

@@ -0,0 +1,224 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
import math
import torch
from torch import nn
from torch.nn import functional as F
from modules.dac.model.encodec import SConv1d
from . import commons
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
causal=False,
):
super(WN, self).__init__()
conv1d_type = SConv1d
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
self.cond_layer = conv1d_type(
gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
)
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = conv1d_type(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
norm="weight_norm",
causal=causal,
)
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = conv1d_type(
hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
)
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)