init commit

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-26 15:34:31 +08:00
commit 80513a3258
141 changed files with 24966 additions and 0 deletions

19
qwen_tts/core/__init__.py Normal file
View File

@@ -0,0 +1,19 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config
from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Model

View File

@@ -0,0 +1,18 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_qwen3_tts import Qwen3TTSConfig
from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
from .processing_qwen3_tts import Qwen3TTSProcessor

View File

@@ -0,0 +1,502 @@
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`].
    It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model
    architecture. The architecture is based on the ECAPA-TDNN model.

    Args:
        mel_dim (`int`, *optional*, defaults to 128):
            The dimension of the input mel-spectrogram.
        enc_dim (`int`, *optional*, defaults to 1024):
            The dimension of the final speaker embedding.
        enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder. The first channel size is for the
            initial TDNN layer, the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, and the last one
            for the multi-layer feature aggregation.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder, corresponding to `enc_channels`.
        enc_attention_channels (`int`, *optional*, defaults to 128):
            The number of attention channels in the `AttentiveStatisticsPooling` layer.
        enc_res2net_scale (`int`, *optional*, defaults to 8):
            The scale of the `Res2NetBlock` in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 128):
            The number of channels in the squeeze part of the `SqueezeExcitationBlock`.
        sample_rate (`int`, *optional*, defaults to 24000):
            The sampling rate, in Hz, of the audio the speaker encoder operates on.
    """

    def __init__(
        self,
        mel_dim=128,
        enc_dim=1024,
        enc_channels=None,
        enc_kernel_sizes=None,
        enc_dilations=None,
        enc_attention_channels=128,
        enc_res2net_scale=8,
        enc_se_channels=128,
        sample_rate=24000,
        **kwargs,
    ):
        # Forward extra keys (e.g. `transformers_version` added when a config is
        # serialized) to `PretrainedConfig`. The previous signature accepted no kwargs
        # and never initialized the base class, which broke save/load round-trips.
        super().__init__(**kwargs)
        self.mel_dim = mel_dim
        self.enc_dim = enc_dim
        # Materialize the per-layer defaults here instead of using mutable default
        # arguments, which would be shared across every instantiation.
        self.enc_channels = enc_channels if enc_channels is not None else [512, 512, 512, 512, 1536]
        self.enc_kernel_sizes = enc_kernel_sizes if enc_kernel_sizes is not None else [5, 3, 3, 3, 1]
        self.enc_dilations = enc_dilations if enc_dilations is not None else [1, 2, 3, 4, 1]
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        self.sample_rate = sample_rate
class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. It is used to instantiate a
    Qwen3TTSTalkerCodePredictor model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 2048):
            Vocabulary size of the Qwen3TTSTalkerCodePredictor model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen3TTSTalkerCodePredictorModel`]
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 5):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
            additional layer afterwards will use SWA (Sliding Window Attention).
        layer_types (`list`, *optional*):
            Attention pattern for each layer.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Number of codec code groups. Stored on the config for use by the modeling code; its exact role is not
            visible from this file — confirm against `modeling_qwen3_tts`.
    """

    model_type = "qwen3_tts_talker_code_predictor"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Default pipeline parallel plan: stage boundaries and their tensor interfaces.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=2048,
        hidden_size=1024,
        intermediate_size=3072,
        num_hidden_layers=5,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        layer_types=None,
        attention_dropout=0,
        num_code_groups=32,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # The window size is only meaningful when SWA is enabled; None disables it.
        self.sliding_window = sliding_window if self.use_sliding_window else None
        self.max_window_layers = max_window_layers
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
        self.layer_types = layer_types
        if self.layer_types is None:
            # Layers at index >= max_window_layers use sliding attention, but only
            # when a sliding window is actually configured.
            self.layer_types = [
                "sliding_attention"
                if self.sliding_window is not None and i >= self.max_window_layers
                else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)
        self.num_code_groups = num_code_groups
class Qwen3TTSTalkerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to instantiate a
    Qwen3TTSTalker model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        code_predictor_config (`dict` or [`Qwen3TTSTalkerCodePredictorConfig`], *optional*):
            Configuration of the nested code predictor sub-model. When `None`, a default-constructed
            [`Qwen3TTSTalkerCodePredictorConfig`] is used.
        vocab_size (`int`, *optional*, defaults to 3072):
            Vocabulary size of the Qwen3TTSTalker model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen3TTSTalkerModel`]
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 20):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `2`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Number of codec code groups. Exact role is defined by the modeling code, not visible here.
        text_hidden_size (`int`, *optional*, defaults to 2048):
            Hidden size of the text model whose representations are consumed by the talker.
            NOTE(review): inferred from the name — confirm against the modeling code.
        codec_eos_token_id (`int`, *optional*, defaults to 4198):
            Codec end-of-sequence token id.
        codec_think_id (`int`, *optional*, defaults to 4202):
            Codec "think" token id.
        codec_nothink_id (`int`, *optional*, defaults to 4203):
            Codec "no-think" token id.
        codec_think_bos_id (`int`, *optional*, defaults to 4204):
            Codec think begin token id.
        codec_think_eos_id (`int`, *optional*, defaults to 4205):
            Codec think end token id.
        codec_pad_id (`int`, *optional*, defaults to 4196):
            Codec padding token id.
        codec_bos_id (`int`, *optional*, defaults to 4197):
            Codec begin-of-sequence token id.
        spk_id (*optional*):
            Speaker id mapping. NOTE(review): type and structure are not visible in this file — confirm
            against the modeling code.
        spk_is_dialect (*optional*):
            Per-speaker dialect indicator. NOTE(review): type not visible here — confirm against the modeling code.
        codec_language_id (*optional*):
            Language id mapping for the codec. NOTE(review): type not visible here — confirm against the modeling code.
    """

    model_type = "qwen3_tts_talker"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalker`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Default pipeline parallel plan: stage boundaries and their tensor interfaces.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig}

    def __init__(
        self,
        code_predictor_config=None,
        vocab_size=3072,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        attention_dropout=0,
        num_code_groups=32,
        text_hidden_size=2048,
        codec_eos_token_id=4198,
        codec_think_id=4202,
        codec_nothink_id=4203,
        codec_think_bos_id=4204,
        codec_think_eos_id=4205,
        codec_pad_id=4196,
        codec_bos_id=4197,
        spk_id=None,
        spk_is_dialect=None,
        codec_language_id=None,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # The window size is only meaningful when SWA is enabled; None disables it.
        self.sliding_window = sliding_window if use_sliding_window else None
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        # Accept the nested config as None (default-construct), as an already-built
        # config object, or as a plain dict of constructor kwargs.
        if code_predictor_config is None:
            code_predictor_config = {}
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig()
            logger.info("code_predictor_config is None. Initializing code_predictor model with default values")
        elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig):
            self.code_predictor_config = code_predictor_config
        else:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config)
        self.num_code_groups = num_code_groups
        self.text_hidden_size = text_hidden_size
        self.codec_eos_token_id = codec_eos_token_id
        self.codec_think_id = codec_think_id
        self.codec_language_id = codec_language_id
        self.codec_nothink_id = codec_nothink_id
        self.codec_think_bos_id = codec_think_bos_id
        self.codec_think_eos_id = codec_think_eos_id
        self.codec_pad_id = codec_pad_id
        self.codec_bos_id = codec_bos_id
        self.spk_id = spk_id
        self.spk_is_dialect = spk_is_dialect
class Qwen3TTSConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`].

    Args:
        talker_config (`dict`, *optional*):
            Constructor kwargs for the talker sub-model ([`Qwen3TTSTalkerConfig`]). Defaults are used when `None`.
        speaker_encoder_config (`dict`, *optional*):
            Constructor kwargs for the speaker encoder sub-model ([`Qwen3TTSSpeakerEncoderConfig`]). Defaults are
            used when `None`.
        tokenizer_type (`str`, *optional*):
            Identifier of the audio tokenizer variant. NOTE(review): accepted values not visible in this file.
        tts_model_size (`str`, *optional*):
            Size tag of the TTS model. NOTE(review): accepted values not visible in this file.
        tts_model_type (`str`, *optional*):
            Type tag of the TTS model. NOTE(review): accepted values not visible in this file.
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Token id of the chat-message start (`im_start`) token.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Token id of the chat-message end (`im_end`) token.
        tts_pad_token_id (`int`, *optional*, defaults to 151671):
            Token id of the TTS padding token.
        tts_bos_token_id (`int`, *optional*, defaults to 151672):
            Token id of the TTS begin-of-sequence token.
        tts_eos_token_id (`int`, *optional*, defaults to 151673):
            Token id of the TTS end-of-sequence token.
    """

    model_type = "qwen3_tts"
    sub_configs = {
        "talker_config": Qwen3TTSTalkerConfig,
        "speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig,
    }

    def __init__(
        self,
        talker_config=None,
        speaker_encoder_config=None,
        tokenizer_type=None,
        tts_model_size=None,
        tts_model_type=None,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if talker_config is None:
            talker_config = {}
            logger.info("talker_config is None. Initializing talker model with default values")
        if speaker_encoder_config is None:
            speaker_encoder_config = {}
            # Bug fix: this message previously said "Initializing talker model"
            # (copy-pasted from the branch above), misreporting which sub-config
            # was being defaulted.
            logger.info("speaker_encoder_config is None. Initializing speaker encoder model with default values")
        self.talker_config = Qwen3TTSTalkerConfig(**talker_config)
        self.speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config)
        self.tokenizer_type = tokenizer_type
        self.tts_model_size = tts_model_size
        self.tts_model_type = tts_model_type
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id
__all__ = ["Qwen3TTSConfig", "Qwen3TTSTalkerConfig", "Qwen3TTSSpeakerEncoderConfig"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema for [`Qwen3TTSProcessor`] with its tokenizer defaults."""

    # Tokenizer defaults: no padding, and when padding is requested it is applied
    # on the left.
    _defaults = {
        "text_kwargs": {"padding": False, "padding_side": "left"},
    }
