feat: Integrate IndexTTS2 model and update related schemas and frontend components
This commit is contained in:
@@ -0,0 +1 @@
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
|
||||
|
||||
"""
|
||||
Implementation of model from:
|
||||
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
|
||||
Convolutional Recurrent Neural Networks" (2019)
|
||||
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class JDCNet(nn.Module):
    """
    Joint Detection and Classification Network model for singing voice melody.

    Architecture from Kum et al. (2019).  This variant also exposes
    intermediate features (see ``get_feature_GAN`` / ``forward``).

    Args:
        num_class: number of pitch classes predicted per frame.
        seq_len: nominal input sequence length.  NOTE(review): accepted but
            never stored or used by this implementation.
        leaky_relu_slope: negative slope shared by all LeakyReLU activations.
    """

    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
        super().__init__()
        self.num_class = num_class

        # input = (b, 1, 31, 513), b = batch size
        self.conv_block = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
            ),  # out: (b, 64, 31, 513)
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
        )

        # res blocks; each ResBlock halves the frequency axis via its MaxPool
        self.res_block1 = ResBlock(
            in_channels=64, out_channels=128
        )  # (b, 128, 31, 128)
        self.res_block2 = ResBlock(
            in_channels=128, out_channels=192
        )  # (b, 192, 31, 32)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)

        # pool block: BN -> LeakyReLU -> MaxPool(1, 4) -> Dropout
        self.pool_block = nn.Sequential(
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
            nn.Dropout(p=0.2),
        )

        # maxpool layers (for auxiliary network inputs)
        # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
        # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
        # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))

        # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
        # NOTE(review): detector_conv / maxpool1-3 / bilstm_detector / detector
        # are built here but never used by forward() below.
        self.detector_conv = nn.Sequential(
            nn.Conv2d(640, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Dropout(p=0.2),
        )

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_classifier = nn.LSTM(
            input_size=512, hidden_size=256, batch_first=True, bidirectional=True
        )  # (b, 31, 512)

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_detector = nn.LSTM(
            input_size=512, hidden_size=256, batch_first=True, bidirectional=True
        )  # (b, 31, 512)

        # input: (b * 31, 512)
        self.classifier = nn.Linear(
            in_features=512, out_features=self.num_class
        )  # (b * 31, num_class)

        # input: (b * 31, 512)
        self.detector = nn.Linear(
            in_features=512, out_features=2
        )  # (b * 31, 2) - binary classifier

        # initialize weights
        self.apply(self.init_weights)

    def get_feature_GAN(self, x):
        # Run the trunk up to (and including) pool_block's BN + LeakyReLU,
        # skipping its MaxPool/Dropout; result has the last two dims swapped
        # back to the input layout.
        seq_len = x.shape[-2]  # NOTE(review): unused
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return poolblock_out.transpose(-1, -2)

    def get_feature(self, x):
        # Same trunk as get_feature_GAN, but additionally applies pool_block's
        # MaxPool (dropout is still skipped); no final transpose.
        seq_len = x.shape[-2]  # NOTE(review): unused
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return self.pool_block[2](poolblock_out)

    def forward(self, x):
        """
        Returns:
            (abs(classifier_out), GAN_feature, poolblock_out)

        NOTE(review): unlike the upstream JDC model, the detector branch is
        never evaluated here — the comments below about a detection output are
        inherited from the original code but do not describe this return.
        """
        ###############################
        # forward pass for classifier #
        ###############################
        seq_len = x.shape[-1]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)

        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(-1, -2)
        poolblock_out = self.pool_block[2](poolblock_out)

        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
        classifier_out = (
            poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        )
        classifier_out, _ = self.bilstm_classifier(
            classifier_out
        )  # ignore the hidden states

        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view(
            (-1, seq_len, self.num_class)
        )  # (b, 31, num_class)

        # sizes: (b, 31, 722), (b, 31, 2)
        # classifier output consists of predicted pitch classes per frame
        # detector output consists of: (isvoice, notvoice) estimates per frame
        return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out

    @staticmethod
    def init_weights(m):
        # Kaiming-uniform for Linear, Xavier-normal for Conv2d; LSTM weight
        # matrices get orthogonal init, their 1-D params (biases) get normal.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
            for p in m.parameters():
                if p.data is None:
                    continue

                if len(p.shape) >= 2:
                    nn.init.orthogonal_(p.data)
                else:
                    nn.init.normal_(p.data)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Pre-activation residual block with frequency-axis downsampling.

    Applies BN -> LeakyReLU -> MaxPool(1, 2) first (Figure 1b of the JDC
    paper), then two 3x3 convolutions.  When the channel counts differ, a
    1x1 convolution projects the skip path to match.
    """

    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels

        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
        )

        # two 3x3 convolutions with a BN/LReLU between them
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )

        # 1 x 1 convolution layer to match the feature dimensions
        self.conv1by1 = None
        if self.downsample:
            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        x = self.pre_conv(x)
        # Project the skip connection only when the channel count changes.
        skip = self.conv1by1(x) if self.downsample else x
        return self.conv(x) + skip
