"""
Enterprise knowledge-base module (no external vector-store dependency).

Storage backend: SQLite (the same file as the main application database)
  - knowledge_vectors table: text chunks + embedding vectors stored as JSON
  - knowledge_files table: file metadata (created by init_db in app.py)

Retrieval strategy:
  Qwen / OpenAI / Kimi provider -> Embedding API + cosine similarity
                                   (semantic search)
  DeepSeek / Ollama / others    -> SQL LIKE keyword search (fallback)
"""
import heapq
import json
import logging
import math
import os
import re
import sqlite3
import threading
from contextlib import closing
from datetime import datetime

import config
from utils.file_utils import extract_text, split_text_chunks

logger = logging.getLogger(__name__)

# Names of files currently being ingested by a background thread, so the
# frontend can poll for a "processing" state. Guarded by _processing_lock.
_processing_files: set[str] = set()
_processing_lock = threading.Lock()

# Number of chunks sent per Embedding API request (keeps each request small).
_EMBED_BATCH = 16

# Providers for which _get_embeddings_batch calls an Embedding API.
_EMBED_PROVIDERS = ('qwen', 'openai', 'kimi')

# Keyword-search tokens made only of digits / numbering marks (e.g. "1.2",
# "一、") are dropped so they don't match unrelated paragraphs.
_NUM_PAT = re.compile(r'^[\d\.\-、一二三四五六七八九十]+$')

# Maximum number of query tokens used in keyword search.
_MAX_KEYWORDS = 6


# ─── Database ────────────────────────────────────────────────────────────────

def _conn() -> sqlite3.Connection:
    """Open a new connection to the shared database at config.DB_PATH."""
    return sqlite3.connect(config.DB_PATH)


def _init_tables(cur: sqlite3.Cursor) -> None:
    """Create the knowledge_vectors table if it does not exist yet.

    knowledge_files is not created here; it is expected to already exist
    (created by init_db in app.py).
    """
    cur.execute('''
        CREATE TABLE IF NOT EXISTS knowledge_vectors (
            id         INTEGER PRIMARY KEY AUTOINCREMENT,
            file_name  TEXT NOT NULL,
            chunk_idx  INTEGER NOT NULL,
            text       TEXT NOT NULL,
            embedding  TEXT,
            UNIQUE(file_name, chunk_idx)
        )
    ''')


# ─── Embedding API ───────────────────────────────────────────────────────────

def _get_embeddings_batch(texts: list[str]) -> list[list[float] | None]:
    """
    Embed ``texts`` with the current provider's Embedding API.

    Args:
        texts: text chunks to embed in a single request.

    Returns:
        A list with exactly one entry per input text, in input order. Each
        entry is the embedding vector, or None when the provider has no
        Embedding API (DeepSeek / Ollama / Doubao ...) or the call fails;
        callers treat None as "fall back to keyword search".
    """
    if not texts:
        return []

    provider = getattr(config, 'MODEL_PROVIDER', '')
    try:
        # Imported lazily: the SDK is only needed for embedding providers.
        from openai import OpenAI

        if provider == 'qwen':
            client = OpenAI(api_key=config.QWEN_API_KEY, base_url=config.QWEN_BASE_URL)
            model = config.QWEN_EMBEDDING_MODEL
        elif provider == 'openai':
            client = OpenAI(api_key=config.OPENAI_API_KEY, base_url=config.OPENAI_BASE_URL)
            model = config.OPENAI_EMBEDDING_MODEL
        elif provider == 'kimi':
            client = OpenAI(api_key=config.KIMI_API_KEY, base_url=config.KIMI_BASE_URL)
            model = config.KIMI_EMBEDDING_MODEL
        else:
            # No public Embedding API for this provider: keyword-search fallback.
            return [None] * len(texts)

        resp = client.embeddings.create(input=texts, model=model)
        # Each item carries the index of the input it belongs to; order by it
        # so vectors line up with ``texts``. The sort is stable, so items
        # without an index keep the order in which they were returned.
        items = sorted(resp.data, key=lambda item: getattr(item, 'index', 0))
        vectors = [item.embedding for item in items]
    except Exception as e:
        logger.warning(f'Embedding API 调用失败,将使用关键词检索降级: {e}')
        return [None] * len(texts)

    if len(vectors) != len(texts):
        # A mismatched count would pair chunks with the wrong vectors (or drop
        # chunks) when callers zip the two lists together.
        logger.warning(
            f'Embedding API 返回数量不符({len(vectors)}/{len(texts)}),将使用关键词检索降级'
        )
        return [None] * len(texts)
    return vectors


