372 lines
12 KiB
Python
372 lines
12 KiB
Python
"""
|
||
技术评分驱动的章节字数分配:读取 data/word_allocation_rules.json,
|
||
结合 VOLUME_PRESETS 的 base/core 与项目 rating_json,为每个叶节点生成
|
||
min_chars、word_count_spec(及可选 max_tokens)。
|
||
"""
|
||
from __future__ import annotations

# Standard library.
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

# Project configuration (provides DATA_DIR, TARGET_PAGES, PAGE_CHAR_ESTIMATE,
# MODEL_PROVIDER — read via getattr with fallbacks below).
import config

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Keep in sync with modules/generator.VOLUME_PRESETS.
# Tuple layout (see the unpack in compute_leaf_allocations):
# (base_chars, core_chars, label — unused in this module, preset_max_tokens).
VOLUME_PRESETS: Dict[str, Tuple[int, int, str, int]] = {
    'concise': (1200, 2500, '精简版', 5000),
    'standard': (2000, 4000, '标准版', 8000),
    'detailed': (3000, 5500, '详细版', 12000),
    'full': (4000, 7000, '充实版', 16000),
}

# Hard per-provider ceiling on max_tokens for a single generation call;
# unknown providers fall back to 8192 in compute_leaf_allocations.
_PROVIDER_TOKEN_LIMITS = {
    'deepseek': 8192,
    'qwen': 8192,
    'openai': 16384,
    'ollama': 8192,
    'doubao': 8192,
    'kimi': 8192,
}
# Built-in fallback rules, used when the rules JSON file is missing or fails
# to parse (see load_rules).
DEFAULT_RULES: Dict[str, Any] = {
    'schema_version': 1,
    # Blend between relevance-driven and uniform allocation; clamped to [0, 1]
    # by compute_leaf_allocations.
    'alpha': 0.85,
    # Budget source: 'target_pages' (pages × chars/page), 'anchor_base'
    # (n × base), anything else → n × mean(base, core).
    'budget_mode': 'target_pages',
    # Optional per-leaf min/max character overrides; None → derived from the
    # volume preset (floor = base/2, cap = core).
    'per_section_floor': None,
    'per_section_cap': None,
    # Rating items with weight below min_rating_weight are ignored when
    # scoring title relevance.
    'relevance': {'method': 'keyword_overlap', 'min_rating_weight': 0.01},
    'rating_parse': {},
    # Prompt construction: how many top rating items to list per section,
    # and the lead-in line ('' → built-in default in compute_leaf_allocations).
    'prompt': {'top_k_rating_items': 4, 'intro_line': ''},
    # When true, each section's max_tokens is scaled with its char target.
    'max_tokens_scale': False,
}
def rules_path() -> str:
    """Return the path of the word-allocation rules file under DATA_DIR."""
    rules_file = 'word_allocation_rules.json'
    return os.path.join(config.DATA_DIR, rules_file)
def load_rules(path: Optional[str] = None) -> Dict[str, Any]:
    """Load the rules JSON; fall back to DEFAULT_RULES when missing/broken.

    The 'relevance' and 'prompt' sub-dicts from the file are merged on top of
    the defaults instead of replacing them; keys starting with '_' are treated
    as comments and skipped.

    Args:
        path: Optional explicit file path; defaults to rules_path().

    Returns:
        A fresh dict — safe for the caller to mutate.
    """
    p = path or rules_path()
    data = dict(DEFAULT_RULES)
    # Fix: dict(DEFAULT_RULES) is a shallow copy, so without these copies the
    # nested 'relevance'/'prompt' dicts would be shared with DEFAULT_RULES and
    # a caller mutating the result would corrupt the module-level defaults.
    data['relevance'] = dict(DEFAULT_RULES.get('relevance') or {})
    data['prompt'] = dict(DEFAULT_RULES.get('prompt') or {})
    if not os.path.isfile(p):
        return data
    try:
        with open(p, encoding='utf-8') as f:
            raw = json.load(f)
        if isinstance(raw, dict):
            for k, v in raw.items():
                if k.startswith('_'):  # underscore keys act as comments
                    continue
                if k == 'relevance' and isinstance(v, dict):
                    data['relevance'] = {**data['relevance'], **v}
                elif k == 'prompt' and isinstance(v, dict):
                    data['prompt'] = {**data['prompt'], **v}
                else:
                    data[k] = v
    except Exception as e:
        # Best-effort config loading: any failure falls back to the defaults.
        logger.warning('加载 word_allocation_rules.json 失败,使用内置默认: %s', e)
    return data
