""" AI API 调用封装,支持 OpenAI、阿里云通义千问、DeepSeek、Ollama(均兼容 OpenAI SDK) """ import re import time import logging from openai import OpenAI import config from contextlib import contextmanager # for type hints if needed logger = logging.getLogger(__name__) PROVIDER_NAMES = { 'qwen': '通义千问 (Qwen)', 'deepseek': 'DeepSeek', 'openai': 'OpenAI', 'ollama': 'Ollama 本地', 'doubao': '豆包 (Doubao)', 'kimi': 'Kimi (Moonshot)', } PROVIDER_LINKS = { 'qwen': 'https://dashscope.aliyun.com/', 'deepseek': 'https://platform.deepseek.com/', 'openai': 'https://platform.openai.com/', 'ollama': 'https://ollama.com/', 'doubao': 'https://console.volcengine.com/ark/', 'kimi': 'https://platform.moonshot.cn/', } def _check_api_key(): """调用前预检 API Key,无效时直接抛出友好提示,不做无意义的重试""" provider = config.MODEL_PROVIDER # Ollama 本地无需 API Key,跳过检查 if provider == 'ollama': return name = PROVIDER_NAMES.get(provider, provider) link = PROVIDER_LINKS.get(provider, '') if provider == 'qwen': key = config.QWEN_API_KEY elif provider == 'deepseek': key = config.DEEPSEEK_API_KEY elif provider == 'doubao': key = config.DOUBAO_API_KEY elif provider == 'kimi': key = config.KIMI_API_KEY else: key = config.OPENAI_API_KEY if not key or key.startswith('sk-your'): raise RuntimeError( f'尚未配置 {name} 的 API Key。' f'请点击右上角设置按钮,选择"{name}"并填入有效的 API Key。' f'申请地址:{link}' ) def _get_client() -> OpenAI: """根据 MODEL_PROVIDER 返回对应的 OpenAI 兼容客户端""" if config.MODEL_PROVIDER == 'qwen': return OpenAI(api_key=config.QWEN_API_KEY, base_url=config.QWEN_BASE_URL) if config.MODEL_PROVIDER == 'deepseek': return OpenAI(api_key=config.DEEPSEEK_API_KEY, base_url=config.DEEPSEEK_BASE_URL) if config.MODEL_PROVIDER == 'ollama': return OpenAI(api_key='ollama', base_url=config.OLLAMA_BASE_URL) if config.MODEL_PROVIDER == 'doubao': return OpenAI(api_key=config.DOUBAO_API_KEY, base_url=config.DOUBAO_BASE_URL) if config.MODEL_PROVIDER == 'kimi': return OpenAI(api_key=config.KIMI_API_KEY, base_url=config.KIMI_BASE_URL) return OpenAI(api_key=config.OPENAI_API_KEY, base_url=config.OPENAI_BASE_URL) def _get_model() -> str: if config.MODEL_PROVIDER == 'qwen': return config.QWEN_MODEL if config.MODEL_PROVIDER == 'deepseek': return config.DEEPSEEK_MODEL if config.MODEL_PROVIDER == 'ollama': return config.OLLAMA_MODEL if config.MODEL_PROVIDER == 'doubao': return config.DOUBAO_MODEL if config.MODEL_PROVIDER == 'kimi': return config.KIMI_MODEL return config.OPENAI_MODEL def _clean_response(text: str) -> str: """ 过滤推理模型(DeepSeek R1 / QwQ 等)输出的 ... 思考过程标签, 只保留最终正文内容,避免思考链污染标书正文。 """ # 去除 ... 块(含跨行内容) text = re.sub(r'[\s\S]*?', '', text, flags=re.IGNORECASE) return text.strip() def _is_auth_error(e: Exception) -> bool: """判断是否为认证错误(401 / invalid_api_key),无需重试""" # 优先用 openai 原生异常类型判断 try: from openai import AuthenticationError, PermissionDeniedError if isinstance(e, (AuthenticationError, PermissionDeniedError)): return True except ImportError: pass # 兜底:字符串匹配 err_str = str(e).lower() return ('401' in err_str or 'invalid_api_key' in err_str or 'incorrect api key' in err_str or 'authentication' in err_str) # OpenAI o 系列推理模型:不支持 temperature,max_tokens 需用 max_completion_tokens _OPENAI_REASONING_MODELS = {'o1', 'o1-mini', 'o1-pro', 'o3', 'o3-mini', 'o3-pro', 'o4-mini'} def _build_chat_kwargs( model: str, messages: list, temperature: float, max_tokens: int, request_timeout: float | None = None, ) -> dict: """ 根据模型类型构建 chat.completions.create 的参数字典。 OpenAI o 系列推理模型不接受 temperature,且使用 max_completion_tokens 替代 max_tokens。 """ base_model = model.split(':')[0] # 去掉 ollama tag 后缀 is_reasoning = base_model in _OPENAI_REASONING_MODELS to = request_timeout if request_timeout is not None else config.REQUEST_TIMEOUT kwargs = { 'model': model, 'messages': messages, 'timeout': to, } if is_reasoning: kwargs['max_completion_tokens'] = max_tokens else: kwargs['temperature'] = temperature kwargs['max_tokens'] = max_tokens return kwargs def chat( prompt: str, system: str = '你是一位专业的投标文件撰写专家。', temperature: float = 0.7, max_tokens: int = 8192, retries: int = None, request_timeout: float | None = None, ) -> str: """ 调用 AI 接口,返回文本响应。 认证错误立即终止;其他错误指数退避重试。 自动兼容 OpenAI o 系列推理模型的参数差异。 所有调用受全局 LLM 信号量保护(默认 40 路,与 config 一致)。 """ _check_api_key() max_retries = retries if retries is not None else config.MAX_RETRIES client = _get_client() model = _get_model() provider = config.MODEL_PROVIDER name = PROVIDER_NAMES.get(provider, provider) messages = [ {'role': 'system', 'content': system}, {'role': 'user', 'content': prompt}, ] for attempt in range(max_retries): try: with config.llm_call(): # 全局并发控制 kwargs = _build_chat_kwargs( model, messages, temperature, max_tokens, request_timeout=request_timeout ) resp = client.chat.completions.create(**kwargs) return _clean_response(resp.choices[0].message.content.strip()) except Exception as e: if _is_auth_error(e): raise RuntimeError( f'{name} API Key 无效或已过期,请在设置中重新配置。' f'申请地址:{PROVIDER_LINKS.get(provider, "")}' ) from e wait = 2 ** attempt logger.warning(f'AI 请求失败 (第{attempt+1}次),{wait}s 后重试: {e}') if attempt < max_retries - 1: time.sleep(wait) else: raise RuntimeError(f'AI 接口调用失败(已重试 {max_retries} 次): {e}') from e return '' def chat_with_history(system: str, messages: list, temperature: float = 0.7, max_tokens: int = 4096) -> str: """ 多轮对话接口,支持完整历史上下文,用于对话式章节生成。 messages 格式:[{'role': 'user'|'assistant', 'content': str}, ...] 受全局LLM_SEMAPHORE保护。 """ _check_api_key() client = _get_client() model = _get_model() provider = config.MODEL_PROVIDER name = PROVIDER_NAMES.get(provider, provider) full_messages = [{'role': 'system', 'content': system}] + messages for attempt in range(config.MAX_RETRIES): try: with config.llm_call(): # 全局并发控制 kwargs = _build_chat_kwargs(model, full_messages, temperature, max_tokens) resp = client.chat.completions.create(**kwargs) return _clean_response(resp.choices[0].message.content.strip()) except Exception as e: if _is_auth_error(e): raise RuntimeError( f'{name} API Key 无效或已过期,请在设置中重新配置。' f'申请地址:{PROVIDER_LINKS.get(provider, "")}' ) from e wait = 2 ** attempt logger.warning(f'对话 AI 请求失败 (第{attempt+1}次),{wait}s 后重试: {e}') if attempt < config.MAX_RETRIES - 1: time.sleep(wait) else: raise RuntimeError(f'AI 接口调用失败(已重试 {config.MAX_RETRIES} 次): {e}') from e return '' def get_embeddings(texts: list[str]) -> list[list[float]]: """获取文本嵌入向量。 支持 Qwen、OpenAI、Kimi;DeepSeek / Ollama / 豆包 暂不提供 Embedding API。 受全局LLM_SEMAPHORE保护(嵌入调用计入并发上限)。 """ provider = config.MODEL_PROVIDER if provider in ('deepseek', 'ollama', 'doubao'): raise NotImplementedError( f'{PROVIDER_NAMES.get(provider)} 暂不支持 Embedding API,知识库将使用关键词检索降级' ) client = _get_client() if provider == 'qwen': model = config.QWEN_EMBEDDING_MODEL elif provider == 'kimi': model = config.KIMI_EMBEDDING_MODEL else: model = config.OPENAI_EMBEDDING_MODEL with config.llm_call(): # 嵌入也受并发限制 resp = client.embeddings.create(model=model, input=texts) return [item.embedding for item in resp.data]