def _cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors in pure Python (no numpy).

    Returns 0.0 when either norm is zero (or their product underflows to 0).
    """
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    denom = na * nb
    return dot / denom if denom else 0.0


# ─── Public API ──────────────────────────────────────────────────────────────

def is_available() -> dict:
    """
    Report knowledge-base status. The knowledge base itself is always
    available (it has no external dependencies).

    Returns:
        A dict with keys:
          available:   always True
          doc_count:   number of stored chunks (rows in knowledge_vectors)
          processing:  file names currently being ingested
          search_mode: 'vector' (semantic) or 'keyword' (fallback), i.e. the
                       mode search() is expected to use
          error:       only present when the database query failed
    """
    with _processing_lock:
        processing = list(_processing_files)

    provider = getattr(config, 'MODEL_PROVIDER', '')
    can_embed = provider in _EMBED_PROVIDERS
    try:
        with closing(_conn()) as db:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()
            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            doc_count = cur.fetchone()[0]
            # Has any chunk been stored with a vector?
            cur.execute('SELECT 1 FROM knowledge_vectors WHERE embedding IS NOT NULL LIMIT 1')
            has_embedding = cur.fetchone() is not None
    except Exception as e:
        return {
            'available': True,
            'doc_count': 0,
            'processing': processing,
            'search_mode': 'keyword',
            'error': str(e),
        }

    # search() ranks by vectors only if it can embed the query (embedding
    # provider) AND stored vectors exist. An empty knowledge base on an
    # embedding provider is reported as 'vector' since new files would be
    # embedded on ingestion.
    mode = 'vector' if can_embed and (has_embedding or doc_count == 0) else 'keyword'
    return {
        'available': True,
        'doc_count': doc_count,
        'processing': processing,
        'search_mode': mode,
    }


def add_file(file_path: str, db_path: str) -> dict:
    """
    Chunk a file, embed the chunks in batches and store them in SQLite.

    Intended to run in a background thread; while it runs, the file's base
    name is listed in _processing_files so the frontend can show progress.
    Re-adding a file with the same base name replaces its previous chunks.

    Args:
        file_path: path of the file to ingest; its base name is the key.
        db_path: accepted for interface compatibility but not used — data is
            written to config.DB_PATH. NOTE(review): confirm callers pass the
            same path.

    Returns:
        {'success': True, 'chunks': n} or {'success': False, 'error': msg}.
    """
    file_name = os.path.basename(file_path)
    with _processing_lock:
        _processing_files.add(file_name)
    try:
        text = extract_text(file_path)
        chunks = split_text_chunks(text, config.CHUNK_SIZE, config.CHUNK_OVERLAP)
        if not chunks:
            return {'success': False, 'error': '文件内容为空,无法入库'}

        # One entry per chunk: a vector for embedding providers, else None.
        embeddings: list[list[float] | None] = []
        for start in range(0, len(chunks), _EMBED_BATCH):
            embeddings.extend(_get_embeddings_batch(chunks[start:start + _EMBED_BATCH]))

        rows = [
            (file_name, idx, chunk, json.dumps(emb) if emb is not None else None)
            for idx, (chunk, emb) in enumerate(zip(chunks, embeddings))
        ]

        with closing(_conn()) as db:
            cur = db.cursor()
            _init_tables(cur)
            # Drop any chunks left over from an earlier ingestion of this file.
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            cur.executemany(
                'INSERT INTO knowledge_vectors (file_name, chunk_idx, text, embedding) VALUES (?,?,?,?)',
                rows,
            )
            # isoformat(sep=' ') is the exact text sqlite3's (deprecated since
            # Python 3.12) default datetime adapter used to store.
            cur.execute('''
                INSERT OR REPLACE INTO knowledge_files (file_name, file_path, chunk_count, added_at)
                VALUES (?, ?, ?, ?)
            ''', (file_name, file_path, len(chunks), datetime.now().isoformat(sep=' ')))
            db.commit()

        logger.info(f'知识库入库完成: {file_name},{len(chunks)} 块'
                    f'{"(含向量)" if any(e is not None for e in embeddings) else "(关键词模式)"}')
        return {'success': True, 'chunks': len(chunks)}
    except Exception as e:
        logger.exception('知识库添加文件失败')
        return {'success': False, 'error': str(e)}
    finally:
        with _processing_lock:
            _processing_files.discard(file_name)


def _vector_search(cur: sqlite3.Cursor, query: str, top_k: int) -> list[str] | None:
    """Rank stored chunks by cosine similarity to the query's embedding.

    Returns None (caller falls back to keyword search) when the query cannot
    be embedded or no stored vector is comparable with it.
    """
    q_embs = _get_embeddings_batch([query])
    q_emb = q_embs[0] if q_embs else None
    if q_emb is None:
        return None

    cur.execute('SELECT text, embedding FROM knowledge_vectors WHERE embedding IS NOT NULL')
    scored: list[tuple[float, str]] = []
    for text, emb_json in cur.fetchall():
        try:
            emb = json.loads(emb_json)
            # Vectors from a different embedding model (other dimension) are
            # not comparable; zip() in _cosine would silently truncate them.
            if not isinstance(emb, list) or len(emb) != len(q_emb):
                continue
            scored.append((_cosine(q_emb, emb), text))
        except (TypeError, ValueError, ArithmeticError):
            continue  # malformed stored vector: skip this chunk
    if not scored:
        return None
    # Same result as sorted(scored, reverse=True)[:top_k] without a full sort.
    return [text for _, text in heapq.nlargest(top_k, scored)]


def _like_pattern(word: str) -> str:
    """Build a LIKE substring pattern for *word*, backslash-escaping LIKE wildcards."""
    escaped = word.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
    return f'%{escaped}%'


def _keyword_search(cur: sqlite3.Cursor, query: str, top_k: int) -> list[str]:
    """Fallback search: chunks containing any of the query's keywords (unranked)."""
    words = [
        w for w in query.split()
        if len(w) > 1 and not _NUM_PAT.match(w)
    ][:_MAX_KEYWORDS]
    if not words:
        # Nothing usable to match on: return the first top_k chunks.
        cur.execute('SELECT text FROM knowledge_vectors LIMIT ?', (top_k,))
        return [r[0] for r in cur.fetchall()]

    # Each word is matched as a literal substring; escaping keeps a "%" or "_"
    # typed by the user from matching arbitrary text.
    conditions = ' OR '.join(["text LIKE ? ESCAPE '\\'"] * len(words))
    params = [_like_pattern(w) for w in words] + [top_k]
    cur.execute(f'SELECT text FROM knowledge_vectors WHERE {conditions} LIMIT ?', params)
    return [r[0] for r in cur.fetchall()]