class Qwen3TTSProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen3TTS processor.

    Wraps a Qwen2 text tokenizer behind the standard `ProcessorMixin` interface.

    Args:
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The text tokenizer.
        chat_template (`Optional[str]`, *optional*):
            The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self, tokenizer=None, chat_template=None
    ):
        super().__init__(tokenizer, chat_template=chat_template)

    def __call__(self, text=None, **kwargs) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).

        Returns:
            `BatchFeature`: the tokenizer outputs, optionally converted to the tensor type
            given by `return_tensors`.

        Raises:
            ValueError: if `text` is `None`.
        """
        if text is None:
            raise ValueError("You need to specify either a `text` input to process.")
        # Merge caller kwargs over the class `_defaults` and the tokenizer's init kwargs.
        output_kwargs = self._merge_kwargs(
            Qwen3TTSProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # Normalize a single string into a batch of one.
        if not isinstance(text, list):
            text = [text]
        texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        return BatchFeature(
            data={**texts_inputs},
            tensor_type=kwargs.get("return_tensors"),
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, conversations, chat_template=None, **kwargs):
        """
        Apply the chat template to a single conversation or a batch of conversations.

        A single conversation (detected as a list whose first element is a message dict)
        is wrapped into a batch of one before delegating to the base implementation.
        """
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        return super().apply_chat_template(conversations, chat_template, **kwargs)

    @property
    def model_input_names(self):
        """Model input names from the tokenizer, deduplicated while preserving order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        return list(
            dict.fromkeys(
                tokenizer_input_names
            )
        )
__all__ = ["Qwen3TTSProcessor"]

View File

