feat: add NSFW script generation feature and Grok API configuration
This commit is contained in:
@@ -8,7 +8,7 @@ from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import settings
|
||||
from core.llm_service import LLMService
|
||||
from core.llm_service import LLMService, GrokLLMService
|
||||
from core import progress_store as ps
|
||||
from db import crud
|
||||
from db.models import AudiobookProject, AudiobookCharacter, User
|
||||
@@ -44,6 +44,20 @@ def _get_llm_service(db: Session) -> LLMService:
|
||||
return LLMService(base_url=base_url, api_key=api_key, model=model)
|
||||
|
||||
|
||||
def _get_grok_service(db: Session) -> GrokLLMService:
|
||||
from core.security import decrypt_api_key
|
||||
from db.crud import get_system_setting
|
||||
api_key_encrypted = get_system_setting(db, "grok_api_key")
|
||||
base_url = get_system_setting(db, "grok_base_url")
|
||||
model = get_system_setting(db, "grok_model") or "grok-4"
|
||||
if not api_key_encrypted or not base_url:
|
||||
raise ValueError("Grok config not set. Please configure Grok API key and base URL in admin settings.")
|
||||
api_key = decrypt_api_key(api_key_encrypted)
|
||||
if not api_key:
|
||||
raise ValueError("Failed to decrypt Grok API key.")
|
||||
return GrokLLMService(base_url=base_url, api_key=api_key, model=model)
|
||||
|
||||
|
||||
def _get_gendered_instruct(gender: Optional[str], base_instruct: str) -> str:
|
||||
"""Ensure the instruction sent to the TTS model has explicit gender cues if known."""
|
||||
if not gender or gender == "未知":
|
||||
@@ -1472,3 +1486,136 @@ async def generate_character_preview(project_id: int, char_id: int, user: User,
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate preview for char {char_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def generate_ai_script_nsfw(project_id: int, user: User, db: Session) -> None:
|
||||
from core.database import SessionLocal
|
||||
|
||||
project = db.query(AudiobookProject).filter(AudiobookProject.id == project_id).first()
|
||||
if not project or not project.script_config:
|
||||
return
|
||||
|
||||
key = str(project_id)
|
||||
ps.reset(key)
|
||||
cfg = project.script_config
|
||||
|
||||
try:
|
||||
crud.update_audiobook_project_status(db, project_id, "analyzing")
|
||||
ps.append_line(key, f"[NSFW剧本] 项目「{project.title}」开始生成剧本")
|
||||
|
||||
llm = _get_grok_service(db)
|
||||
_llm_model = crud.get_system_setting(db, "grok_model") or "grok-4"
|
||||
_user_id = user.id
|
||||
|
||||
def _log_usage(prompt_tokens: int, completion_tokens: int) -> None:
|
||||
log_db = SessionLocal()
|
||||
try:
|
||||
crud.create_usage_log(log_db, _user_id, prompt_tokens, completion_tokens,
|
||||
model=_llm_model, context="nsfw_script_generate")
|
||||
finally:
|
||||
log_db.close()
|
||||
|
||||
genre = cfg.get("genre", "")
|
||||
subgenre = cfg.get("subgenre", "")
|
||||
premise = cfg.get("premise", "")
|
||||
style = cfg.get("style", "")
|
||||
num_characters = cfg.get("num_characters", 5)
|
||||
num_chapters = cfg.get("num_chapters", 8)
|
||||
|
||||
ps.append_line(key, f"\n[Step 1] 生成 {num_characters} 个角色...\n")
|
||||
ps.append_line(key, "")
|
||||
|
||||
def on_token(token: str) -> None:
|
||||
ps.append_token(key, token)
|
||||
|
||||
characters_data = await llm.generate_story_characters(
|
||||
genre=genre, subgenre=subgenre, premise=premise, style=style,
|
||||
num_characters=num_characters, usage_callback=_log_usage,
|
||||
)
|
||||
|
||||
has_narrator = any(c.get("name") in ("narrator", "旁白") for c in characters_data)
|
||||
if not has_narrator:
|
||||
characters_data.insert(0, {
|
||||
"name": "旁白",
|
||||
"gender": "未知",
|
||||
"description": "第三人称旁白叙述者",
|
||||
"instruct": (
|
||||
"音色信息:浑厚醇厚的男性中低音,嗓音饱满有力,带有传统说书人的磁性与感染力\n"
|
||||
"身份背景:中国传统说书艺人,精通评书、章回小说叙述艺术,深谙故事节奏与听众心理\n"
|
||||
"年龄设定:中年男性,四五十岁,声音历经岁月沉淀,成熟稳重而不失活力\n"
|
||||
"外貌特征:面容沉稳,气度从容,台风大气,给人以可信赖的叙述者印象\n"
|
||||
"性格特质:沉稳睿智,叙事冷静客观,情到深处能引发共鸣,不动声色间娓娓道来\n"
|
||||
"叙事风格:语速适中偏慢,抑扬顿挫,擅长铺垫悬念,停顿恰到好处,语气庄重而生动,富有画面感"
|
||||
)
|
||||
})
|
||||
|
||||
ps.append_line(key, f"\n\n[完成] 角色列表:{', '.join(c.get('name', '') for c in characters_data)}")
|
||||
|
||||
crud.delete_audiobook_segments(db, project_id)
|
||||
crud.delete_audiobook_characters(db, project_id)
|
||||
|
||||
backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"
|
||||
|
||||
for char_data in characters_data:
|
||||
name = char_data.get("name", "旁白")
|
||||
if name == "narrator":
|
||||
name = "旁白"
|
||||
instruct = char_data.get("instruct", "")
|
||||
description = char_data.get("description", "")
|
||||
gender = char_data.get("gender") or ("未知" if name == "旁白" else None)
|
||||
try:
|
||||
voice_design = crud.create_voice_design(
|
||||
db=db,
|
||||
user_id=user.id,
|
||||
name=f"[有声书] {project.title} - {name}",
|
||||
instruct=instruct,
|
||||
backend_type=backend_type,
|
||||
preview_text=description[:100] if description else None,
|
||||
)
|
||||
crud.create_audiobook_character(
|
||||
db=db,
|
||||
project_id=project_id,
|
||||
name=name,
|
||||
gender=gender,
|
||||
description=description,
|
||||
instruct=instruct,
|
||||
voice_design_id=voice_design.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create char/voice for {name}: {e}")
|
||||
|
||||
crud.update_audiobook_project_status(db, project_id, "characters_ready")
|
||||
ps.append_line(key, f"\n[状态] 角色创建完成,请确认角色后继续生成剧本")
|
||||
ps.mark_done(key)
|
||||
|
||||
user_id = user.id
|
||||
|
||||
async def _generate_all_previews():
|
||||
temp_db = SessionLocal()
|
||||
try:
|
||||
characters = crud.list_audiobook_characters(temp_db, project_id)
|
||||
char_ids = [c.id for c in characters]
|
||||
finally:
|
||||
temp_db.close()
|
||||
if not char_ids:
|
||||
return
|
||||
sem = asyncio.Semaphore(3)
|
||||
async def _gen(char_id: int):
|
||||
async with sem:
|
||||
local_db = SessionLocal()
|
||||
try:
|
||||
db_user = crud.get_user_by_id(local_db, user_id)
|
||||
await generate_character_preview(project_id, char_id, db_user, local_db)
|
||||
except Exception as e:
|
||||
logger.error(f"Background preview failed for char {char_id}: {e}")
|
||||
finally:
|
||||
local_db.close()
|
||||
await asyncio.gather(*[_gen(cid) for cid in char_ids])
|
||||
|
||||
asyncio.create_task(_generate_all_previews())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"generate_ai_script_nsfw failed for project {project_id}: {e}", exc_info=True)
|
||||
ps.append_line(key, f"\n[错误] {e}")
|
||||
ps.mark_done(key)
|
||||
crud.update_audiobook_project_status(db, project_id, "error", error_message=str(e))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import httpx
|
||||
@@ -8,6 +9,22 @@ import httpx
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def strip_grok_thinking(text: str) -> str:
|
||||
lines = text.split('\n')
|
||||
cleaned = []
|
||||
for line in lines:
|
||||
if line.startswith('> '):
|
||||
continue
|
||||
cleaned.append(line)
|
||||
result = []
|
||||
for line in cleaned:
|
||||
if result and line and not line.startswith('【') and result[-1] != '':
|
||||
result[-1] += line
|
||||
else:
|
||||
result.append(line)
|
||||
return '\n'.join(result).strip()
|
||||
|
||||
|
||||
class LLMService:
|
||||
def __init__(self, base_url: str, api_key: str, model: str):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
@@ -68,11 +85,15 @@ class LLMService:
|
||||
if not raw:
|
||||
raise ValueError("LLM returned empty response")
|
||||
if raw.startswith("```"):
|
||||
lines = raw.split("\n")
|
||||
inner = lines[1:]
|
||||
if inner and inner[-1].strip().startswith("```"):
|
||||
inner = inner[:-1]
|
||||
raw = "\n".join(inner).strip()
|
||||
m = re.search(r'^```[a-z]*\n?([\s\S]*?)```\s*$', raw)
|
||||
if m:
|
||||
raw = m.group(1).strip()
|
||||
else:
|
||||
lines = raw.split("\n")
|
||||
inner = lines[1:]
|
||||
if inner and inner[-1].strip().startswith("```"):
|
||||
inner = inner[:-1]
|
||||
raw = "\n".join(inner).strip()
|
||||
if not raw:
|
||||
raise ValueError("LLM returned empty JSON after stripping markdown")
|
||||
try:
|
||||
@@ -115,11 +136,15 @@ class LLMService:
|
||||
if not raw:
|
||||
raise ValueError("LLM returned empty response")
|
||||
if raw.startswith("```"):
|
||||
lines = raw.split("\n")
|
||||
inner = lines[1:]
|
||||
if inner and inner[-1].strip().startswith("```"):
|
||||
inner = inner[:-1]
|
||||
raw = "\n".join(inner).strip()
|
||||
m = re.search(r'^```[a-z]*\n?([\s\S]*?)```\s*$', raw)
|
||||
if m:
|
||||
raw = m.group(1).strip()
|
||||
else:
|
||||
lines = raw.split("\n")
|
||||
inner = lines[1:]
|
||||
if inner and inner[-1].strip().startswith("```"):
|
||||
inner = inner[:-1]
|
||||
raw = "\n".join(inner).strip()
|
||||
if not raw:
|
||||
raise ValueError("LLM returned empty JSON after stripping markdown")
|
||||
try:
|
||||
@@ -379,3 +404,13 @@ class LLMService:
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
class GrokLLMService(LLMService):
|
||||
async def stream_chat(self, system_prompt: str, user_message: str, on_token=None, max_tokens: int = 8192, usage_callback=None) -> str:
|
||||
full_text = await super().stream_chat(system_prompt, user_message, on_token, max_tokens=max_tokens, usage_callback=usage_callback)
|
||||
return strip_grok_thinking(full_text)
|
||||
|
||||
async def chat(self, system_prompt: str, user_message: str, usage_callback=None) -> str:
|
||||
full_text = await super().chat(system_prompt, user_message, usage_callback=usage_callback)
|
||||
return strip_grok_thinking(full_text)
|
||||
|
||||
Reference in New Issue
Block a user