"""
企业知识库模块(无外部向量库依赖)

存储后端:SQLite(与主数据库共用同一文件)
- knowledge_vectors 表:文本块 + JSON 向量
- knowledge_files 表:文件元数据(已在 app.py init_db 中建立)

检索策略:
    Qwen / OpenAI provider → Embedding API + 余弦相似度(语义检索)
    DeepSeek / Ollama → SQL LIKE 关键词检索(降级)
"""
import json
import logging
import math
import os
import re
import sqlite3
import threading
from datetime import datetime

import config
from utils.file_utils import extract_text, split_text_chunks
||
logger = logging.getLogger(__name__)

# File names currently being ingested on a background thread; the frontend
# polls this set to render an "in progress" state.
_processing_files: set[str] = set()
# Guards _processing_files against concurrent add/discard from worker threads.
_processing_lock = threading.Lock()

# Number of chunks per Embedding API batch request (keeps each request small).
_EMBED_BATCH = 16
# ─── 数据库 ──────────────────────────────────────────────────────────────────
|
||
|
||
def _conn() -> sqlite3.Connection:
    """Open a connection to the shared application SQLite database."""
    db_file = config.DB_PATH
    return sqlite3.connect(db_file)
||
|
||
|
||
def _init_tables(cur: sqlite3.Cursor) -> None:
|
||
"""确保向量块表存在并创建优化索引(极速检索)。knowledge_files 已由 app.py init_db 创建"""
|
||
cur.execute('''
|
||
CREATE TABLE IF NOT EXISTS knowledge_vectors (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
file_name TEXT NOT NULL,
|
||
chunk_idx INTEGER NOT NULL,
|
||
text TEXT NOT NULL,
|
||
embedding TEXT,
|
||
UNIQUE(file_name, chunk_idx)
|
||
)
|
||
''')
|
||
# 优化索引:加速LIKE查询和向量检索时的文本过滤
|
||
cur.execute('CREATE INDEX IF NOT EXISTS idx_kv_file ON knowledge_vectors(file_name)')
|
||
cur.execute('CREATE INDEX IF NOT EXISTS idx_kv_text ON knowledge_vectors(text)') # helps FTS/LIKE
|
||
cur.execute('PRAGMA optimize') # SQLite auto-optimization
|
||
|
||
|
||
# ─── Embedding API ────────────────────────────────────────────────────────────
|
||
|
||
def _get_embeddings_batch(texts: list[str]) -> list[list[float] | None]:
|
||
"""
|
||
调用 ai_client.get_embeddings (复用全局 semaphore 和客户端逻辑)。
|
||
不支持 Embedding 的 provider 返回全 None 列表。优化了并发控制。
|
||
"""
|
||
if not texts:
|
||
return []
|
||
|
||
try:
|
||
# 使用统一 ai_client 接口,确保全局LLM semaphore生效,避免重复客户端创建
|
||
from utils import ai_client
|
||
embeddings = ai_client.get_embeddings(texts)
|
||
return embeddings
|
||
except Exception as e:
|
||
if "NotImplementedError" in str(type(e).__name__) or "不支持" in str(e):
|
||
logger.info('Embedding provider不支持,降级到关键词检索')
|
||
return [None] * len(texts)
|
||
logger.warning(f'Embedding API 调用失败,将使用关键词检索降级: {e}')
|
||
return [None] * len(texts)
|
||
|
||
|
||
def _cosine(a: list[float], b: list[float]) -> float:
|
||
"""纯 Python 余弦相似度,无需 numpy"""
|
||
dot = sum(x * y for x, y in zip(a, b))
|
||
na = math.sqrt(sum(x * x for x in a))
|
||
nb = math.sqrt(sum(x * x for x in b))
|
||
return dot / (na * nb) if na and nb else 0.0
|
||
|
||
|
||
# ─── 公开接口 ─────────────────────────────────────────────────────────────────
|
||
|
||
def is_available() -> dict:
    """Report knowledge-base status; the store itself is always available.

    Returns:
        dict with keys:
            available:   always True (no external dependency can be missing)
            doc_count:   number of stored chunks
            processing:  file names currently being ingested in the background
            search_mode: 'vector' (semantic) or 'keyword' (degraded fallback)
            error:       present only when the status query itself failed
    """
    with _processing_lock:
        processing = list(_processing_files)

    db = None
    try:
        db = _conn()
        cur = db.cursor()
        _init_tables(cur)
        db.commit()

        cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
        doc_count = cur.fetchone()[0]

        # Any stored embedding proves the Embedding API has worked before.
        cur.execute('SELECT 1 FROM knowledge_vectors WHERE embedding IS NOT NULL LIMIT 1')
        has_embedding = cur.fetchone() is not None

        provider = getattr(config, 'MODEL_PROVIDER', '')
        can_embed = provider in ('qwen', 'openai', 'kimi')
        mode = 'vector' if (has_embedding or can_embed) else 'keyword'

        return {
            'available': True,
            'doc_count': doc_count,
            'processing': processing,
            'search_mode': mode,
        }
    except Exception as e:
        return {
            'available': True,
            'doc_count': 0,
            'processing': processing,
            'search_mode': 'keyword',
            'error': str(e),
        }
    finally:
        # BUG FIX: the original leaked the connection when a query raised
        # (the except path never closed it); close unconditionally.
        if db is not None:
            db.close()