@@ -0,0 +1,172 @@
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3TTSTokenizerV2 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import MimiConfig
logger = logging.get_logger(__name__)
class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2DecoderConfig`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        codebook_size (`int`, *optional*, defaults to 2048):
            Number of entries in each residual codebook used for acoustic token quantization.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder.
        latent_dim (`int`, *optional*, defaults to 1024):
            Dimensionality of the latent features. NOTE(review): absent from the original docstring; its exact role
            is defined by the modeling code, not visible here.
        max_position_embeddings (`int`, *optional*, defaults to 8000):
            Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period for rotary position embeddings (RoPE) applied to attention layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key and value attention heads used in grouped-query attention (if applicable).
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the attention projection layers.
        sliding_window (`int`, *optional*, defaults to 72):
            Window size for local attention mechanism, limiting attention context to improve efficiency.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (intermediate) layer in each transformer block.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function used in the feed-forward layers. Supports `"silu"`, `"relu"`, `"gelu"`, etc.
        layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
            Initial value for LayerScale applied in transformer blocks, helping stabilize training.
        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon value for RMS normalization layers to prevent division by zero.
        num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of transformer blocks in the autoregressive decoder.
        num_quantizers (`int`, *optional*, defaults to 16):
            Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction.
        upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`):
            Rate at which features are upsampled in the final waveform synthesis stage.
        upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`):
            Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform.
        decoder_dim (`int`, *optional*, defaults to 1536):
            Final dimensionality of the decoder's output before waveform generation.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention weights in the decoder.
    """

    def __init__(
        self,
        codebook_size=2048,
        hidden_size=1024,
        latent_dim=1024,
        max_position_embeddings=8000,
        rope_theta=10000,
        num_attention_heads=16,
        num_key_value_heads=16,
        attention_bias=False,
        sliding_window=72,
        intermediate_size=3072,
        hidden_act="silu",
        layer_scale_initial_scale=0.01,
        rms_norm_eps=1e-5,
        num_hidden_layers=8,
        num_quantizers=16,
        upsample_rates=(8, 5, 4, 3),
        upsampling_ratios=(2, 2),
        decoder_dim=1536,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.codebook_size = codebook_size
        self.hidden_size = hidden_size
        self.latent_dim = latent_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.sliding_window = sliding_window
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.rms_norm_eps = rms_norm_eps
        self.num_hidden_layers = num_hidden_layers
        self.num_quantizers = num_quantizers
        self.upsample_rates = upsample_rates
        self.upsampling_ratios = upsampling_ratios
        self.decoder_dim = decoder_dim
        self.attention_dropout = attention_dropout

    @property
    def layer_types(self):
        """Every layer in the code2wav decoder uses sliding-window attention."""
        return ["sliding_attention"] * self.num_hidden_layers
class Qwen3TTSTokenizerV2Config(PretrainedConfig):
    r"""
    Configuration class for a [`Qwen3TTSTokenizerV2Model`] (the 12hz tokenizer).

    It bundles the configurations of the two sub-models — a Mimi-based encoder and a
    Qwen3-TTS decoder — together with the audio-rate bookkeeping used when encoding
    and decoding waveforms. Configuration objects inherit from [`PretrainedConfig`];
    read the [`PretrainedConfig`] documentation for the inherited behavior.

    Args:
        encoder_config (`dict`, *optional*):
            Keyword arguments used to build the encoder sub-config ([`MimiConfig`]).
        decoder_config (`dict`, *optional*):
            Keyword arguments used to build the decoder sub-config
            ([`Qwen3TTSTokenizerV2DecoderConfig`]).
    """

    model_type = "qwen3_tts_tokenizer_12hz"
    sub_configs = {
        "encoder_config": MimiConfig,
        "decoder_config": Qwen3TTSTokenizerV2DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        encoder_valid_num_quantizers=16,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Missing sub-configs fall back to the sub-model defaults.
        if encoder_config is None:
            logger.info("encoder_config is None. Initializing encoder with default values")
            encoder_config = {}
        if decoder_config is None:
            logger.info("decoder_config is None. Initializing decoder with default values")
            decoder_config = {}
        self.encoder_config = MimiConfig(**encoder_config)
        self.decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config)
        self.encoder_valid_num_quantizers = encoder_valid_num_quantizers
        # Sample-rate / resampling bookkeeping used by the tokenizer wrapper.
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
# Public API of this module (consumed by `from ... import *` and re-exports).
__all__ = ["Qwen3TTSTokenizerV2Config", "Qwen3TTSTokenizerV2DecoderConfig"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,332 @@
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3TTSTokenizerV1 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT.
    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            The dimension of the model.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            The number of transformer blocks in the DiT model.
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of attention heads in each transformer block.
        ff_mult (`int`, *optional*, defaults to 2):
            The multiplier for the feedforward layer in each transformer block.
        emb_dim (`int`, *optional*, defaults to 512):
            The dimension of the embedding layer.
        head_dim (`int`, *optional*, defaults to 64):
            The dimension of each attention head.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the rotary position embeddings.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length the model is configured for.
        block_size (`int`, *optional*, defaults to 24):
            Block size used by the model's block-wise attention (see the modeling code).
        look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
            Layer indices configured with look-ahead attention in the modeling code.
        look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
            Layer indices configured with look-backward attention in the modeling code.
        repeats (`int`, *optional*, defaults to 2):
            The number of times the codec embeddings are repeated.
        num_embeds (`int`, *optional*, defaults to 8193):
            The number of unique embeddings in the codec.
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout rate for the transformer blocks.
        enc_emb_dim (`int`, *optional*, defaults to 192):
            The dimension of the pre-trained speaker embedding.
        enc_dim (`int`, *optional*, defaults to 128):
            The dimension of the encoder output.
        enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder.
        enc_attention_channels (`int`, *optional*, defaults to 64):
            The number of attention channels in the SqueezeExcitationBlock.
        enc_res2net_scale (`int`, *optional*, defaults to 2):
            The scale of the Res2Net block in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 64):
            The number of output channels after squeeze in the SqueezeExcitationBlock.
    """
    model_type = "qwen3_tts_tokenizer_v1_decoder_dit"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=22,
        num_attention_heads=16,
        ff_mult=2,
        emb_dim=512,
        head_dim=64,
        rope_theta=10000.0,
        max_position_embeddings=32768,
        block_size=24,
        look_ahead_layers=None,
        look_backward_layers=None,
        repeats=2,
        num_embeds=8193,
        mel_dim=80,
        dropout=0.1,
        enc_emb_dim=192,
        enc_dim=128,
        enc_channels=None,
        enc_kernel_sizes=None,
        enc_dilations=None,
        enc_attention_channels=64,
        enc_res2net_scale=2,
        enc_se_channels=64,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.ff_mult = ff_mult
        self.emb_dim = emb_dim
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.block_size = block_size
        # `None` sentinels avoid the shared-mutable-default pitfall: with list
        # literals in the signature, every config instance built from defaults
        # would alias the same list objects, so mutating one config would
        # silently change all of them.
        self.look_ahead_layers = [10] if look_ahead_layers is None else look_ahead_layers
        self.look_backward_layers = [0, 20] if look_backward_layers is None else look_backward_layers
        self.repeats = repeats
        self.num_embeds = num_embeds
        self.mel_dim = mel_dim
        self.dropout = dropout
        self.enc_emb_dim = enc_emb_dim
        self.enc_dim = enc_dim
        self.enc_channels = [256, 256, 256, 256, 768] if enc_channels is None else enc_channels
        self.enc_kernel_sizes = [5, 3, 3, 3, 1] if enc_kernel_sizes is None else enc_kernel_sizes
        self.enc_dilations = [1, 2, 3, 4, 1] if enc_dilations is None else enc_dilations
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        # PretrainedConfig consumes the remaining generic kwargs; called last to
        # match the original attribute-assignment order.
        super().__init__(**kwargs)
class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module.
    It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.

    Args:
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        upsample_initial_channel (`int`, *optional*, defaults to 1536):
            The number of channels in the initial upsampling layer.
        resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
            A list of kernel sizes for each residual block.
        resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            A list of dilation sizes for each residual block.
        upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
            A list of upsampling rates for each upsampling layer.
        upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
            A list of kernel sizes for each upsampling layer.
    """
    model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan"

    def __init__(
        self,
        mel_dim=80,
        upsample_initial_channel=1536,
        resblock_kernel_sizes=None,
        resblock_dilation_sizes=None,
        upsample_rates=None,
        upsample_kernel_sizes=None,
        **kwargs,
    ):
        # `None` sentinels replace mutable list defaults so each instance gets
        # a fresh list; a shared default list would leak mutations between
        # every config constructed with the defaults.
        self.mel_dim = mel_dim
        self.upsample_initial_channel = upsample_initial_channel
        self.resblock_kernel_sizes = [3, 7, 11] if resblock_kernel_sizes is None else resblock_kernel_sizes
        self.resblock_dilation_sizes = (
            [[1, 3, 5], [1, 3, 5], [1, 3, 5]] if resblock_dilation_sizes is None else resblock_dilation_sizes
        )
        self.upsample_rates = [5, 3, 2, 2, 2, 2] if upsample_rates is None else upsample_rates
        self.upsample_kernel_sizes = [11, 7, 4, 4, 4, 4] if upsample_kernel_sizes is None else upsample_kernel_sizes
        super().__init__(**kwargs)
class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig):
    r"""
    Configuration for the Qwen3TTSTokenizerV1 decoder (token-to-waveform stack).

    The decoder is a two-stage pipeline: a Diffusion Transformer (DiT) producing
    mel-spectrograms, followed by a BigVGAN vocoder producing waveforms.
    Configuration objects inherit from [`PretrainedConfig`]; read its
    documentation for the inherited behavior.

    Args:
        dit_config ([`DiT_Args`], *optional*):
            Keyword arguments for the DiT sub-config
            ([`Qwen3TTSTokenizerV1DecoderDiTConfig`]).
        bigvgan_config ([`BigVGAN_Args`], *optional*):
            Keyword arguments for the BigVGAN sub-config
            ([`Qwen3TTSTokenizerV1DecoderBigVGANConfig`]).
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder"
    sub_configs = {
        "dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig,
        "bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig,
    }

    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
        # Absent sub-configs are built entirely from the sub-model defaults.
        self.dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**(dit_config or {}))
        self.bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**(bigvgan_config or {}))
        super().__init__(**kwargs)
class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig):
    r"""
    Configuration for the Qwen3TTSTokenizerV1 encoder.

    The encoder consumes mel-spectrogram features and produces high-level audio
    representations; an optional Audio-VQ module (e.g. GRVQ) then discretizes the
    continuous representations into codes.

    Args:
        n_mels (`int`, *optional*, defaults to 128):
            Number of mel bins in the input mel-spectrogram.
        n_ctx (`int`, *optional*, defaults to 1500):
            Maximum input sequence length (frames/tokens) of the encoder.
        n_state (`int`, *optional*, defaults to 1280):
            Hidden size (model dimension) of the encoder transformer.
        n_head (`int`, *optional*, defaults to 20):
            Number of attention heads per transformer layer.
        n_layer (`int`, *optional*, defaults to 32):
            Number of transformer layers.
        n_window (`int`, *optional*, defaults to 100):
            Window size for local attention / chunking (implementation-dependent).
        output_dim (`int`, *optional*, defaults to 3584):
            Output feature dimension produced by the encoder head
            (before/after projection, implementation-dependent).
        grad_checkpointing (`bool`, *optional*, defaults to `False`):
            Whether to use gradient checkpointing to save memory during training.
        enable_mp (`bool`, *optional*, defaults to `False`):
            Whether to enable model-parallel features (implementation-dependent).
        audio_sequence_parallel (`bool`, *optional*, defaults to `False`):
            Whether to enable sequence parallelism for the audio branch
            (implementation-dependent).
        audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`):
            Kind of audio vector-quantization module (e.g. `"GRVQ"`, `"RVQ"`).
        audio_vq_layers (`int`, *optional*, defaults to 6):
            Number of VQ layers / residual quantizers.
        audio_vq_codebook_size (`int`, *optional*, defaults to 32768):
            Number of entries per codebook.
        audio_vq_codebook_dim (`int`, *optional*, defaults to 1280):
            Dimension of codebook vectors (often equals the encoder hidden size).
        audio_vq_pe (`bool`, *optional*, defaults to `True`):
            Whether the VQ module uses positional encoding / position embeddings.
        audio_vq_ds_rate (`int`, *optional*, defaults to 2):
            Temporal downsampling rate applied before VQ.
    """

    model_type = "qwen3_tts_tokenizer_v1_encoder"

    def __init__(
        self,
        n_mels=128,
        n_ctx=1500,
        n_state=1280,
        n_head=20,
        n_layer=32,
        n_window=100,
        output_dim=3584,
        grad_checkpointing=False,
        enable_mp=False,
        audio_sequence_parallel=False,
        audio_vq_type="GRVQ",
        audio_vq_layers=6,
        audio_vq_codebook_size=32768,
        audio_vq_codebook_dim=1280,
        audio_vq_pe=True,
        audio_vq_ds_rate=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store every hyper-parameter on the config, preserving signature order.
        for attr_name, attr_value in (
            ("n_mels", n_mels),
            ("n_ctx", n_ctx),
            ("n_state", n_state),
            ("n_head", n_head),
            ("n_layer", n_layer),
            ("n_window", n_window),
            ("output_dim", output_dim),
            ("grad_checkpointing", grad_checkpointing),
            ("enable_mp", enable_mp),
            ("audio_sequence_parallel", audio_sequence_parallel),
            ("audio_vq_type", audio_vq_type),
            ("audio_vq_layers", audio_vq_layers),
            ("audio_vq_codebook_size", audio_vq_codebook_size),
            ("audio_vq_codebook_dim", audio_vq_codebook_dim),
            ("audio_vq_pe", audio_vq_pe),
            ("audio_vq_ds_rate", audio_vq_ds_rate),
        ):
            setattr(self, attr_name, attr_value)
class Qwen3TTSTokenizerV1Config(PretrainedConfig):
    r"""
    Configuration class for a [`Qwen3TTSTokenizerV1Model`] (the 25hz tokenizer).

    Bundles the encoder and decoder sub-model configurations plus the audio-rate
    bookkeeping used by the tokenizer. Configuration objects inherit from
    [`PretrainedConfig`]; read its documentation for the inherited behavior.

    Args:
        encoder_config (`dict`, *optional*):
            Keyword arguments for [`Qwen3TTSTokenizerV1EncoderConfig`].
        decoder_config (`dict`, *optional*):
            Keyword arguments for [`Qwen3TTSTokenizerV1DecoderConfig`].
    """

    model_type = "qwen3_tts_tokenizer_25hz"
    sub_configs = {
        "encoder_config": Qwen3TTSTokenizerV1EncoderConfig,
        "decoder_config": Qwen3TTSTokenizerV1DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Missing sub-configs fall back to the sub-model defaults.
        if encoder_config is None:
            logger.info("encoder_config is None. Initializing encoder with default values")
            encoder_config = {}
        if decoder_config is None:
            logger.info("decoder_config is None. Initializing decoder with default values")
            decoder_config = {}
        self.encoder_config = Qwen3TTSTokenizerV1EncoderConfig(**encoder_config)
        self.decoder_config = Qwen3TTSTokenizerV1DecoderConfig(**decoder_config)
        # Sample-rate / resampling bookkeeping used by the tokenizer wrapper.
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
# Public API of this module (consumed by `from ... import *` and re-exports).
__all__ = [
    "Qwen3TTSTokenizerV1Config",
    "Qwen3TTSTokenizerV1EncoderConfig",
    "Qwen3TTSTokenizerV1DecoderConfig",
    "Qwen3TTSTokenizerV1DecoderBigVGANConfig",
    "Qwen3TTSTokenizerV1DecoderDiTConfig"
]

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@@ -0,0 +1,523 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import random
import typing as tp
from random import randrange
import numpy as np
from einops import rearrange, repeat
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F
def round_up_multiple(num, mult):
    """Round `num` up to the nearest integer multiple of `mult`."""
    return mult * ceil(num / mult)
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val` unless it is None, otherwise the fallback `d`."""
    return d if val is None else val
