
"""
企业知识库模块(无外部向量库依赖)
存储后端SQLite与主数据库共用同一文件
- knowledge_vectors 表:文本块 + JSON 向量
- knowledge_files 表:文件元数据(已在 app.py init_db 中建立)
检索策略:
Qwen / OpenAI provider → Embedding API + 余弦相似度(语义检索)
DeepSeek / Ollama → SQL LIKE 关键词检索(降级)
"""
import json
import math
import logging
import os
import sqlite3
import threading
from datetime import datetime
import config
from utils.file_utils import extract_text, split_text_chunks
logger = logging.getLogger(__name__)
# 正在后台入库的文件名集合(供前端轮询感知"处理中"状态)
_processing_files: set = set()
_processing_lock = threading.Lock()
# 每次 Embedding API 批量请求的块数(避免单次请求过大)
_EMBED_BATCH = 16
# ─── Database ────────────────────────────────────────────────────────────────
def _conn() -> sqlite3.Connection:
    """Open a connection to the shared application SQLite database."""
    db_file = config.DB_PATH
    return sqlite3.connect(db_file)
def _init_tables(cur: sqlite3.Cursor) -> None:
"""确保向量块表存在并创建优化索引极速检索。knowledge_files 已由 app.py init_db 创建"""
cur.execute('''
CREATE TABLE IF NOT EXISTS knowledge_vectors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT NOT NULL,
chunk_idx INTEGER NOT NULL,
text TEXT NOT NULL,
embedding TEXT,
UNIQUE(file_name, chunk_idx)
)
''')
# 优化索引加速LIKE查询和向量检索时的文本过滤
cur.execute('CREATE INDEX IF NOT EXISTS idx_kv_file ON knowledge_vectors(file_name)')
cur.execute('CREATE INDEX IF NOT EXISTS idx_kv_text ON knowledge_vectors(text)') # helps FTS/LIKE
cur.execute('PRAGMA optimize') # SQLite auto-optimization
# ─── Embedding API ────────────────────────────────────────────────────────────
def _get_embeddings_batch(texts: list[str]) -> list[list[float] | None]:
"""
调用 ai_client.get_embeddings (复用全局 semaphore 和客户端逻辑)。
不支持 Embedding 的 provider 返回全 None 列表。优化了并发控制。
"""
if not texts:
return []
try:
# 使用统一 ai_client 接口确保全局LLM semaphore生效避免重复客户端创建
from utils import ai_client
embeddings = ai_client.get_embeddings(texts)
return embeddings
except Exception as e:
if "NotImplementedError" in str(type(e).__name__) or "不支持" in str(e):
logger.info('Embedding provider不支持降级到关键词检索')
return [None] * len(texts)
logger.warning(f'Embedding API 调用失败,将使用关键词检索降级: {e}')
return [None] * len(texts)
def _cosine(a: list[float], b: list[float]) -> float:
"""纯 Python 余弦相似度,无需 numpy"""
dot = sum(x * y for x, y in zip(a, b))
na = math.sqrt(sum(x * x for x in a))
nb = math.sqrt(sum(x * x for x in b))
return dot / (na * nb) if na and nb else 0.0
# ─── Public API ──────────────────────────────────────────────────────────────
def is_available() -> dict:
    """Report knowledge-base status (always available: no external deps).

    Returns a dict with:
        available   -- always True
        doc_count   -- number of stored text chunks
        processing  -- file names currently being ingested in background
        search_mode -- 'vector' (semantic) or 'keyword' (fallback)
        error       -- present only when the status query itself failed
    """
    with _processing_lock:
        processing = list(_processing_files)
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()
            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            doc_count = cur.fetchone()[0]
            # Any stored vector implies the Embedding API has worked before.
            cur.execute('SELECT 1 FROM knowledge_vectors WHERE embedding IS NOT NULL LIMIT 1')
            has_embedding = cur.fetchone() is not None
        finally:
            db.close()  # original leaked the connection when a query raised
        provider = getattr(config, 'MODEL_PROVIDER', '')
        can_embed = provider in ('qwen', 'openai', 'kimi')
        mode = 'vector' if (has_embedding or can_embed) else 'keyword'
        return {
            'available': True,
            'doc_count': doc_count,
            'processing': processing,
            'search_mode': mode,
        }
    except Exception as e:
        return {
            'available': True,
            'doc_count': 0,
            'processing': processing,
            'search_mode': 'keyword',
            'error': str(e),
        }
def add_file(file_path: str, db_path: str) -> dict:
    """Chunk a file, embed the chunks in batches, and store them in SQLite.

    Runs in a background thread; _processing_files lets the frontend poll
    an "in progress" state while ingestion happens.

    Parameters:
        file_path -- path of the file to ingest
        db_path   -- accepted for interface compatibility but unused;
                     storage always goes through _conn() (the shared DB)

    Returns {'success': True, 'chunks': N} on success,
    {'success': False, 'error': msg} otherwise.
    """
    file_name = os.path.basename(file_path)
    with _processing_lock:
        _processing_files.add(file_name)
    try:
        text = extract_text(file_path)
        chunks = split_text_chunks(text, config.CHUNK_SIZE, config.CHUNK_OVERLAP)
        if not chunks:
            return {'success': False, 'error': '文件内容为空,无法入库'}
        # Batch embedding: effective for Qwen/OpenAI providers; all-None
        # otherwise (keyword fallback mode).
        embeddings: list[list[float] | None] = []
        for i in range(0, len(chunks), _EMBED_BATCH):
            embeddings.extend(_get_embeddings_batch(chunks[i:i + _EMBED_BATCH]))
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            # Replace any previous version of the same file.
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            rows = [
                (file_name, idx, chunk, json.dumps(emb) if emb is not None else None)
                for idx, (chunk, emb) in enumerate(zip(chunks, embeddings))
            ]
            # executemany: single C-level loop instead of per-row execute calls.
            cur.executemany(
                'INSERT INTO knowledge_vectors (file_name, chunk_idx, text, embedding) VALUES (?,?,?,?)',
                rows,
            )
            cur.execute('''
                INSERT OR REPLACE INTO knowledge_files (file_name, file_path, chunk_count, added_at)
                VALUES (?, ?, ?, ?)
            ''', (file_name, file_path, len(chunks), datetime.now()))
            db.commit()
        finally:
            db.close()
        logger.info(f'知识库入库完成: {file_name}{len(chunks)}'
                    f'{"(含向量)" if any(e is not None for e in embeddings) else "(关键词模式)"}')
        return {'success': True, 'chunks': len(chunks)}
    except Exception as e:
        logger.exception('知识库添加文件失败')
        return {'success': False, 'error': str(e)}
    finally:
        with _processing_lock:
            _processing_files.discard(file_name)