|
||
|
||
|
||
def add_file(file_path: str, db_path: str) -> dict:
    """Ingest a file into the knowledge base.

    Pipeline: extract text → split into chunks → batch-embed → persist to
    SQLite. Runs on a background thread; ``_processing_files`` holds the
    in-flight file name so the frontend can poll an "in progress" state.

    Args:
        file_path: Path of the uploaded file to ingest.
        db_path: Unused here — writes go through ``_conn()`` which opens
            ``config.DB_PATH``. NOTE(review): presumably kept for interface
            symmetry with ``list_files``/``delete_file``; confirm callers.

    Returns:
        ``{'success': True, 'chunks': N}`` on success, otherwise
        ``{'success': False, 'error': msg}``.
    """
    file_name = os.path.basename(file_path)
    with _processing_lock:
        _processing_files.add(file_name)

    try:
        text = extract_text(file_path)
        chunks = split_text_chunks(text, config.CHUNK_SIZE, config.CHUNK_OVERLAP)
        if not chunks:
            return {'success': False, 'error': '文件内容为空,无法入库'}

        # Batch embedding (effective for Qwen/OpenAI providers; otherwise the
        # helper returns all-None and chunks are stored keyword-only).
        embeddings: list[list[float] | None] = []
        for i in range(0, len(chunks), _EMBED_BATCH):
            batch = chunks[i:i + _EMBED_BATCH]
            embeddings.extend(_get_embeddings_batch(batch))

        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)

            # Replace any previous ingestion of the same file name.
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))

            for idx, (chunk, emb) in enumerate(zip(chunks, embeddings)):
                # Embeddings are stored as JSON text; NULL marks a keyword-only chunk.
                emb_json = json.dumps(emb) if emb is not None else None
                cur.execute(
                    'INSERT INTO knowledge_vectors (file_name, chunk_idx, text, embedding) VALUES (?,?,?,?)',
                    (file_name, idx, chunk, emb_json),
                )

            cur.execute('''
                INSERT OR REPLACE INTO knowledge_files (file_name, file_path, chunk_count, added_at)
                VALUES (?, ?, ?, ?)
            ''', (file_name, file_path, len(chunks), datetime.now()))

            db.commit()
        finally:
            db.close()

        logger.info(f'知识库入库完成: {file_name},{len(chunks)} 块'
                    f'{"(含向量)" if any(e is not None for e in embeddings) else "(关键词模式)"}')
        return {'success': True, 'chunks': len(chunks)}

    except Exception as e:
        logger.exception('知识库添加文件失败')
        return {'success': False, 'error': str(e)}
    finally:
        # Always clear the "processing" marker, success or failure.
        with _processing_lock:
            _processing_files.discard(file_name)