def search(query: str, top_k: int = None) -> list[str]:
    """
    Return up to ``top_k`` knowledge-base chunks relevant to ``query``.

    - Vector mode: embed the query and rank stored vectors by cosine
      similarity. Used when the query can be embedded and at least one stored
      vector has the same dimension.
    - Keyword mode (fallback, e.g. DeepSeek / Ollama without an Embedding
      API): SQL LIKE match on up to 6 whitespace-separated query tokens,
      OR-ed together; results are not ranked.

    Args:
        query: search text.
        top_k: maximum number of chunks; defaults to config.TOP_K_KNOWLEDGE.
            A value <= 0 returns an empty list.

    Returns:
        Chunk texts; an empty list when nothing matches or on any error
        (errors are logged, not raised).
    """
    if top_k is None:
        top_k = config.TOP_K_KNOWLEDGE
    try:
        if top_k <= 0:
            return []
        with closing(_conn()) as db:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()
            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            if cur.fetchone()[0] == 0:
                return []

            hits = _vector_search(cur, query, top_k)
            if hits is not None:
                return hits
            return _keyword_search(cur, query, top_k)
    except Exception as e:
        logger.error(f'知识库检索失败: {e}')
        return []


def list_files(db_path: str) -> list[dict]:
    """
    List files ingested into the knowledge base, most recently added first.

    Args:
        db_path: path of the SQLite database holding knowledge_files.

    Returns:
        [{'name': ..., 'chunks': ..., 'added_at': ...}, ...]; an empty list
        if the query fails (the failure is logged).
    """
    try:
        with closing(sqlite3.connect(db_path)) as db:
            rows = db.execute(
                'SELECT file_name, chunk_count, added_at FROM knowledge_files ORDER BY added_at DESC'
            ).fetchall()
    except Exception:
        logger.exception('知识库文件列表读取失败')
        return []
    return [{'name': name, 'chunks': chunks, 'added_at': added_at} for name, chunks, added_at in rows]


def delete_file(file_name: str, db_path: str) -> dict:
    """
    Remove all chunks and the metadata row of ``file_name`` from the knowledge base.

    Args:
        file_name: base name the file was ingested under.
        db_path: accepted for interface compatibility but not used — data is
            deleted from config.DB_PATH. NOTE(review): confirm callers pass
            the same path.

    Returns:
        {'success': True} or {'success': False, 'error': msg}.
    """
    try:
        with closing(_conn()) as db:
            cur = db.cursor()
            _init_tables(cur)
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            cur.execute('DELETE FROM knowledge_files WHERE file_name=?', (file_name,))
            db.commit()
        return {'success': True}
    except Exception as e:
        logger.exception('知识库删除文件失败')
        return {'success': False, 'error': str(e)}