142 lines
5.0 KiB
Python
142 lines
5.0 KiB
Python
"""
|
||
配置持久化:将用户在界面中设置的 API Key 等配置保存到 data/settings.json,
|
||
服务重启后自动恢复,不再每次重启都丢失 Key。
|
||
"""
|
||
import json
|
||
import os
|
||
import logging
|
||
|
||
logger = logging.getLogger(name=__name__)

# Path of settings.json; stays empty until app.py injects it at startup via init().
_SETTINGS_PATH: str = ''
|
||
|
||
|
||
def init(settings_path: str) -> None:
    """Record the location of settings.json for later load()/save() calls.

    Args:
        settings_path: Path of the JSON settings file. While this is empty,
            load() only applies environment overrides and save() returns
            without writing anything.
    """
    global _SETTINGS_PATH
    _SETTINGS_PATH = settings_path
|
||
|
||
|
||
def load(cfg) -> None:
    """Load settings.json onto the config module, overriding its defaults.

    Environment-variable overrides are applied exactly once, after the file
    has been processed, on every path (no path configured, file missing,
    load failure, success). Environment variables therefore always win.

    Loading is best-effort: any failure to read, parse or apply the file is
    logged as a warning, and cfg keeps whatever values it holds at that
    point (a failure inside _apply may leave it partially updated).

    Args:
        cfg: The config module whose attributes are overwritten.
    """
    loaded = False
    if _SETTINGS_PATH and os.path.exists(_SETTINGS_PATH):
        try:
            with open(_SETTINGS_PATH, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # _apply() needs a mapping; report a top-level list/string/number
            # clearly instead of failing later with an opaque AttributeError.
            if not isinstance(data, dict):
                raise ValueError(f'配置文件顶层应为 JSON 对象,实际为 {type(data).__name__}')
            _apply(cfg, data)
            loaded = True
        except Exception as e:  # best-effort: a bad settings file must not abort startup
            logger.warning('加载配置文件失败: %s', e)

    _apply_env_overrides(cfg)

    if loaded:
        # Logged after the env overrides so the reported provider is the effective one.
        logger.info('已从 %s 恢复配置,当前 provider=%s', _SETTINGS_PATH, cfg.MODEL_PROVIDER)
|
||
|
||
|
||
# (environment variable name, config attribute name) pairs read by
# _apply_env_overrides. The two names are currently identical for every entry.
_ENV_API_KEYS = tuple(
    (key_name, key_name)
    for key_name in (
        'QWEN_API_KEY',
        'OPENAI_API_KEY',
        'DEEPSEEK_API_KEY',
        'DOUBAO_API_KEY',
        'KIMI_API_KEY',
    )
)
|
||
|
||
|
||
def _apply_env_overrides(cfg) -> None:
    """Let environment variables take precedence over settings.json.

    Useful for Docker or a local .env file. MODEL_PROVIDER and every API-key
    variable listed in _ENV_API_KEYS are read. Each value is stripped of
    surrounding whitespace first, and the variable is ignored when the
    stripped value is empty. API-key values starting with 'sk-your'
    (presumably an unfilled template placeholder) are also ignored. An
    ignored variable leaves the current cfg value untouched.

    Args:
        cfg: The config module whose attributes are overwritten.
    """
    provider = os.environ.get('MODEL_PROVIDER', '').strip()
    if provider:
        cfg.MODEL_PROVIDER = provider

    for env_name, attr in _ENV_API_KEYS:
        # Strip before validating: a blank variable must not wipe a stored key
        # with '', and an indented placeholder must still be recognised.
        value = os.environ.get(env_name, '').strip()
        if value and not value.startswith('sk-your'):
            setattr(cfg, attr, value)
|
||
|
||
|
||
def save(cfg) -> None:
    """Write the key settings of the config module to settings.json.

    Does nothing while no path has been set via init(). Saving is
    best-effort: I/O and serialization errors are logged as warnings and
    never raised to the caller.

    NOTE(review): the write is not atomic; a crash in the middle of the
    final write can still leave a truncated file.

    Args:
        cfg: The config module. Saved: the provider; each provider's key,
            model and base URL; concurrency; content volume; page settings.
    """
    if not _SETTINGS_PATH:
        return
    data = {
        'model_provider': cfg.MODEL_PROVIDER,
        'qwen_api_key': cfg.QWEN_API_KEY,
        'qwen_model': cfg.QWEN_MODEL,
        'qwen_base_url': cfg.QWEN_BASE_URL,
        'openai_api_key': cfg.OPENAI_API_KEY,
        'openai_model': cfg.OPENAI_MODEL,
        'openai_base_url': cfg.OPENAI_BASE_URL,
        'deepseek_api_key': cfg.DEEPSEEK_API_KEY,
        'deepseek_model': cfg.DEEPSEEK_MODEL,
        'deepseek_base_url': cfg.DEEPSEEK_BASE_URL,
        'ollama_base_url': cfg.OLLAMA_BASE_URL,
        'ollama_model': cfg.OLLAMA_MODEL,
        'doubao_api_key': cfg.DOUBAO_API_KEY,
        'doubao_model': cfg.DOUBAO_MODEL,
        'doubao_base_url': cfg.DOUBAO_BASE_URL,
        'kimi_api_key': cfg.KIMI_API_KEY,
        'kimi_model': cfg.KIMI_MODEL,
        'kimi_base_url': cfg.KIMI_BASE_URL,
        'max_concurrent': cfg.MAX_CONCURRENT_SECTIONS,
        'content_volume': cfg.CONTENT_VOLUME,
        'target_pages': getattr(cfg, 'TARGET_PAGES', 0),
        'page_char_estimate': getattr(cfg, 'PAGE_CHAR_ESTIMATE', 700),
    }
    try:
        # Serialize before opening: open(..., 'w') truncates the file, so a
        # TypeError from a non-JSON value must surface while the previous
        # settings.json is still intact.
        text = json.dumps(data, ensure_ascii=False, indent=2)
        directory = os.path.dirname(_SETTINGS_PATH)
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(_SETTINGS_PATH, 'w', encoding='utf-8') as f:
            f.write(text)
    except (OSError, TypeError, ValueError) as e:
        logger.warning('保存配置文件失败: %s', e)
|
||
|
||
|
||
def _coerce_int(data: dict, key: str) -> 'int | None':
    """Return int(data[key]), or None when the key is absent or not convertible."""
    if key not in data:
        return None
    try:
        return int(data[key])
    except (ValueError, TypeError, OverflowError):
        # OverflowError: json.load accepts Infinity, and int(float('inf')) overflows.
        return None


def _apply(cfg, data: dict) -> None:
    """Copy validated values from a parsed settings dict onto the config module.

    - String fields are copied only when the value is a non-empty str; a
      missing, empty or non-string value keeps the current attribute.
    - max_concurrent is clamped to [1, 20].
    - content_volume is accepted only when it is a known volume name.
    - target_pages is floored at 0.
    - page_char_estimate is clamped to [300, 3000].

    Numeric values that cannot be converted to int are ignored, so a
    malformed field never prevents the remaining fields from being applied.

    Args:
        cfg: The config module whose attributes are overwritten.
        data: Parsed settings.json content (expects a dict).
    """
    str_fields = {
        'model_provider': 'MODEL_PROVIDER',
        'qwen_api_key': 'QWEN_API_KEY',
        'qwen_model': 'QWEN_MODEL',
        'qwen_base_url': 'QWEN_BASE_URL',
        'openai_api_key': 'OPENAI_API_KEY',
        'openai_model': 'OPENAI_MODEL',
        'openai_base_url': 'OPENAI_BASE_URL',
        'deepseek_api_key': 'DEEPSEEK_API_KEY',
        'deepseek_model': 'DEEPSEEK_MODEL',
        'deepseek_base_url': 'DEEPSEEK_BASE_URL',
        'ollama_base_url': 'OLLAMA_BASE_URL',
        'ollama_model': 'OLLAMA_MODEL',
        'doubao_api_key': 'DOUBAO_API_KEY',
        'doubao_model': 'DOUBAO_MODEL',
        'doubao_base_url': 'DOUBAO_BASE_URL',
        'kimi_api_key': 'KIMI_API_KEY',
        'kimi_model': 'KIMI_MODEL',
        'kimi_base_url': 'KIMI_BASE_URL',
    }
    for key, attr in str_fields.items():
        val = data.get(key)
        if val and isinstance(val, str):
            setattr(cfg, attr, val)

    max_concurrent = _coerce_int(data, 'max_concurrent')
    if max_concurrent is not None:
        cfg.MAX_CONCURRENT_SECTIONS = max(1, min(max_concurrent, 20))

    valid_volumes = ('concise', 'standard', 'detailed', 'full')
    vol = data.get('content_volume')
    if vol and vol in valid_volumes:
        cfg.CONTENT_VOLUME = vol

    target_pages = _coerce_int(data, 'target_pages')
    if target_pages is not None:
        cfg.TARGET_PAGES = max(0, target_pages)

    page_chars = _coerce_int(data, 'page_char_estimate')
    if page_chars is not None:
        cfg.PAGE_CHAR_ESTIMATE = max(300, min(3000, page_chars))
|