def _as_float(x: Any, default: float = 0.0) -> float:
|
||
if x is None:
|
||
return default
|
||
if isinstance(x, (int, float)):
|
||
return float(x)
|
||
if isinstance(x, str):
|
||
s = re.sub(r'[^\d.\-]', '', x)
|
||
if not s:
|
||
return default
|
||
try:
|
||
return float(s)
|
||
except ValueError:
|
||
return default
|
||
return default
|
||
|
||
|
||
def _item_name(d: Dict[str, Any]) -> str:
|
||
for k in ('name', 'title', 'item_name', '评分项', '评分项名称', 'indicator'):
|
||
v = d.get(k)
|
||
if isinstance(v, str) and v.strip():
|
||
return v.strip()
|
||
return ''
|
||
|
||
|
||
def _item_weight(d: Dict[str, Any]) -> float:
    """Return the first positive weight found among known weight-like keys;
    1.0 when nothing usable is present."""
    weight_keys = ('weight', 'score', '分值', 'max_score', '满分', 'points')
    for key in weight_keys:
        if key not in d:
            continue
        value = _as_float(d.get(key), 0.0)
        if value > 0:
            return value
    return 1.0
def _collect_rating_dicts(obj: Any, acc: List[Dict[str, Any]]) -> None:
|
||
if isinstance(obj, dict):
|
||
acc.append(obj)
|
||
for v in obj.values():
|
||
_collect_rating_dicts(v, acc)
|
||
elif isinstance(obj, list):
|
||
for v in obj:
|
||
_collect_rating_dicts(v, acc)
|
||
|
||
|
||
def parse_rating_json(raw: Optional[str]) -> List[Dict[str, Any]]:
    """Parse rating items from a rating_json string.

    Returns a list of {'name': str, 'weight': float, 'keywords': List[str]},
    de-duplicated case-insensitively by name. Any parse failure, empty input
    or non-string yields [].
    """
    if not isinstance(raw, str) or not raw.strip():
        return []
    try:
        root = json.loads(raw.strip())
    except json.JSONDecodeError:
        return []

    candidates: List[Dict[str, Any]] = []
    _collect_rating_dicts(root, candidates)

    results: List[Dict[str, Any]] = []
    seen_names: set = set()
    for cand in candidates:
        name = _item_name(cand)
        # Skip unnamed or single-character names.
        if len(name) < 2:
            continue
        lowered = name.lower()
        if lowered in seen_names:
            continue
        keywords: List[str] = []
        raw_kw = cand.get('keywords') or cand.get('keyword') or cand.get('要点')
        if isinstance(raw_kw, list):
            keywords = [
                str(entry).strip()
                for entry in raw_kw
                if isinstance(entry, (str, int, float)) and str(entry).strip()
            ]
        elif isinstance(raw_kw, str) and raw_kw.strip():
            keywords = [raw_kw.strip()]
        seen_names.add(lowered)
        results.append({
            'name': name,
            'weight': _item_weight(cand),
            'keywords': keywords,
        })

    return results
def _title_tokens(title: str) -> List[str]:
|
||
if not title:
|
||
return []
|
||
s = re.sub(r'[\s\d..、,,;;::/\\()()【】\[\]「」]+', ' ', title)
|
||
parts = [p for p in s.split() if len(p) >= 2]
|
||
toks = list(parts)
|
||
for m in re.findall(r'[\u4e00-\u9fff]{2,}', title):
|
||
if m not in toks:
|
||
toks.append(m)
|
||
return toks
|
||
|
||
|
||
def _overlap_score(title: str, item: Dict[str, Any]) -> float:
    """Heuristic relevance in [0, 1] between a section title and a rating item.

    Base score is the fraction of title tokens found in the item's name and
    keywords; containment of name/title boosts to >= 0.85, a keyword literally
    inside the title boosts to >= 0.7.
    """
    tokens = _title_tokens(title)
    if not tokens:
        return 0.0
    haystack = item['name'] + ''.join(item.get('keywords') or [])
    hits = sum(1 for tok in tokens if tok and tok in haystack)
    score = hits / max(len(tokens), 1)
    # Strong boost when name and title contain each other either way.
    if item['name'] in title or title in item['name']:
        score = max(score, 0.85)
    # Medium boost for any keyword (length >= 2) present verbatim in the title.
    for keyword in item.get('keywords') or []:
        if isinstance(keyword, str) and len(keyword) >= 2 and keyword in title:
            score = max(score, 0.7)
    return min(1.0, score)
