239 lines
8.5 KiB
Python
239 lines
8.5 KiB
Python
"""
|
||
AI API 调用封装,支持 OpenAI、阿里云通义千问、DeepSeek、Ollama(均兼容 OpenAI SDK)
|
||
"""
|
||
import re
|
||
import time
|
||
import logging
|
||
from openai import OpenAI
|
||
import config
|
||
|
||
# Module-level logger, namespaced to this module per the standard logging convention.
logger = logging.getLogger(__name__)


# Human-readable display names for each supported provider key.
# Used in user-facing error messages (missing/invalid API key prompts).
PROVIDER_NAMES = {
    'qwen': '通义千问 (Qwen)',
    'deepseek': 'DeepSeek',
    'openai': 'OpenAI',
    'ollama': 'Ollama 本地',
    'doubao': '豆包 (Doubao)',
    'kimi': 'Kimi (Moonshot)',
}

# Console / sign-up URLs shown to the user when an API key is missing or rejected.
PROVIDER_LINKS = {
    'qwen': 'https://dashscope.aliyun.com/',
    'deepseek': 'https://platform.deepseek.com/',
    'openai': 'https://platform.openai.com/',
    'ollama': 'https://ollama.com/',
    'doubao': 'https://console.volcengine.com/ark/',
    'kimi': 'https://platform.moonshot.cn/',
}
|
||
|
||
|
||
def _check_api_key():
    """Pre-flight check of the configured provider's API key.

    Raises a friendly RuntimeError immediately when the key is absent or still
    the placeholder value, so callers never waste retries on a doomed request.
    Ollama runs locally and needs no key.
    """
    provider = config.MODEL_PROVIDER

    # Local Ollama requires no credentials at all.
    if provider == 'ollama':
        return

    name = PROVIDER_NAMES.get(provider, provider)
    link = PROVIDER_LINKS.get(provider, '')

    # Map provider -> config attribute holding its key; anything unknown
    # falls back to the OpenAI key, matching the client/model dispatch.
    key_attrs = {
        'qwen': 'QWEN_API_KEY',
        'deepseek': 'DEEPSEEK_API_KEY',
        'doubao': 'DOUBAO_API_KEY',
        'kimi': 'KIMI_API_KEY',
    }
    key = getattr(config, key_attrs.get(provider, 'OPENAI_API_KEY'))

    # 'sk-your...' is the placeholder shipped in the sample config.
    if not key or key.startswith('sk-your'):
        raise RuntimeError(
            f'尚未配置 {name} 的 API Key。'
            f'请点击右上角设置按钮,选择"{name}"并填入有效的 API Key。'
            f'申请地址:{link}'
        )
|
||
|
||
|
||
def _get_client() -> OpenAI:
    """Build an OpenAI-compatible client for the provider in config.MODEL_PROVIDER."""
    provider = config.MODEL_PROVIDER

    # Ollama needs no real key; the SDK still requires a non-empty string.
    if provider == 'ollama':
        return OpenAI(api_key='ollama', base_url=config.OLLAMA_BASE_URL)

    # Config attributes follow the '<PROVIDER>_API_KEY' / '<PROVIDER>_BASE_URL'
    # naming scheme; unknown providers fall back to the OpenAI settings.
    prefix = provider.upper() if provider in ('qwen', 'deepseek', 'doubao', 'kimi') else 'OPENAI'
    return OpenAI(
        api_key=getattr(config, f'{prefix}_API_KEY'),
        base_url=getattr(config, f'{prefix}_BASE_URL'),
    )
|
||
|
||
|
||
def _get_model() -> str:
    """Return the model name configured for the active provider."""
    provider = config.MODEL_PROVIDER
    # Config attributes follow the '<PROVIDER>_MODEL' naming scheme;
    # unknown providers fall back to the OpenAI model.
    known = ('qwen', 'deepseek', 'ollama', 'doubao', 'kimi')
    prefix = provider.upper() if provider in known else 'OPENAI'
    return getattr(config, f'{prefix}_MODEL')
