init commit

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-26 15:34:31 +08:00
commit 80513a3258
141 changed files with 24966 additions and 0 deletions

Binary file not shown.

View File

@@ -0,0 +1,523 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import random
import typing as tp
from random import randrange
import numpy as np
from einops import rearrange, repeat
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F
def round_up_multiple(num, mult):
    """Round `num` up to the nearest integer multiple of `mult`."""
    quotient = ceil(num / mult)
    return quotient * mult
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val` unless it is None, in which case fall back to `d`."""
    if val is None:
        return d
    return val
def ema_inplace(moving_avg, new, decay: float):
    """In-place EMA step: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=1.0 - decay)
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth the count vector `x` and renormalize it to sum to 1."""
    denominator = x.sum() + n_categories * epsilon
    return (x + epsilon) / denominator
def uniform_init(*shape: int):
    """Allocate a tensor of the given shape filled with Kaiming-uniform values."""
    tensor = torch.empty(shape)
    nn.init.kaiming_uniform_(tensor)
    return tensor
def sample_vectors(samples, num: int):
    """Draw `num` rows from `samples` — without replacement when enough rows
    exist, otherwise with replacement."""
    total, device = samples.shape[0], samples.device
    if total < num:
        idx = torch.randint(0, total, (num,), device=device)
    else:
        idx = torch.randperm(total, device=device)[:num]
    return samples[idx]
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Lloyd's k-means over `samples` of shape (N, D).

    Returns:
        means: (num_clusters, D) centroids.
        bins: final per-cluster population counts (from the last assignment).
    """
    dim, dtype = samples.shape[-1], samples.dtype
    # Seed centroids from random data points.
    means = sample_vectors(samples, num_clusters)
    for _ in range(num_iters):
        # Negative squared Euclidean distance, so argmax = nearest centroid.
        dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )
        buckets = dists.max(dim=-1).indices
        del dists  # free the (N, K) distance matrix early
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty-cluster counts to 1 to avoid division by zero below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        # Keep the previous centroid for clusters that received no points.
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins
def preprocess(x):
    """Collapse every leading dimension: (..., d) -> (N, d)."""
    return x.reshape(-1, x.shape[-1])
def postprocess_emb(embed_ind, shape):
    """Restore flat code indices to the leading shape of the original input."""
    leading = shape[:-1]
    return embed_ind.view(*leading)
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Codebook state (inited / cluster_size / embed / embed_avg) is NOT owned by
    this module: every encode/decode/forward call receives the four tensors via
    `buffers` and binds them to attributes, so one module instance can serve
    several externally registered codebooks.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        # NOTE(review): kmeans_init is accepted but unused here — the init mode
        # is decided by whoever owns the buffers; confirm this is intentional.
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 2.0,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code
        # Codebook state is injected per call; None until the first call.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        # NOTE(review): shadows nn.Module.training — starts True even before
        # .train()/.eval() is ever called.
        self.training = True

    def init_embed_(self, data):
        """Lazily k-means-initialize the injected codebook from `data`; no-op once inited."""
        if self.inited:
            return
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])

    def replace_(self, samples, mask):
        """Overwrite the codebook rows selected by `mask` with random rows of `samples`."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Re-seed codes whose normalized EMA usage fell below the dead-code threshold."""
        if self.threshold_ema_dead_code == 0:
            return
        # Normalize usage so it sums to codebook_size (average usage == 1).
        cluster_size = self.cluster_size / sum(self.cluster_size) * self.codebook_size
        expired_codes = cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        else:
            print(f"VQ expire infos: num_expire={sum(expired_codes)}, cluster_size[:5]={cluster_size[:5]}")
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return nearest-code indices for flattened inputs x of shape (N, D)."""
        embed = self.embed.t()
        # Negative squared distance, so argmax picks the nearest codebook entry.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up the code vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x, buffers):
        """Quantize x using externally supplied (inited, cluster_size, embed, embed_avg)."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind, buffers):
        """Inverse of encode: indices -> code vectors, using the supplied buffers."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, buffers):
        """Quantize x and, in training mode, EMA-update the injected codebook buffers.

        Returns (quantize, embed_ind) with the same leading shape as x.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
        shape, dtype = x.shape, x.dtype
        x = preprocess(x)
        self.init_embed_(x)
        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)
        if self.training:
            # EMA update of usage counts and per-code vector sums, then
            # Laplace-smoothed renormalization of the codebook itself.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            # Note: after ema update, there is a very small difference between codebooks on GPUs.
            # The impact can be very small, ignore it.
        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.

    Codebook state is not stored here: each encode/decode/forward call takes
    `buffers` and forwards them to the inner EuclideanCodebook.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: float = 2.0,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)
        # Only project in/out when the codebook lives in a different dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size
        # NOTE(review): shadows nn.Module.training (starts True regardless of mode).
        self.training = True

    @property
    def codebook(self):
        # Embedding table last bound to the inner codebook.
        return self._codebook.embed

    def encode(self, x, buffers):
        """Project x and map it to code indices with the supplied codebook buffers."""
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x, buffers)
        return embed_in

    def decode(self, embed_ind, buffers):
        """Map code indices back to vectors and project them out."""
        quantize = self._codebook.decode(embed_ind, buffers)
        quantize = self.project_out(quantize)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x, buffers):
        """Quantize x; returns (quantize, embed_ind, loss).

        In training mode a straight-through estimator keeps gradients flowing
        into x and `loss` carries the weighted commitment term; otherwise the
        loss stays at zero.
        """
        device = x.device
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x, buffers)
        if self.training:
            # Straight-through: forward value is quantize, backward sees identity.
            quantize = x + (quantize - x).detach()
        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight
        quantize = self.project_out(quantize)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
class DistributedResidualVectorQuantization(nn.Module):
    """Efficient distributed residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    All per-quantizer codebook state is registered HERE as buffers (one row
    per quantizer); each VectorQuantization layer receives its slice of them
    on every call.
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        """
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        """
        # NOTE(review): the conditional binds only to the second tuple element
        # (codebook_dim), and kwargs["codebook_dim"] raises KeyError when the
        # key is absent — callers must always pass codebook_dim; confirm intent.
        codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["codebook_dim"] if kwargs["codebook_dim"] else kwargs["dim"]
        kmeans_init = kwargs["kmeans_init"]
        if isinstance(kmeans_init, bool):
            if not kwargs["kmeans_init"]:
                # use uniform init
                embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
                inited = True
            else:
                # to perform kmeans init on first batch
                embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
                inited = False
        elif isinstance(kmeans_init, str):
            # use prepared kmeans init
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if embed.dim() == 2:
                # Single-codebook file: add the quantizer axis.
                embed = embed.unsqueeze(0)
            inited = True
        else:
            raise TypeError("kmeans_init should be either a bool or string path to init weights.")
        self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
        self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())
        # Optional extra temporal downsampling applied to the first quantizer only.
        self.q0_ds_ratio = 1
        if "q0_ds_ratio" in kwargs:
            # pop() so the VectorQuantization constructors below don't receive it.
            self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")
        self.layers = nn.ModuleList()
        for i in range(num_quantizers):
            vq_args = dict(**kwargs)
            vq = VectorQuantization(**vq_args)
            self.layers.append(vq)
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Residual-quantize x of shape (B, C, T); returns (quantized, indices, losses).

        With quantize dropout enabled (training only), layers at or past a
        randomly drawn cutoff emit placeholder indices/losses instead of
        quantizing.
        """
        quantized_out = torch.zeros_like(x)
        residual = x
        bb, cc, tt = x.shape
        device = x.device
        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)
        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            # Keep only the first rand_quantize_dropout_index quantizers this step.
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)
            null_indices_shape = (x.shape[0], x.shape[2])
            null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue
            quant_in = residual
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # NOTE(review): hard-coded halving — assumes q0_ds_ratio == 2; confirm.
                quant_in = F.interpolate(quant_in, size=[tt//2])
            quantized, indices, loss = layer(quant_in, [
                self.inited[quantizer_index],
                self.cluster_size[quantizer_index],
                self.embed[quantizer_index],
                self.embed_avg[quantizer_index]
            ])
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Upsample both outputs back to the full time axis.
                quantized = F.interpolate(quantized, size=[tt])
                indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)
        # sync buffers after one forward step
        # distrib.broadcast_tensors(self.buffers())
        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode x through the first n_q quantizers; returns stacked indices."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            indices = layer.encode(residual, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            # Decode immediately so the next layer sees the remaining residual.
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded contribution of each quantizer's index row."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized_out = quantized_out + quantized
        return quantized_out
class DistributedGroupResidualVectorQuantization(nn.Module):
    """Efficient distributed group residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/abs/2305.02765
    Group Then rvq

    The channel axis is chunked into `num_groups` groups, each quantized by
    its own independent residual-VQ stack.
    """
    def __init__(self, *,
                 num_groups,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        # One independent RVQ per channel group; kwargs forwarded unchanged.
        self.rvqs = nn.ModuleList(
            [
                DistributedResidualVectorQuantization(
                    num_quantizers=num_quantizers,
                    quantize_dropout=quantize_dropout,
                    rand_num_quant=rand_num_quant,
                    **kwargs
                )
                for _ in range(num_groups)
            ]
        )
        self.num_groups = num_groups

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Split x along channels, run each group's RVQ, then re-assemble.

        Losses are averaged over groups; indices are stacked at dim 1.
        """
        x_lst = torch.chunk(x, chunks=self.num_groups, dim=1)
        all_quantized_out = []
        all_indices = []
        all_losses = []
        for mod, item in zip(self.rvqs, x_lst):
            quantized_out, out_indices, out_losses = mod(item, n_q)
            all_quantized_out.append(quantized_out)
            all_indices.append(out_indices)
            all_losses.append(out_losses)
        out_losses = torch.stack(all_losses, dim=1).mean(dim=1)
        return torch.cat(all_quantized_out, dim=1), torch.stack(all_indices, dim=1), out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Per-group code indices, stacked along dim 1."""
        x_lst = torch.chunk(x, chunks=self.num_groups, dim=1)
        return torch.stack([mod.encode(item, n_q) for mod, item in zip(self.rvqs, x_lst)], dim=1)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Inverse of encode: decode each group's indices and concat the channels."""
        q_indices_lst = torch.chunk(q_indices, chunks=self.num_groups, dim=1)
        return torch.cat([mod.decode(item.squeeze(1)) for mod, item in zip(self.rvqs, q_indices_lst)], dim=1)

View File

@@ -0,0 +1,357 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sox
import copy
import torch
import operator
import onnxruntime
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as kaldi
from librosa.filters import mel as librosa_mel_fn
from itertools import accumulate
from typing import List
from torch import Tensor
from .core_vq import DistributedGroupResidualVectorQuantization
from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(max(x, clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
def spectral_normalize_torch(magnitudes):
    """Apply dynamic-range (log) compression to spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
class MelSpectrogramFeatures(nn.Module):
    """
    Calculate the BigVGAN style mel spectrogram of an input signal.
    Args:
        filter_length (int): The number of samples in the filter window, used for the Fourier Transform. Default is 1024.
        hop_length (int): The number of samples between successive frames (stride of the STFT). Default is 160.
        win_length (int): The length of the window function applied to each frame, usually less than or equal to the filter length. Default is 640.
        n_mel_channels (int): The number of Mel-frequency channels to output from the Mel-scale spectrogram. Default is 80.
        mel_fmin (int): The minimum frequency (in Hz) of the Mel-scale spectrogram. Default is 0.
        mel_fmax (int): The maximum frequency (in Hz) of the Mel-scale spectrogram. Default is 8000.
        sampling_rate (int): The sampling rate of the audio data (in Hz). Default is 16000.
        sampling_rate_org (int, optional): The original sampling rate of the audio data before any resampling (in Hz), if applicable. Default is None.
        padding (str): The padding mode for the input signal. 'center' pads the signal symmetrically around its center. Default is 'center'.
    Returns:
        torch.Tensor: Mel spectrogram.
    """
    def __init__(self,
                 filter_length=1024,
                 hop_length=160,
                 win_length=640,
                 n_mel_channels=80,
                 mel_fmin=0,
                 mel_fmax=8000,
                 sampling_rate=16000,
                 sampling_rate_org=None,
                 padding='center',
                 use_db = False,  # accepted but not used by extract()
                 ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
        # Lazy per-device caches for the mel filterbank and Hann window.
        # NOTE(review): plain dicts, not registered buffers — they will not
        # follow .to()/state_dict; entries are created on the first call.
        self.mel_basis = {}
        self.hann_window = {}

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """Gradient-free wrapper around extract()."""
        with torch.no_grad():
            feats = self.extract(audio, **kwargs)
        return feats

    def extract(self, audio, **kwargs):
        """Compute the log-mel spectrogram.

        audio: (B, T), or 3-D with a singleton channel axis which is dropped.
        Returns: (B, n_mel_channels, frames).
        """
        if len(audio.shape) == 3:
            # Drop the singleton channel axis wherever it sits.
            audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
        assert len(audio.shape) == 2
        y = audio
        if len(list(self.mel_basis.keys())) == 0:
            # First call: build the mel filterbank and window for this device.
            mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
            self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
            self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device)
        # Reflect-pad so frames align with the hop grid (BigVGAN convention).
        y = torch.nn.functional.pad(y.unsqueeze(1), (int((self.filter_length-self.hop_length)/2), int((self.filter_length-self.hop_length)/2)), mode='reflect')
        y = y.squeeze(1)
        spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[str(y.device)],
                          center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
        # Magnitude with a small floor for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
        spec = torch.matmul(self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)], spec)
        spec = spectral_normalize_torch(spec)
        return spec
class XVectorExtractor(nn.Module):
    """Speaker-embedding (x-vector) extractor backed by an ONNX model.

    Inference runs on CPU through onnxruntime; the waveform is loudness
    normalized to -6 dB with sox before fbank extraction.
    """
    def __init__(self, audio_codec_with_xvector):
        # audio_codec_with_xvector: path to the ONNX x-vector model file.
        super().__init__()
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ["CPUExecutionProvider"]
        self.ort_session = onnxruntime.InferenceSession(audio_codec_with_xvector, sess_options=option, providers=providers)
        self.tfm = sox.Transformer()
        self.tfm.norm(db_level=-6)
        self.mel_ext = MelSpectrogramFeatures(
            filter_length=1024,
            hop_length=160,
            win_length=640,
            n_mel_channels=80,
            mel_fmin=0,
            mel_fmax=8000,
            sampling_rate=16000
        )

    def extract_code(self, audio):
        """Return (L2-normalized x-vector, reference mel) as numpy arrays.

        audio: 1-D numpy waveform — assumed 16 kHz, matching sox_norm/fbank; TODO confirm caller.
        """
        with torch.no_grad():
            norm_audio = self.sox_norm(audio)
            # deepcopy: sox may hand back a read-only / shared array.
            norm_audio = torch.from_numpy(copy.deepcopy(norm_audio)).unsqueeze(0)
            feat = kaldi.fbank(norm_audio,
                               num_mel_bins=80,
                               dither=0,
                               sample_frequency=16000)
            # Cepstral mean normalization over the time axis.
            feat = feat - feat.mean(dim=0, keepdim=True)
            norm_embedding = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()
            norm_embedding = F.normalize(torch.from_numpy(norm_embedding), dim=0)
            ref_mel = self.mel_ext.extract(audio=norm_audio)
        # ref_mel returned time-major: (T, 80).
        return norm_embedding.numpy(), ref_mel.permute(0,2,1).squeeze(0).numpy()

    def sox_norm(self, audio):
        """Loudness-normalize the waveform to -6 dB with sox."""
        wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
        return wav_norm
class WhisperEncoderVQ(WhisperEncoder):
    """Whisper encoder with an optional vector-quantization bottleneck
    inserted after encoder block `audio_vq_layers`.

    Relies on the base WhisperEncoder for conv front-end (conv1/conv2),
    positional_embedding, transformer blocks, pooling and output projection.
    """
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
        audio_vq_layers: int = -1,  # 1-based index of the block after which VQ is applied; must be > 0
        audio_vq_type: str = "NULL",  # "GRVQ" enables quantization; "NULL" disables it
        audio_vq_codebook_size: int = 4096,
        audio_vq_pe: bool = False,  # re-add positional embedding (with a projection) after VQ
        audio_vq_commit_loss: float = 0.0,
        audio_vq_out_commit_loss: float = 0.0,  # weight of MSE between pre-VQ and post-VQ features
        audio_vq_no_quantize: bool = False,
        audio_vq_ff_layer: int = 0,
        audio_vq_threshold_ema_dead_code: float = 0.1,
        audio_vq_codebook_dim: int = None,
        audio_vq_ds_rate: int = None,  # extra temporal downsampling around the quantizer
    ):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)
        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_type = audio_vq_type
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_commit_loss = audio_vq_commit_loss
        self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
        self.audio_vq_no_quantize = audio_vq_no_quantize
        self.audio_vq_ff_layer = audio_vq_ff_layer
        if audio_vq_layers > 0:
            self.vq_feature_dim = self.n_state
            self.audio_vq_ds_rate = 1
        else:
            raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")
        if self.audio_vq_ds_rate == audio_vq_ds_rate:
            # No extra resampling needed around the quantizer.
            self.audio_vq_downsample = nn.Identity()
            self.audio_vq_upsample = nn.Identity()
        else:
            assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
            stride = audio_vq_ds_rate // self.audio_vq_ds_rate
            # Strided conv down, transposed conv back up, same feature dim.
            self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_ds_rate = audio_vq_ds_rate
        if audio_vq_type == "GRVQ":
            self.audio_quantizer = DistributedGroupResidualVectorQuantization(
                codebook_size = audio_vq_codebook_size,
                dim = self.vq_feature_dim,
                # NOTE(review): self.vq_codebook_dim is not assigned anywhere in
                # this class — if audio_vq_codebook_dim is None this raises
                # AttributeError unless the base class defines it; confirm.
                codebook_dim = self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim,
                num_groups=1,
                num_quantizers=1,
                kmeans_init=False,
                threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
            )
        else:
            raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")
        if self.audio_vq_pe:
            self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)

    def _calc_quantize_activities(self, indices):
        """Codebook-usage stats: number of distinct codes used and total tokens."""
        indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
        vq_num_activities = sum(indices_onehot>0)
        vq_num_tokens = sum(indices_onehot)
        return {
            "vq_num_activities": vq_num_activities,
            "vq_num_tokens": vq_num_tokens,
        }

    def _do_quantize(self, x, pe=None, y=None):
        """
        x: torch.Tensor, shape = (T, D)
        q: torch.Tensor, shape = (T, D)
        i: torch.Tensor, shape = (T)

        Downsample -> quantize (encode+decode) -> optional PE re-injection ->
        upsample; returns (features, code indices, stats dict).
        """
        if self.audio_vq_out_commit_loss > 0:
            # Keep the pre-quantization features as the reconstruction target.
            x_teacher = x.clone()
        x = x.unsqueeze(0)
        x = self.audio_vq_downsample(x.transpose(1, 2))
        x = x.transpose(1, 2)
        vq_stats = {}
        if self.audio_vq_type == "GRVQ":
            if self.training:
                # Training-mode quantization is not implemented in this file.
                raise NotImplementedError
            else:
                indices = self.audio_quantizer.encode(x)
                x = self.audio_quantizer.decode(indices)
                # Drop the group and quantizer axes (both are 1).
                indices = indices.squeeze(2).squeeze(1)
                vq_stats.update(self._calc_quantize_activities(indices))
        x, indices = x.squeeze(0), indices.squeeze(0)
        if self.audio_vq_pe:
            # Re-add positional information lost through quantization.
            x = x + pe
            x = self.project_after_vq_pe(x)
        x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
        x = x.transpose(1, 2).squeeze(0)
        if self.audio_vq_out_commit_loss > 0:
            vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
            vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss
        return x, indices, vq_stats

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
        """
        x : torch.Tensor, shape = (n_mels, n_ctx)
            the mel spectrogram of the audio

        Packs all audios into one concatenated sequence (varlen attention via
        cu_seqlens), applies the VQ bottleneck inside the block stack, then
        pools, projects, and brackets each audio with BOS/EOS embeddings.
        """
        aftercnn_x_list = []
        pe_for_vq_list = []
        for each_x in x_list:
            # Chunk long mels so each piece fits the positional table.
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
                # Positional slice matching the (downsampled) VQ time axis.
                pe_for_vq_split = self.positional_embedding[:each_x_split.shape[0] // self.audio_vq_ds_rate]
                pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))
        pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
        x = torch.cat(aftercnn_x_list, dim=0)
        src_len = x.size(0)
        # Build cu_seqlens for varlen attention, capping segments at n_window.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)
        cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
        layer_id = 0
        for block in self.blocks:
            layer_id+=1
            x = block(x, cu_seqlens=cu_seqlens)
            if self.audio_vq_layers == layer_id: # vq inside encoder
                x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
                if return_indices:
                    # Early exit: caller only wants the discrete codes.
                    return x, indices
        if self.avg_pooler:
            # Per-audio temporal pooling.
            x_list = x.split(audio_aftercnnlens, dim=0)
            token_x_list = []
            for x in x_list:
                x = x.permute(1, 0)
                x = self.avg_pooler(x)
                x = x.permute(1, 0)
                token_x_list.append(x)
            x = torch.cat(token_x_list, dim=0)
        x = self.ln_post(x)
        x = self.proj(x)
        # Reserve two extra slots (BOS/EOS) per audio in the output sequence.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )
        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x
        if self.audio_vq_type != "NULL":
            # NOTE(review): vq_stats is only bound if some block index matched
            # audio_vq_layers above; with n_layer < audio_vq_layers this raises
            # UnboundLocalError — confirm configs guarantee the match.
            return output, vq_stats
        return output

View File

@@ -0,0 +1,406 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import torch
import operator
import numpy as np
import torch.nn.functional as F
from functools import lru_cache
from typing import Optional, Union, List
from torch import nn, Tensor
from itertools import accumulate
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
except ImportError:
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
except ImportError:
print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
flash_attn_varlen_func = None
N_FFT = 400
HOP_LENGTH = 160
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:
        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )

    Results are cached per (device, n_mels) by lru_cache; requires the
    `assets/mel_filters.npz` file shipped next to this module.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        A NumPy array or Tensor containing the audio waveform in 16 kHz.
        NOTE(review): despite the type hint, a `str` path is not handled
        here — `torch.from_numpy` would fail on it; confirm callers.
    n_mels: int
        The number of Mel-frequency filters; 80 and 128 are supported
        (see mel_filters).
    padding: int
        Number of zero samples to pad to the right
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Drop the last frame (Whisper convention) and take the power spectrum.
    magnitudes = stft[..., :-1].abs() ** 2
    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Clamp dynamic range to 8 dB-decades below the max, then rescale.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
def get_T_after_cnn(L_in, dilation=1):
    """Compute the time-axis length after the encoder's two front-end convs.

    Applies the standard conv arithmetic
        L_out = floor((L_in + 2*padding - dilation*(kernel-1) - 1) / stride) + 1
    for the layers (padding=1, kernel=3, stride=1) then (padding=1, kernel=3,
    stride=2), matching conv1/conv2 of the Whisper front-end.

    Args:
        L_in: input length in frames.
        dilation: dilation shared by both layers (default 1).
    Returns:
        Output length after both convolutions.
    """
    # Literal layer spec; the previous eval() of a string constant was
    # needless and unsafe.
    for padding, kernel_size, stride in ((1, 3, 1), (1, 3, 2)):
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out
def get_mel_audio(audio, padding=False, audio_vq_ds_rate=1, n_mels=128):
    """Compute a log-Mel spectrogram, optionally right-padded to a block boundary.

    When `padding` is True the waveform is zero-padded so its sample count is a
    multiple of 160 * 2 * audio_vq_ds_rate (hop length times the conv/VQ
    downsampling), so the frame count divides evenly downstream.
    """
    if not padding:
        return log_mel_spectrogram(audio, n_mels=n_mels)  # [F, T]
    block = 160 * 2 * audio_vq_ds_rate
    # Number of samples needed to reach the next multiple of `block`
    # (equivalent to ceil(len/block)*block - len).
    pad_samples = -len(audio) % block
    return log_mel_spectrogram(audio, n_mels=n_mels, padding=pad_samples)
def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional-embedding table of shape (length, channels).

    The first channels//2 columns are sines, the last channels//2 are cosines,
    over geometrically spaced timescales from 1 up to max_timescale.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(torch.arange(half) * -log_increment)
    angles = torch.arange(length).unsqueeze(1) * inv_timescales.unsqueeze(0)
    return torch.cat([angles.sin(), angles.cos()], dim=1)
class Conv1d(nn.Conv1d):
    """nn.Conv1d that casts weight/bias to the input dtype before convolving,
    so mixed-precision activations work without the caller casting parameters."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        cast_weight = weight.to(x.dtype)
        cast_bias = bias if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, cast_weight, cast_bias)
class ConvTranspose1d(nn.ConvTranspose1d):
    """nn.ConvTranspose1d that casts weight/bias to the input dtype.

    BUG FIX: the previous version overrode ``_conv_forward``, but
    ``nn.ConvTranspose1d.forward`` never calls ``_conv_forward`` (that hook only
    exists on the forward-conv classes), so the dtype cast was dead code and
    half-precision inputs hit a dtype mismatch. The cast is now applied in
    ``forward`` itself.
    """

    def forward(self, x: Tensor, output_size=None) -> Tensor:
        if output_size is not None:
            # Rare path: defer to the stock implementation (no dtype cast),
            # since output_size handling is version-dependent.
            return super().forward(x, output_size)
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.conv_transpose1d(
            x,
            weight,
            bias,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            groups=self.groups,
            dilation=self.dilation,
        )
class Linear(nn.Linear):
    """nn.Linear that casts weight/bias to the input dtype (mixed-precision safe)."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a packed variable-length batch.

    Inputs are packed along dim 0 as (total_tokens, n_state); ``cu_seqlens``
    holds cumulative sequence boundaries ``[0, len0, len0+len1, ...]`` (int32).
    Uses the flash-attn varlen kernel when it is importable and activations are
    fp16/bf16, otherwise a padded manual implementation.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        # Key projection carries no bias (Whisper convention).
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
        self.use_flash_attention = True

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None,
    ):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        if self.use_flash_attention:
            if flash_attn_varlen_func is None:
                # flash-attn not installed: manual fallback.
                x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
            else:
                if q.dtype not in [torch.float16, torch.bfloat16]:
                    # flash-attn only supports half precision; after the first
                    # fp32 input we permanently switch to the manual path.
                    x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
                    self.use_flash_attention = False
                else:
                    x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
        else:
            x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
        output = self.out(x)
        return output

    def qkv_flash_attention(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
    ):
        """Varlen flash-attention path; q/k/v are (total_tokens, n_state)."""
        n_ctx, n_state = q.shape
        q = q.view(n_ctx, self.n_head, -1)  # (total_tokens, nheads, headdim)
        k = k.view(n_ctx, self.n_head, -1)
        v = v.view(n_ctx, self.n_head, -1)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
        )
        x = x.reshape(n_ctx, n_state)
        return x

    def qkv_attention_manual(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
    ):
        """Reference attention: unpack to a padded batch, mask padding, repack.

        BUG FIX: the additive attention mask was previously built with
        ``masked_fill`` on a *bool* tensor, so the intended large-negative bias
        never survived and padded key positions received attention weight
        whenever batch sequences had unequal lengths. The bias is now a float
        tensor of 0 / finfo.min.
        """
        n_ctx, n_state = q.shape
        head_dim = n_state // self.n_head
        scale = head_dim ** -0.5
        q = q.view(n_ctx, self.n_head, head_dim)
        k = k.view(n_ctx, self.n_head, head_dim)
        v = v.view(n_ctx, self.n_head, head_dim)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        batch_size = len(seqlens)
        max_seqlen = max(seqlens)
        # Unpack the packed tokens into a zero-padded (B, T, H, D) batch.
        q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
        k_padded = torch.zeros_like(q_padded)
        v_padded = torch.zeros_like(q_padded)
        for i in range(batch_size):
            start_idx = cu_seqlens[i]
            end_idx = cu_seqlens[i + 1]
            seq_len = seqlens[i]
            q_padded[i, :seq_len] = q[start_idx:end_idx]
            k_padded[i, :seq_len] = k[start_idx:end_idx]
            v_padded[i, :seq_len] = v[start_idx:end_idx]
        q_padded = q_padded.transpose(1, 2)  # (B, H, T, D)
        k_padded = k_padded.transpose(1, 2)
        v_padded = v_padded.transpose(1, 2)
        # Additive bias over key positions: 0 where valid, finfo.min where padded.
        key_valid = torch.arange(max_seqlen, device=q.device)[None, :] < torch.tensor(seqlens, device=q.device)[:, None]
        attn_bias = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q_padded.dtype, device=q.device)
        attn_bias.masked_fill_(~key_valid[:, None, None, :], torch.finfo(q_padded.dtype).min)
        attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
        attn_scores = attn_scores + attn_bias
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v_padded)
        context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)
        # Repack: keep only the valid rows of each sequence, in batch order.
        output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)
        assert output_packed.shape == (n_ctx, n_state)
        return output_packed
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and a 4x GELU MLP, each
    wrapped in a residual connection.

    ``enable_mp`` and ``sequence_parallel`` are accepted for interface
    compatibility but are not used in this implementation.
    """

    def __init__(self, n_state: int, n_head: int,
                 enable_mp: bool = False, sequence_parallel: bool = False):
        super().__init__()
        # Registration order matches the original (attn_ln, mlp_ln, attn, mlp)
        # so state_dict layout is unchanged.
        self.attn_ln = nn.LayerNorm(n_state)
        self.mlp_ln = nn.LayerNorm(n_state)
        self.attn = MultiHeadAttention(n_state, n_head)
        hidden = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden), nn.GELU(), Linear(hidden, n_state)
        )

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None
    ):
        attn_out = self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_ln(x))
        return x + mlp_out
class WhisperEncoder(nn.Module):
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
    ):
        """Whisper-style audio encoder over packed variable-length mel inputs.

        Args:
            n_mels: mel bins expected on input spectrograms (conv1 in-channels).
            n_ctx: length of the sinusoidal positional-embedding table.
            n_state: transformer hidden width.
            n_head: attention heads per block.
            n_layer: number of ResidualAttentionBlock layers.
            n_window: chunk length (in post-conv frames) used to window
                attention in ``forward``.
            output_dim: width of the projected output embeddings.
            grad_checkpointing: stored on the instance; not used in the code
                visible here — presumably read elsewhere, TODO confirm.
            enable_mp: forwarded to the blocks (unused there in this file).
            audio_sequence_parallel: forwarded to the blocks as
                ``sequence_parallel`` (unused there in this file).
        """
        super().__init__()
        # Two-conv frontend; conv2 has stride 2, halving the time axis.
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed (non-learned) sinusoidal positions, saved as a buffer.
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.n_layer = n_layer
        self.n_mels = n_mels
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel)
            for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)
        # Halves the time axis again after the transformer stack.
        self.avg_pooler = nn.AvgPool1d(2, stride=2)
        # NOTE(review): stock torch.nn.Linear here (not the dtype-casting
        # Linear subclass used elsewhere in this file) — confirm intentional.
        self.proj = torch.nn.Linear(n_state, output_dim)
        # Row 0 = BOS embedding, row 1 = EOS embedding; inserted in forward().
        self.audio_bos_eos_token = nn.Embedding(2, output_dim)
        self.output_dim = output_dim
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.n_head = n_head
        self.n_state = n_state
        self.n_window = n_window
        self.audio_sequence_parallel = audio_sequence_parallel
        self.tp_world_size = 1
        # Tag non-block parameters with `audio_sync` (see set_audio_sync).
        self.set_audio_sync()
def set_audio_sync(self):
for name, param in self.named_parameters():
if not name.startswith("blocks"):
setattr(param, "audio_sync", True)
    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
        """Encode a batch of variable-length mel spectrograms into packed embeddings.

        Args:
            x_list: per-audio mel spectrograms, each of shape (n_mels, T_mel).
            audio_mellens: mel-frame lengths per audio; unused in this body.
            audio_aftercnnlens: per-audio lengths after the stride-2 conv
                frontend (i.e. roughly T_mel // 2) — assumed consistent with
                x_list, TODO confirm with caller.
            audio_seqlens: per-audio output token counts used to size the final
                output — presumably already includes the BOS/EOS slots added
                below; verify against callers.

        Returns:
            A packed (sum(audio_seqlens), output_dim) tensor in which each
            audio's span starts with the BOS embedding and ends with the EOS
            embedding, with the projected encoder tokens in between.
        """
        aftercnn_x_list = []
        # Conv frontend is applied per chunk of 2*n_window mel frames (stride-2
        # conv halves each chunk to at most n_window frames).
        for each_x in x_list:
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                # Positional embeddings restart at 0 for every chunk.
                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
        x = torch.cat(aftercnn_x_list, dim=0)
        src_len = x.size(0)  # NOTE(review): unused below
        # Rebuild per-chunk lengths (each <= n_window) so attention is windowed
        # to chunks; cu_seqlens are cumulative boundaries for varlen attention.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)
        cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
        layer_id = 0
        for block in self.blocks:
            layer_id+=1
            x = block(x, cu_seqlens=cu_seqlens)
        # NOTE(review): an nn.AvgPool1d instance is always truthy, so this
        # branch is effectively unconditional.
        if self.avg_pooler:
            # Pool each audio separately so pooling never crosses boundaries.
            x_list = x.split(audio_aftercnnlens, dim=0)
            token_x_list = []
            for x in x_list:
                x = x.permute(1, 0)
                x = self.avg_pooler(x)
                x = x.permute(1, 0)
                token_x_list.append(x)
            x = torch.cat(token_x_list, dim=0)
        x = self.ln_post(x)
        x = self.proj(x)
        # Allocate two extra slots per audio for the BOS/EOS embeddings.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )
        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        # First/last index of each audio's span within the padded output.
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
        # True where encoder tokens go; False at BOS/EOS positions.
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x
        return output
def lock(self, layers: int):
self.conv1.requires_grad_(False)
self.conv2.requires_grad_(False)
for i in range(min(layers, len(self.blocks))):
self.blocks[i].requires_grad_(False)