Chapter 04

自定义 Module:把多步 LLM 调用组合起来

真实任务很少一次调用就完事。继承 dspy.Module,把 Signature、检索、判断、路由拼成一个可复用、可优化的程序。

最简 RAG Module

class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        passages = self.retrieve(question).passages
        return self.generate(context=passages, question=question)

# 配好 RM 就能直接跑
dspy.configure(lm=lm, rm=dspy.ColBERTv2(url="http://..."))
rag = RAG()
rag("谁发明了 Transformer?").answer

dspy.Module 的约定

__init__
把子 Module(Predict/CoT/Retrieve/自定义)存成实例属性。DSPy 会自动发现并纳入优化范围。
forward(self, ...inputs)
定义实际逻辑。参数名要和 Signature 的 InputField 对应。返回 dspy.Prediction 或一个子 Module 的调用结果。
__call__
不要自己覆盖!基类会自动走 forward,并把所有中间调用记录进 trace,Optimizer 要用。

条件分支 Module

class SmartQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.classify = dspy.Predict("question -> kind: Literal['factual','opinion','calc']")
        self.retrieve = dspy.Retrieve(k=5)
        self.factual = dspy.ChainOfThought("context, question -> answer")
        self.opinion = dspy.ChainOfThought("question -> answer")
        self.calc    = dspy.ProgramOfThought("question -> answer")

    def forward(self, question):
        kind = self.classify(question=question).kind
        if kind == "factual":
            ctx = self.retrieve(question).passages
            return self.factual(context=ctx, question=question)
        if kind == "calc":
            return self.calc(question=question)
        return self.opinion(question=question)

把分类器也当成 Module——优化器会同时优化分类器 + 每个分支。

多跳检索:MultiHop

一个问题可能需要多次检索(比如"比较 A 和 B",得分别查 A 和 B):

class MultiHopRAG(dspy.Module):
    def __init__(self, max_hops=3, num_passages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_query = dspy.ChainOfThought(
            "context, question -> search_query"
        )
        self.generate_answer = dspy.ChainOfThought(
            "context, question -> answer"
        )
        self.max_hops = max_hops

    def forward(self, question):
        context = []
        for _ in range(self.max_hops):
            query = self.generate_query(
                context=context, question=question
            ).search_query
            passages = self.retrieve(query).passages
            context = dspy.deduplicate(context + passages)
        return self.generate_answer(context=context, question=question)
关键技巧:同名 Module 可以多次调用
上面 generate_query 在循环里被调用多次,但只是一个 Module。Optimizer 优化一次,循环里每次调用都用同一份优化后的 prompt/demos。

参数化 Module

通过 __init__ 把超参暴露出来,便于不同场景复用:

class Summarizer(dspy.Module):
    def __init__(self, max_words=80, style="neutral"):
        super().__init__()
        self.max_words = max_words
        self.style = style
        sig = type(
            "SumSig", (dspy.Signature,),
            {
                "__doc__": f"产生 {style} 风格的摘要,不超过 {max_words} 字。",
                "text": dspy.InputField(),
                "summary": dspy.OutputField(),
                "__annotations__": {"text": str, "summary": str},
            },
        )
        self.predict = dspy.Predict(sig)

    def forward(self, text):
        return self.predict(text=text)

嵌套 Module

Module 里用别的 Module,DSPy 的 trace 会正确 unroll:

class ArticleWriter(dspy.Module):
    def __init__(self):
        super().__init__()
        self.outline = dspy.ChainOfThought("topic -> outline: list[str]")
        self.section_writer = SectionWriter()        # 子 Module
        self.editor = dspy.ChainOfThought("draft -> polished")

    def forward(self, topic):
        outline = self.outline(topic=topic).outline
        sections = [self.section_writer(topic=topic, point=p).text for p in outline]
        draft = "\n\n".join(sections)
        return self.editor(draft=draft)

查看 Module 内部

rag = RAG()

# 列出所有 Predict 子模块
for name, pred in rag.named_predictors():
    print(name, pred.signature)

# 查看最近一次调用 trace
out = rag("...")
dspy.inspect_history(n=3)    # 打印最近 3 次 LLM 调用的完整 prompt/response

保存与加载

# 保存 Module 状态(含优化后的 demos 和指令)
rag.save("rag_compiled.json")

# 加载
new = RAG()
new.load("rag_compiled.json")

调试 Module

三板斧
dspy.inspect_history():看 prompt 真的长什么样
② 临时把某个子 Module 换成 dspy.Predict("... -> ..."),二分法定位问题
③ 开 LM 的 cache=False 排除缓存干扰

Module 设计准则

  1. 职责单一:一个 Module 最好只负责 1-3 步,太深了难以优化
  2. 接口明确:forward 参数名和 Signature 对齐,不要用 **kwargs 魔法
  3. 子 Module 命名语义化:self.query_genself.cot1
  4. 避免副作用:Module 不写文件、不改全局,Optimizer 要反复调用
  5. 可测:forward 逻辑能 mock LM 跑单测,至少覆盖主分支

常见坑

症状解法
忘记 super().__init__()保存/优化失败永远调父类构造
在 forward 里 new ModuleOptimizer 抓不到子 Module 必须在 __init__ 里定义
forward 返回 dict属性访问报错返回 dspy.Prediction(...) 或子 Module 调用
循环里不 dedupecontext 爆炸dspy.deduplicate 或自行维护 set

本章小结