|
||||
@@ -0,0 +1,437 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py
|
||||
|
||||
import copy
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from . import commons
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for (B, C, ...) tensors.

    Moves the channel axis last, normalizes over it with learnable
    gamma/beta, then restores the original layout.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        moved = x.transpose(1, -1)
        normed = F.layer_norm(moved, (self.channels,), self.gamma, self.beta, self.eps)
        return normed.transpose(1, -1)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """
    Transformer encoder: n_layers of (relative-position self-attention +
    convolutional FFN), each followed by dropout and a post-norm residual.

    Args:
        hidden_channels: model width (attention input/output channels).
        filter_channels: FFN inner width.
        n_heads: attention heads per layer.
        n_layers: number of attention+FFN layers.
        kernel_size: FFN convolution kernel size.
        p_dropout: dropout applied after attention/FFN and inside them.
        window_size: relative-attention window passed to MultiHeadAttention.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        # x: (b, hidden_channels, t); x_mask: (b, 1, t), 1 on valid frames.
        # Outer product of the mask with itself gives the (b, 1, t, t)
        # attention mask.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # Post-norm residual: attn -> dropout -> add & norm ...
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            # ... then FFN -> dropout -> add & norm.
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """
    Transformer decoder: n_layers of (causal self-attention, encoder-decoder
    cross-attention, causal convolutional FFN), each with dropout and a
    post-norm residual.

    Args:
        hidden_channels: model width.
        filter_channels: FFN inner width.
        n_heads: attention heads per layer.
        n_layers: number of decoder layers.
        kernel_size: FFN convolution kernel size.
        p_dropout: dropout rate.
        proximal_bias: add a distance-based bias in self-attention.
        proximal_init: initialize self-attention key weights from the query
            weights (see MultiHeadAttention).
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,  # decoder FFN must not look at future frames
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input, (b, hidden_channels, t_dec)
        h: encoder output, (b, hidden_channels, t_enc)
        x_mask / h_mask: (b, 1, t) validity masks.
        """
        # Lower-triangular causal mask for the self-attention branch.
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # Causal self-attention block.
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            # Cross-attention over the encoder output.
            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            # Causal FFN block.
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with optional relative positional embeddings
    (window_size), proximal bias, and block-local masking.

    Args:
        channels: input width (must be divisible by n_heads).
        out_channels: output width after the final 1x1 projection.
        n_heads: number of attention heads.
        p_dropout: dropout on attention weights.
        window_size: if set, use learned relative-position embeddings over
            offsets in [-window_size, window_size] (self-attention only).
        heads_share: share the relative embeddings across heads.
        block_length: if set (and a mask is given), restrict attention to a
            band of this width around the diagonal.
        proximal_bias: add -log(1+|i-j|) bias to scores (self-attention only).
        proximal_init: copy query weights into the key projection at init.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Last attention map, stored by forward() for inspection.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Start keys equal to queries so initial attention is proximal.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        # x: attention query source (b, channels, t_t);
        # c: key/value source (b, channels, t_s); x is c for self-attention.
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # -1e4 (not -inf) keeps fp16 finite while still zeroing softmax.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        # Pad/slice the learned (2*window_size+1)-entry table down to the
        # 2*length-1 relative offsets actually needed for this sequence.
        max_relative_position = 2 * self.window_size + 1  # NOTE(review): unused
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # padd along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
||||
|
||||
|
||||
class FFN(nn.Module):
    """Convolutional feed-forward block: conv -> activation -> dropout -> conv.

    Padding is either "same" (default) or causal (left-only), selected once
    at construction time.  The mask is re-applied before each convolution and
    on the output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Bind the padding strategy once instead of branching per call.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid approximation of GELU.
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        # Pad only on the left so outputs never depend on future frames.
        spec = [[0, 0], [0, 0], [self.kernel_size - 1, 0]]
        return F.pad(x, commons.convert_pad_shape(spec))

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [left, right]]))
|
||||
@@ -0,0 +1,331 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
import os.path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from munch import Munch
|
||||
import json
|
||||
|
||||
|
||||
class AttrDict(dict):
    """Dictionary whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the dict itself so that
        # d.key and d["key"] always agree.
        self.__dict__ = self
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
    """Initialize convolution weights in-place from N(mean, std).

    Intended for ``module.apply(init_weights)``; any module whose class name
    does not contain "Conv" is left untouched.
    """
    classname = m.__class__.__name__
    # Membership test is the idiomatic (and clearer) form of find(...) != -1.
    if "Conv" in classname:
        m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a dilated 1-D convolution's output length 'same'."""
    span = kernel_size * dilation - dilation  # receptive-field growth
    return int(span / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
    """Flatten a [[left, right], ...] per-dimension pad spec for F.pad.

    torch.nn.functional.pad expects pads for the *last* dimension first, so
    the per-dimension pairs are reversed before flattening.
    """
    # Comprehension replaces the manual nested loop; the original also used
    # the ambiguous single-letter name ``l`` (flake8 E741).
    reversed_pairs = pad_shape[::-1]
    return [amount for pair in reversed_pairs for amount in pair]
|
||||
|
||||
|
||||
def intersperse(lst, item):
    """Return lst with item inserted before, between, and after all elements."""
    out = [item]
    for value in lst:
        out.append(value)
        out.append(item)
    return out
|
||||
|
||||
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q) between diagonal Gaussians given means and log-stddevs."""
    variance_term = 0.5 * (
        torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)
    ) * torch.exp(-2.0 * logs_q)
    return (logs_q - logs_p) - 0.5 + variance_term
|
||||
|
||||
|
||||
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protected from log(0) overflow."""
    # Squeeze uniform samples into a strict (0, 1) interior.
    u = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(u))
|
||||
|
||||
|
||||
def rand_gumbel_like(x):
    """Gumbel sample matching x's shape, dtype, and device."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def slice_segments(x, ids_str, segment_size=4):
    """Gather per-batch windows x[i, :, ids_str[i] : ids_str[i] + segment_size]."""
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i, start in enumerate(ids_str):
        ret[i] = x[i, :, start : start + segment_size]
    return ret
|
||||
|
||||
|
||||
def slice_segments_audio(x, ids_str, segment_size=4):
    """Gather per-batch 1-D windows x[i, ids_str[i] : ids_str[i] + segment_size]."""
    ret = torch.zeros_like(x[:, :segment_size])
    for i, start in enumerate(ids_str):
        ret[i] = x[i, start : start + segment_size]
    return ret
|
||||
|
||||
|
||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Randomly pick a segment_size window per batch item.

    Returns (slices, start_indices); starts are drawn uniformly from the
    valid range implied by x_lengths (full length when None).
    """
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    random_pos = torch.rand([b]).to(device=x.device) * ids_str_max
    ids_str = random_pos.clip(0).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
|
||||
|
||||
|
||||
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Transformer sinusoidal positional signal of shape (1, channels, length).

    First channels//2 rows are sines, the next channels//2 are cosines; odd
    channel counts get one zero row of padding.
    """
    half = channels // 2
    positions = torch.arange(length, dtype=torch.float)
    # Geometric progression of timescales from min_timescale to max_timescale.
    log_step = math.log(float(max_timescale) / float(min_timescale)) / (half - 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(half, dtype=torch.float) * -log_step
    )
    scaled = positions.unsqueeze(0) * inv_timescales.unsqueeze(1)
    sig = torch.cat([torch.sin(scaled), torch.cos(scaled)], 0)
    sig = F.pad(sig, [0, 0, 0, channels % 2])
    return sig.view(1, channels, length)
|
||||
|
||||
|
||||
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add a sinusoidal positional signal to x of shape (b, channels, length)."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate a sinusoidal positional signal onto x along ``axis``."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||
|
||||
|
||||
def subsequent_mask(length):
    """Lower-triangular causal mask of shape (1, 1, length, length)."""
    return torch.tril(torch.ones(length, length))[None, None]
|
||||
|
||||
|
||||
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # WaveNet-style gated activation: tanh(a+b) * sigmoid(a+b), where the
    # first n_channels channels feed the tanh gate and the remainder feed
    # the sigmoid gate.  n_channels is indexed (presumably an IntTensor so
    # the function stays TorchScript-compatible — confirm with callers).
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
    """Flatten a [[left, right], ...] pad spec (reversed) into F.pad's form.

    NOTE(review): duplicate of the identical ``convert_pad_shape`` defined
    earlier in this module; this later definition wins at import time.
    """
    flattened = []
    for pair in pad_shape[::-1]:
        flattened.extend(pair)
    return flattened
|
||||
|
||||
|
||||
def shift_1d(x):
    """Shift x one step right along the last dim, zero-filling position 0."""
    padded = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))
    return padded[:, :, :-1]
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
    """Boolean mask of shape (len(length), max_length): True where idx < length[i]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def generate_path(duration, mask):
    """
    Build a monotonic alignment path from per-token durations.

    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    Returns a {0, 1} tensor of shape [b, 1, t_y, t_x] in which frame y is
    assigned to token x iff cum_duration[x-1] <= y < cum_duration[x].
    """
    # (The original also bound ``device = duration.device`` — unused, removed.)
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Difference of shifted cumulative masks isolates each token's frame span.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients in-place to [-clip_value, clip_value].

    Returns the total gradient norm (computed *before* clamping).  When
    clip_value is None, only the norm is computed and nothing is clamped.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grads = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in with_grads:
        total_norm += p.grad.data.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    return total_norm ** (1.0 / norm_type)