def search(query: str, top_k: int | None = None) -> list[str]:
    """Retrieve the text chunks most relevant to *query*.

    - Vector mode: embed the query, rank stored chunks by cosine similarity.
    - Keyword fallback (DeepSeek / Ollama, no Embedding API): SQL LIKE
      multi-word matching.

    Parameters:
        query -- user query text
        top_k -- max number of chunks; defaults to config.TOP_K_KNOWLEDGE

    Returns a list of chunk texts ([] on error or empty knowledge base).
    """
    if top_k is None:
        top_k = config.TOP_K_KNOWLEDGE
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()
            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            if cur.fetchone()[0] == 0:
                return []
            q_embs = _get_embeddings_batch([query])
            q_emb = q_embs[0] if q_embs else None
            if q_emb is not None:
                hits = _vector_search(cur, q_emb, top_k)
                if hits is not None:
                    return hits
            return _keyword_search(cur, query, top_k)
        finally:
            db.close()
    except Exception as e:
        logger.error(f'知识库检索失败: {e}')
        return []


def _vector_search(cur: sqlite3.Cursor, q_emb: list[float], top_k: int) -> list[str] | None:
    """Rank recently added embedded chunks by cosine similarity.

    Returns None when no embedded rows exist (caller falls back to keywords).
    """
    # Cap the scan at the 500 most recent chunks to avoid a full-table
    # scan on large knowledge bases.
    cur.execute(
        '''SELECT text, embedding FROM knowledge_vectors
           WHERE embedding IS NOT NULL
           ORDER BY id DESC LIMIT 500'''
    )
    rows = cur.fetchall()
    if not rows:
        return None
    scored: list[tuple[float, str]] = []
    for text, emb_json in rows:
        try:
            scored.append((_cosine(q_emb, json.loads(emb_json)), text))
        except Exception:
            continue  # skip rows with corrupt embedding JSON
    scored.sort(reverse=True)
    return [t for _, t in scored[:top_k]]


def _keyword_search(cur: sqlite3.Cursor, query: str, top_k: int) -> list[str]:
    """LIKE-based fallback for providers without an Embedding API."""
    import re as _re
    # Filter out purely numeric / enumeration tokens ("1.2", "一、") that
    # would match unrelated passages.
    num_pat = _re.compile(r'^[\d\.\-、一二三四五六七八九十]+$')
    words = [
        w.strip() for w in query.split()
        if len(w.strip()) > 1 and not num_pat.match(w.strip())
    ][:6]
    if not words:
        cur.execute('SELECT text FROM knowledge_vectors LIMIT ?', (top_k,))
        return [r[0] for r in cur.fetchall()]
    conditions = ' OR '.join(['text LIKE ?' for _ in words])
    params = [f'%{w}%' for w in words] + [top_k]
    cur.execute(
        f'SELECT text FROM knowledge_vectors WHERE {conditions} LIMIT ?', params
    )
    return [r[0] for r in cur.fetchall()]
def list_files(db_path: str) -> list[dict]:
    """List ingested knowledge-base files, newest first.

    Parameters:
        db_path -- path of the SQLite database holding knowledge_files

    Returns [{'name', 'chunks', 'added_at'}, ...]; [] when the query fails
    (e.g. table not yet created) so callers can render an empty list.
    """
    try:
        db = sqlite3.connect(db_path)
        try:
            cur = db.cursor()
            cur.execute(
                'SELECT file_name, chunk_count, added_at FROM knowledge_files ORDER BY added_at DESC'
            )
            rows = cur.fetchall()
        finally:
            db.close()  # original leaked the connection when the query raised
        return [{'name': r[0], 'chunks': r[1], 'added_at': r[2]} for r in rows]
    except Exception:
        # Still best-effort, but no longer a silent swallow.
        logger.warning('知识库文件列表查询失败', exc_info=True)
        return []
def delete_file(file_name: str, db_path: str) -> dict:
    """Remove all stored chunks and metadata for *file_name*.

    Parameters:
        file_name -- file whose data should be purged
        db_path   -- accepted for interface compatibility but unused;
                     deletion goes through _conn() (matching add_file)

    Returns {'success': True} or {'success': False, 'error': msg}.
    """
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            cur.execute('DELETE FROM knowledge_files WHERE file_name=?', (file_name,))
            db.commit()
        finally:
            db.close()  # original leaked the connection when an execute raised
        return {'success': True}
    except Exception as e:
        logger.exception('知识库删除文件失败')
        return {'success': False, 'error': str(e)}