Chapter 08

Assertion & Suggestion:让 LLM 自己改错

Optimizer 在训练期帮你调 prompt,Assertion 在运行时帮你守约束。LLM 输出不合规?DSPy 会把错误塞回 prompt 让它自己重试。

Assert vs Suggest

API违反时用途
dspy.Assert最终抛异常 → 调用失败硬约束(格式错 = 彻底不可用)
dspy.Suggest尝试几次还不过就警告但继续软约束(越满足越好,但不强求)

Assert 示例:必须是合法 JSON

import dspy, json

class GenSpec(dspy.Signature):
    """为 API 生成请求体示例"""
    endpoint: str = dspy.InputField()
    body: str = dspy.OutputField(desc="JSON 字符串")

class SpecGen(dspy.Module):
    def __init__(self):
        super().__init__()
        self.gen = dspy.ChainOfThought(GenSpec)

    def forward(self, endpoint):
        pred = self.gen(endpoint=endpoint)

        # 硬约束:必须能 parse JSON
        try:
            json.loads(pred.body)
        except json.JSONDecodeError as e:
            dspy.Assert(
                False,
                msg=f"body 必须是合法 JSON,当前解析失败: {e}",
                target_module=self.gen,
            )

        return pred

backtrack 机制

DSPy 的 Assert 不是"立刻 raise":

  1. 第一次 Assert 失败 → DSPy 捕获错误消息
  2. 把错误消息作为反馈,附加到 target_module 的 prompt 再调一次
  3. 默认最多 retry N 次(通过 dspy.settings.max_backtrack 配)
  4. 仍然失败 → 抛出 DSPyAssertionError

加了 Assert 的 Module 必须用 assert_transform_module 包一层才启用 backtrack:

from dspy.primitives.assertions import assert_transform_module, backtrack_handler

spec_gen = assert_transform_module(SpecGen(), backtrack_handler)

spec_gen(endpoint="/api/order")   # 出错自动重试

Suggest 示例:鼓励短答案

class Summarizer(dspy.Module):
    def __init__(self):
        super().__init__()
        self.sum = dspy.ChainOfThought("text -> summary")

    def forward(self, text):
        pred = self.sum(text=text)

        # 软约束:超过 80 字 → 重试几次,还不行就放行(分数打折)
        dspy.Suggest(
            len(pred.summary) < 80,
            msg="摘要过长,请压缩到 80 字以内",
            target_module=self.sum,
        )

        return pred

Suggest 失败不会让调用爆炸,适合"能守最好、守不住也能用"的场景。

典型约束模式

1. 事实引用检查

def forward(self, question, context):
    pred = self.cot(context=context, question=question)

    # 答案里至少有一个 context 片段的关键短语
    has_cite = any(p[:30] in pred.answer for p in context)
    dspy.Suggest(has_cite, msg="答案必须引用 context 中的原话片段", target_module=self.cot)
    return pred

2. 枚举值限制

allowed = {"bug", "feature", "billing"}
dspy.Assert(
    pred.category in allowed,
    msg=f"category 必须是 {allowed} 之一,当前: {pred.category}",
    target_module=self.classify,
)

3. 数值范围

dspy.Suggest(
    0 <= pred.confidence <= 1,
    msg="confidence 必须在 [0,1]",
    target_module=self.score,
)

4. 长度约束

dspy.Suggest(
    30 <= len(pred.title) <= 80,
    msg="标题 30-80 字最佳",
    target_module=self.title_gen,
)

5. 去重

dspy.Suggest(
    len(pred.tags) == len(set(pred.tags)),
    msg="tags 不应有重复",
    target_module=self.tag_gen,
)

多个 Assertion 串联

def forward(self, text):
    pred = self.gen(text=text)

    dspy.Assert(pred.body, msg="body 不能为空", target_module=self.gen)
    dspy.Suggest(len(pred.body) < 1000, msg="不超过 1000 字", target_module=self.gen)
    dspy.Suggest(not pred.body.startswith("抱歉"), msg="不要道歉", target_module=self.gen)

    return pred

DSPy 会按顺序逐个检查,违反的那条会触发 backtrack。

和 Optimizer 结合

Assertion 不仅在运行时生效,训练期也会被 Optimizer 感知:

所以加 Assertion 相当于给 Optimizer 多一条 reward 信号——强烈建议生产 Module 都带上。

配置

dspy.settings.configure(
    max_backtrack=3,            # 最多重试 3 次
    bypass_assert=False,        # 生产保持 False;评估时可临时 True 测原始分
    bypass_suggest=False,
)

Assertion 的设计原则

  1. 约束要确定性可测:json.loads(x) 行,"语义合理" 不行(那是 metric 的事)
  2. msg 要具体可改:LLM 只看这段文字去修,写"输出不对"是废话
  3. 先硬约束后软约束:顺序影响 retry 的优先级
  4. 别堆 10 条:超过 3-4 条后 retry 大概率都无法同时满足,精简到最关键的
  5. 失败日志记下:线上 Assert 失败要进 metrics,看哪条最常出问题

调试 Assertion 失败

from dspy.primitives.assertions import DSPyAssertionError

try:
    pred = spec_gen(endpoint="/api/order")
except DSPyAssertionError as e:
    print("最后一次失败原因:", e.msg)
    print("中间尝试:")
    for attempt in e.history:
        print(attempt.prompt[-500:])
        print("---")

本章小结