|
||||
|
||||
|
||||
def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)

    De-normalizes a log-mel tensor, takes the L2 norm over ``dim``, and
    returns its log.
    """
    mel = torch.exp(x * std + mean)
    return torch.log(mel.norm(dim=dim))
|
||||
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def load_F0_models(path):
    # Load the pretrained JDC pitch extractor.  If ``path`` does not exist
    # locally, fall back to downloading the checkpoint from the HuggingFace
    # Hub.  Returned model is left in train() mode.
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    if not os.path.exists(path):
        path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
    # NOTE(review): torch.load unpickles the (possibly downloaded) file —
    # trusted-source assumption on the checkpoint.
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model
|
||||
|
||||
|
||||
# Generators
|
||||
from modules.dac.model.dac import Encoder, Decoder
|
||||
from .quantize import FAquantizer, FApredictors
|
||||
|
||||
# Discriminators
|
||||
from modules.dac.model.discriminator import Discriminator
|
||||
|
||||
|
||||
def build_model(args):
    # Assemble the codec sub-networks from the config ``args`` and return
    # them in a Munch so each is addressable by name (nets.encoder, ...).
    # NOTE: Encoder/Decoder here are the DAC codec modules imported above
    # (modules.dac.model.dac), not the attention Encoder/Decoder.
    encoder = Encoder(
        d_model=args.DAC.encoder_dim,
        strides=args.DAC.encoder_rates,
        d_latent=1024,
        causal=args.causal,
        lstm=args.lstm,
    )

    # Factorized-attribute quantizer: separate codebooks for prosody (p),
    # content (c), timbre (t), and residual (r) streams.
    quantizer = FAquantizer(
        in_dim=1024,
        n_p_codebooks=1,
        n_c_codebooks=args.n_c_codebooks,
        n_t_codebooks=2,
        n_r_codebooks=3,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0.5,
        causal=args.causal,
        separate_prosody_encoder=args.separate_prosody_encoder,
        timbre_norm=args.timbre_norm,
    )

    # Gradient-reversal predictor heads used for attribute disentanglement.
    fa_predictors = FApredictors(
        in_dim=1024,
        use_gr_content_f0=args.use_gr_content_f0,
        use_gr_prosody_phone=args.use_gr_prosody_phone,
        use_gr_residual_f0=True,
        use_gr_residual_phone=True,
        use_gr_timbre_content=True,
        use_gr_timbre_prosody=args.use_gr_timbre_prosody,
        use_gr_x_timbre=True,
        norm_f0=args.norm_f0,
        timbre_norm=args.timbre_norm,
        use_gr_content_global_f0=args.use_gr_content_global_f0,
    )

    decoder = Decoder(
        input_channel=1024,
        channels=args.DAC.decoder_dim,
        rates=args.DAC.decoder_rates,
        causal=args.causal,
        lstm=args.lstm,
    )

    # Multi-period + multi-band STFT discriminator.
    discriminator = Discriminator(
        rates=[],
        periods=[2, 3, 5, 7, 11],
        fft_sizes=[2048, 1024, 512],
        sample_rate=args.DAC.sr,
        bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
    )

    nets = Munch(
        encoder=encoder,
        quantizer=quantizer,
        decoder=decoder,
        discriminator=discriminator,
        fa_predictors=fa_predictors,
    )

    return nets
|
||||
|
||||
|
||||
def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=None,
    is_distributed=False,
):
    """Load a training checkpoint into `model`.

    Args:
        model: mapping of component name -> nn.Module (e.g. the Munch
            returned by build_model).
        optimizer: optimizer wrapper exposing ``load_state_dict`` and
            ``load_scheduler_state_dict``; only touched when
            ``load_only_params`` is False.
        path: checkpoint produced by ``torch.save`` containing at least
            "net", "epoch" and "iters" (plus "optimizer"/"scheduler" when
            resuming the optimizer state).
        load_only_params: when True, restore weights only.
        ignore_modules: component names to skip (default: none).
        is_distributed: when False, strip the DDP "module." prefix from
            parameter names before loading into plain (non-DDP) modules.

    Returns:
        (model, optimizer, epoch, iters) — epoch is the saved epoch + 1,
        i.e. the epoch training should resume from.
    """
    # Fix: the original used a mutable default (`ignore_modules=[]`), which
    # is shared across calls.
    if ignore_modules is None:
        ignore_modules = []

    state = torch.load(path, map_location="cpu")
    params = state["net"]

    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # Strip the DDP prefix ("module.") in place so the plain
                # modules can load the state dict with strict=True.
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            print("%s loaded" % key)
            model[key].load_state_dict(params[key], strict=True)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    # Both branches of the original computed these identically; hoisted.
    epoch = state["epoch"] + 1
    iters = state["iters"]

    return model, optimizer, epoch, iters
|
||||
|
||||
|
||||
def recursive_munch(d):
    """Recursively convert nested dicts (including dicts inside lists) to Munch."""
    if isinstance(d, dict):
        return Munch((key, recursive_munch(value)) for key, value in d.items())
    if isinstance(d, list):
        return [recursive_munch(item) for item in d]
    # Scalars and everything else pass through unchanged.
    return d
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GradientReversal(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, alpha):
|
||||
ctx.save_for_backward(x, alpha)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input = None
|
||||
_, alpha = ctx.saved_tensors
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = -alpha * grad_output
|
||||
return grad_input, None
|
||||
|
||||
|
||||
revgrad = GradientReversal.apply
|
||||
|
||||
|
||||
class GradientReversal(nn.Module):
    """Module wrapper around `revgrad` with a fixed reversal strength.

    Note: this intentionally shadows the Function of the same name above;
    `revgrad` was bound before the shadowing, so forward still works.
    """

    def __init__(self, alpha):
        super().__init__()
        # Non-trainable tensor, in the form revgrad expects for alpha.
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        reversed_x = revgrad(x, self.alpha)
        return reversed_x
|
||||
@@ -0,0 +1,460 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Optional, Any
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import torchaudio.functional as audio_F
|
||||
|
||||
import random
|
||||
|
||||
# Module-level side effect: seed the global stdlib RNG so any code using
# `random` in this module is reproducible across runs.
random.seed(0)
|
||||
|
||||
|
||||
def _get_activation_fn(activ):
|
||||
if activ == "relu":
|
||||
return nn.ReLU()
|
||||
elif activ == "lrelu":
|
||||
return nn.LeakyReLU(0.2)
|
||||
elif activ == "swish":
|
||||
return lambda x: x * torch.sigmoid(x)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
|
||||
)
|
||||
|
||||
|
||||
class LinearNorm(torch.nn.Module):
    """Linear layer with Xavier-uniform initialization.

    The init gain is derived from `w_init_gain` via
    torch.nn.init.calculate_gain (Tacotron-style convention).
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        return self.linear_layer(x)
|
||||
|
||||
|
||||
class ConvNorm(torch.nn.Module):
    """1-D convolution with Xavier-uniform init and 'same' default padding.

    When `padding` is None the kernel size must be odd, and padding is
    chosen so that the output length equals the input length.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = dilation * (kernel_size - 1) // 2

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        gain = torch.nn.init.calculate_gain(w_init_gain, param=param)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=gain)

    def forward(self, signal):
        return self.conv(signal)
|
||||
|
||||
|
||||
class CausualConv(nn.Module):
    """Causal 1-D convolution with Xavier-uniform init.

    Pads by twice the 'same' padding and trims the excess from the right,
    so no output frame depends on future input frames. (The misspelled
    class name is kept for compatibility with existing callers.)
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=1,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            # Fix: the original assigned this to a local `padding`, so
            # self.padding was never set and padding=None crashed below.
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, x):
        x = self.conv(x)
        if self.padding > 0:
            # Drop the right-hand (future) context introduced by padding.
            # Fix: guard padding == 0, where x[:, :, :-0] would return an
            # empty tensor.
            x = x[:, :, : -self.padding]
        return x