def _raw_utilities(
    leaves: List[Dict[str, Any]],
    items: List[Dict[str, Any]],
    min_w: float,
) -> Tuple[List[float], List[List[Tuple[str, float]]]]:
    """Per-leaf utility u_i = sum_j w_j * c_ij, normalized so max(u) == 1.0.

    Also returns, for each leaf, its top contributing items (name, contrib),
    sorted by contribution descending and truncated to 12. Items below min_w
    are dropped unless that would drop everything. All-zero utilities fall
    back to a uniform 1.0 per leaf.
    """
    usable = [item for item in items if item['weight'] >= min_w] or items
    count = len(leaves)
    utilities = [0.0] * count
    top_lists: List[List[Tuple[str, float]]] = [[] for _ in range(count)]

    for idx, leaf in enumerate(leaves):
        title = leaf.get('section_title') or ''
        contributions: List[Tuple[str, float]] = []
        total = 0.0
        for item in usable:
            contribution = item['weight'] * _overlap_score(title, item)
            if contribution > 0:
                contributions.append((item['name'], contribution))
                total += contribution
        utilities[idx] = total
        contributions.sort(key=lambda pair: pair[1], reverse=True)
        top_lists[idx] = contributions[:12]

    peak = max(utilities) if utilities else 0.0
    if peak <= 0:
        utilities = [1.0] * count
    else:
        utilities = [value / peak for value in utilities]
    return utilities, top_lists
