refactor: rename canto-backend → backend, canto-frontend → frontend
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
436
backend/indextts/s2mel/modules/gpt_fast/generate.py
Normal file
@@ -0,0 +1,436 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config

def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass
    else:
        print(f"device={device} is not yet supported")


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from model import Transformer
from tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
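# Why this works: for probabilities p_i and i.i.d. noise e_i ~ Exponential(1),
# argmax_i(p_i / e_i) has exactly the Categorical(p) distribution (the
# exponential-race form of the Gumbel-max trick), so one argmax kernel stands
# in for torch.multinomial without its host-device synchronization.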

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
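# Worked example of the top-k mask: with logits [2.0, 1.0, 0.5, -1.0] and
# top_k=2 the pivot is 1.0, the last two logits become -inf, and softmax
# renormalizes over the survivors: probs ≈ [0.731, 0.269, 0.0, 0.0].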

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]

def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)

def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)

    return new_tokens, new_probs


def model_forward(model, x, input_pos):
    return model(x, input_pos)

def speculative_decode(
    model: Transformer,
    draft_model: Transformer,
    cur_token: torch.Tensor,
    input_pos: int,
    speculate_k: int,
    **sampling_kwargs
) -> torch.Tensor:
    # draft model inference sequentially
    device = cur_token.device
    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)

    draft_tokens = torch.cat(draft_tokens)
    # parallel inference on target model using draft tokens
    target_logits = model_forward(
        model,
        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
    )
    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
    draft_probs = torch.stack(draft_probs)
    # q: target prob, p: draft prob
    # q >= p: always accept draft token
    # q < p: q/p prob to accept draft token
    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()

    if rejected_locations.shape[0] == 0:  # All draft tokens have been accepted
        accept_length = speculate_k + 1
        last_token = multinomial_sample_one_no_sync(target_probs[-1])
        # fill last token into draft model
        model_forward(
            draft_model,
            draft_tokens[-1].view(1, -1),
            orig_input_pos + speculate_k,
        )
        return torch.cat([draft_tokens, last_token])
    else:
        accept_length = rejected_locations[0].item()
        p = draft_probs[accept_length]
        q = target_probs[accept_length]
        new = q - p
        new = torch.where(new > 0, new, 0.0)
        new = new / new.sum()
        next_token = multinomial_sample_one_no_sync(new)
        return torch.cat([draft_tokens[:accept_length], next_token])
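# Correctness note: this is standard speculative sampling (Leviathan et al.,
# 2023; Chen et al., 2023). Accepting draft token i with probability
# min(1, q_i / p_i) and, on the first rejection, resampling from the residual
# distribution max(q - p, 0) / sum(max(q - p, 0)) leaves the output marginally
# distributed exactly as the target model's q.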

@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    draft_model: Transformer,
    speculate_k: Optional[int] = 8,
    callback = lambda x: x,
    **sampling_kwargs
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    is_speculative = draft_model is not None
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(0)
    T_new = T + max_new_tokens
    if interactive:
        max_seq_length = 350
    else:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        if is_speculative and draft_model is not model:
            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
    if is_speculative:
        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    accept_counts = [0] * (speculate_k + 1)

    if is_speculative:
        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
        while input_pos < T_new - 1:
            cur_token = next_token.view(())

            next_tokens = speculative_decode(
                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
            )

            accept_counts[len(next_tokens) - 1] += 1
            num_added = min(T_new - input_pos - 1, len(next_tokens))
            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added]
            for i in next_tokens[:num_added]:
                callback(i)
            input_pos = input_pos + num_added
            next_token = next_tokens[-1]
    else:
        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
        seq[T + 1:] = torch.cat(generated_tokens)

    generate_stats = {
        'accept_counts': accept_counts
    }
    return seq, generate_stats

def encode_tokens(tokenizer, string, bos=True, device=default_device):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision, use_tp):
    use_cuda = 'cuda' in device
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler
        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)

    if use_tp:
        from tp import apply_tp
        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()

def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                [
                    p.numel() * p.dtype.itemsize
                    for p in itertools.chain(child.parameters(), child.buffers())
                ]
            )
    return model_size

B_INST, E_INST = "[INST]", "[/INST]"

def main(
    prompt: str = "Hello, my name is",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    draft_checkpoint_path: Optional[Path] = None,
    speculate_k: int = 5,
    device=default_device,
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.
    """
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    global print
    from tp import maybe_init_dist
    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp:
        if rank != 0:
            # only print on rank 0
            print = lambda *args, **kwargs: None

    print(f"Using device={device}")
    precision = torch.bfloat16
    is_speculative = draft_checkpoint_path is not None
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)

    if is_speculative:
        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
    else:
        draft_model = None

    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
    prompt_length = encoded.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)
    if compile:
        if is_speculative and use_tp:  # and ("cuda" in device):
            torch._inductor.config.triton.cudagraph_trees = False  # Bug with cudagraph trees in this case

        if is_speculative:
            global model_forward, logits_to_probs
            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

        global decode_one_token, prefill
        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

        # Uncomment to squeeze more perf out of prefill
        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)


    aggregate_metrics = {
        'tokens_per_sec': [],
        'accept_counts': [],
    }
    start = -1 if compile else 0

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            if is_chat:
                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

        if interactive and i >= 0:
            buffer = []
            period_id = tokenizer.encode('.')[0]
            done_generating = False
            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
                    print(''.join(buffer), end='', flush=True)
                    buffer.clear()
            # print(, end='', flush=True)
        else:
            callback = lambda x: x
        t0 = time.perf_counter()
        import contextlib
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y, metrics = generate(
                model,
                encoded,
                max_new_tokens,
                draft_model=draft_model,
                speculate_k=speculate_k,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
            aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0

        if not interactive:
            print(tokenizer.decode(y.tolist()))
        else:
            print()
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
    print("==========")
    if is_speculative:
        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
        acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
        print(f"Acceptance probs: {acceptance_probs}")
        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Your CLI description.')

    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
    parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
    parser.add_argument('--device', type=str, default=default_device, help='Device to use')

    args = parser.parse_args()
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
        args.speculate_k, args.device
    )
360
backend/indextts/s2mel/modules/gpt_fast/model.py
Normal file
@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
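# e.g. find_multiple(13, 8) == 16 and find_multiple(16, 8) == 16: the result
# is the smallest multiple of k that is >= n.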

class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if embedding is None:
            return self.norm(input)
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias
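# In effect this is AdaLN-style conditioning: the embedding is projected to a
# per-channel (weight, bias) pair that scales and shifts the normalized
# activations, and with no embedding it falls back to the plain wrapped norm.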


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    uvit_skip_connection: bool = False
    time_as_token: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        # self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [config for config in transformer_configs if config.lower() in str(name).lower()]

        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best match
        # by taking the longer name, as it has more characters matched.
        if len(config) > 1:
            config.sort(key=len, reverse=True)
            assert len(config[0]) != len(config[1]), name  # make sure only one 'best' match

        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
    "7B": dict(n_layer=32, n_head=32, dim=4096),
    "13B": dict(n_layer=40, n_head=40, dim=5120),
    "30B": dict(n_layer=60, n_head=52, dim=6656),
    "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016,
                rope_base=1000000),  # CodeLlama-34B-Python-hf
    "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
    "stories15M": dict(n_layer=6, n_head=6, dim=288),
    "stories110M": dict(n_layer=12, n_head=12, dim=768),

    "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336,
                       vocab_size=128256, rope_base=500000),
    "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672,
                        vocab_size=128256, rope_base=500000),
}


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
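# Note: update() writes k_val/v_val at the given positions and returns the
# full preallocated caches, so attention always runs over max_seq_length
# entries and the causal mask is what hides the not-yet-written slots.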


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.project_layer.weight.dtype
        device = self.norm.project_layer.weight.device

        if not self.training and use_kv_cache:
            for b in self.layers:
                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
                                              self.config.rope_base, dtype).to(device)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
        self.use_kv_cache = use_kv_cache
        self.uvit_skip_connection = self.config.uvit_skip_connection
        if self.uvit_skip_connection:
            self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
            self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
        else:
            self.layers_emit_skip = []
            self.layers_receive_skip = []

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_input_pos: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if mask is None:  # in case of non-causal model
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if context is not None:
            context_freqs_cis = self.freqs_cis[context_input_pos]
        else:
            context_freqs_cis = None
        skip_in_x_list = []
        for i, layer in enumerate(self.layers):
            if self.uvit_skip_connection and i in self.layers_receive_skip:
                skip_in_x = skip_in_x_list.pop(-1)
            else:
                skip_in_x = None
            x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
            if self.uvit_skip_connection and i in self.layers_emit_skip:
                skip_in_x_list.append(x)
        x = self.norm(x, c)
        return x

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        if config.has_cross_attention:
            self.has_cross_attention = True
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        else:
            self.has_cross_attention = False

        if config.uvit_skip_connection:
            self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
            self.uvit_skip_connection = True
        else:
            self.uvit_skip_connection = False

        self.time_as_token = config.time_as_token

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                skip_in_x: Optional[Tensor] = None,
                ) -> Tensor:
        c = None if self.time_as_token else c
        if self.uvit_skip_connection and skip_in_x is not None:
            x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
        if self.has_cross_attention:
            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h, c))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        # self._register_load_state_dict_pre_hook(self.load_hook)

    # def load_hook(self, state_dict, prefix, *args):
    #     if prefix + "wq.weight" in state_dict:
    #         wq = state_dict.pop(prefix + "wq.weight")
    #         wk = state_dict.pop(prefix + "wk.weight")
    #         wv = state_dict.pop(prefix + "wv.weight")
    #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                input_pos: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
        seq_len: int, n_elem: int, base: int = 10000,
        dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
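# The stack above is complex multiplication written out in real arithmetic:
# with each (even, odd) channel pair read as x0 + i*x1 and freqs_cis as
# cos(t*theta) + i*sin(t*theta), the two rows are the real and imaginary parts
# of their product, i.e. a rotation by the position-dependent angle t*theta.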
622
backend/indextts/s2mel/modules/gpt_fast/quantize.py
Normal file
@@ -0,0 +1,622 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizer import get_tokenizer

try:
    from GPTQ import GenericGPTQRunner, InputRecorder
    from eval import get_task_dict, evaluate, lm_eval
except:
    pass

from model import Transformer

##### Quantization Primitives ######

def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
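# Worked example (symmetric int8, quant_min=-128, quant_max=127): a channel
# spanning [-0.5, 2.0] gets max_val_pos = 2.0 and scale = 2.0 / 127.5 ≈ 0.0157,
# so 2.0 quantizes to clamp(round(127.5)) = 127 and values round-trip to
# within about half a scale step of the original.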

def get_group_qparams(w, n_bit=4, groupsize=128):
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    zeros = min_val + scales * (2 ** (n_bit - 1))
    return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
        torch.bfloat16
    ).reshape(w.shape[0], -1)
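# The (scales, zeros) pair defines the affine map used throughout this file:
# integer codes in [0, 2**n_bit - 1] are centered at 2**(n_bit - 1), so
# w ≈ (w_int - 2**(n_bit - 1)) * scales + zeros, with zeros chosen so that
# min_val and max_val land exactly on the ends of the integer range.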


def pack_scales_and_zeros(scales, zeros):
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16
    return (
        torch.cat(
            [
                scales.reshape(scales.size(0), scales.size(1), 1),
                zeros.reshape(zeros.size(0), zeros.size(1), 1),
            ],
            2,
        )
        .transpose(0, 1)
        .contiguous()
    )


def unpack_scales_and_zeros(scales_and_zeros):
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype == torch.float
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    assert groupsize > 1
    # needed for GPTQ single column quantize
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    min_val = zeros - scales * (2 ** (n_bit - 1))
    max_int = 2**n_bit - 1
    min_int = 0
    w_int32 = (
        to_quant.sub(min_val)
        .div(scales)
        .round()
        .clamp_(min_int, max_int)
        .to(torch.int32)
        .reshape_as(w)
    )

    return w_int32


def group_quantize_tensor(w, n_bit=4, groupsize=128):
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
    scales_and_zeros = pack_scales_and_zeros(scales, zeros)
    return w_int32, scales_and_zeros


def group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit=4, groupsize=128
):
    assert groupsize > 1
    # needed for GPTQ single column dequantize
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    w_int32_grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)

    w_dq = (
        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
    )
    return w_dq
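# Round trip: for a tensor quantized by group_quantize_tensor_from_qparams,
# w_dq recovers each entry to within scales / 2 per group (the usual uniform
# quantization bound), since the quantizer rounds (w - min_val) / scales to
# the nearest integer in [0, 2**n_bit - 1].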


def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
    return group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit, groupsize
    )

class QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        pass

    def convert_for_runtime(self) -> "nn.Module":
        pass

class GPTQQuantHandler(QuantHandler):
    """
    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
    Unlike the base QuantHandler class, the user does not need to implement create_quantized_state_dict; instead they reimplement
    __init__ so that it defines the functions for the quantization mode. The user is also expected to reimplement convert_for_runtime.

    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
    create_quantized_state_dict. Here is a description of each function.

    get_qparams_func:
        A function that calculates the quantization qparams for an input tensor.
        Args:
            weight: A 2d weight tensor with non-integer dtype.
        Returns:
            qparams: it can have any format but will need to be handled by the other defined functions below.

    quantize_func:
        A function that applies quantization to an input tensor. It should be noted
        that this function needs to be able to handle quantizing the entire weight tensor, a single group,
        or a single column.
        Args:
            weight: A 2d weight tensor with non-integer dtype.
            qparams: the output from get_qparams_func
        Returns:
            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)


    dequantize_func:
        A function that dequantizes an input quantized weight tensor. It should be noted
        that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
        or a single column.
        Args:
            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
            qparams: the output from get_qparams_func
        Returns:
            weight: A 2d weight tensor with non-integer dtype.

    combine_qparams_list_func:
        A function that combines several qparams into one qparam.
        Args:
            qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
            on a single group from a weight tensor
        Returns:
            qparams: an object of the same format as the qparams above.

    skip_layer_func:
        A function that determines which linear layers should be skipped during GPTQ
        Args:
            weight: A 2d weight tensor with non-integer dtype.
        Returns:
            skip: boolean indicating whether layer should be skipped

    make_names_and_values_dict_func:
        A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
        should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
        Args:
            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
            qparams: the output from get_qparams_func
        Returns:
            names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
            corresponding quantized weights and qparams.
    """
    def __init__(self):
        assert self.mod is not None
        assert self.get_qparams_func is not None
        assert self.quantize_func is not None
        assert self.dequantize_func is not None
        assert self.combine_qparams_list_func is not None
        assert self.make_names_and_values_dict_func is not None

    @staticmethod
    def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
        input_recorder = InputRecorder(
            model,
            tokenizer,
            calibration_seq_length,
            pad_calibration_inputs,
        )

        try:
            lm_eval.tasks.initialize_tasks()
        except:
            pass
        task_dict = get_task_dict(calibration_tasks)
        print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)

        evaluate(
            input_recorder,
            task_dict,
            limit=calibration_limit,
        )
        inputs = input_recorder.get_recorded_inputs()
        assert inputs is not None, (
            f"No inputs were collected, use a task other than {calibration_tasks}, " +
            f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " +
            f"{calibration_seq_length})"
        )
        print(f"Obtained {len(inputs[0].values)} calibration samples")
        return inputs

    @torch.no_grad()
    def create_quantized_state_dict(
        self,
        tokenizer,
        blocksize,
        percdamp,
        groupsize,
        calibration_tasks,
        calibration_limit,
        calibration_seq_length,
        pad_calibration_inputs,
    ) -> "StateDict":
        inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
        print("Tracing model for GPTQ")
        GPTQ_runner = GenericGPTQRunner(
            self.mod,
            inputs,
            blocksize,
            percdamp,
            groupsize,
        ).configure_quantization_mode(
            self.get_qparams_func,
            self.quantize_func,
            self.dequantize_func,
            self.combine_qparams_list_func,
            self.make_names_and_values_dict_func,
            self.skip_layer_func
        )

        print("Applying GPTQ to weights")
        GPTQ_runner.run()
        return GPTQ_runner.get_quantized_state_dict()

    def convert_for_runtime(self) -> "nn.Module":
        pass
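# Concrete instance: WeightOnlyInt4GPTQQuantHandler below is the one subclass
# in this file; its __init__ assigns all six functions described in the class
# docstring, and its convert_for_runtime swaps nn.Linear layers for
# WeightOnlyInt4Linear.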

##### Weight-only int8 per-channel quantized code ######

def replace_linear_weight_only_int8_per_channel(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
        else:
            replace_linear_weight_only_int8_per_channel(child)

class WeightOnlyInt8QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
                cur_state_dict[f"{fqn}.weight"] = int8_weight
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
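# Dequantization is folded into the matmul output here: because each scale is
# per output channel, (x @ W_int8.T) * scales equals x @ (W_int8 * scales).T,
# so no dequantized copy of the weight is ever materialized.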

##### weight only int4 per channel groupwise quantized code ######

def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    weight_int32, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
    return weight_int4pack, scales_and_zeros


def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    origin_x_size = x.size()
    x = x.reshape(-1, origin_x_size[-1])
    c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
    new_shape = origin_x_size[:-1] + (out_features,)
    c = c.reshape(new_shape)
    return c


def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
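# e.g. k=4096 with groupsize=128 and inner_k_tiles=8 passes (4096 % 128 == 0
# on both counts), while k=1000 fails and triggers the padding path in
# create_quantized_state_dict below.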

def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
                ))
            elif padding:
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                ))
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)


class WeightOnlyInt4QuantHandler:
    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self, use_cuda=True):
        if use_cuda:
            device = "cuda"
        else:
            device = "cpu"

        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                weight = mod.weight.data
                if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
                    if self.padding:
                        from model import find_multiple
                        import torch.nn.functional as F
                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                    else:
                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
                              "and that groupsize and inner_k_tiles*16 evenly divide into it")
                        continue
                weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod

class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        from model import find_multiple
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
        self.quantize_func = lambda w, qparams: \
            group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
        self.dequantize_func = lambda q, qparams: \
            group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
        self.combine_qparams_list_func = lambda qparams_list: \
            [torch.cat(x, dim=1) for x in zip(*qparams_list)]
        # skip unless padding=True or it's correctly sized
        self.skip_layer_func = lambda linear_weight: not (
            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
        )
        # we need to do the padding here, both for q and the qparams if necessary
        def make_names_and_values_dict_func(q, qparams):
            k = q.shape[1]
            new_k = find_multiple(k, 1024)
            # how much we need to pad the weight
            delta_k = new_k - q.shape[1]
            final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
            scales_and_zeros = pack_scales_and_zeros(*qparams)
            # how many new groups we need for padded weight
            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
            final_s_and_z = F.pad(scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1)
            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
        self.make_names_and_values_dict_func = make_names_and_values_dict_func
        super().__init__()


    def convert_for_runtime(self):
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod

class WeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
            self, in_features: int, out_features: int,
            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            from model import find_multiple
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
        self.register_buffer(
            "weight",
            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
        )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(torch.bfloat16)
        if self.padding:
            import torch.nn.functional as F
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            input,
            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )


def quantize(
    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
    mode: str = 'int8',
    # following arguments only available when setting int4 quantization.
    groupsize: int = 128,
    # following arguments only used for GPTQ
    calibration_tasks: list = ["hellaswag"],
    calibration_limit: int = 1000,
    calibration_seq_length: int = 100,
    pad_calibration_inputs: bool = False,
    percdamp: float = .01,
    blocksize: int = 128,
    label: str = '',
) -> None:
    assert checkpoint_path.is_file(), checkpoint_path

    device = 'cpu'
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(dtype=precision, device=device)

    if mode == 'int8':
        print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
        quant_handler = WeightOnlyInt8QuantHandler(model)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f'{label}int8.pth')

    elif mode == 'int4':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")

    elif mode == 'int4-gptq':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
        quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), str(tokenizer_path)
        tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

        quantized_state_dict = quant_handler.create_quantized_state_dict(
            tokenizer,
            blocksize,
            percdamp,
            groupsize,
            calibration_tasks,
            calibration_limit,
            calibration_seq_length,
            pad_calibration_inputs
        )

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
    else:
        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]")

    quantize_path = dir_name / new_base_name
    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one is already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")
    return

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Quantize a model.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
    parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
    parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
    parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
    parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
    parser.add_argument('--label', type=str, default='_', help='label to add to output filename')

    args = parser.parse_args()
    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)