|
||||
|
||||
|
||||
class CausualBlock(nn.Module):
    """Stack of residual causal-conv sub-blocks with exponentially growing dilation.

    Sub-block i uses dilation 3**i, so the receptive field grows
    geometrically with depth.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        # Identity (residual) connection around each sub-block.
        for sub_block in self.blocks:
            x = x + sub_block(x)
        return x

    def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
        """One sub-block: dilated causal conv -> act -> BN -> dropout -> causal conv -> act -> dropout."""
        return nn.Sequential(
            CausualConv(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        )
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """Stack of residual (non-causal) dilated-conv sub-blocks with GroupNorm."""

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        # Identity (residual) connection around each sub-block.
        for sub_block in self.blocks:
            x = x + sub_block(x)
        return x

    def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
        """One sub-block: dilated conv -> act -> GroupNorm -> dropout -> conv -> act -> dropout."""
        return nn.Sequential(
            ConvNorm(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        )
|
||||
|
||||
|
||||
class LocationLayer(nn.Module):
    """Project (previous, cumulative) attention weights into attention space.

    A 1-D conv over the two-channel weight history followed by a linear
    projection, as in location-sensitive attention (Tacotron 2).
    """

    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super(LocationLayer, self).__init__()
        same_pad = (attention_kernel_size - 1) // 2
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=same_pad,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat):
        # (B, 2, T) -> (B, n_filters, T) -> (B, T, n_filters) -> (B, T, attn_dim)
        out = self.location_conv(attention_weights_cat)
        out = out.transpose(1, 2)
        return self.location_dense(out)
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Location-sensitive additive attention (Tacotron 2 style)."""

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # Masked positions score -inf so softmax assigns them zero weight.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """Return unnormalized alignment scores.

        Args:
            query: decoder output, (B, n_mel_channels * n_frames_per_step).
            processed_memory: processed encoder outputs, (B, T_in, attention_dim).
            attention_weights_cat: previous + cumulative attention weights,
                (B, 2, T_in).

        Returns:
            energies of shape (B, T_in).
        """
        projected_query = self.query_layer(query.unsqueeze(1))
        projected_location = self.location_layer(attention_weights_cat)
        scores = self.v(
            torch.tanh(projected_query + projected_location + processed_memory)
        )
        return scores.squeeze(-1)

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
    ):
        """Compute attention context and normalized weights.

        Args:
            attention_hidden_state: last attention-RNN output.
            memory: encoder outputs, (B, T_in, embedding_dim).
            processed_memory: pre-projected encoder outputs.
            attention_weights_cat: previous + cumulative attention weights.
            mask: boolean mask of padded positions, or None.

        Returns:
            (attention_context, attention_weights).
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        # Weighted sum over the encoder time axis.
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        return attention_context.squeeze(1), attention_weights
|
||||
|
||||
|
||||
class ForwardAttentionV2(nn.Module):
    """Forward attention: location-sensitive scores combined with a
    monotonic log-alpha recursion, so alignment can only stay or advance
    by one encoder step per decoder step."""

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # Large finite negative value rather than -inf: these scores are fed
        # to logsumexp, where -inf could poison whole rows.
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """Return unnormalized alignment scores (log energies).

        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )

        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
        log_alpha,
    ):
        """One step of forward attention.

        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cummulative attention weights
        mask: binary mask for padded data
        log_alpha: previous step's log forward variable, (B, max_time)

        RETURNS
        -------
        (attention_context, attention_weights, log_alpha_new)
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )

        # Mask padded encoder positions with a very negative score.
        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # attention_weights = F.softmax(alignment, dim=1)

        # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
        # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]

        # log_total_score = log_alpha + content_score

        # previous_attention_weights = attention_weights_cat[:,0,:]

        # Forward recursion in log space: each position t may be reached
        # from t (shift 0) or t-1 (shift 1). Shifted copies are padded on
        # the left with the mask value so logsumexp ignores them.
        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, : max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        # log( alpha[t] + alpha[t-1] ), computed stably in log space.
        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        # Weighted sum over the encoder time axis.
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new
|
||||
|
||||
|
||||
class PhaseShuffle2d(nn.Module):
    """Randomly rotate the last (time) axis of a 4-D tensor by up to ±n.

    WaveGAN-style phase shuffle. Uses a private RNG seeded with 1 so the
    shuffle sequence is reproducible and independent of the global RNG.
    """

    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x: (B, C, M, L); `move` may be forced by the caller.
        if move is None:
            move = self.random.randint(-self.n, self.n)
        if move == 0:
            return x
        head = x[:, :, :, :move]
        tail = x[:, :, :, move:]
        return torch.cat([tail, head], dim=3)
|
||||
|
||||
|
||||
class PhaseShuffle1d(nn.Module):
    """Randomly rotate the last (time) axis of a 3-D tensor by up to ±n.

    1-D counterpart of PhaseShuffle2d; same private, reproducible RNG.
    """

    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x: (B, C, L); `move` may be forced by the caller.
        if move is None:
            move = self.random.randint(-self.n, self.n)
        if move == 0:
            return x
        head = x[:, :, :move]
        tail = x[:, :, move:]
        return torch.cat([tail, head], dim=2)
|
||||
|
||||
|
||||
class MFCC(nn.Module):
    """Compute MFCCs from a (log-)mel spectrogram via a fixed DCT matrix."""

    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        # Precompute the (n_mels, n_mfcc) DCT basis once; registering it as
        # a buffer lets it follow the module across devices/dtypes.
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        # Accept both unbatched (n_mels, T) and batched (B, n_mels, T) input.
        squeeze_out = mel_specgram.dim() == 2
        if squeeze_out:
            mel_specgram = mel_specgram.unsqueeze(0)
        # (B, n_mels, T) -> (B, T, n_mfcc) -> (B, n_mfcc, T)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc.squeeze(0) if squeeze_out else mfcc