def ema_inplace(moving_avg, new, decay: float):
    """In-place exponential moving average: avg <- decay * avg + (1 - decay) * new."""
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=1 - decay)
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth a count vector into a normalized, zero-free distribution."""
    denominator = x.sum() + n_categories * epsilon
    return (x + epsilon) / denominator
def uniform_init(*shape: int):
    """Allocate a tensor of the given shape filled with Kaiming-uniform noise."""
    tensor = torch.empty(shape)
    nn.init.kaiming_uniform_(tensor)
    return tensor
def sample_vectors(samples, num: int):
    """Draw `num` rows from `samples`: without replacement when enough rows exist,
    otherwise with replacement."""
    available, device = samples.shape[0], samples.device
    if available >= num:
        idx = torch.randperm(available, device=device)[:num]
    else:
        idx = torch.randint(0, available, (num,), device=device)
    return samples[idx]
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run Lloyd's k-means on the row vectors of `samples`.

    Returns:
        (means, counts): the (num_clusters, dim) centroid matrix and the
        per-cluster assignment counts from the final iteration.
    """
    dim, dtype = samples.shape[-1], samples.dtype
    means = sample_vectors(samples, num_clusters)
    for _ in range(num_iters):
        # Negated squared Euclidean distance, so the nearest centroid is the max.
        neg_sq_dist = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )
        assignments = neg_sq_dist.max(dim=-1).indices
        del neg_sq_dist
        counts = torch.bincount(assignments, minlength=num_clusters)
        empty_clusters = counts == 0
        safe_counts = counts.masked_fill(empty_clusters, 1)
        cluster_sums = assignments.new_zeros(num_clusters, dim, dtype=dtype)
        cluster_sums.scatter_add_(0, assignments.unsqueeze(1).repeat(1, dim), samples)
        refreshed = cluster_sums / safe_counts[..., None]
        # Clusters that received no samples keep their previous centroid.
        means = torch.where(empty_clusters[..., None], means, refreshed)
    return means, counts
def preprocess(x):
    """Flatten all leading dimensions so each row is a vector of the last dim."""
    return x.reshape(-1, x.shape[-1])
