refactor: rename canto-backend → backend, canto-frontend → frontend
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
104  backend/indextts/utils/maskgct/models/codec/facodec/optimizer.py  Normal file
@@ -0,0 +1,104 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import reduce

import torch
from torch.optim import AdamW


class MultiOptimizer:
    """Groups per-model optimizers and schedulers so they can be stepped together."""

    def __init__(self, optimizers=None, schedulers=None):
        # Use None defaults instead of mutable default arguments.
        self.optimizers = optimizers if optimizers is not None else {}
        self.schedulers = schedulers if schedulers is not None else {}
        self.keys = list(self.optimizers.keys())
        self.param_groups = reduce(
            lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
        )

    def state_dict(self):
        return [(key, self.optimizers[key].state_dict()) for key in self.keys]

    def scheduler_state_dict(self):
        return [(key, self.schedulers[key].state_dict()) for key in self.keys]

    def load_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.optimizers[key].load_state_dict(val)
            except Exception:
                print("Failed to load optimizer state for %s" % key)

    def load_scheduler_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.schedulers[key].load_state_dict(val)
            except Exception:
                print("Failed to load scheduler state for %s" % key)

    def step(self, key=None, scaler=None):
        keys = [key] if key is not None else self.keys
        for k in keys:
            self._step(k, scaler)

    def _step(self, key, scaler=None):
        if scaler is not None:
            # AMP path: GradScaler unscales gradients and skips the step on inf/NaN.
            scaler.step(self.optimizers[key])
            scaler.update()
        else:
            self.optimizers[key].step()

    def zero_grad(self, key=None):
        if key is not None:
            self.optimizers[key].zero_grad()
        else:
            for k in self.keys:
                self.optimizers[k].zero_grad()

    def scheduler(self, *args, key=None):
        if key is not None:
            self.schedulers[key].step(*args)
        else:
            for k in self.keys:
                self.schedulers[k].step(*args)


def define_scheduler(optimizer, params):
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])


def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
    # scheduler_params_dict is accepted for call compatibility but unused here;
    # a fixed gamma is applied to every scheduler below.
    optim = {}
    for key, model in model_dict.items():
        if type == "AdamW":
            optim[key] = AdamW(
                model.parameters(),
                lr=lr,
                betas=(0.9, 0.98),
                eps=1e-9,
                weight_decay=0.1,
            )
        else:
            raise ValueError("Unknown optimizer type: %s" % type)

    schedulers = {
        key: torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996)
        for key, opt in optim.items()
    }

    return MultiOptimizer(optim, schedulers)
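For orientation, a minimal sketch of how the helper added above might be driven in a training loop; this is illustrative only, not part of the commit. The import path assumes the renamed backend/ package is importable, and the two nn.Linear models are placeholders for real sub-modules.

import torch
from torch import nn

from backend.indextts.utils.maskgct.models.codec.facodec.optimizer import build_optimizer

# Two placeholder sub-models keyed by name, as build_optimizer expects.
models = {"encoder": nn.Linear(8, 8), "decoder": nn.Linear(8, 8)}
multi_optim = build_optimizer(models, scheduler_params_dict={}, lr=1e-4)

x = torch.randn(4, 8)
loss = (models["decoder"](models["encoder"](x)) ** 2).mean()

multi_optim.zero_grad()   # zeroes gradients in every wrapped optimizer
loss.backward()
multi_optim.step()        # steps every wrapped optimizer (pass scaler= for AMP)
multi_optim.scheduler()   # advances every ExponentialLR scheduler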