|
||
|
||
|
||
def _clean_response(text: str) -> str:
|
||
"""
|
||
过滤推理模型(DeepSeek R1 / QwQ 等)输出的 <think>...</think> 思考过程标签,
|
||
只保留最终正文内容,避免思考链污染标书正文。
|
||
"""
|
||
# 去除 <think>...</think> 块(含跨行内容)
|
||
text = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
|
||
return text.strip()
|
||
|
||
|
||
def _is_auth_error(e: Exception) -> bool:
|
||
"""判断是否为认证错误(401 / invalid_api_key),无需重试"""
|
||
# 优先用 openai 原生异常类型判断
|
||
try:
|
||
from openai import AuthenticationError, PermissionDeniedError
|
||
if isinstance(e, (AuthenticationError, PermissionDeniedError)):
|
||
return True
|
||
except ImportError:
|
||
pass
|
||
# 兜底:字符串匹配
|
||
err_str = str(e).lower()
|
||
return ('401' in err_str or 'invalid_api_key' in err_str
|
||
or 'incorrect api key' in err_str or 'authentication' in err_str)
|
||
|
||
|
||
# OpenAI o-series reasoning models: they reject the `temperature` parameter and
# take `max_completion_tokens` in place of `max_tokens` (handled in _build_chat_kwargs).
_OPENAI_REASONING_MODELS = {'o1', 'o1-mini', 'o1-pro', 'o3', 'o3-mini', 'o3-pro', 'o4-mini'}
|
||
|
||
|
||
def _build_chat_kwargs(model: str, messages: list, temperature: float, max_tokens: int) -> dict:
    """Assemble the keyword arguments for chat.completions.create().

    OpenAI o-series reasoning models reject `temperature` and expect
    `max_completion_tokens` instead of `max_tokens`; every other model gets
    the standard sampling parameters.
    """
    # Ollama model names may carry a ':tag' suffix; compare the bare name.
    is_reasoning = model.split(':')[0] in _OPENAI_REASONING_MODELS

    common = {
        'model': model,
        'messages': messages,
        'timeout': config.REQUEST_TIMEOUT,
    }
    if is_reasoning:
        return {**common, 'max_completion_tokens': max_tokens}
    return {**common, 'temperature': temperature, 'max_tokens': max_tokens}
|
||
|
||
|
||
def chat(prompt: str, system: str = '你是一位专业的投标文件撰写专家。',
         temperature: float = 0.7, max_tokens: int = 8192,
         retries: int | None = None) -> str:
    """Call the AI chat endpoint and return the cleaned text response.

    Authentication errors abort immediately (retrying cannot help); other
    errors are retried with exponential backoff (1s, 2s, 4s, ...).
    Parameter differences of OpenAI o-series reasoning models are handled
    transparently via _build_chat_kwargs.

    Args:
        prompt: User message content.
        system: System-role instruction (defaults to a bid-writing persona).
        temperature: Sampling temperature (ignored for o-series models).
        max_tokens: Response token budget.
        retries: Max attempts; defaults to config.MAX_RETRIES when None.

    Returns:
        The response text with any <think> blocks stripped; '' only when
        the retry budget is zero or negative.

    Raises:
        RuntimeError: On missing/invalid API key or after all retries fail.
    """
    _check_api_key()

    max_retries = retries if retries is not None else config.MAX_RETRIES
    client = _get_client()
    model = _get_model()
    provider = config.MODEL_PROVIDER
    name = PROVIDER_NAMES.get(provider, provider)

    messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': prompt},
    ]

    for attempt in range(max_retries):
        try:
            kwargs = _build_chat_kwargs(model, messages, temperature, max_tokens)
            resp = client.chat.completions.create(**kwargs)
            return _clean_response(resp.choices[0].message.content.strip())
        except Exception as e:
            if _is_auth_error(e):
                # Bad credentials cannot be fixed by retrying — fail fast.
                raise RuntimeError(
                    f'{name} API Key 无效或已过期,请在设置中重新配置。'
                    f'申请地址:{PROVIDER_LINKS.get(provider, "")}'
                ) from e

            if attempt < max_retries - 1:
                # Fix: only announce a retry when one will actually happen
                # (the original logged "...s 后重试" even on the final attempt).
                wait = 2 ** attempt
                logger.warning(f'AI 请求失败 (第{attempt+1}次),{wait}s 后重试: {e}')
                time.sleep(wait)
            else:
                raise RuntimeError(f'AI 接口调用失败(已重试 {max_retries} 次): {e}') from e

    # Reachable only when max_retries <= 0: preserve the historical '' result.
    return ''