def postprocess_emb(embed_ind, shape):
    """Fold a flat index vector back to the leading dims of the original shape."""
    return embed_ind.reshape(shape[:-1])
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    The codebook state tensors (`inited`, `cluster_size`, `embed`, `embed_avg`)
    are NOT registered on this module: the caller passes them into
    `encode`/`decode`/`forward` as the `buffers` argument and they are bound to
    attributes per call, so one codebook implementation can serve a stack of
    quantizer layers whose state lives on the owning module.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (float): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 2.0,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code
        # Placeholders only; the real tensors are injected per call via `buffers`.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        # NOTE(review): nn.Module.__init__ already sets `training = True`, so
        # this explicit assignment looks redundant — confirm it is not load-bearing.
        self.training = True
    def init_embed_(self, data):
        """Lazily initialize the codebook via k-means on the first batch seen."""
        if self.inited:
            return
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])
    def replace_(self, samples, mask):
        """Overwrite the masked codebook rows with vectors sampled from `samples`."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)
    def expire_codes_(self, batch_samples):
        """Re-seed "dead" codes whose normalized EMA usage fell below the threshold."""
        if self.threshold_ema_dead_code == 0:
            return
        # Normalize usage so the threshold is independent of the total EMA mass.
        cluster_size = self.cluster_size / sum(self.cluster_size) * self.codebook_size
        expired_codes = cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        else:
            print(f"VQ expire infos: num_expire={sum(expired_codes)}, cluster_size[:5]={cluster_size[:5]}")
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())
    def quantize(self, x):
        """Return the index of the nearest codebook entry for each row of `x`."""
        embed = self.embed.t()
        # Negated squared Euclidean distance, so argmax picks the nearest code.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind
    def dequantize(self, embed_ind):
        """Look up codebook vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize
    def encode(self, x, buffers):
        """Quantize `x` to code indices, using the injected codebook `buffers`."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind
    def decode(self, embed_ind, buffers):
        """Reconstruct vectors from code indices, using the injected `buffers`."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
        quantize = self.dequantize(embed_ind)
        return quantize
    def forward(self, x, buffers):
        """Quantize `x`; in training mode, also apply EMA codebook updates.

        Returns:
            (quantize, embed_ind): dequantized vectors and their code indices.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
        shape, dtype = x.shape, x.dtype
        x = preprocess(x)
        self.init_embed_(x)
        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)
        if self.training:
            # EMA updates of per-code usage counts and running vector sums,
            # then renormalize with Laplace smoothing to avoid dividing by zero.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
        # Note: after ema update, there is a very small difference between codebooks on GPUs.
        # The impact can be very small, ignore it.
        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.

    Codebook state is injected per call via the `buffers` argument (see
    `EuclideanCodebook`); this module only owns the optional in/out projections.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: float = 2.0,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)
        # Project in/out only when the codebook lives in a different dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size
        # NOTE(review): nn.Module.__init__ already sets `training = True`; this
        # explicit assignment looks redundant — confirm it is not load-bearing.
        self.training = True
    @property
    def codebook(self):
        # Raw codebook tensor currently bound on the inner EuclideanCodebook.
        return self._codebook.embed
    def encode(self, x, buffers):
        """Project `x` and return its code indices."""
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x, buffers)
        return embed_in
    def decode(self, embed_ind, buffers):
        """Reconstruct vectors from code indices and project back out."""
        quantize = self._codebook.decode(embed_ind, buffers)
        quantize = self.project_out(quantize)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return quantize
    def forward(self, x, buffers):
        """Quantize `x`.

        Returns:
            (quantize, embed_ind, loss): quantized output, code indices, and the
            (weighted) commitment loss — zero when not training.
        """
        device = x.device
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x, buffers)
        if self.training:
            # Straight-through estimator: gradients flow to x, values come from
            # the quantized vectors.
            quantize = x + (quantize - x).detach()
        # NOTE(review): requires_grad on a constructed constant mirrors training
        # mode; the gradient path comes from the added commit_loss below.
        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
        if self.training:
            if self.commitment_weight > 0:
                # Pull encoder outputs toward the (frozen) codebook entries.
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight
        quantize = self.project_out(quantize)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
class DistributedResidualVectorQuantization(nn.Module):
    """Efficient distributed residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    The per-layer codebook state (init flag, EMA cluster sizes, codebook and EMA
    accumulator) is held here as stacked buffers and handed to each
    `VectorQuantization` layer call, so the whole stack can be synchronized
    across workers in one place.
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        # `kwargs` is forwarded to VectorQuantization and must contain at least
        # `dim` and `codebook_size`; `codebook_dim` and `kmeans_init` are optional.
        codebook_size = kwargs["codebook_size"]
        # Fall back to `dim` when `codebook_dim` is absent or falsy, matching
        # VectorQuantization's own default (the previous direct indexing raised
        # KeyError when the key was not supplied at all).
        codebook_dim = kwargs.get("codebook_dim") or kwargs["dim"]
        # Default mirrors VectorQuantization(kmeans_init=True).
        kmeans_init = kwargs.get("kmeans_init", True)
        if isinstance(kmeans_init, bool):
            if not kmeans_init:
                # use uniform init
                embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
                inited = True
            else:
                # to perform kmeans init on first batch
                embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
                inited = False
        elif isinstance(kmeans_init, str):
            # use prepared kmeans init
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if embed.dim() == 2:
                embed = embed.unsqueeze(0)
            inited = True
        else:
            raise TypeError("kmeans_init should be either a bool or string path to init weights.")
        # Stacked state, one row per quantizer layer; injected into each layer call.
        self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
        self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())
        self.q0_ds_ratio = 1
        if "q0_ds_ratio" in kwargs:
            # Pop so it is not forwarded to VectorQuantization below.
            self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")
        self.layers = nn.ModuleList()
        for _ in range(num_quantizers):
            vq_args = dict(**kwargs)
            vq = VectorQuantization(**vq_args)
            self.layers.append(vq)
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant
    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize `x` of shape (batch, dim, time) through up to `n_q` layers.

        Returns:
            (quantized_out, out_indices, out_losses): summed reconstruction,
            stacked per-layer code indices, and stacked per-layer losses.
        """
        quantized_out = torch.zeros_like(x)
        residual = x
        _, _, tt = x.shape
        device = x.device
        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)
        # Structured quantizer dropout: during training, randomly truncate the
        # stack so the model stays usable with fewer quantizers at inference.
        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)
            null_indices_shape = (x.shape[0], x.shape[2])
            # Integer fill value for an integer tensor (previously a float literal).
            null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue
            quant_in = residual
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Downsample the first quantizer's input by the configured ratio
                # (previously hard-coded to a factor of 2; identical when ratio == 2).
                quant_in = F.interpolate(quant_in, size=[tt // self.q0_ds_ratio])
            quantized, indices, loss = layer(quant_in, [
                self.inited[quantizer_index],
                self.cluster_size[quantizer_index],
                self.embed[quantizer_index],
                self.embed_avg[quantizer_index]
            ])
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Bring outputs back to full temporal resolution.
                quantized = F.interpolate(quantized, size=[tt])
                indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)
        # sync buffers after one forward step
        # distrib.broadcast_tensors(self.buffers())
        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
        return quantized_out, out_indices, out_losses
    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode `x` to stacked code indices using up to `n_q` layers.

        NOTE(review): this path does not apply `q0_ds_ratio` downsampling the
        way `forward` does — confirm whether that is intended for inference.
        """
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            indices = layer.encode(residual, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices
    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded contributions of each layer's indices."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized_out = quantized_out + quantized
        return quantized_out
class DistributedGroupResidualVectorQuantization(nn.Module):
    """Group residual vector quantization: split channels into groups and run an
    independent distributed RVQ stack on each group.

    Follows Algorithm 1. in https://arxiv.org/abs/2305.02765 (group, then RVQ).
    """
    def __init__(self, *,
                 num_groups,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        # One full RVQ stack per channel group; all share the same hyper-parameters.
        group_rvqs = [
            DistributedResidualVectorQuantization(
                num_quantizers=num_quantizers,
                quantize_dropout=quantize_dropout,
                rand_num_quant=rand_num_quant,
                **kwargs
            )
            for _ in range(num_groups)
        ]
        self.rvqs = nn.ModuleList(group_rvqs)
        self.num_groups = num_groups
    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize each channel group; concat reconstructions, stack indices,
        and average losses across groups."""
        group_inputs = torch.chunk(x, chunks=self.num_groups, dim=1)
        quantized_parts, index_parts, loss_parts = [], [], []
        for rvq, group_in in zip(self.rvqs, group_inputs):
            quantized, indices, losses = rvq(group_in, n_q)
            quantized_parts.append(quantized)
            index_parts.append(indices)
            loss_parts.append(losses)
        mean_losses = torch.stack(loss_parts, dim=1).mean(dim=1)
        return (
            torch.cat(quantized_parts, dim=1),
            torch.stack(index_parts, dim=1),
            mean_losses,
        )
    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode each channel group and stack the resulting index tensors."""
        group_inputs = torch.chunk(x, chunks=self.num_groups, dim=1)
        encoded = [rvq.encode(group_in, n_q) for rvq, group_in in zip(self.rvqs, group_inputs)]
        return torch.stack(encoded, dim=1)
    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode each group's indices and concatenate along the channel dim."""
        group_indices = torch.chunk(q_indices, chunks=self.num_groups, dim=1)
        decoded = [rvq.decode(idx.squeeze(1)) for rvq, idx in zip(self.rvqs, group_indices)]
        return torch.cat(decoded, dim=1)

View File

@@ -0,0 +1,357 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sox
import copy
import torch
import operator
import onnxruntime
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as kaldi
from librosa.filters import mel as librosa_mel_fn
from itertools import accumulate
from typing import List
from torch import Tensor
from .core_vq import DistributedGroupResidualVectorQuantization
from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(max(x, clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
def spectral_normalize_torch(magnitudes):
    """Apply dynamic-range (log) compression to a magnitude spectrogram."""
    return dynamic_range_compression_torch(magnitudes)
class MelSpectrogramFeatures(nn.Module):
    """
    Calculate the BigVGAN style mel spectrogram of an input signal.
    Args:
        filter_length (int): The number of samples in the filter window, used for the Fourier Transform. Default is 1024.
        hop_length (int): The number of samples between successive frames (stride of the STFT). Default is 160.
        win_length (int): The length of the window function applied to each frame, usually less than or equal to the filter length. Default is 640.
        n_mel_channels (int): The number of Mel-frequency channels to output from the Mel-scale spectrogram. Default is 80.
        mel_fmin (int): The minimum frequency (in Hz) of the Mel-scale spectrogram. Default is 0.
        mel_fmax (int): The maximum frequency (in Hz) of the Mel-scale spectrogram. Default is 8000.
        sampling_rate (int): The sampling rate of the audio data (in Hz). Default is 16000.
        sampling_rate_org (int, optional): The original sampling rate of the audio data before any resampling (in Hz), if applicable. Default is None.
        padding (str): The padding mode for the input signal. 'center' pads the signal symmetrically around its center. Default is 'center'.
        use_db (bool): NOTE(review): accepted but never stored or read anywhere
            in this class — confirm whether it was meant to select dB output.
    Returns:
        torch.Tensor: Mel spectrogram.
    """
    def __init__(self,
                 filter_length=1024,
                 hop_length=160,
                 win_length=640,
                 n_mel_channels=80,
                 mel_fmin=0,
                 mel_fmax=8000,
                 sampling_rate=16000,
                 sampling_rate_org=None,
                 padding='center',
                 use_db = False,
                 ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        # Fall back to the working sample rate when no original rate is given.
        self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
        # Lazily-built caches for the mel filterbank and Hann window,
        # keyed by fmax/device strings (filled on first extract() call).
        self.mel_basis = {}
        self.hann_window = {}
    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """Compute the mel spectrogram without tracking gradients."""
        with torch.no_grad():
            feats = self.extract(audio, **kwargs)
        return feats
    def extract(self, audio, **kwargs):
        """Pad -> STFT -> mel projection -> log compression.

        Accepts (B, T), (B, 1, T) or (B, T, 1) input; returns (B, n_mels, frames).
        """
        # Collapse a singleton channel/time axis to get a (B, T) waveform.
        if len(audio.shape) == 3:
            audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
        assert len(audio.shape) == 2
        y = audio
        # Build filterbank/window on first use.
        # NOTE(review): the guard only checks that the cache is non-empty, so a
        # later call on a *different* device would KeyError on the lookups
        # below — confirm single-device usage.
        if len(list(self.mel_basis.keys())) == 0:
            mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
            self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
            self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device)
        # Reflect-pad so frames align with the hop grid (BigVGAN-style framing).
        y = torch.nn.functional.pad(y.unsqueeze(1), (int((self.filter_length-self.hop_length)/2), int((self.filter_length-self.hop_length)/2)), mode='reflect')
        y = y.squeeze(1)
        spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[str(y.device)],
                          center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
        spec = torch.matmul(self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)], spec)
        spec = spectral_normalize_torch(spec)
        return spec
class XVectorExtractor(nn.Module):
    """Speaker-embedding (x-vector) extractor backed by an ONNX model.

    Pipeline: sox loudness normalization to -6 dB -> 80-dim Kaldi fbank with
    per-utterance mean subtraction -> ONNX speaker model -> L2-normalized
    embedding.  Also returns a BigVGAN-style reference mel spectrogram.

    Args:
        audio_codec_with_xvector: path to the ONNX speaker-embedding model.
    """
    def __init__(self, audio_codec_with_xvector):
        super().__init__()
        # Single-threaded CPU inference session with full graph optimization.
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ["CPUExecutionProvider"]
        self.ort_session = onnxruntime.InferenceSession(audio_codec_with_xvector, sess_options=option, providers=providers)
        # sox transformer that normalizes loudness to -6 dB.
        self.tfm = sox.Transformer()
        self.tfm.norm(db_level=-6)
        self.mel_ext = MelSpectrogramFeatures(
            filter_length=1024,
            hop_length=160,
            win_length=640,
            n_mel_channels=80,
            mel_fmin=0,
            mel_fmax=8000,
            sampling_rate=16000
        )
    def extract_code(self, audio):
        """Return (speaker_embedding, reference_mel) for a 16 kHz waveform.

        audio: 1-D numpy waveform.  Both return values are numpy arrays;
        the mel is transposed to (frames, n_mels).
        """
        with torch.no_grad():
            norm_audio = self.sox_norm(audio)
            # deepcopy detaches from sox's internal buffer before wrapping.
            norm_audio = torch.from_numpy(copy.deepcopy(norm_audio)).unsqueeze(0)
            feat = kaldi.fbank(norm_audio,
                               num_mel_bins=80,
                               dither=0,
                               sample_frequency=16000)
            # Per-utterance cepstral mean normalization.
            feat = feat - feat.mean(dim=0, keepdim=True)
            norm_embedding = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()
            norm_embedding = F.normalize(torch.from_numpy(norm_embedding), dim=0)
            # NOTE(review): the reference mel is computed on the *normalized*
            # audio, not the raw input — confirm this is intended.
            ref_mel = self.mel_ext.extract(audio=norm_audio)
            return norm_embedding.numpy(), ref_mel.permute(0,2,1).squeeze(0).numpy()
    def sox_norm(self, audio):
        """Loudness-normalize a waveform to -6 dB via sox; returns a numpy array."""
        wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
        return wav_norm
class WhisperEncoderVQ(WhisperEncoder):
    """WhisperEncoder with an in-stack vector-quantization bottleneck.

    After `audio_vq_layers` transformer blocks, hidden states are optionally
    down-sampled, quantized with a grouped residual VQ ("GRVQ"), optionally
    recombined with positional embeddings, up-sampled back, and fed through
    the remaining blocks.  `forward(..., return_indices=True)` returns the
    discrete code indices right after the VQ layer.
    """
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
        audio_vq_layers: int = -1,  # block index after which to quantize (> 0 required)
        audio_vq_type: str = "NULL",  # "GRVQ" is the only supported quantizer here
        audio_vq_codebook_size: int = 4096,
        audio_vq_pe: bool = False,  # re-add positional embeddings after the VQ
        audio_vq_commit_loss: float = 0.0,
        audio_vq_out_commit_loss: float = 0.0,  # weight of pre/post-VQ MSE loss
        audio_vq_no_quantize: bool = False,
        audio_vq_ff_layer: int = 0,
        audio_vq_threshold_ema_dead_code: float = 0.1,
        audio_vq_codebook_dim: int = None,
        audio_vq_ds_rate: int = None,  # extra temporal down-sampling before the VQ
    ):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)
        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_type = audio_vq_type
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_commit_loss = audio_vq_commit_loss
        self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
        self.audio_vq_no_quantize = audio_vq_no_quantize
        self.audio_vq_ff_layer = audio_vq_ff_layer
        if audio_vq_layers > 0:
            # Quantize raw transformer features (n_state wide); no additional
            # down-sampling is implied by the layer choice itself.
            self.vq_feature_dim = self.n_state
            self.audio_vq_ds_rate = 1
        else:
            raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")
        if self.audio_vq_ds_rate == audio_vq_ds_rate:
            self.audio_vq_downsample = nn.Identity()
            self.audio_vq_upsample = nn.Identity()
        else:
            # Strided conv / transposed-conv pair around the quantizer.
            # NOTE(review): audio_vq_ds_rate defaults to None, which would make
            # the modulo below raise TypeError — callers presumably always pass
            # an explicit int; confirm.
            assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
            stride = audio_vq_ds_rate // self.audio_vq_ds_rate
            self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_ds_rate = audio_vq_ds_rate
        if audio_vq_type == "GRVQ":
            # NOTE(review): `self.vq_codebook_dim` is not defined in this class
            # or the visible base class, so this raises AttributeError when
            # audio_vq_codebook_dim is None — confirm the intended default.
            self.audio_quantizer = DistributedGroupResidualVectorQuantization(
                codebook_size = audio_vq_codebook_size,
                dim = self.vq_feature_dim,
                codebook_dim = self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim,
                num_groups=1,
                num_quantizers=1,
                kmeans_init=False,
                threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
            )
        else:
            raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")
        if self.audio_vq_pe:
            self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)
    def _calc_quantize_activities(self, indices):
        """Codebook usage stats: distinct codes used and total tokens quantized."""
        indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
        vq_num_activities = sum(indices_onehot>0)
        vq_num_tokens = sum(indices_onehot)
        return {
            "vq_num_activities": vq_num_activities,
            "vq_num_tokens": vq_num_tokens,
        }
    def _do_quantize(self, x, pe=None, y=None):
        """
        Quantize packed features via down-sample -> VQ -> (pe) -> up-sample.

        x: torch.Tensor, shape = (T, D)
        q: torch.Tensor, shape = (T, D)
        i: torch.Tensor, shape = (T)
        pe: positional embeddings re-added after quantization (audio_vq_pe only).
        y: unused.
        Returns (quantized features, code indices, stats dict).
        """
        if self.audio_vq_out_commit_loss > 0:
            # Keep the pre-VQ features as the reconstruction target.
            x_teacher = x.clone()
        x = x.unsqueeze(0)
        x = self.audio_vq_downsample(x.transpose(1, 2))
        x = x.transpose(1, 2)
        vq_stats = {}
        if self.audio_vq_type == "GRVQ":
            if self.training:
                # Training-time quantization is not implemented in this class.
                raise NotImplementedError
            else:
                indices = self.audio_quantizer.encode(x)
                x = self.audio_quantizer.decode(indices)
                indices = indices.squeeze(2).squeeze(1)
                vq_stats.update(self._calc_quantize_activities(indices))
        x, indices = x.squeeze(0), indices.squeeze(0)
        if self.audio_vq_pe:
            x = x + pe
            x = self.project_after_vq_pe(x)
        x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
        x = x.transpose(1, 2).squeeze(0)
        if self.audio_vq_out_commit_loss > 0:
            vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
            vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss
        return x, indices, vq_stats
    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
        """
        x : torch.Tensor, shape = (n_mels, n_ctx)
            the mel spectrogram of the audio

        Same packed layout as WhisperEncoder.forward, plus:
        return_indices: return (features, code indices) right after the VQ layer.
        audio_pitchs: unused here.
        """
        aftercnn_x_list = []
        pe_for_vq_list = []
        for each_x in x_list:
            # At most n_window*2 mel frames per chunk (= n_window after conv2).
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
                # Positional embeddings at the (down-sampled) VQ frame rate,
                # re-added after quantization when audio_vq_pe is set.
                pe_for_vq_split = self.positional_embedding[:each_x_split.shape[0] // self.audio_vq_ds_rate]
                pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))
        pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
        x = torch.cat(aftercnn_x_list, dim=0)
        src_len = x.size(0)  # NOTE: unused
        # Split long sequences into n_window windows and build cumulative
        # lengths in the flash-attn varlen format.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)
        cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
        layer_id = 0
        for block in self.blocks:
            layer_id+=1
            x = block(x, cu_seqlens=cu_seqlens)
            if self.audio_vq_layers == layer_id: # vq inside encoder
                x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
                if return_indices:
                    return x, indices
        if self.avg_pooler:
            # Per-utterance 2x average pooling (4x total downsampling).
            x_list = x.split(audio_aftercnnlens, dim=0)
            token_x_list = []
            for x in x_list:
                x = x.permute(1, 0)
                x = self.avg_pooler(x)
                x = x.permute(1, 0)
                token_x_list.append(x)
            x = torch.cat(token_x_list, dim=0)
        x = self.ln_post(x)
        x = self.proj(x)
        # Insert learned BOS/EOS embeddings at the first/last position of each
        # sequence; the pooled features fill the interior positions.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )
        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x
        if self.audio_vq_type != "NULL":
            # NOTE(review): vq_stats is only bound when audio_vq_layers matched
            # a block index above; otherwise this raises NameError — confirm
            # configs always keep 0 < audio_vq_layers <= n_layer.
            return output, vq_stats
        return output

