2026-04-23 14:36:26 +08:00

260 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI API 调用封装,支持 OpenAI、阿里云通义千问、DeepSeek、Ollama均兼容 OpenAI SDK
"""
import re
import time
import logging
from openai import OpenAI
import config
from contextlib import contextmanager # for type hints if needed
logger = logging.getLogger(__name__)
# Human-readable display names for each supported provider.
# Keys are the values config.MODEL_PROVIDER may take.
PROVIDER_NAMES = {
    'qwen': '通义千问 (Qwen)',
    'deepseek': 'DeepSeek',
    'openai': 'OpenAI',
    'ollama': 'Ollama 本地',
    'doubao': '豆包 (Doubao)',
    'kimi': 'Kimi (Moonshot)',
}
# Sign-up / console URLs shown to the user when an API key is
# missing or invalid (see _check_api_key and the chat error paths).
PROVIDER_LINKS = {
    'qwen': 'https://dashscope.aliyun.com/',
    'deepseek': 'https://platform.deepseek.com/',
    'openai': 'https://platform.openai.com/',
    'ollama': 'https://ollama.com/',
    'doubao': 'https://console.volcengine.com/ark/',
    'kimi': 'https://platform.moonshot.cn/',
}
# Maps a provider id to the config attribute that holds its API key.
# Providers absent from this table fall back to OPENAI_API_KEY.
_KEY_ATTRS = {
    'qwen': 'QWEN_API_KEY',
    'deepseek': 'DEEPSEEK_API_KEY',
    'doubao': 'DOUBAO_API_KEY',
    'kimi': 'KIMI_API_KEY',
}


def _check_api_key():
    """Pre-flight check of the configured API key.

    Fails fast with a friendly RuntimeError when the key is absent or
    still the placeholder value, instead of burning retries on requests
    that can never succeed. Ollama runs locally and needs no key.

    Raises:
        RuntimeError: the active provider has no usable API key.
    """
    provider = config.MODEL_PROVIDER
    if provider == 'ollama':
        return  # local deployment, no credentials required
    name = PROVIDER_NAMES.get(provider, provider)
    link = PROVIDER_LINKS.get(provider, '')
    # getattr keeps the lookup lazy: only the active provider's config
    # attribute is ever touched.
    key = getattr(config, _KEY_ATTRS.get(provider, 'OPENAI_API_KEY'))
    # 'sk-your...' is the placeholder shipped in the sample config.
    if not key or key.startswith('sk-your'):
        raise RuntimeError(
            f'尚未配置 {name} 的 API Key。'
            f'请点击右上角设置按钮,选择"{name}"并填入有效的 API Key。'
            f'申请地址:{link}'
        )
def _get_client() -> OpenAI:
    """Build an OpenAI-compatible client for the active MODEL_PROVIDER."""
    provider = config.MODEL_PROVIDER
    if provider == 'ollama':
        # Ollama ignores the key, but the SDK requires a non-empty string.
        return OpenAI(api_key='ollama', base_url=config.OLLAMA_BASE_URL)
    # (api-key attribute, base-url attribute) per provider; anything
    # unknown falls back to the plain OpenAI settings.
    credential_attrs = {
        'qwen': ('QWEN_API_KEY', 'QWEN_BASE_URL'),
        'deepseek': ('DEEPSEEK_API_KEY', 'DEEPSEEK_BASE_URL'),
        'doubao': ('DOUBAO_API_KEY', 'DOUBAO_BASE_URL'),
        'kimi': ('KIMI_API_KEY', 'KIMI_BASE_URL'),
    }
    key_attr, url_attr = credential_attrs.get(
        provider, ('OPENAI_API_KEY', 'OPENAI_BASE_URL')
    )
    return OpenAI(api_key=getattr(config, key_attr),
                  base_url=getattr(config, url_attr))
def _get_model() -> str:
    """Return the chat model name configured for the active provider."""
    model_attrs = {
        'qwen': 'QWEN_MODEL',
        'deepseek': 'DEEPSEEK_MODEL',
        'ollama': 'OLLAMA_MODEL',
        'doubao': 'DOUBAO_MODEL',
        'kimi': 'KIMI_MODEL',
    }
    # Unknown providers fall back to the plain OpenAI model setting.
    attr = model_attrs.get(config.MODEL_PROVIDER, 'OPENAI_MODEL')
    return getattr(config, attr)
def _clean_response(text: str) -> str:
"""
过滤推理模型DeepSeek R1 / QwQ 等)输出的 <think>...</think> 思考过程标签,
只保留最终正文内容,避免思考链污染标书正文。
"""
# 去除 <think>...</think> 块(含跨行内容)
text = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
return text.strip()
def _is_auth_error(e: Exception) -> bool:
"""判断是否为认证错误401 / invalid_api_key无需重试"""
# 优先用 openai 原生异常类型判断
try:
from openai import AuthenticationError, PermissionDeniedError
if isinstance(e, (AuthenticationError, PermissionDeniedError)):
return True
except ImportError:
pass
# 兜底:字符串匹配
err_str = str(e).lower()
return ('401' in err_str or 'invalid_api_key' in err_str
or 'incorrect api key' in err_str or 'authentication' in err_str)
# OpenAI o 系列推理模型:不支持 temperaturemax_tokens 需用 max_completion_tokens
_OPENAI_REASONING_MODELS = {'o1', 'o1-mini', 'o1-pro', 'o3', 'o3-mini', 'o3-pro', 'o4-mini'}
def _build_chat_kwargs(
model: str,
messages: list,
temperature: float,
max_tokens: int,
request_timeout: float | None = None,
) -> dict:
"""
根据模型类型构建 chat.completions.create 的参数字典。
OpenAI o 系列推理模型不接受 temperature且使用 max_completion_tokens 替代 max_tokens。
"""
base_model = model.split(':')[0] # 去掉 ollama tag 后缀
is_reasoning = base_model in _OPENAI_REASONING_MODELS
to = request_timeout if request_timeout is not None else config.REQUEST_TIMEOUT
kwargs = {
'model': model,
'messages': messages,
'timeout': to,
}
if is_reasoning:
kwargs['max_completion_tokens'] = max_tokens
else:
kwargs['temperature'] = temperature
kwargs['max_tokens'] = max_tokens
return kwargs
def chat(
    prompt: str,
    system: str = '你是一位专业的投标文件撰写专家。',
    temperature: float = 0.7,
    max_tokens: int = 8192,
    retries: int | None = None,
    request_timeout: float | None = None,
) -> str:
    """Call the chat API and return the cleaned text response.

    Authentication errors abort immediately with a friendly message;
    other errors are retried with exponential backoff (1s, 2s, 4s, ...).
    OpenAI "o" series reasoning-model parameter differences are handled
    by _build_chat_kwargs, and every request is throttled by the global
    LLM semaphore (config.llm_call).

    Args:
        prompt: user message content.
        system: system-role instruction.
        temperature: sampling temperature (ignored by reasoning models).
        max_tokens: response token budget.
        retries: max attempts; defaults to config.MAX_RETRIES.
            (Annotation fixed: None is an accepted sentinel.)
        request_timeout: per-request timeout in seconds; defaults to
            config.REQUEST_TIMEOUT.

    Returns:
        Response text with <think> blocks stripped, or '' when the
        effective retry count is <= 0.

    Raises:
        RuntimeError: invalid API key, or all attempts exhausted.
    """
    _check_api_key()
    max_retries = retries if retries is not None else config.MAX_RETRIES
    client = _get_client()
    model = _get_model()
    provider = config.MODEL_PROVIDER
    name = PROVIDER_NAMES.get(provider, provider)
    messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': prompt},
    ]
    for attempt in range(max_retries):
        try:
            with config.llm_call():  # global concurrency cap
                kwargs = _build_chat_kwargs(
                    model, messages, temperature, max_tokens, request_timeout=request_timeout
                )
                resp = client.chat.completions.create(**kwargs)
                return _clean_response(resp.choices[0].message.content.strip())
        except Exception as e:
            if _is_auth_error(e):
                # Bad credentials never succeed on retry -- fail fast.
                raise RuntimeError(
                    f'{name} API Key 无效或已过期,请在设置中重新配置。'
                    f'申请地址:{PROVIDER_LINKS.get(provider, "")}'
                ) from e
            if attempt < max_retries - 1:
                wait = 2 ** attempt
                # Fix: only announce a retry when one actually follows
                # (previously this was logged even on the final failure).
                logger.warning(f'AI 请求失败 (第{attempt+1}次){wait}s 后重试: {e}')
                time.sleep(wait)
            else:
                raise RuntimeError(f'AI 接口调用失败(已重试 {max_retries} 次): {e}') from e
    return ''  # reachable only when max_retries <= 0
def chat_with_history(system: str, messages: list,
                      temperature: float = 0.7, max_tokens: int = 4096) -> str:
    """Multi-turn chat with full conversation history.

    Used for conversational section generation. Authentication errors
    abort immediately; other errors back off exponentially. Requests
    are throttled by the global LLM semaphore (config.llm_call).

    Args:
        system: system-role instruction prepended to the history.
        messages: [{'role': 'user'|'assistant', 'content': str}, ...]
        temperature: sampling temperature (ignored by reasoning models).
        max_tokens: response token budget.

    Returns:
        Response text with <think> blocks stripped, or '' when
        config.MAX_RETRIES <= 0.

    Raises:
        RuntimeError: invalid API key, or all attempts exhausted.
    """
    _check_api_key()
    client = _get_client()
    model = _get_model()
    provider = config.MODEL_PROVIDER
    name = PROVIDER_NAMES.get(provider, provider)
    full_messages = [{'role': 'system', 'content': system}] + messages
    for attempt in range(config.MAX_RETRIES):
        try:
            with config.llm_call():  # global concurrency cap
                kwargs = _build_chat_kwargs(model, full_messages, temperature, max_tokens)
                resp = client.chat.completions.create(**kwargs)
                return _clean_response(resp.choices[0].message.content.strip())
        except Exception as e:
            if _is_auth_error(e):
                # Bad credentials never succeed on retry -- fail fast.
                raise RuntimeError(
                    f'{name} API Key 无效或已过期,请在设置中重新配置。'
                    f'申请地址:{PROVIDER_LINKS.get(provider, "")}'
                ) from e
            if attempt < config.MAX_RETRIES - 1:
                wait = 2 ** attempt
                # Fix: only announce a retry when one actually follows
                # (previously this was logged even on the final failure).
                logger.warning(f'对话 AI 请求失败 (第{attempt+1}次){wait}s 后重试: {e}')
                time.sleep(wait)
            else:
                raise RuntimeError(f'AI 接口调用失败(已重试 {config.MAX_RETRIES} 次): {e}') from e
    return ''  # reachable only when config.MAX_RETRIES <= 0
def get_embeddings(texts: list[str]) -> list[list[float]]:
    """Return embedding vectors for *texts*.

    Supported providers: Qwen, OpenAI, Kimi. DeepSeek / Ollama / Doubao
    expose no embedding endpoint and raise NotImplementedError so the
    caller can degrade to keyword retrieval. Calls count against the
    global LLM semaphore.

    Args:
        texts: batch of strings to embed.

    Returns:
        One embedding vector per input text; [] for an empty batch.

    Raises:
        NotImplementedError: the active provider has no embedding API.
    """
    # Robustness fix: an empty batch needs no API round-trip (some
    # embedding endpoints reject an empty input list outright).
    if not texts:
        return []
    provider = config.MODEL_PROVIDER
    if provider in ('deepseek', 'ollama', 'doubao'):
        raise NotImplementedError(
            f'{PROVIDER_NAMES.get(provider)} 暂不支持 Embedding API知识库将使用关键词检索降级'
        )
    client = _get_client()
    if provider == 'qwen':
        model = config.QWEN_EMBEDDING_MODEL
    elif provider == 'kimi':
        model = config.KIMI_EMBEDDING_MODEL
    else:
        model = config.OPENAI_EMBEDDING_MODEL
    with config.llm_call():  # embeddings share the concurrency cap
        resp = client.embeddings.create(model=model, input=texts)
    return [item.embedding for item in resp.data]