feat: Implement Aliyun TTS backend integration and API key management
This commit is contained in:
@@ -36,9 +36,12 @@ async def process_custom_voice_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
backend_type: str,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -51,42 +54,24 @@ async def process_custom_voice_job(
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing custom-voice job {job_id}")
|
||||
logger.info(f"Processing custom-voice job {job_id} with backend {backend_type}")
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("custom-voice")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
user_api_key = None
|
||||
if backend_type == "aliyun":
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user and user.aliyun_api_key:
|
||||
user_api_key = decrypt_api_key(user.aliyun_api_key)
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load custom-voice model")
|
||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
||||
|
||||
result = tts.generate_custom_voice(
|
||||
text=request_data['text'],
|
||||
language=request_data['language'],
|
||||
speaker=request_data['speaker'],
|
||||
instruct=request_data.get('instruct', ''),
|
||||
max_new_tokens=request_data['max_new_tokens'],
|
||||
temperature=request_data['temperature'],
|
||||
top_k=request_data['top_k'],
|
||||
top_p=request_data['top_p'],
|
||||
repetition_penalty=request_data['repetition_penalty']
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
if isinstance(result, tuple):
|
||||
audio_data = result[0]
|
||||
elif isinstance(result, list):
|
||||
audio_data = np.array(result)
|
||||
else:
|
||||
audio_data = result
|
||||
|
||||
from pathlib import Path
|
||||
audio_bytes, sample_rate = await backend.generate_custom_voice(request_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
save_audio_file(audio_data, 24000, output_path)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
@@ -112,9 +97,12 @@ async def process_voice_design_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
backend_type: str,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -127,41 +115,24 @@ async def process_voice_design_job(
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing voice-design job {job_id}")
|
||||
logger.info(f"Processing voice-design job {job_id} with backend {backend_type}")
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("voice-design")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
user_api_key = None
|
||||
if backend_type == "aliyun":
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user and user.aliyun_api_key:
|
||||
user_api_key = decrypt_api_key(user.aliyun_api_key)
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load voice-design model")
|
||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
||||
|
||||
result = tts.generate_voice_design(
|
||||
text=request_data['text'],
|
||||
language=request_data['language'],
|
||||
instruct=request_data['instruct'],
|
||||
max_new_tokens=request_data['max_new_tokens'],
|
||||
temperature=request_data['temperature'],
|
||||
top_k=request_data['top_k'],
|
||||
top_p=request_data['top_p'],
|
||||
repetition_penalty=request_data['repetition_penalty']
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
if isinstance(result, tuple):
|
||||
audio_data = result[0]
|
||||
elif isinstance(result, list):
|
||||
audio_data = np.array(result)
|
||||
else:
|
||||
audio_data = result
|
||||
|
||||
from pathlib import Path
|
||||
audio_bytes, sample_rate = await backend.generate_voice_design(request_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
save_audio_file(audio_data, 24000, output_path)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
@@ -188,9 +159,11 @@ async def process_voice_clone_job(
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
ref_audio_path: str,
|
||||
backend_type: str,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
import numpy as np
|
||||
|
||||
db = SessionLocal()
|
||||
@@ -204,7 +177,14 @@ async def process_voice_clone_job(
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing voice-clone job {job_id}")
|
||||
logger.info(f"Processing voice-clone job {job_id} with backend {backend_type}")
|
||||
|
||||
from core.security import decrypt_api_key
|
||||
user_api_key = None
|
||||
if backend_type == "aliyun":
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user and user.aliyun_api_key:
|
||||
user_api_key = decrypt_api_key(user.aliyun_api_key)
|
||||
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
@@ -212,49 +192,49 @@ async def process_voice_clone_job(
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
x_vector = None
|
||||
cache_id = None
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||
if cached:
|
||||
x_vector = cached['data']
|
||||
cache_id = cached['cache_id']
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
if x_vector is None:
|
||||
cache_metrics.record_miss(user_id)
|
||||
logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=request_data.get('ref_text', ''),
|
||||
x_vector_only_mode=request_data.get('x_vector_only_mode', False)
|
||||
)
|
||||
if request_data.get('x_vector_only_mode', False) and backend_type == "local":
|
||||
x_vector = None
|
||||
cache_id = None
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': request_data.get('ref_text', ''),
|
||||
'x_vector_only_mode': request_data.get('x_vector_only_mode', False)
|
||||
}
|
||||
cache_id = await cache_manager.set_cache(
|
||||
user_id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")
|
||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||
if cached:
|
||||
x_vector = cached['data']
|
||||
cache_id = cached['cache_id']
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
if x_vector is None:
|
||||
cache_metrics.record_miss(user_id)
|
||||
logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=request_data.get('ref_text', ''),
|
||||
x_vector_only_mode=True
|
||||
)
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': request_data.get('ref_text', ''),
|
||||
'x_vector_only_mode': True
|
||||
}
|
||||
cache_id = await cache_manager.set_cache(
|
||||
user_id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
if request_data.get('x_vector_only_mode', False):
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = f"x_vector_cached_{cache_id}"
|
||||
job.completed_at = datetime.utcnow()
|
||||
@@ -262,31 +242,16 @@ async def process_voice_clone_job(
|
||||
logger.info(f"Job {job_id} completed (x_vector_only_mode)")
|
||||
return
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
wavs, sample_rate = tts.generate_voice_clone(
|
||||
text=request_data['text'],
|
||||
language=request_data['language'],
|
||||
voice_clone_prompt=x_vector,
|
||||
max_new_tokens=request_data['max_new_tokens'],
|
||||
temperature=request_data['temperature'],
|
||||
top_k=request_data['top_k'],
|
||||
top_p=request_data['top_p'],
|
||||
repetition_penalty=request_data['repetition_penalty']
|
||||
)
|
||||
|
||||
audio_data = wavs[0] if isinstance(wavs, list) else wavs
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
save_audio_file(audio_data, sample_rate, output_path)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
@@ -319,6 +284,16 @@ async def create_custom_voice_job(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
backend_type = req_data.backend or settings.DEFAULT_BACKEND
|
||||
if backend_type == "aliyun":
|
||||
if not current_user.aliyun_api_key:
|
||||
raise HTTPException(status_code=400, detail="Aliyun API key not configured. Please set your API key first.")
|
||||
user_api_key = decrypt_api_key(current_user.aliyun_api_key)
|
||||
if not user_api_key:
|
||||
raise HTTPException(status_code=400, detail="Invalid Aliyun API key. Please update your API key.")
|
||||
|
||||
try:
|
||||
validate_text_length(req_data.text)
|
||||
language = validate_language(req_data.language)
|
||||
@@ -339,6 +314,7 @@ async def create_custom_voice_job(
|
||||
user_id=current_user.id,
|
||||
job_type="custom-voice",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type=backend_type,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
@@ -365,6 +341,7 @@ async def create_custom_voice_job(
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
@@ -384,6 +361,16 @@ async def create_voice_design_job(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
backend_type = req_data.backend or settings.DEFAULT_BACKEND
|
||||
if backend_type == "aliyun":
|
||||
if not current_user.aliyun_api_key:
|
||||
raise HTTPException(status_code=400, detail="Aliyun API key not configured. Please set your API key first.")
|
||||
user_api_key = decrypt_api_key(current_user.aliyun_api_key)
|
||||
if not user_api_key:
|
||||
raise HTTPException(status_code=400, detail="Invalid Aliyun API key. Please update your API key.")
|
||||
|
||||
try:
|
||||
validate_text_length(req_data.text)
|
||||
language = validate_language(req_data.language)
|
||||
@@ -406,6 +393,7 @@ async def create_voice_design_job(
|
||||
user_id=current_user.id,
|
||||
job_type="voice-design",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type=backend_type,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
@@ -430,6 +418,7 @@ async def create_voice_design_job(
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
@@ -455,10 +444,21 @@ async def create_voice_clone_job(
|
||||
top_k: Optional[int] = Form(default=50),
|
||||
top_p: Optional[float] = Form(default=1.0),
|
||||
repetition_penalty: Optional[float] = Form(default=1.05),
|
||||
backend: Optional[str] = Form(default=None),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
backend_type = backend or settings.DEFAULT_BACKEND
|
||||
if backend_type == "aliyun":
|
||||
if not current_user.aliyun_api_key:
|
||||
raise HTTPException(status_code=400, detail="Aliyun API key not configured. Please set your API key first.")
|
||||
user_api_key = decrypt_api_key(current_user.aliyun_api_key)
|
||||
if not user_api_key:
|
||||
raise HTTPException(status_code=400, detail="Invalid Aliyun API key. Please update your API key.")
|
||||
|
||||
try:
|
||||
validate_text_length(text)
|
||||
language = validate_language(language)
|
||||
@@ -486,6 +486,7 @@ async def create_voice_clone_job(
|
||||
user_id=current_user.id,
|
||||
job_type="voice-clone",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type=backend_type,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": text,
|
||||
@@ -520,6 +521,7 @@ async def create_voice_clone_job(
|
||||
current_user.id,
|
||||
request_data,
|
||||
tmp_audio_path,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user