|
||
|
||
|
||
def chat_with_history(system: str, messages: list,
                      temperature: float = 0.7, max_tokens: int = 4096) -> str:
    """Multi-turn chat with full history, for conversational section drafting.

    Args:
        system: System-role instruction prepended to the history.
        messages: History entries, [{'role': 'user'|'assistant', 'content': str}, ...].
        temperature: Sampling temperature (ignored for o-series models).
        max_tokens: Response token budget.

    Returns:
        The response text with any <think> blocks stripped; '' only when
        config.MAX_RETRIES is zero or negative.

    Raises:
        RuntimeError: On missing/invalid API key or after all retries fail.
    """
    _check_api_key()

    client = _get_client()
    model = _get_model()
    provider = config.MODEL_PROVIDER
    name = PROVIDER_NAMES.get(provider, provider)

    full_messages = [{'role': 'system', 'content': system}] + messages

    for attempt in range(config.MAX_RETRIES):
        try:
            kwargs = _build_chat_kwargs(model, full_messages, temperature, max_tokens)
            resp = client.chat.completions.create(**kwargs)
            return _clean_response(resp.choices[0].message.content.strip())
        except Exception as e:
            if _is_auth_error(e):
                # Bad credentials cannot be fixed by retrying — fail fast.
                raise RuntimeError(
                    f'{name} API Key 无效或已过期,请在设置中重新配置。'
                    f'申请地址:{PROVIDER_LINKS.get(provider, "")}'
                ) from e

            if attempt < config.MAX_RETRIES - 1:
                # Fix: only announce a retry when one will actually happen
                # (the original logged "...s 后重试" even on the final attempt).
                wait = 2 ** attempt
                logger.warning(f'对话 AI 请求失败 (第{attempt+1}次),{wait}s 后重试: {e}')
                time.sleep(wait)
            else:
                raise RuntimeError(f'AI 接口调用失败(已重试 {config.MAX_RETRIES} 次): {e}') from e

    # Reachable only when config.MAX_RETRIES <= 0: preserve the historical '' result.
    return ''
|
||
|
||
|
||
def get_embeddings(texts: list[str]) -> list[list[float]]:
    """Return embedding vectors for the given texts.

    Qwen, OpenAI and Kimi are supported; DeepSeek / Ollama / Doubao expose no
    Embedding API, so callers should fall back to keyword retrieval.
    """
    provider = config.MODEL_PROVIDER
    if provider in ('deepseek', 'ollama', 'doubao'):
        raise NotImplementedError(
            f'{PROVIDER_NAMES.get(provider)} 暂不支持 Embedding API,知识库将使用关键词检索降级'
        )

    client = _get_client()
    # Map provider -> config attribute for its embedding model; anything else
    # (i.e. the 'openai' default) uses the OpenAI embedding model.
    model_attrs = {
        'qwen': 'QWEN_EMBEDDING_MODEL',
        'kimi': 'KIMI_EMBEDDING_MODEL',
    }
    model = getattr(config, model_attrs.get(provider, 'OPENAI_EMBEDDING_MODEL'))

    resp = client.embeddings.create(model=model, input=texts)
    return [item.embedding for item in resp.data]
|