|
||||
@@ -0,0 +1,741 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from modules.dac.nn.quantize import ResidualVectorQuantize
|
||||
from torch import nn
|
||||
from .wavenet import WN
|
||||
from .style_encoder import StyleEncoder
|
||||
from .gradient_reversal import GradientReversal
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.functional as audio_F
|
||||
import numpy as np
|
||||
from ..alias_free_torch import *
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch import nn, sin, pow
|
||||
from einops.layers.torch import Rearrange
|
||||
from modules.dac.model.encodec import SConv1d
|
||||
|
||||
|
||||
def init_weights(m):
    """Initializer for ``module.apply``: truncated-normal Conv1d weights, zero bias.

    Modules other than nn.Conv1d are left untouched.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Fix: Conv1d(bias=False) has m.bias is None; the original called
        # constant_ unconditionally and crashed on bias-free layers.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
    """nn.Conv1d wrapped in weight normalization (Salimans & Kingma, 2016)."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
    """nn.ConvTranspose1d wrapped in weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
    """Snake activation with a separate magnitude parameter:

        snakebeta(x) = x + (1 / beta) * sin^2(alpha * x)

    alpha controls the frequency of the periodic component and beta its
    magnitude; both are per-channel and trainable. With
    alpha_logscale=True the parameters are stored in log space and
    exponentiated on the fly (so an init of 0 means an effective 1).

    Reference: Ziyin, Hartwig & Ueda, https://arxiv.org/abs/2006.08195

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 10)
        >>> y = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """Initialize per-channel alpha/beta.

        Args:
            in_features: number of channels (C).
            alpha: initial value (linear scale only; log scale starts at 0).
            alpha_trainable: whether alpha/beta receive gradients.
            alpha_logscale: store parameters in log space.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale parameters start at 0 (exp(0) == 1); linear-scale ones
        # start at `alpha` directly.
        if alpha_logscale:
            init = torch.zeros(in_features)
        else:
            init = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(init.clone())
        self.beta = nn.Parameter(init.clone())

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # Tiny epsilon so 1/beta never divides by zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply snakebeta elementwise: x + (1/beta) * sin^2(alpha * x)."""
        # Broadcast per-channel parameters to (1, C, 1) so they line up
        # with (B, C, T) input.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return x + (1.0 / (beta + self.no_div_by_zero)) * torch.sin(x * alpha) ** 2
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
    """Residual block: Snake -> dilated conv (k=7) -> Snake -> 1x1 conv, plus skip."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        # 'same' padding for a kernel of 7 with the given dilation.
        same_pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=same_pad),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        # Identity skip around the whole conv stack.
        return x + self.block(x)
|
||||
|
||||
|
||||
class CNNLSTM(nn.Module):
    """Residual conv stack followed by one or more linear prediction heads.

    NOTE(review): despite the name, this module contains no LSTM — only
    residual convolutions and linear heads.
    """

    def __init__(self, indim, outdim, head, global_pred=False):
        super().__init__()
        self.global_pred = global_pred
        self.model = nn.Sequential(
            ResidualUnit(indim, dilation=1),
            ResidualUnit(indim, dilation=2),
            ResidualUnit(indim, dilation=3),
            Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
            Rearrange("b c t -> b t c"),
        )
        self.heads = nn.ModuleList([nn.Linear(indim, outdim) for _ in range(head)])

    def forward(self, x):
        # x: (B, C, T) -> per-frame features (B, T, C)
        feats = self.model(x)
        if self.global_pred:
            # Pool over time for a single utterance-level prediction.
            feats = torch.mean(feats, dim=1, keepdim=False)
        return [head(feats) for head in self.heads]
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
    """Boolean mask of shape (B, max_length): True where position < length[b].

    Args:
        length: 1-D tensor of sequence lengths.
        max_length: mask width; defaults to length.max().
    """
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
class MFCC(nn.Module):
    """Compute MFCCs from a (log-)mel spectrogram via a fixed DCT matrix.

    NOTE(review): an identical MFCC helper exists in another module of this
    package — a shared utility would avoid the duplication.
    """

    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        # Fixed (n_mels, n_mfcc) DCT basis; a buffer so it follows the
        # module across devices/dtypes.
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        # Accept both unbatched (n_mels, T) and batched (B, n_mels, T) input.
        squeeze_out = mel_specgram.dim() == 2
        if squeeze_out:
            mel_specgram = mel_specgram.unsqueeze(0)
        # (B, n_mels, T) -> (B, T, n_mfcc) -> (B, n_mfcc, T)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc.squeeze(0) if squeeze_out else mfcc
|
||||
|
||||
|
||||
class FAquantizer(nn.Module):
    """Factorized audio quantizer (FACodec-style).

    Splits a latent stream into separately quantized prosody, content,
    timbre and residual branches, each a ResidualVectorQuantize stack.
    With timbre_norm=True the timbre branch is replaced by a style encoder
    and `forward` is rebound to `forward_v2` on the instance.
    """

    def __init__(
        self,
        in_dim=1024,
        n_p_codebooks=1,
        n_c_codebooks=2,
        n_t_codebooks=2,
        n_r_codebooks=3,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0.5,
        causal=False,
        separate_prosody_encoder=False,
        timbre_norm=False,
    ):
        super(FAquantizer, self).__init__()
        # SConv1d handles both causal and non-causal via its `causal` kwarg.
        conv1d_type = SConv1d  # if causal else nn.Conv1d
        # Prosody branch quantizer.
        self.prosody_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_p_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        # Content branch quantizer.
        self.content_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_c_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        if not timbre_norm:
            # Timbre captured by its own quantizer stack.
            self.timbre_quantizer = ResidualVectorQuantize(
                input_dim=in_dim,
                n_codebooks=n_t_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            # Timbre captured by a mel-based style encoder instead; the
            # linear layer produces per-channel (gamma, beta) for the
            # LayerNorm below (bias init: gamma=1, beta=0 => identity).
            self.timbre_encoder = StyleEncoder(
                in_dim=80, hidden_dim=512, out_dim=in_dim
            )
            self.timbre_linear = nn.Linear(1024, 1024 * 2)
            self.timbre_linear.bias.data[:1024] = 1
            self.timbre_linear.bias.data[1024:] = 0
            # NOTE: this attribute shadows the `timbre_norm` flag name; the
            # boolean is kept separately as self.is_timbre_norm below.
            self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)

        # Residual branch: whatever is left after the other branches.
        self.residual_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_r_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        if separate_prosody_encoder:
            # Prosody is derived from the 20 lowest log-mel bins:
            # 20 -> 256 (linear conv) -> WaveNet encoder -> 1024.
            self.melspec_linear = conv1d_type(
                in_channels=20, out_channels=256, kernel_size=1, causal=causal
            )
            self.melspec_encoder = WN(
                hidden_channels=256,
                kernel_size=5,
                dilation_rate=1,
                n_layers=8,
                gin_channels=0,
                p_dropout=0.2,
                causal=causal,
            )
            self.melspec_linear2 = conv1d_type(
                in_channels=256, out_channels=1024, kernel_size=1, causal=causal
            )
        else:
            pass
        self.separate_prosody_encoder = separate_prosody_encoder

        # Probability of dropping the residual branch per batch item in
        # forward() (mask drawn with np.random.choice).
        self.prob_random_mask_residual = 0.75

        SPECT_PARAMS = {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
        }
        MEL_PARAMS = {
            "n_mels": 80,
        }

        # Mel front-end used by preprocess(); 24 kHz audio, hop 300 =>
        # frame_rate = 80 frames/s.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
        )
        # Fixed log-mel normalization constants.
        self.mel_mean, self.mel_std = -4, 4
        self.frame_rate = 24000 / 300
        self.hop_length = 300

        self.is_timbre_norm = timbre_norm
        if timbre_norm:
            # Instance-level rebinding: this object's forward() is
            # forward_v2 (style-encoder timbre path) from here on.
            self.forward = self.forward_v2
|
||||
|
||||
def preprocess(self, wave_tensor, n_bins=20):
|
||||
mel_tensor = self.to_mel(wave_tensor.squeeze(1))
|
||||
mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
|
||||
return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, codes):
|
||||
code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
|
||||
|
||||
z_c = self.content_quantizer.from_codes(code_c)[0]
|
||||
z_p = self.prosody_quantizer.from_codes(code_p)[0]
|
||||
z_t = self.timbre_quantizer.from_codes(code_t)[0]
|
||||
|
||||
z = z_c + z_p + z_t
|
||||
|
||||
return z, [z_c, z_p, z_t]
|
||||
|
||||
    @torch.no_grad()
    def encode(self, x, wave_segments, n_c=1):
        """Quantize latents `x` into (codes, latents) per branch.

        Runs the prosody -> content -> timbre -> residual cascade, where
        each later branch quantizes what the earlier ones left over.

        Args:
            x: latent stream, (B, in_dim, T).
            wave_segments: waveform segments used to derive the prosody
                input when separate_prosody_encoder is set.
            n_c: number of content codebooks to use.

        Returns:
            ([codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]).
        """
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, 20, T): lowest log-mel bins
            f0_input = self.melspec_linear(f0_input)
            # All-ones mask: every frame of the segment is valid.
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            # Mel frames and latent frames can differ by one; align both.
            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            # No separate prosody encoder: quantize prosody straight from x.
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        # Timbre quantizes what prosody + content did not explain.
        timbre_residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_t,
            codes_t,
            latents_t,
            commitment_loss_t,
            codebook_loss_t,
        ) = self.timbre_quantizer(timbre_residual_feature, 2)
        outs += z_t  # we should not detach timbre

        # Residual branch picks up what is left after timbre.
        residual_feature = timbre_residual_feature - z_t

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
|
||||
|
||||
    def forward(
        self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
    ):
        """Training forward pass: quantize and sum all branches.

        Same prosody -> content -> timbre -> residual cascade as encode(),
        plus a per-item random mask on the residual branch (dropped with
        probability self.prob_random_mask_residual), overridden by the
        noise flags.

        Args:
            x: latent stream, (B, in_dim, T).
            wave_segments: waveform segments for the prosody input.
            noise_added_flags / recon_noisy_flags: per-item flags — assumed
                boolean tensors of shape (B,); TODO confirm with caller.
            n_c, n_t: number of content / timbre codebooks to use.

        Returns:
            (outs, [z_p, z_c, z_t, z_r], commitment_losses, codebook_losses).
        """
        # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
        # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, 20, T): lowest log-mel bins
            f0_input = self.melspec_linear(f0_input)
            # All-ones mask: every frame of the segment is valid.
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            # Mel frames and latent frames can differ by one; align both.
            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            # No separate prosody encoder: quantize prosody straight from x.
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        # Timbre quantizes what prosody + content did not explain.
        timbre_residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_t,
            codes_t,
            latents_t,
            commitment_loss_t,
            codebook_loss_t,
        ) = self.timbre_quantizer(timbre_residual_feature, n_t)
        outs += z_t  # we should not detach timbre

        # Residual branch picks up what is left after timbre.
        residual_feature = timbre_residual_feature - z_t

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        # Randomly drop the residual branch per batch item: mask is 0 (drop)
        # with probability prob_random_mask_residual (0.75), else 1 (keep).
        bsz = z_r.shape[0]
        res_mask = np.random.choice(
            [0, 1],
            size=bsz,
            p=[
                self.prob_random_mask_residual,
                1 - self.prob_random_mask_residual,
            ],
        )
        res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
        res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
        # Noise flags override the random mask: items reconstructing the
        # noisy signal must keep the residual; denoised ones must drop it.
        noise_must_on = noise_added_flags * recon_noisy_flags
        noise_must_off = noise_added_flags * (~recon_noisy_flags)
        res_mask[noise_must_on] = 1
        res_mask[noise_must_off] = 0

        outs += z_r * res_mask

        quantized = [z_p, z_c, z_t, z_r]
        commitment_losses = (
            commitment_loss_p
            + commitment_loss_c
            + commitment_loss_t
            + commitment_loss_r
        )
        codebook_losses = (
            codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
        )

        return outs, quantized, commitment_losses, codebook_losses
|
||||
|
||||
    def forward_v2(
        self,
        x,
        wave_segments,
        n_c=1,
        n_t=2,  # NOTE(review): unused in this v2 path (no timbre quantizer); kept for signature parity with forward
        full_waves=None,
        wave_lens=None,
        return_codes=False,
    ):
        """Quantize features into prosody/content/residual streams and re-style with timbre.

        Unlike ``forward``, timbre is not quantized here: it is extracted by the
        timbre encoder from the (full or segment) waveform mel and injected at the
        end via a learned affine (gamma/beta) over the normalized sum of streams.

        Args:
            x: frame-level feature tensor, layout (B, C, T) — sliced on dim 2 below.
            wave_segments: waveform segments used for the mel/prosody inputs.
            n_c: number of content-quantizer stages.
            n_t: unused in this method.
            full_waves / wave_lens: optional full-length waveforms (+ lengths in
                samples) for timbre extraction; falls back to ``wave_segments``.
            return_codes: if True, also return the discrete codes per stream.

        Returns:
            (outs, quantized, commitment_losses, codebook_losses, timbre[, codes])
        """
        # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
        if full_waves is None:
            # No full waveform given: treat every segment frame as valid (all-ones mask).
            mel = self.preprocess(wave_segments, n_bins=80)
            timbre = self.timbre_encoder(
                mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
            )
        else:
            # Mask out padding frames using sample lengths converted to frames.
            mel = self.preprocess(full_waves, n_bins=80)
            timbre = self.timbre_encoder(
                mel,
                sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
            )
        outs = 0  # running sum of quantized streams (int 0 + tensor broadcasts fine)
        if self.separate_prosody_encoder:
            # Build the prosody-quantizer input from a dedicated mel branch.
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, T, 20)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            # The prosody branch and x can disagree in length by a frame or two;
            # truncate both to the common minimum before quantizing.
            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            # Shared encoder: quantize prosody directly from x.
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        # What prosody + content did not explain goes to the residual quantizer.
        residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        # Randomly drop the residual stream per sample during training so the
        # decoder does not become dependent on it.
        bsz = z_r.shape[0]
        res_mask = np.random.choice(
            [0, 1],
            size=bsz,
            p=[
                self.prob_random_mask_residual,
                1 - self.prob_random_mask_residual,
            ],
        )
        res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
        res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)

        # At eval time always keep the residual (deterministic inference).
        if not self.training:
            res_mask = torch.ones_like(res_mask)
        outs += z_r * res_mask

        quantized = [z_p, z_c, z_r]
        codes = [codes_p, codes_c, codes_r]
        commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
        codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r

        # Inject timbre: FiLM-style affine over the channel-normalized sum.
        style = self.timbre_linear(timbre).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        outs = outs.transpose(1, 2)
        outs = self.timbre_norm(outs)
        outs = outs.transpose(1, 2)
        outs = outs * gamma + beta

        if return_codes:
            return outs, quantized, commitment_losses, codebook_losses, timbre, codes
        else:
            return outs, quantized, commitment_losses, codebook_losses, timbre
|
||||
|
||||
def voice_conversion(self, z, ref_wave):
|
||||
ref_mel = self.preprocess(ref_wave, n_bins=80)
|
||||
ref_timbre = self.timbre_encoder(
|
||||
ref_mel,
|
||||
sequence_mask(
|
||||
torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
|
||||
ref_mel.size(-1),
|
||||
).unsqueeze(1),
|
||||
)
|
||||
style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1)
|
||||
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
||||
outs = z.transpose(1, 2)
|
||||
outs = self.timbre_norm(outs)
|
||||
outs = outs.transpose(1, 2)
|
||||
outs = outs * gamma + beta
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
class FApredictors(nn.Module):
    """Factor-attribute predictors over quantized speech streams.

    Predicts f0/uv, phone (content) and speaker (timbre) targets from the
    corresponding quantized latents, plus "reverse" (gradient-reversal)
    predictions used as adversarial disentanglement losses: each rev predictor
    tries to recover a factor from the latents that should NOT contain it.
    """

    def __init__(
        self,
        in_dim=1024,
        use_gr_content_f0=False,
        use_gr_prosody_phone=False,
        use_gr_residual_f0=False,
        use_gr_residual_phone=False,
        use_gr_timbre_content=True,
        use_gr_timbre_prosody=True,
        use_gr_x_timbre=False,
        norm_f0=True,
        timbre_norm=False,
        use_gr_content_global_f0=False,
    ):
        super(FApredictors, self).__init__()
        # f0 predictor outputs 2 heads (f0 and uv); phone predictor 1024 classes.
        self.f0_predictor = CNNLSTM(in_dim, 1, 2)
        self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
        if timbre_norm:
            # Timbre arrives as a pooled embedding -> plain linear speaker classifier.
            self.timbre_predictor = nn.Linear(in_dim, 20000)
        else:
            # Timbre arrives as a latent sequence -> globally pooled CNNLSTM classifier.
            self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)

        # Flags choosing which streams feed each gradient-reversal branch.
        self.use_gr_content_f0 = use_gr_content_f0
        self.use_gr_prosody_phone = use_gr_prosody_phone
        self.use_gr_residual_f0 = use_gr_residual_f0
        self.use_gr_residual_phone = use_gr_residual_phone
        self.use_gr_timbre_content = use_gr_timbre_content
        self.use_gr_timbre_prosody = use_gr_timbre_prosody
        self.use_gr_x_timbre = use_gr_x_timbre

        # Adversarial heads: GradientReversal flips gradients so the upstream
        # encoder learns to REMOVE the predicted factor from those latents.
        self.rev_f0_predictor = nn.Sequential(
            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
        )
        self.rev_content_predictor = nn.Sequential(
            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
        )
        self.rev_timbre_predictor = nn.Sequential(
            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
        )

        self.norm_f0 = norm_f0
        self.timbre_norm = timbre_norm
        if timbre_norm:
            # NOTE(review): instance-level rebinding — with timbre_norm the public
            # forward() becomes forward_v2(quantized, timbre).
            self.forward = self.forward_v2
            self.global_f0_predictor = nn.Linear(in_dim, 1)

        self.use_gr_content_global_f0 = use_gr_content_global_f0
        if use_gr_content_global_f0:
            # NOTE(review): defined but not used in either forward path shown here.
            self.rev_global_f0_predictor = nn.Sequential(
                GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
            )

    def forward(self, quantized):
        """Predict factors from ``quantized = [prosody, content, timbre, residual]``.

        Returns:
            (preds, rev_preds) dicts of direct and adversarial predictions.
        """
        prosody_latent = quantized[0]
        content_latent = quantized[1]
        timbre_latent = quantized[2]
        residual_latent = quantized[3]
        content_pred = self.phone_predictor(content_latent)[0]

        if self.norm_f0:
            # f0 was normalized upstream: prosody alone carries it.
            spk_pred = self.timbre_predictor(timbre_latent)[0]
            f0_pred, uv_pred = self.f0_predictor(prosody_latent)
        else:
            # Unnormalized f0 leaks speaker info: mix timbre into both heads.
            spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
            f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)

        # Adversarial f0: try to read f0 from streams that should not carry it.
        prosody_rev_latent = torch.zeros_like(quantized[0])
        if self.use_gr_content_f0:
            prosody_rev_latent += quantized[1]
        if self.use_gr_timbre_prosody:
            prosody_rev_latent += quantized[2]
        if self.use_gr_residual_f0:
            prosody_rev_latent += quantized[3]
        rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)

        # Adversarial content: read phones from non-content streams.
        content_rev_latent = torch.zeros_like(quantized[1])
        if self.use_gr_prosody_phone:
            content_rev_latent += quantized[0]
        if self.use_gr_timbre_content:
            content_rev_latent += quantized[2]
        if self.use_gr_residual_phone:
            content_rev_latent += quantized[3]
        rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]

        # Adversarial timbre: read speaker from everything except the timbre stream.
        if self.norm_f0:
            timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
        else:
            timbre_rev_latent = quantized[1] + quantized[3]
        if self.use_gr_x_timbre:
            x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
        else:
            x_spk_pred = None

        preds = {
            "f0": f0_pred,
            "uv": uv_pred,
            "content": content_pred,
            "timbre": spk_pred,
        }

        rev_preds = {
            "rev_f0": rev_f0_pred,
            "rev_uv": rev_uv_pred,
            "rev_content": rev_content_pred,
            "x_timbre": x_spk_pred,
        }
        return preds, rev_preds

    def forward_v2(self, quantized, timbre):
        """Variant for timbre_norm=True: ``quantized = [prosody, content, residual]``
        and ``timbre`` is a pooled embedding (classified by the nn.Linear head).
        """
        prosody_latent = quantized[0]
        content_latent = quantized[1]
        residual_latent = quantized[2]
        content_pred = self.phone_predictor(content_latent)[0]

        # timbre_predictor is nn.Linear here, so no [0] indexing as in forward().
        spk_pred = self.timbre_predictor(timbre)
        f0_pred, uv_pred = self.f0_predictor(prosody_latent)

        prosody_rev_latent = torch.zeros_like(prosody_latent)
        if self.use_gr_content_f0:
            prosody_rev_latent += content_latent
        if self.use_gr_residual_f0:
            prosody_rev_latent += residual_latent
        rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)

        content_rev_latent = torch.zeros_like(content_latent)
        if self.use_gr_prosody_phone:
            content_rev_latent += prosody_latent
        if self.use_gr_residual_phone:
            content_rev_latent += residual_latent
        rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]

        # All non-timbre streams together should not reveal the speaker.
        timbre_rev_latent = prosody_latent + content_latent + residual_latent
        if self.use_gr_x_timbre:
            x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
        else:
            x_spk_pred = None

        preds = {
            "f0": f0_pred,
            "uv": uv_pred,
            "content": content_pred,
            "timbre": spk_pred,
        }

        rev_preds = {
            "rev_f0": rev_f0_pred,
            "rev_uv": rev_uv_pred,
            "rev_content": rev_content_pred,
            "x_timbre": x_spk_pred,
        }
        return preds, rev_preds
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
|
||||
|
||||
from . import attentions
|
||||
from torch import nn
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))`` (https://arxiv.org/abs/1908.08681)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Smooth, non-monotonic self-gated activation.
        return x.mul(F.softplus(x).tanh())
|
||||
|
||||
|
||||
class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU(Gated Linear Unit) with residual connection.
    For GLU refer to https://arxiv.org/abs/1612.08083 paper.

    The residual add requires the conv to preserve the time dimension and
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        # Fix: padding was hard-coded to 2, which only preserves length for
        # kernel_size=5 (the residual addition then crashes for any other odd
        # kernel size). kernel_size // 2 is "same" padding for odd kernels and
        # is identical to the old behavior for the kernel_size=5 call sites.
        self.conv1 = nn.Conv1d(
            in_channels,
            2 * out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        # Split the doubled channels into value/gate halves.
        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x
|
||||
|
||||
|
||||
class StyleEncoder(torch.nn.Module):
    """Global style (timbre) encoder: spectral convs -> gated temporal convs ->
    self-attention -> masked temporal average pooling to one vector per utterance.
    """

    def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):

        super().__init__()

        self.in_dim = in_dim  # Linear 513 wav2vec 2.0 1024
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.kernel_size = 5
        self.n_head = 2
        self.dropout = 0.1

        # Per-frame spectral projection (1x1 convs with Mish + dropout).
        self.spectral = nn.Sequential(
            nn.Conv1d(self.in_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout),
            nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout),
        )

        # Gated temporal context (residual Conv1dGLU blocks).
        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        # Project-local attention module; proximal_init ties key/query weights.
        self.slf_attn = attentions.MultiHeadAttention(
            self.hidden_dim,
            self.hidden_dim,
            self.n_head,
            p_dropout=self.dropout,
            proximal_bias=False,
            proximal_init=True,
        )
        self.atten_drop = nn.Dropout(self.dropout)
        self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)

    def forward(self, x, mask=None):
        """Encode (B, in_dim, T) features into a pooled style vector.

        NOTE(review): despite the ``mask=None`` default, the body multiplies by
        ``mask`` unconditionally, so a mask is effectively required here.
        """

        # spectral
        x = self.spectral(x) * mask
        # temporal
        x = self.temporal(x) * mask

        # self-attention (pairwise valid-frame mask: (B, 1, T) -> (B, 1, T, T))
        attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
        y = self.slf_attn(x, x, attn_mask=attn_mask)
        x = x + self.atten_drop(y)

        # fc
        x = self.fc(x)

        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=mask)

        return w

    def temporal_avg_pool(self, x, mask=None):
        """Average (B, C, T) over time; with a mask, average only valid frames."""
        if mask is None:
            out = torch.mean(x, dim=2)
        else:
            len_ = mask.sum(dim=2)  # (B, 1): number of valid frames
            x = x.sum(dim=2)  # (B, C): masked frames are already zeroed upstream

            out = torch.div(x, len_)  # broadcasts (B, C) / (B, 1)
        return out
|
||||
@@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2023 Amphion.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from modules.dac.model.encodec import SConv1d
|
||||
|
||||
from . import commons
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of (B, C, T) tensors.

    Transposes so channels are last, applies ``F.layer_norm`` with learnable
    gamma/beta, then restores the original layout.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable per-channel scale and shift.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Channels-last for F.layer_norm, then back to (B, C, T).
        y = x.transpose(1, -1)
        y = F.layer_norm(y, (self.channels,), self.gamma, self.beta, self.eps)
        return y.transpose(1, -1)
|
||||
|
||||
|
||||
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers with a final
    zero-initialized 1x1 projection added residually to the input.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 0."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        # First layer maps in_channels -> hidden; the rest are hidden -> hidden.
        input_sizes = [in_channels] + [hidden_channels] * (n_layers - 1)
        for ch_in in input_sizes:
            self.conv_layers.append(
                nn.Conv1d(ch_in, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        # Zero-init projection so the block is an identity at the start of training.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the conv stack under ``x_mask`` and add the projected result to x."""
        x_org = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            x = self.relu_drop(norm(conv(x * x_mask)))
        return (x_org + self.proj(x)) * x_mask
