Merge pull request #5 from bdim404/dev

feat: Enhance stream_chat methods to accept max_tokens parameter
This commit is contained in:
2026-03-11 18:48:18 +08:00
committed by GitHub

View File

@@ -14,7 +14,7 @@ class LLMService:
self.api_key = api_key self.api_key = api_key
self.model = model self.model = model
async def stream_chat(self, system_prompt: str, user_message: str, on_token=None) -> str: async def stream_chat(self, system_prompt: str, user_message: str, on_token=None, max_tokens: int = 8192) -> str:
url = f"{self.base_url}/chat/completions" url = f"{self.base_url}/chat/completions"
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
@@ -27,6 +27,7 @@ class LLMService:
{"role": "user", "content": user_message}, {"role": "user", "content": user_message},
], ],
"temperature": 0.3, "temperature": 0.3,
"max_tokens": max_tokens,
"stream": True, "stream": True,
} }
full_text = "" full_text = ""
@@ -54,8 +55,8 @@ class LLMService:
continue continue
return full_text return full_text
async def stream_chat_json(self, system_prompt: str, user_message: str, on_token=None): async def stream_chat_json(self, system_prompt: str, user_message: str, on_token=None, max_tokens: int = 8192):
raw = await self.stream_chat(system_prompt, user_message, on_token) raw = await self.stream_chat(system_prompt, user_message, on_token, max_tokens=max_tokens)
raw = raw.strip() raw = raw.strip()
if not raw: if not raw:
raise ValueError("LLM returned empty response") raise ValueError("LLM returned empty response")
@@ -204,7 +205,7 @@ class LLMService:
'[{"character": "narrator", "text": "叙述文字"}, {"character": "角色名", "text": "对话内容"}, ...]' '[{"character": "narrator", "text": "叙述文字"}, {"character": "角色名", "text": "对话内容"}, ...]'
) )
user_message = f"请解析以下章节文本:\n\n{chapter_text}" user_message = f"请解析以下章节文本:\n\n{chapter_text}"
result = await self.stream_chat_json(system_prompt, user_message, on_token) result = await self.stream_chat_json(system_prompt, user_message, on_token, max_tokens=16384)
if isinstance(result, list): if isinstance(result, list):
return result return result
return [] return []