def _clamp_int(x: int, lo: int, hi: int) -> int:
|
||
return max(lo, min(hi, x))
|
||
|
||
|
||
def _water_adjust(
|
||
targets: List[int],
|
||
budget: int,
|
||
floor_v: int,
|
||
cap_v: int,
|
||
priority: List[float],
|
||
) -> List[int]:
|
||
"""在 [floor_v, cap_v] 内将 targets 整数化并尽量使 sum 接近 budget。"""
|
||
n = len(targets)
|
||
if n == 0:
|
||
return []
|
||
if floor_v > cap_v:
|
||
floor_v, cap_v = cap_v, floor_v
|
||
if n * floor_v > budget:
|
||
floor_v = max(1, budget // n)
|
||
if n * cap_v < budget:
|
||
cap_v = max(floor_v, (budget + n - 1) // n)
|
||
cur = [_clamp_int(t, floor_v, cap_v) for t in targets]
|
||
s = sum(cur)
|
||
delta = budget - s
|
||
order = sorted(range(n), key=lambda i: -priority[i])
|
||
inv_order = sorted(range(n), key=lambda i: priority[i])
|
||
step = 0
|
||
max_steps = max(n * 2000, abs(delta) + n)
|
||
while delta != 0 and step < max_steps:
|
||
step += 1
|
||
if delta > 0:
|
||
moved = False
|
||
for i in order:
|
||
if cur[i] < cap_v:
|
||
cur[i] += 1
|
||
delta -= 1
|
||
moved = True
|
||
break
|
||
if not moved:
|
||
break
|
||
else:
|
||
moved = False
|
||
for i in inv_order:
|
||
if cur[i] > floor_v:
|
||
cur[i] -= 1
|
||
delta += 1
|
||
moved = True
|
||
break
|
||
if not moved:
|
||
break
|
||
return cur
|
||
|
||
|
||
def compute_leaf_allocations(
    volume_key: str,
    leaves: List[Dict[str, Any]],
    rating_raw: Optional[str],
    rules: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[int, Dict[str, Any]]]:
    """
    Compute target_chars / word_count_spec / max_tokens for every leaf node.

    With technical rating items present, characters are allocated by title
    relevance. Without rating items: if budget_mode is 'target_pages' and a
    page target is configured, the total budget B = TARGET_PAGES *
    PAGE_CHAR_ESTIMATE is spread uniformly; otherwise returns None and the
    caller keeps its legacy logic.

    leaves: [{'id': int, 'section_title': str}, ...]
    Returns: {leaf_id: {'target_chars', 'word_count_spec', 'max_tokens'}},
    or None (fall back), or {} for empty input.
    """
    rules = rules or load_rules()
    if not leaves:
        return {}

    # Preset anchors: base/core char band and the preset's token ceiling.
    base, core, _, preset_tokens = VOLUME_PRESETS.get(
        volume_key, VOLUME_PRESETS['standard']
    )
    floor_default = int(base * 0.5)
    cap_default = core
    floor_v = int(rules['per_section_floor']) if rules.get('per_section_floor') is not None else floor_default
    cap_v = int(rules['per_section_cap']) if rules.get('per_section_cap') is not None else cap_default
    floor_v = min(floor_v, cap_v)
    # alpha blends relevance-driven vs uniform allocation; clamp to [0, 1].
    alpha = float(rules.get('alpha', 0.85))
    alpha = max(0.0, min(1.0, alpha))
    min_w = float(rules.get('relevance', {}).get('min_rating_weight', 0.01))

    n = len(leaves)
    # Total character budget, per budget_mode.
    mode = (rules.get('budget_mode') or 'anchor_mean').strip()
    pages_cfg = int(getattr(config, 'TARGET_PAGES', 0) or 0)
    pce = max(1, int(getattr(config, 'PAGE_CHAR_ESTIMATE', 700) or 700))
    if mode == 'target_pages' and pages_cfg > 0:
        budget = int(round(pages_cfg * pce))
    elif mode == 'anchor_base':
        budget = int(round(n * base))
    else:
        budget = int(round(n * (base + core) / 2.0))

    items = parse_rating_json(rating_raw)
    if not items:
        # No rating items: only proceed with a uniform split when a page
        # target drives the budget; otherwise let the caller use old logic.
        if not (mode == 'target_pages' and pages_cfg > 0):
            return None
        u = [1.0] * n
        top_lists = [[] for _ in range(n)]
        mid = 0.5 * (base + core)
        raw_float = [float(mid)] * n
    else:
        # Relevance-weighted targets inside the [base, core] band.
        u, top_lists = _raw_utilities(leaves, items, min_w)
        band = core - base
        raw_float = [
            base + band * (alpha * u[i] + (1.0 - alpha) * 0.5) for i in range(n)
        ]

    targets = [int(round(x)) for x in raw_float]
    # Redistribute so the total approaches the budget within [floor, cap].
    adjusted = _water_adjust(targets, budget, floor_v, cap_v, u)

    # Per-provider token ceiling, optionally scaled per section below.
    provider = getattr(config, 'MODEL_PROVIDER', 'openai')
    tok_limit = _PROVIDER_TOKEN_LIMITS.get(provider, 8192)
    base_max_tok = min(preset_tokens, tok_limit)
    scale_tokens = bool(rules.get('max_tokens_scale', False))

    prompt_cfg = rules.get('prompt') or {}
    top_k = int(prompt_cfg.get('top_k_rating_items', 4))
    intro = (prompt_cfg.get('intro_line') or '').strip() or (
        '本节须对下列技术评分要点作实质展开(结合工艺、流程、标准与可验证措施,禁止空泛承诺与复述招标文件):'
    )

    out: Dict[int, Dict[str, Any]] = {}
    for i, leaf in enumerate(leaves):
        sid = int(leaf['id'])
        min_chars = max(1, adjusted[i])
        contribs = top_lists[i][:top_k]
        if contribs:
            # Spec listing the leaf's top rating items to be expanded on.
            lines = '\n'.join(f' · {name}' for name, _ in contribs[:top_k])
            spec = (
                f'- 字数硬性要求(必须达到,不达标将续写补足):本节正文不少于 {min_chars} 字\n'
                f'- {intro}\n{lines}\n'
                f'- 内容须由可检验的技术与管理措施支撑,禁止堆砌套话与重复背景'
            )
        else:
            # Generic spec when no rating item is relevant to this leaf.
            spec = (
                f'- 字数硬性要求(必须达到,不达标将续写补足):本节正文不少于 {min_chars} 字\n'
                f'- 须紧扣章节标题与标书目录定位,充分展开可执行方案细节\n'
                f'- 内容须由可检验的技术与管理措施支撑,禁止堆砌套话与重复背景'
            )

        max_tok = base_max_tok
        if scale_tokens and base > 0:
            # Scale tokens with the char target, bounded to [1024, tok_limit].
            max_tok = int(min(tok_limit, max(1024, base_max_tok * min_chars / base)))

        out[sid] = {
            'target_chars': min_chars,
            'word_count_spec': spec,
            'max_tokens': max_tok,
        }
    return out
def continuation_threshold(target_chars: int) -> int:
    """Matches generator._get_min_chars: a continuation pass stops at ~65% of
    the target (never below 200); successive passes approach the full goal."""
    threshold = target_chars * 0.65
    if threshold < 200:
        threshold = 200
    return int(threshold)