285 lines
8.5 KiB
Python
285 lines
8.5 KiB
Python
"""
章节级图/表意图：字符特征 + 大纲上下文窗口计分，栈式优先级，驱动提示词附加段。
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||
|
||
import config
|
||
from utils import prompts as P
|
||
|
||
logger = logging.getLogger(__name__)

# Built-in rule set: returned as-is when no rules file exists, and used as the
# base that load_diagram_rules() merges the file over. Both keyword lists start
# empty, so these defaults alone never push a score up to a threshold.
DEFAULT_DIAGRAM_RULES: Dict[str, Any] = dict(
    schema_version=1,
    # minimum combined score (inclusive) for a figure / table intent
    threshold_figure=1.0,
    threshold_table=1.0,
    # multipliers for keyword hits in the section title / outline window
    title_weight=1.0,
    context_weight=0.6,
    # outline lines kept above / below the matched heading line
    outline_context_lines=dict(before=4, after=6),
    # order when both kinds qualify: 'score_desc', 'figure_first', 'table_first'
    stack_order_when_both='score_desc',
    # entries: plain strings or {'text'|'pattern': str, 'weight': number}
    figure_keywords=[],
    table_keywords=[],
)
|
||
|
||
|
||
def diagram_rules_path() -> str:
    """Return the location of the rules file: <config.DATA_DIR>/diagram_intent_rules.json."""
    rules_filename = 'diagram_intent_rules.json'
    return os.path.join(config.DATA_DIR, rules_filename)
||
|
||
|
||
def load_diagram_rules(path: Optional[str] = None) -> Dict[str, Any]:
    """Load the diagram-intent rules JSON merged over the built-in defaults.

    Top-level keys starting with '_' are ignored (comment/metadata slots).
    'outline_context_lines' is merged key-by-key into the default window, so a
    file may override only 'before' or only 'after'. Every other key replaces
    the default value wholesale.

    Args:
        path: Rules file to read. When falsy, diagram_rules_path() is used.

    Returns:
        A fresh rules dict. The defaults are deep-copied, so mutating the
        result (including nested lists/dicts) never alters
        DEFAULT_DIAGRAM_RULES. If the file is missing, unreadable, not valid
        JSON, or not a JSON object, the built-in defaults are returned.
    """
    import copy  # local: only this function needs it

    p = path or diagram_rules_path()
    # Deep copy: a shallow dict() copy shared the nested 'outline_context_lines'
    # dict and keyword lists with the module defaults, so a caller mutating the
    # returned rules silently corrupted DEFAULT_DIAGRAM_RULES.
    data = copy.deepcopy(DEFAULT_DIAGRAM_RULES)
    if not os.path.isfile(p):
        return data
    try:
        with open(p, encoding='utf-8') as f:
            raw = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning('加载 diagram_intent_rules.json 失败,使用内置默认: %s', e)
        return data
    if not isinstance(raw, dict):
        logger.warning('diagram_intent_rules.json 顶层不是 JSON 对象,使用内置默认: %s', p)
        return data
    for k, v in raw.items():
        if k.startswith('_'):
            continue
        if k == 'outline_context_lines':
            if isinstance(v, dict):
                data[k] = {**data.get(k, {}), **v}
            else:
                # A non-object here would later break DiagramIntentAgent.plan();
                # keep the default window instead.
                logger.warning('diagram_intent_rules.json 中 outline_context_lines 不是对象,已忽略: %r', v)
            continue
        data[k] = v
    return data
|
||
|
||
|
||
def _normalize_keyword_entries(raw: Any) -> List[Tuple[str, float]]:
|
||
out: List[Tuple[str, float]] = []
|
||
if not isinstance(raw, list):
|
||
return out
|
||
for item in raw:
|
||
if isinstance(item, str) and item.strip():
|
||
out.append((item.strip(), 1.0))
|
||
elif isinstance(item, dict):
|
||
t = (item.get('text') or item.get('pattern') or '').strip()
|
||
if not t:
|
||
continue
|
||
w = float(item.get('weight', 1.0))
|
||
out.append((t, w))
|
||
return out
|
||
|
||
|
||
def _score_text(text: str, entries: Sequence[Tuple[str, float]]) -> float:
|
||
if not text or not entries:
|
||
return 0.0
|
||
s = 0.0
|
||
for kw, w in entries:
|
||
if kw in text:
|
||
s += w
|
||
return s
|
||
|
||
|
||
# Intent kind label. Producers in this module emit only 'figure' or 'table'.
DiagramKind = str


@dataclass(frozen=True)
class DiagramIntent:
    """One planned diagram request: what kind, how strongly, and where it came from."""

    # 'figure' or 'table'. Renderers treat any non-'figure' kind as a table.
    kind: DiagramKind
    # Combined keyword score; fallback intents use 1.0.
    score: float
    # Provenance tag, e.g. 'title+context' or 'fallback'.
    sources: str


# Ordered intents. Index 0 is the top of the stack and takes priority.
DiagramStack = List[DiagramIntent]
|
||
|
||
|
||
def score_figure_table(
    title: str,
    context_snippet: str,
    rules: Dict[str, Any],
) -> Tuple[float, float]:
    """Return (figure_score, table_score) for one section.

    Each score is title_weight * (keyword hits in *title*) plus
    context_weight * (keyword hits in *context_snippet*). The weights come
    from *rules* and default to 1.0 and 0.6. The keyword lists are
    rules['figure_keywords'] and rules['table_keywords'].
    """
    figure_entries = _normalize_keyword_entries(rules.get('figure_keywords'))
    table_entries = _normalize_keyword_entries(rules.get('table_keywords'))
    title_factor = float(rules.get('title_weight', 1.0))
    context_factor = float(rules.get('context_weight', 0.6))
    title_text = title or ''
    context_text = context_snippet or ''

    def _weighted(entries: Sequence[Tuple[str, float]]) -> float:
        from_title = title_factor * _score_text(title_text, entries)
        from_context = context_factor * _score_text(context_text, entries)
        return from_title + from_context

    return _weighted(figure_entries), _weighted(table_entries)
|
||
|
||
|
||
# Leading heading number: '1. ', '一、', '1.2.3 ', '2) ', '3．' ...
# The dotted-group part handles multi-level numbering, which the old pattern
# only half-removed ('1.2 X' became '2 X').
_SERIAL_PREFIX_RE = re.compile(
    r'^\s*[\d一二三四五六七八九十]+(?:[.．][\d一二三四五六七八九十]+)*[、.．)）\s]+'
)


def _strip_serial(s: str) -> str:
    """Remove a leading heading number from *s* and trim surrounding whitespace."""
    return _SERIAL_PREFIX_RE.sub('', s).strip()


def extract_outline_window(
    outline_text: str,
    section_title: str,
    before: int,
    after: int,
    fallback_chars: int = 1200,
) -> str:
    """Return the outline lines surrounding the heading of *section_title*.

    The heading line is located in ranked passes over the whole outline,
    stopping at the first pass that finds a line:
      1. exact heading: the stripped line equals the title, or equals it once
         leading numbering is removed from both;
      2. the line contains the full title;
      3. the numbering-free title occurs inside the line.
    Ranking the passes (instead of taking the first line that satisfies any
    rule) keeps '二、概述' from losing to an earlier '一、项目概述' when the
    title is '概述'.

    Args:
        outline_text: Full outline, one heading per line.
        section_title: Title of the section being generated.
        before: Lines to keep above the heading (negative treated as 0).
        after: Lines to keep below the heading (negative treated as 0).
        fallback_chars: Prefix length returned when no heading is found.

    Returns:
        The '\n'-joined window of lines, clamped to the outline. If the title
        is empty or blank, or no line matches, returns the first
        *fallback_chars* characters of the outline ('' for empty/None outline).
    """
    if not outline_text or not section_title:
        return (outline_text or '')[:fallback_chars]
    title_stripped = section_title.strip()
    if not title_stripped:
        return outline_text[:fallback_chars]

    lines = outline_text.splitlines()
    core = _strip_serial(title_stripped)
    # (stripped line, stripped line without numbering), computed once for all passes.
    candidates = [(s, _strip_serial(s)) for s in (line.strip() for line in lines)]

    passes = (
        lambda s, u: s == title_stripped or (bool(core) and u == core),
        lambda s, u: title_stripped in s,
        lambda s, u: bool(core) and (core in u or core in s),
    )
    idx = -1
    for matches in passes:
        idx = next((i for i, (s, u) in enumerate(candidates) if matches(s, u)), -1)
        if idx >= 0:
            break
    if idx < 0:
        return outline_text[:fallback_chars]

    lo = max(0, idx - max(0, before))
    hi = min(len(lines), idx + max(0, after) + 1)
    return '\n'.join(lines[lo:hi])
|
||
|
||
|
||
def build_stack(
    fig_score: float,
    tbl_score: float,
    rules: Dict[str, Any],
    enable_figure: bool,
    enable_table: bool,
) -> DiagramStack:
    """Turn the two scores into an ordered intent stack (index 0 = top).

    A kind is included when it is enabled and its score is >= its threshold
    (rules 'threshold_figure' / 'threshold_table', default 1.0). When both
    qualify, rules['stack_order_when_both'] decides the order. 'figure_first'
    and 'table_first' force one kind on top. Any other value, including the
    default 'score_desc', puts the higher score on top, and a tie keeps
    figure first.
    """
    figure_threshold = float(rules.get('threshold_figure', 1.0))
    table_threshold = float(rules.get('threshold_table', 1.0))
    order_mode = (rules.get('stack_order_when_both') or 'score_desc').strip()

    candidates = (
        ('figure', enable_figure, fig_score, figure_threshold),
        ('table', enable_table, tbl_score, table_threshold),
    )
    intents: List[DiagramIntent] = [
        DiagramIntent(kind, score, 'title+context')
        for kind, enabled, score, threshold in candidates
        if enabled and score >= threshold
    ]
    if len(intents) < 2:
        return intents

    # Both qualified: the list is always [figure, table] at this point.
    figure_intent, table_intent = intents
    if order_mode == 'figure_first':
        return [figure_intent, table_intent]
    if order_mode == 'table_first':
        return [table_intent, figure_intent]
    # Stable sort: equal scores keep figure on top.
    return sorted(intents, key=lambda intent: -intent.score)
|
||
|
||
|
||
def stack_compact_labels(stack: DiagramStack) -> List[str]:
    """Return one human-readable label per intent, in stack order.

    These are the same labels stack_to_addon() passes to the priority
    preamble, for prompts that only ask for the attachment blocks. Any kind
    other than 'figure' is labelled as a table.
    """
    return [
        '图示([FIGURE] 块)' if intent.kind == 'figure' else '表格([TABLE] 块)'
        for intent in stack
    ]
|
||
|
||
|
||
def make_fallback_stack(kind: str) -> DiagramStack:
    """Build a one-intent stack for when planning found nothing but output is required.

    *kind* is trimmed and lower-cased. Anything other than 'figure' or 'table'
    (including None or '') becomes 'table'. The intent has score 1.0 and
    source 'fallback'.
    """
    requested = (kind or '').strip().lower()
    chosen = requested if requested in ('figure', 'table') else 'table'
    return [DiagramIntent(chosen, 1.0, 'fallback')]
|
||
|
||
|
||
def stack_to_addon(stack: DiagramStack) -> str:
    """Render the prompt add-on text for *stack*.

    The output is P.diagram_priority_preamble(labels) followed by the full
    figure or table spec (P.get_figure_addon() / P.get_table_addon()) for each
    intent in stack order. Parts are concatenated with no separator. Labels
    come from stack_compact_labels(), so this prompt and the attachment-only
    prompt always describe the stack identically. Returns '' for an empty
    stack.
    """
    if not stack:
        return ''
    # Reuse the shared label builder instead of duplicating its mapping here.
    parts: List[str] = [P.diagram_priority_preamble(stack_compact_labels(stack))]
    parts.extend(
        P.get_figure_addon() if it.kind == 'figure' else P.get_table_addon()
        for it in stack
    )
    return ''.join(parts)
|
||
|
||
|
||
class DiagramIntentAgent:
    """Holds one rule set. Plans a section's diagram stack and renders its prompt add-on."""

    def __init__(self, rules: Optional[Dict[str, Any]] = None) -> None:
        """Create an agent.

        Args:
            rules: Rule dict shaped like load_diagram_rules() output. None
                loads the rules file. An explicitly passed dict, even an empty
                one, is used as-is.
        """
        # `is None`, not truthiness: `rules or load_diagram_rules()` silently
        # swapped an explicitly passed empty dict for the on-disk rules.
        self.rules = load_diagram_rules() if rules is None else rules

    @classmethod
    def load_default(cls) -> 'DiagramIntentAgent':
        """Build an agent from the default rules file, or the built-in defaults if it is absent."""
        return cls(load_diagram_rules())

    def plan(
        self,
        section_title: str,
        outline_text: str,
        enable_figure: bool,
        enable_table: bool,
    ) -> DiagramStack:
        """Score one section and return its ordered intent stack.

        The outline is first narrowed to the lines around the section heading,
        using rules['outline_context_lines'] 'before'/'after' (default 4/6).
        The title and that window are then scored with the keyword rules.
        """
        r = self.rules
        oc = r.get('outline_context_lines')
        if not isinstance(oc, dict):
            # Missing, null, or malformed (e.g. a list): use the built-in window.
            oc = {}
        before = int(oc.get('before', 4))
        after = int(oc.get('after', 6))
        ctx = extract_outline_window(
            outline_text, section_title, before, after,
        )
        fig_s, tbl_s = score_figure_table(section_title, ctx, r)
        return build_stack(fig_s, tbl_s, r, enable_figure, enable_table)

    def render_for_section(
        self,
        section_title: str,
        outline_text: str,
        enable_figure: bool,
        enable_table: bool,
    ) -> str:
        """Return the diagram-related prompt add-on for one section.

        When both figures and tables are disabled, returns
        P.get_chart_forbid_addon() so the prohibition is stated explicitly.
        Otherwise returns stack_to_addon() of the planned stack. That is ''
        when no enabled kind reached its threshold.
        """
        if not enable_figure and not enable_table:
            return P.get_chart_forbid_addon()
        stack = self.plan(
            section_title, outline_text, enable_figure, enable_table,
        )
        return stack_to_addon(stack)
|
||
|
||
|
||
# Module-level default agent, created lazily for one-off calls from the generator.
_default_agent: Optional[DiagramIntentAgent] = None


def get_diagram_agent() -> DiagramIntentAgent:
    """Return the shared agent, building it from the default rules on first use."""
    global _default_agent
    if _default_agent is not None:
        return _default_agent
    _default_agent = DiagramIntentAgent.load_default()
    return _default_agent


def invalidate_diagram_agent_cache() -> None:
    """Drop the shared agent so the next get_diagram_agent() call reloads the rules."""
    global _default_agent
    _default_agent = None
|