2026-04-23 14:37:19 +08:00

293 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Enterprise knowledge-base module (no external vector-store dependency).
Storage backend: SQLite, sharing the same file as the main database.
- knowledge_vectors table: text chunks + JSON-encoded vectors
- knowledge_files table: file metadata (created by init_db in app.py)
Retrieval strategy:
    Qwen / OpenAI / Kimi providers → Embedding API + cosine similarity (semantic search)
    DeepSeek / Ollama             → SQL LIKE keyword matching (degraded fallback)
"""
import json
import math
import logging
import os
import sqlite3
import threading
from datetime import datetime
import config
from utils.file_utils import extract_text, split_text_chunks
logger = logging.getLogger(__name__)
# 正在后台入库的文件名集合(供前端轮询感知"处理中"状态)
_processing_files: set = set()
_processing_lock = threading.Lock()
# 每次 Embedding API 批量请求的块数(避免单次请求过大)
_EMBED_BATCH = 16
# ─── 数据库 ──────────────────────────────────────────────────────────────────
def _conn() -> sqlite3.Connection:
    """Open a fresh connection to the shared application SQLite file."""
    db_file = config.DB_PATH
    return sqlite3.connect(db_file)
def _init_tables(cur: sqlite3.Cursor) -> None:
"""确保向量块表存在knowledge_files 已由 app.py init_db 创建)"""
cur.execute('''
CREATE TABLE IF NOT EXISTS knowledge_vectors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT NOT NULL,
chunk_idx INTEGER NOT NULL,
text TEXT NOT NULL,
embedding TEXT,
UNIQUE(file_name, chunk_idx)
)
''')
# ─── Embedding API ────────────────────────────────────────────────────────────
def _get_embeddings_batch(texts: list[str]) -> list[list[float] | None]:
"""
调用当前 provider 的 Embedding API批量返回向量列表。
不支持 Embedding 的 providerDeepSeek / Ollama返回全 None 列表。
"""
if not texts:
return []
provider = getattr(config, 'MODEL_PROVIDER', '')
try:
from openai import OpenAI
if provider == 'qwen':
client = OpenAI(api_key=config.QWEN_API_KEY, base_url=config.QWEN_BASE_URL)
model = config.QWEN_EMBEDDING_MODEL
elif provider == 'openai':
client = OpenAI(api_key=config.OPENAI_API_KEY, base_url=config.OPENAI_BASE_URL)
model = config.OPENAI_EMBEDDING_MODEL
elif provider == 'kimi':
client = OpenAI(api_key=config.KIMI_API_KEY, base_url=config.KIMI_BASE_URL)
model = config.KIMI_EMBEDDING_MODEL
else:
# DeepSeek / Ollama / 豆包 无公开 Embedding API降级到关键词检索
return [None] * len(texts)
resp = client.embeddings.create(input=texts, model=model)
return [item.embedding for item in resp.data]
except Exception as e:
logger.warning(f'Embedding API 调用失败,将使用关键词检索降级: {e}')
return [None] * len(texts)
def _cosine(a: list[float], b: list[float]) -> float:
"""纯 Python 余弦相似度,无需 numpy"""
dot = sum(x * y for x, y in zip(a, b))
na = math.sqrt(sum(x * x for x in a))
nb = math.sqrt(sum(x * x for x in b))
return dot / (na * nb) if na and nb else 0.0
# ─── 公开接口 ─────────────────────────────────────────────────────────────────
def is_available() -> dict:
    """
    Report knowledge-base status; the store has no external dependencies,
    so 'available' is always True.

    Returns:
        dict with keys:
          available   – always True
          doc_count   – number of stored text chunks
          processing  – file names currently being ingested in the background
          search_mode – 'vector' (semantic) or 'keyword' (degraded fallback)
          error       – present only when the status query itself failed
    """
    with _processing_lock:
        processing = list(_processing_files)
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()
            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            doc_count = cur.fetchone()[0]
            # Any stored vector proves the Embedding API has worked before.
            cur.execute('SELECT 1 FROM knowledge_vectors WHERE embedding IS NOT NULL LIMIT 1')
            has_embedding = cur.fetchone() is not None
        finally:
            # Close even if a query raises (the original leaked the handle here).
            db.close()
        provider = getattr(config, 'MODEL_PROVIDER', '')
        can_embed = provider in ('qwen', 'openai', 'kimi')
        mode = 'vector' if (has_embedding or can_embed) else 'keyword'
        return {
            'available': True,
            'doc_count': doc_count,
            'processing': processing,
            'search_mode': mode,
        }
    except Exception as e:
        # Status probing must never raise; report the failure inline.
        return {
            'available': True,
            'doc_count': 0,
            'processing': processing,
            'search_mode': 'keyword',
            'error': str(e),
        }
def add_file(file_path: str, db_path: str) -> dict:
    """
    Ingest a file: split into chunks → batch Embedding → write to SQLite.

    Called from a background thread; _processing_files lets the frontend
    poll the "in progress" state.

    NOTE(review): db_path is accepted but never used — _conn() always opens
    config.DB_PATH; confirm callers pass the same path.

    Returns:
        {'success': True, 'chunks': n} on success,
        {'success': False, 'error': msg} on failure or empty content.
    """
    file_name = os.path.basename(file_path)
    with _processing_lock:
        _processing_files.add(file_name)
    try:
        text = extract_text(file_path)
        chunks = split_text_chunks(text, config.CHUNK_SIZE, config.CHUNK_OVERLAP)
        if not chunks:
            return {'success': False, 'error': '文件内容为空,无法入库'}
        # Batch-fetch embeddings (effective for Qwen/OpenAI/Kimi providers;
        # otherwise _get_embeddings_batch returns all None → keyword mode).
        embeddings: list[list[float] | None] = []
        for i in range(0, len(chunks), _EMBED_BATCH):
            batch = chunks[i:i + _EMBED_BATCH]
            embeddings.extend(_get_embeddings_batch(batch))
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            # Drop any previous ingest of the same file name first.
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            for idx, (chunk, emb) in enumerate(zip(chunks, embeddings)):
                emb_json = json.dumps(emb) if emb is not None else None
                cur.execute(
                    'INSERT INTO knowledge_vectors (file_name, chunk_idx, text, embedding) VALUES (?,?,?,?)',
                    (file_name, idx, chunk, emb_json),
                )
            cur.execute('''
                INSERT OR REPLACE INTO knowledge_files (file_name, file_path, chunk_count, added_at)
                VALUES (?, ?, ?, ?)
            ''', (file_name, file_path, len(chunks), datetime.now()))
            db.commit()
        finally:
            db.close()
        logger.info(f'知识库入库完成: {file_name}{len(chunks)}'
                    f'{"(含向量)" if any(e is not None for e in embeddings) else "(关键词模式)"}')
        return {'success': True, 'chunks': len(chunks)}
    except Exception as e:
        logger.exception('知识库添加文件失败')
        return {'success': False, 'error': str(e)}
    finally:
        # Always clear the "processing" flag, success or not.
        with _processing_lock:
            _processing_files.discard(file_name)
def search(query: str, top_k: int | None = None) -> list[str]:
    """
    Return the text chunks most relevant to *query*.

    - Vector mode: embed the query, rank stored chunks by cosine similarity.
    - Keyword fallback (providers without an Embedding API): SQL LIKE
      matching on up to six query words.

    Args:
        query: free-text search string.
        top_k: maximum chunks to return; defaults to config.TOP_K_KNOWLEDGE.

    Returns:
        List of chunk texts, best match first; [] on any error or empty store.
    """
    if top_k is None:
        top_k = config.TOP_K_KNOWLEDGE
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            db.commit()
            cur.execute('SELECT COUNT(*) FROM knowledge_vectors')
            if cur.fetchone()[0] == 0:
                return []
            # ── semantic (vector) search ──────────────────────────────────
            q_embs = _get_embeddings_batch([query])
            q_emb = q_embs[0] if q_embs else None
            if q_emb is not None:
                cur.execute(
                    'SELECT text, embedding FROM knowledge_vectors WHERE embedding IS NOT NULL'
                )
                scored: list[tuple[float, str]] = []
                for text, emb_json in cur.fetchall():
                    try:
                        emb = json.loads(emb_json)
                    except Exception:
                        continue  # skip rows with corrupt embedding JSON
                    scored.append((_cosine(q_emb, emb), text))
                if scored:
                    scored.sort(reverse=True)
                    return [t for _, t in scored[:top_k]]
                # No usable vectors stored (original returned [] here even
                # though keyword data exists) → fall through to keyword mode.
            # ── keyword fallback (DeepSeek / Ollama, no Embedding API) ────
            import re as _re
            # Drop pure numbers / list markers ("1.2", "一、") that would
            # match unrelated paragraphs.
            _num_pat = _re.compile(r'^[\d\.\-、一二三四五六七八九十]+$')
            words = [
                w.strip() for w in query.split()
                if len(w.strip()) > 1 and not _num_pat.match(w.strip())
            ][:6]
            if not words:
                cur.execute('SELECT text FROM knowledge_vectors LIMIT ?', (top_k,))
                return [r[0] for r in cur.fetchall()]
            conditions = ' OR '.join(['text LIKE ?' for _ in words])
            params = [f'%{w}%' for w in words] + [top_k]
            cur.execute(
                f'SELECT text FROM knowledge_vectors WHERE {conditions} LIMIT ?', params
            )
            return [r[0] for r in cur.fetchall()]
        finally:
            db.close()
    except Exception as e:
        logger.error(f'知识库检索失败: {e}')
        return []
def list_files(db_path: str) -> list[dict]:
    """
    List files already ingested into the knowledge base.

    Args:
        db_path: path of the SQLite file containing knowledge_files.

    Returns:
        [{'name', 'chunks', 'added_at'}, ...] ordered newest first, or []
        on any error (e.g. the table does not exist yet).
    """
    try:
        db = sqlite3.connect(db_path)
        try:
            cur = db.cursor()
            cur.execute(
                'SELECT file_name, chunk_count, added_at FROM knowledge_files ORDER BY added_at DESC'
            )
            rows = cur.fetchall()
        finally:
            # Close even when the query raises (original leaked the handle).
            db.close()
        return [{'name': r[0], 'chunks': r[1], 'added_at': r[2]} for r in rows]
    except Exception:
        return []
def delete_file(file_name: str, db_path: str) -> dict:
    """
    Remove every trace of *file_name* from the knowledge base
    (both the vector chunks and the file-metadata row).

    NOTE(review): db_path is accepted but never used — _conn() always opens
    config.DB_PATH; confirm callers pass the same path.

    Returns:
        {'success': True} or {'success': False, 'error': msg}.
    """
    try:
        db = _conn()
        try:
            cur = db.cursor()
            _init_tables(cur)
            cur.execute('DELETE FROM knowledge_vectors WHERE file_name=?', (file_name,))
            cur.execute('DELETE FROM knowledge_files WHERE file_name=?', (file_name,))
            db.commit()
        finally:
            # Close even when a statement raises (original leaked the handle).
            db.close()
        return {'success': True}
    except Exception as e:
        logger.exception('知识库删除文件失败')
        return {'success': False, 'error': str(e)}