|
||||
|
||||
|
||||
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution stack.

    Each layer: depthwise dilated conv -> LayerNorm -> GELU -> 1x1 conv ->
    LayerNorm -> GELU -> Dropout, added residually. Dilation grows as
    kernel_size**i so the receptive field expands exponentially with depth.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            dilation = kernel_size**layer_idx
            # "Same" padding for the dilated kernel keeps the time length fixed.
            pad = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,  # depthwise: one filter per channel
                    dilation=dilation,
                    padding=pad,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the stack under ``x_mask``; optional conditioning ``g`` is added once."""
        if g is not None:
            x = x + g
        for sep, norm_a, pointwise, norm_b in zip(
            self.convs_sep, self.norms_1, self.convs_1x1, self.norms_2
        ):
            y = F.gelu(norm_a(sep(x * x_mask)))
            y = F.gelu(norm_b(pointwise(y)))
            x = x + self.drop(y)
        return x * x_mask
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
    """WaveNet-style stack of gated dilated convolutions with residual and skip
    connections, optionally conditioned on a global/auxiliary signal ``g``.
    Convolutions use the project's SConv1d (weight-normalized, optionally causal).
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
        causal=False,
    ):
        super(WN, self).__init__()
        conv1d_type = SConv1d
        assert kernel_size % 2 == 1  # odd kernel so "same" padding works below
        self.hidden_channels = hidden_channels
        # NOTE(review): stored as a 1-tuple (trailing comma); harmless since it
        # is never read back in this class, but looks unintentional.
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # Single conditioning projection producing 2*hidden channels per layer.
            self.cond_layer = conv1d_type(
                gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
            )

        for i in range(n_layers):
            dilation = dilation_rate**i  # exponentially growing receptive field
            padding = int((kernel_size * dilation - dilation) / 2)
            # 2*hidden output: tanh/sigmoid halves of the gated activation.
            in_layer = conv1d_type(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
                norm="weight_norm",
                causal=causal,
            )
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = conv1d_type(
                hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
            )
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of per-layer skip outputs for input (B, C, T)."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice this layer's 2*hidden conditioning channels.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            # Fused tanh(x+g) * sigmoid(x+g) gate (project helper).
            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half is the residual update, second half the skip output.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                # Last layer produces skip channels only.
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        # Strip weight norm for inference/export.
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
|
||||
Reference in New Issue
Block a user