View File

@@ -0,0 +1,406 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import torch
import operator
import numpy as np
import torch.nn.functional as F
from functools import lru_cache
from typing import Optional, Union, List
from torch import nn, Tensor
from itertools import accumulate
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
except ImportError:
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
except ImportError:
print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
flash_attn_varlen_func = None
# STFT framing constants used by log_mel_spectrogram: 400-sample FFT window
# and 160-sample hop (25 ms / 10 ms at the 16 kHz rate this module assumes).
N_FFT = 400
HOP_LENGTH = 160
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting an STFT onto Mel bins.

    Decouples the librosa dependency: the matrices were precomputed with
        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    and are loaded from the package's assets directory.  Results are cached
    per (device, n_mels) via lru_cache.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    assets_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(assets_path, allow_pickle=False) as archive:
        filterbank = archive[f"mel_{n_mels}"]
    return torch.from_numpy(filterbank).to(device)
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute a Whisper-style log-Mel spectrogram.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor]
        A 16 kHz waveform as a NumPy array or Tensor.  NOTE: despite the
        type hint, a path string is NOT loaded here — the value is used
        directly as waveform samples.
    n_mels: int
        Number of Mel filterbank channels; must be 80 or 128 (see mel_filters).
    padding: int
        Number of zero samples appended to the right before the STFT.
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before the STFT.

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        Log-Mel spectrogram, dynamic range clamped to 80 dB below the peak
        and rescaled to roughly [-1, 1].
    """
    waveform = audio if torch.is_tensor(audio) else torch.from_numpy(audio)
    if device is not None:
        waveform = waveform.to(device)
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))
    window = torch.hann_window(N_FFT).to(waveform.device)
    stft = torch.stft(waveform, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Drop the final frame to match Whisper's framing, then take power.
    magnitudes = stft[..., :-1].abs() ** 2
    mel_spec = mel_filters(waveform.device, n_mels) @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Clamp to 8 log10-units (80 dB) below the max, then rescale.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0
def get_T_after_cnn(L_in, dilation=1):
    """Output length after the audio front-end's two conv layers.

    Mirrors conv1 (padding=1, kernel=3, stride=1) followed by
    conv2 (padding=1, kernel=3, stride=2) — i.e. the input length is
    (roughly) halved.

    Args:
        L_in: input length in frames.
        dilation: conv dilation (1 for the encoder's convs).

    Returns:
        The resulting length after both convolutions.
    """
    # (padding, kernel_size, stride) for conv1 then conv2.  The original code
    # built this constant list via eval() on a string, which is both slower
    # and a needless eval-of-source hazard — a plain literal is equivalent.
    for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out
def get_mel_audio(audio, padding=False, audio_vq_ds_rate = 1, n_mels = 128):
    """Compute a log-Mel spectrogram, optionally zero-padding the waveform so
    its length is a multiple of one post-CNN (and post-VQ-downsample) frame.

    audio: 16 kHz waveform samples.
    padding: when True, right-pad up to the next multiple of
        160 * 2 * audio_vq_ds_rate samples.
    Returns a (n_mels, T) tensor.
    """
    if not padding:
        return log_mel_spectrogram(audio, n_mels=n_mels) # [F,T]
    # One output frame covers hop (160) * conv stride (2) * VQ rate samples.
    frame_samples = 160 * 2 * audio_vq_ds_rate
    num_samples = len(audio)
    pad_amount = math.ceil(num_samples / frame_samples) * frame_samples - num_samples
    return log_mel_spectrogram(audio, n_mels=n_mels, padding=pad_amount)
def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional-embedding table of shape (length, channels).

    The first channels//2 columns are sines, the rest cosines, over
    geometrically spaced timescales up to max_timescale.
    """
    assert channels % 2 == 0
    half = channels // 2
    increment = np.log(max_timescale) / (half - 1)
    timescales = torch.exp(-increment * torch.arange(half))
    angles = torch.arange(length)[:, None] * timescales[None, :]
    return torch.cat([angles.sin(), angles.cos()], dim=1)
class Conv1d(nn.Conv1d):
    """nn.Conv1d that casts its weight/bias to the input dtype on the fly,
    so fp32-kept parameters can be applied to fp16/bf16 activations."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)
class ConvTranspose1d(nn.ConvTranspose1d):
    # Intended to mirror Conv1d above: cast weight/bias to the input dtype.
    # NOTE(review): nn.ConvTranspose1d.forward calls F.conv_transpose1d
    # directly and does not route through a `_conv_forward` hook (nor does the
    # base class define one for super() to reach), so this override may never
    # execute — confirm against the installed torch version before relying on
    # the dtype cast here.
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )
class Linear(nn.Linear):
    """nn.Linear that casts its parameters to the input dtype before the matmul."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, weight, bias)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    ``x`` has shape (total_len, n_state): all sequences in the batch are
    concatenated along dim 0 and delimited by ``cu_seqlens`` (cumulative
    sequence lengths, int32, as in flash-attn's varlen API).  Uses flash
    attention when installed and the dtype is fp16/bf16; otherwise falls
    back to a padded pure-PyTorch implementation.
    """
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)  # no key bias (Whisper-style)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
        # Start optimistic; latched to False on the first unsupported dtype.
        self.use_flash_attention = True
    def forward(
        self,
        x: Tensor,
        cu_seqlens = None,
    ):
        """Project to q/k/v, attend within each packed sequence, project out."""
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        if self.use_flash_attention:
            if flash_attn_varlen_func is None:
                # flash-attn is not installed.
                x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
            elif q.dtype not in [torch.float16, torch.bfloat16]:
                # flash-attn only supports half precision; remember the fallback.
                x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
                self.use_flash_attention = False
            else:
                x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
        else:
            x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
        output = self.out(x)
        return output
    def qkv_flash_attention(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
    ):
        """Varlen flash-attention path; q/k/v are (total_len, n_state)."""
        n_ctx, n_state = q.shape
        q = q.view(n_ctx, self.n_head, -1)  # (total_len, nheads, headdim)
        k = k.view(n_ctx, self.n_head, -1)
        v = v.view(n_ctx, self.n_head, -1)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
        )
        x = x.reshape(n_ctx, n_state)
        return x
    def qkv_attention_manual(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
    ):
        """Pure-PyTorch fallback: pad to (batch, max_seqlen), mask, attend, re-pack.

        Bug fix: the padding mask is now an *additive float* bias.  The
        previous code called masked_fill on a bool tensor with -finfo.max —
        a value not representable in bool — so padded key positions were
        never actually masked out of the softmax.
        """
        n_ctx, n_state = q.shape
        head_dim = n_state // self.n_head
        scale = head_dim ** -0.5
        q = q.view(n_ctx, self.n_head, head_dim)
        k = k.view(n_ctx, self.n_head, head_dim)
        v = v.view(n_ctx, self.n_head, head_dim)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        batch_size = len(seqlens)
        max_seqlen = max(seqlens)
        # Scatter the packed sequences into zero-padded batch tensors.
        q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
        k_padded = torch.zeros_like(q_padded)
        v_padded = torch.zeros_like(q_padded)
        for i in range(batch_size):
            start_idx = cu_seqlens[i]
            end_idx = cu_seqlens[i + 1]
            seq_len = seqlens[i]
            q_padded[i, :seq_len] = q[start_idx:end_idx]
            k_padded[i, :seq_len] = k[start_idx:end_idx]
            v_padded[i, :seq_len] = v[start_idx:end_idx]
        q_padded = q_padded.transpose(1, 2)  # (batch, heads, seq, head_dim)
        k_padded = k_padded.transpose(1, 2)
        v_padded = v_padded.transpose(1, 2)
        # Additive bias: 0 on valid key positions, -finfo.max on padding.
        valid = torch.arange(max_seqlen, device=q.device)[None, :] < torch.tensor(seqlens, device=q.device)[:, None]
        attn_bias = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device)
        attn_bias.masked_fill_(~valid[:, None, None, :], -torch.finfo(q.dtype).max)
        attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
        attn_scores = attn_scores + attn_bias
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v_padded)  # (batch, heads, seq, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)
        # Re-pack: drop padded rows, restoring the (total_len, n_state) layout.
        output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)
        assert output_packed.shape == (n_ctx, n_state)
        return output_packed
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x-wide GELU MLP,
    each wrapped in a residual connection.

    `enable_mp` and `sequence_parallel` are accepted for interface
    compatibility but are unused by this implementation.
    """
    def __init__(self, n_state: int, n_head: int,
                 enable_mp: bool = False, sequence_parallel: bool = False):
        super().__init__()
        hidden_dim = n_state * 4
        self.attn_ln = nn.LayerNorm(n_state)
        self.mlp_ln = nn.LayerNorm(n_state)
        self.attn = MultiHeadAttention(n_state, n_head)
        self.mlp = nn.Sequential(
            Linear(n_state, hidden_dim), nn.GELU(), Linear(hidden_dim, n_state)
        )
    def forward(
        self,
        x: Tensor,
        cu_seqlens = None
    ):
        """Apply both sub-layers; cu_seqlens delimits packed sequences."""
        attn_out = self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_ln(x))
        return x + mlp_out
class WhisperEncoder(nn.Module):
    """Whisper-style audio encoder over packed variable-length batches.

    Per-utterance mel spectrograms are conv-downsampled 2x, given sinusoidal
    positional embeddings, concatenated along time, and run through the
    transformer blocks with flash-attn style cu_seqlens.  An average pool
    halves the rate again (4x total), features are projected to `output_dim`,
    and a learned BOS/EOS embedding pair is inserted around each sequence.
    """
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
    ):
        super().__init__()
        # Conv front-end; conv2's stride 2 halves the mel frame rate.
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed (non-trained) sinusoidal positional table.
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.n_layer = n_layer
        self.n_mels = n_mels
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel)
             for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)
        # Temporal 2x average pooling applied after the transformer stack.
        self.avg_pooler = nn.AvgPool1d(2, stride=2)
        self.proj = torch.nn.Linear(n_state, output_dim)
        # Row 0 = BOS embedding, row 1 = EOS embedding (see forward()).
        self.audio_bos_eos_token = nn.Embedding(2, output_dim)
        self.output_dim = output_dim
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.n_head = n_head
        self.n_state = n_state
        self.n_window = n_window
        self.audio_sequence_parallel = audio_sequence_parallel
        self.tp_world_size = 1
        self.set_audio_sync()
    def set_audio_sync(self):
        """Tag all non-transformer parameters with `audio_sync = True`.

        Presumably consumed by the surrounding training framework to decide
        which parameters to synchronize across ranks — TODO confirm.
        """
        for name, param in self.named_parameters():
            if not name.startswith("blocks"):
                setattr(param, "audio_sync", True)
    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
        """
        x : torch.Tensor, shape = (n_mels, n_ctx)
            the mel spectrogram of the audio

        x_list: per-utterance mel spectrograms, each (n_mels, T_mel).
        audio_mellens: mel-frame lengths (unused in this method).
        audio_aftercnnlens: frame counts after the 2x conv downsampling.
        audio_seqlens: final per-utterance token counts including BOS/EOS.
        Returns packed (sum(audio_seqlens), output_dim) features.
        """
        aftercnn_x_list = []
        for each_x in x_list:
            # At most n_window*2 mel frames per chunk (= n_window after conv2),
            # so positional embeddings stay within the n_ctx table.
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
        x = torch.cat(aftercnn_x_list, dim=0)
        src_len = x.size(0)  # NOTE: unused
        # Split long sequences into n_window windows for attention and build
        # cumulative lengths in the flash-attn varlen format.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)
        cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
        layer_id = 0
        for block in self.blocks:
            layer_id+=1
            x = block(x, cu_seqlens=cu_seqlens)
        if self.avg_pooler:
            # Per-utterance 2x average pooling (4x total downsampling).
            x_list = x.split(audio_aftercnnlens, dim=0)
            token_x_list = []
            for x in x_list:
                x = x.permute(1, 0)
                x = self.avg_pooler(x)
                x = x.permute(1, 0)
                token_x_list.append(x)
            x = torch.cat(token_x_list, dim=0)
        x = self.ln_post(x)
        x = self.proj(x)
        # Insert learned BOS/EOS embeddings at the first/last position of each
        # sequence; the pooled features fill the interior positions.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )
        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x
        return output
    def lock(self, layers: int):
        """Freeze the conv front-end and the first `layers` transformer blocks."""
        self.conv1.requires_grad_(False)
        self.conv2.requires_grad_(False)
        for i in range(min(layers, len(self.blocks))):
            self.blocks[i].requires_grad_(False)