|
||
|
||
|
||
def search(query: str, top_k: int | None = None) -> list[str]:
    """Retrieve the text chunks most relevant to *query*.

    Strategy:
        - Vector mode: embed the query, rank stored chunks by cosine
          similarity (providers with an Embedding API).
        - Keyword fallback: multi-word SQL LIKE matching (DeepSeek/Ollama).

    Args:
        query: Free-text search string.
        top_k: Maximum number of chunks to return; defaults to
            ``config.TOP_K_KNOWLEDGE``. (Annotation fixed: was ``int = None``.)

    Returns:
        Up to ``top_k`` chunk texts, best first; ``[]`` on error or empty store.
    """
    if top_k is None:
        top_k = config.TOP_K_KNOWLEDGE

    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()

            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            if cur.fetchone()[0] == 0:
                return []

            # ── Semantic (vector) retrieval ──────────────────────────────
            q_embs = _get_embeddings_batch([query])
            q_emb = q_embs[0] if q_embs else None

            if q_emb is not None:
                # Cap the scan (most recent rows first) so a large knowledge
                # base does not force a full-table cosine pass.
                cur.execute(
                    '''SELECT text, embedding FROM knowledge_vectors
                       WHERE embedding IS NOT NULL
                       ORDER BY id DESC LIMIT 500'''
                )
                rows = cur.fetchall()
                if rows:
                    scored: list[tuple[float, str]] = []
                    for text, emb_json in rows:
                        try:
                            emb = json.loads(emb_json)
                            scored.append((_cosine(q_emb, emb), text))
                        except Exception:
                            continue  # skip corrupt embedding rows
                    scored.sort(key=lambda pair: pair[0], reverse=True)
                    return [t for _, t in scored[:top_k]]

            # ── Keyword fallback (providers without an Embedding API) ────
            # Drop pure numbering tokens ("1.2", "一、") to avoid matching
            # unrelated sections. Uses the module-level `re` import (the
            # original re-imported `re` on every call).
            num_pat = re.compile(r'^[\d\.\-、一二三四五六七八九十]+$')
            words = [
                w.strip() for w in query.split()
                if len(w.strip()) > 1 and not num_pat.match(w.strip())
            ][:6]
            if not words:
                cur.execute('SELECT text FROM knowledge_vectors LIMIT ?', (top_k,))
                return [r[0] for r in cur.fetchall()]

            def _like(w: str) -> str:
                # BUG FIX: escape LIKE wildcards so a literal '%'/'_' in the
                # query cannot silently match everything.
                w = w.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
                return f'%{w}%'

            conditions = ' OR '.join(["text LIKE ? ESCAPE '\\'" for _ in words])
            params = [_like(w) for w in words] + [top_k]
            cur.execute(
                f'SELECT text FROM knowledge_vectors WHERE {conditions} LIMIT ?', params
            )
            return [r[0] for r in cur.fetchall()]

        finally:
            db.close()

    except Exception as e:
        logger.error(f'知识库检索失败: {e}')
        return []
|
||
|
||
|
||
def list_files(db_path: str) -> list[dict]:
    """List files already ingested into the knowledge base.

    Args:
        db_path: Path of the SQLite database holding ``knowledge_files``.

    Returns:
        ``[{'name', 'chunks', 'added_at'}, ...]`` newest first; ``[]`` on any
        error (best-effort, same contract as before).
    """
    try:
        db = sqlite3.connect(db_path)
        try:
            cur = db.cursor()
            cur.execute(
                'SELECT file_name, chunk_count, added_at FROM knowledge_files ORDER BY added_at DESC'
            )
            rows = cur.fetchall()
        finally:
            # BUG FIX: the original leaked the connection if the query raised.
            db.close()
        return [{'name': r[0], 'chunks': r[1], 'added_at': r[2]} for r in rows]
    except Exception:
        return []
|
||
|
||
|
||
def delete_file(file_name: str, db_path: str) -> dict:
    """Remove every stored chunk and the file record for *file_name*.

    Args:
        file_name: Base name of the file to purge.
        db_path: Unused — deletion goes through ``_conn()`` /
            ``config.DB_PATH``. NOTE(review): same asymmetry as ``add_file``;
            confirm callers before dropping the parameter.

    Returns:
        ``{'success': True}`` or ``{'success': False, 'error': msg}``.
    """
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            cur.execute('DELETE FROM knowledge_files WHERE file_name=?', (file_name,))
            db.commit()
        finally:
            # BUG FIX: close even when a statement raises (original leaked the
            # connection on the exception path).
            db.close()
        return {'success': True}
    except Exception as e:
        logger.exception('知识库删除文件失败')
        return {'success': False, 'error': str(e)}