2026-04-23 14:36:26 +08:00

151 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
配置持久化:将用户在界面中设置的 API Key 等配置保存到 data/settings.json
服务重启后自动恢复,不再每次重启都丢失 Key。
"""
import json
import os
import logging
logger = logging.getLogger(name=__name__)
# Location of settings.json. Empty until init() assigns it (the original note
# says app.py injects it during initialisation); load()/save() skip the file
# while it is empty.
_SETTINGS_PATH: str = ""
def init(settings_path: str):
    """Remember where settings.json lives.

    load() and save() both read this module-level path; while it is empty
    they leave the file alone.
    """
    global _SETTINGS_PATH
    _SETTINGS_PATH = settings_path
def load(cfg) -> None:
    """Restore persisted settings onto *cfg*, then layer environment overrides on top.

    Precedence, lowest to highest: values already on *cfg*, settings.json,
    environment variables.

    Loading is best-effort:
    - A missing file (or no path configured via init()) is skipped silently.
    - An unreadable, malformed or non-object file is logged as a warning.
    - In every case the environment overrides are applied exactly once.
    """
    data = _read_settings_file()
    if data is not None:
        try:
            _apply(cfg, data)
        except Exception as e:  # best-effort boundary: a bad value must not block startup
            logger.warning('加载配置文件失败: %s', e)
            data = None
    _apply_env_overrides(cfg)
    if data is not None:
        logger.info('已从 %s 恢复配置,当前 provider=%s', _SETTINGS_PATH, cfg.MODEL_PROVIDER)


def _read_settings_file():
    """Return the parsed settings.json object, or None when absent, unreadable or not a JSON object."""
    if not _SETTINGS_PATH:
        return None
    try:
        with open(_SETTINGS_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # First run / nothing saved yet: not an error.
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning('加载配置文件失败: %s', e)
        return None
    if not isinstance(data, dict):
        logger.warning('加载配置文件失败: %s', f'expected a JSON object, got {type(data).__name__}')
        return None
    return data
# (environment variable, config attribute) pairs copied by _apply_env_overrides().
# The names coincide today; keeping both columns allows them to diverge later.
_ENV_API_KEYS = (
    ('QWEN_API_KEY', 'QWEN_API_KEY'),
    ('OPENAI_API_KEY', 'OPENAI_API_KEY'),
    ('DEEPSEEK_API_KEY', 'DEEPSEEK_API_KEY'),
    ('DOUBAO_API_KEY', 'DOUBAO_API_KEY'),
    ('KIMI_API_KEY', 'KIMI_API_KEY'),
)


def _apply_env_overrides(cfg) -> None:
    """Let environment variables override values restored from settings.json.

    Intended for Docker / local ``.env`` injection. ``MODEL_PROVIDER`` and every
    key listed in ``_ENV_API_KEYS`` are copied onto *cfg* when the variable holds
    a non-blank value.

    Values are stripped *before* they are checked. As a result:
    - A blank variable (e.g. ``KIMI_API_KEY="  "``) no longer wipes a key
      restored from the file.
    - An indented template value (``" sk-your..."``) is still skipped.
    """
    provider = os.environ.get('MODEL_PROVIDER', '').strip()
    if provider:
        cfg.MODEL_PROVIDER = provider
    for env_name, attr in _ENV_API_KEYS:
        key = os.environ.get(env_name, '').strip()
        # Skip unset/blank variables and values starting with 'sk-your'
        # (presumably the placeholder from a sample .env file).
        if key and not key.startswith('sk-your'):
            setattr(cfg, attr, key)
def save(cfg) -> None:
    """Persist the key settings of *cfg* to settings.json.

    Behaviour:
    - Does nothing until init() has set a path.
    - Writes atomically: the JSON goes to a temporary file in the target
      directory, which then replaces settings.json via ``os.replace``. A crash
      mid-write therefore cannot leave a truncated file and lose the saved
      API keys.
    - Failures are logged, not raised, so saving stays best-effort.

    The file holds API keys in plain text. ``tempfile.mkstemp`` creates it
    readable/writable by the owner only.
    """
    if not _SETTINGS_PATH:
        return
    import tempfile  # local import: only needed on the save path

    data = {
        'model_provider': cfg.MODEL_PROVIDER,
        'qwen_api_key': cfg.QWEN_API_KEY,
        'qwen_model': cfg.QWEN_MODEL,
        'qwen_base_url': cfg.QWEN_BASE_URL,
        'openai_api_key': cfg.OPENAI_API_KEY,
        'openai_model': cfg.OPENAI_MODEL,
        'openai_base_url': cfg.OPENAI_BASE_URL,
        'deepseek_api_key': cfg.DEEPSEEK_API_KEY,
        'deepseek_model': cfg.DEEPSEEK_MODEL,
        'deepseek_base_url': cfg.DEEPSEEK_BASE_URL,
        'ollama_base_url': cfg.OLLAMA_BASE_URL,
        'ollama_model': cfg.OLLAMA_MODEL,
        'doubao_api_key': cfg.DOUBAO_API_KEY,
        'doubao_model': cfg.DOUBAO_MODEL,
        'doubao_base_url': cfg.DOUBAO_BASE_URL,
        'kimi_api_key': cfg.KIMI_API_KEY,
        'kimi_model': cfg.KIMI_MODEL,
        'kimi_base_url': cfg.KIMI_BASE_URL,
        'max_concurrent': cfg.MAX_CONCURRENT_SECTIONS,
        'llm_concurrency_limit': getattr(cfg, 'LLM_CONCURRENCY_LIMIT', 20),
        'content_volume': cfg.CONTENT_VOLUME,
        'target_pages': getattr(cfg, 'TARGET_PAGES', 0),
        'page_char_estimate': getattr(cfg, 'PAGE_CHAR_ESTIMATE', 700),
    }
    # abspath() so a bare filename such as 'settings.json' resolves to the
    # current directory instead of '' (os.makedirs('') raises FileNotFoundError).
    target_dir = os.path.dirname(os.path.abspath(_SETTINGS_PATH))
    tmp_path = None
    try:
        os.makedirs(target_dir, exist_ok=True)
        # Temp file in the same directory so os.replace() stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(prefix='.settings-', suffix='.tmp', dir=target_dir)
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, _SETTINGS_PATH)
        tmp_path = None  # now owned by settings.json; nothing to clean up
    except (OSError, TypeError, ValueError) as e:
        # OSError: filesystem problems; TypeError/ValueError: unserialisable values.
        logger.warning(f'保存配置文件失败: {e}')
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def _clamped_int(value, lo, hi=None):
    """Return int(value) clamped to [lo, hi] (no upper bound if hi is None), or None if not convertible.

    OverflowError is caught as well as ValueError/TypeError: Python's json module
    decodes the non-standard ``Infinity`` token to float('inf'), and
    ``int(float('inf'))`` raises OverflowError.
    """
    try:
        number = int(value)
    except (ValueError, TypeError, OverflowError):
        return None
    if hi is not None:
        number = min(number, hi)
    return max(lo, number)


def _apply(cfg, data: dict) -> None:
    """Copy validated values from a settings dict back onto the config module *cfg*.

    Unknown keys are ignored and invalid values are skipped, so whatever *cfg*
    already holds survives a partially bad file:

    - String fields are copied only when they are non-empty strings.
    - Integer fields are coerced with int() and clamped to safe ranges.
    - ``content_volume`` must be one of the known presets.
    """
    str_fields = {
        'model_provider': 'MODEL_PROVIDER',
        'qwen_api_key': 'QWEN_API_KEY',
        'qwen_model': 'QWEN_MODEL',
        'qwen_base_url': 'QWEN_BASE_URL',
        'openai_api_key': 'OPENAI_API_KEY',
        'openai_model': 'OPENAI_MODEL',
        'openai_base_url': 'OPENAI_BASE_URL',
        'deepseek_api_key': 'DEEPSEEK_API_KEY',
        'deepseek_model': 'DEEPSEEK_MODEL',
        'deepseek_base_url': 'DEEPSEEK_BASE_URL',
        'ollama_base_url': 'OLLAMA_BASE_URL',
        'ollama_model': 'OLLAMA_MODEL',
        'doubao_api_key': 'DOUBAO_API_KEY',
        'doubao_model': 'DOUBAO_MODEL',
        'doubao_base_url': 'DOUBAO_BASE_URL',
        'kimi_api_key': 'KIMI_API_KEY',
        'kimi_model': 'KIMI_MODEL',
        'kimi_base_url': 'KIMI_BASE_URL',
    }
    for key, attr in str_fields.items():
        val = data.get(key)
        if val and isinstance(val, str):
            setattr(cfg, attr, val)

    # (json key, config attribute, lower bound, upper bound or None for unbounded)
    int_fields = (
        ('max_concurrent', 'MAX_CONCURRENT_SECTIONS', 1, 20),
        # Slightly higher ceiling so the limit stays configurable. Per the original
        # note, the semaphore is created when the config module initialises, so a
        # new value only takes effect after a restart.
        ('llm_concurrency_limit', 'LLM_CONCURRENCY_LIMIT', 1, 30),
        ('target_pages', 'TARGET_PAGES', 0, None),
        ('page_char_estimate', 'PAGE_CHAR_ESTIMATE', 300, 3000),
    )
    for key, attr, lo, hi in int_fields:
        if key not in data:
            continue
        number = _clamped_int(data[key], lo, hi)
        if number is not None:
            setattr(cfg, attr, number)

    valid_volumes = ('concise', 'standard', 'detailed', 'full')
    vol = data.get('content_volume')
    if vol in valid_volumes:
        cfg.CONTENT_VOLUME = vol