Chapter 05

Dependency Injection · 像 FastAPI 一样写 Agent

Agent 离不开外部世界——数据库、API 客户端、当前用户、配置。Pydantic AI 把这事处理得和 FastAPI 的 Depends() 一模一样:注入、测试、替换,全透明

一、为什么 Agent 需要 DI?

考虑一个典型的客服 Agent——它要能查订单。"查订单"这个工具内部要访问数据库;数据库连接不能硬编码在工具里,否则:

DI(依赖注入)是工程界的经典答案:让"需要什么"显式声明,"谁来提供"由调用方决定。FastAPI 用 Depends(),Spring 用 @Autowired,Pydantic AI 用 deps_type + RunContext.deps

二、最小示例:三步把 DB 注入进去

Step 1:定义 Deps 数据类

from dataclasses import dataclass
import httpx

@dataclass
class AppDeps:
    db: object                    # 假装是 DB 连接池
    http: httpx.AsyncClient
    current_user_id: int
    api_key: str

@dataclass(或 Pydantic BaseModel,都行),字段类型要写清楚——它们会成为 RunContext[AppDeps] 的类型提示,让 IDE/mypy 享受补全。

Step 2:Agent 绑定 deps_type

from pydantic_ai import Agent, RunContext

agent = Agent(
    "openai:gpt-4o",
    deps_type=AppDeps,
    system_prompt="你是电商客服助手。",
)

Step 3:工具里通过 ctx.deps 访问

@agent.tool
async def my_orders(ctx: RunContext[AppDeps]) -> list[dict]:
    """查当前登录用户的所有订单。"""
    return await ctx.deps.db.fetch(
        "SELECT id, total, status FROM orders WHERE user_id = ?",
        ctx.deps.current_user_id,
    )

@agent.tool
async def track_parcel(ctx: RunContext[AppDeps], tracking_no: str) -> dict:
    """查询物流状态。"""
    r = await ctx.deps.http.get(
        f"https://api.kuaidi.com/track/{tracking_no}",
        headers={"X-API-Key": ctx.deps.api_key},
    )
    return r.json()

Step 4:调用时注入真实依赖

async def handle_request(user_id: int, question: str):
    async with httpx.AsyncClient() as http:
        deps = AppDeps(
            db=my_db_pool,
            http=http,
            current_user_id=user_id,
            api_key=os.environ["KUAIDI_KEY"],
        )
        result = await agent.run(question, deps=deps)
        return result.output
关键一句话:工具只声明"我需要 AppDeps",由调用方(业务代码)组装 deps。这就是控制反转(IoC)——工具代码不关心"从哪来"。

三、和 FastAPI 的 Depends() 对比

FastAPI

def get_db():
    return db_pool

@app.get("/orders")
async def list_orders(
    user_id: int,
    db = Depends(get_db),
):
    return await db.fetch(...)

# 测试
app.dependency_overrides[get_db] = lambda: mock_db

Pydantic AI

@dataclass
class Deps:
    db: object

agent = Agent("...", deps_type=Deps)

@agent.tool
async def list_orders(
    ctx: RunContext[Deps],
):
    return await ctx.deps.db.fetch(...)

# 测试
with agent.override(deps=Deps(db=mock_db)):
    ...

核心区别:

四、Deps 用 dataclass 还是 BaseModel?

两个都行,实践中推荐 dataclass:

Pydantic 官方示例几乎全用 dataclass——这是明确的推荐。

五、override:测试时的黄金用法

Agent.override() 是一个上下文管理器,临时替换 deps / model / 工具。这是测试时保命级的功能。

import pytest
from pydantic_ai.models.test import TestModel

class MockDb:
    async def fetch(self, sql, *args):
        return [{"id": 1, "total": 99.0, "status": "paid"}]

@pytest.mark.asyncio
async def test_my_orders():
    mock_deps = AppDeps(db=MockDb(), http=None, current_user_id=1, api_key="")
    with agent.override(deps=mock_deps, model=TestModel()):
        result = await agent.run("我的订单都有哪些?")
        assert len(result.output) >= 1

override 能同时替换:depsmodeltoolsets。离开 with 块后自动还原。第 9 章讲测试时还会细讲 TestModel/FunctionModel。

六、Deps 也能进 system_prompt

第 2 章讲动态 system_prompt 时用过这个模式,这里正式系统化:RunContext.deps 里的信息经常要塞给 LLM 当上下文。

@agent.system_prompt
def add_user_context(ctx: RunContext[AppDeps]) -> str:
    return f"当前用户 ID: {ctx.deps.current_user_id},请只给他本人能看到的信息。"

@agent.system_prompt
async def add_user_profile(ctx: RunContext[AppDeps]) -> str:
    profile = await ctx.deps.db.fetch_one(
        "SELECT name, vip_level FROM users WHERE id = ?",
        ctx.deps.current_user_id,
    )
    return f"用户档案: {profile}"
注意缓存:动态 system_prompt 每次 run 都会被调用,如果里面有查询请先做好请求级缓存——比如在 Deps 里预查一次,工具函数只读 Deps 字段。

七、生命周期:Deps 应该怎么构造

Deps 里常见的两类东西:

  1. 长生命周期对象:DB 连接池、httpx.AsyncClient、配置——进程启动时建一次,重复用
  2. 短生命周期信息:当前用户 ID、本次请求 ID——每个请求不同

实践中,FastAPI 集成的标准做法是:

from fastapi import FastAPI, Depends, Request
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.db = await create_pool()
    app.state.http = httpx.AsyncClient()
    yield
    await app.state.db.close()
    await app.state.http.aclose()

app = FastAPI(lifespan=lifespan)

def build_deps(request: Request) -> AppDeps:
    return AppDeps(
        db=request.app.state.db,               # 长生命周期
        http=request.app.state.http,           # 长生命周期
        current_user_id=request.state.user_id, # 每请求不同,鉴权中间件填
        api_key=os.environ["KUAIDI_KEY"],
    )

@app.post("/chat")
async def chat(question: str, deps: AppDeps = Depends(build_deps)):
    result = await agent.run(question, deps=deps)
    return {"answer": result.output}

这一招把 FastAPI 的 Depends 和 Pydantic AI 的 deps_type 串联在一起——前者负责从 request 里生出 Deps,后者负责把 Deps 交给 Agent 用。完美无缝。

八、嵌套工具 & 共享 Deps

Deps 在 Agent 生命周期里是共享且不可变的——所有工具、system_prompt、output_validator 看到的是同一个实例。你甚至可以在工具里修改它(但不推荐——测试噩梦):

@dataclass
class Deps:
    db: object
    call_count: int = 0        # mutable,别这么用,示例而已

@agent.tool
async def count_me(ctx: RunContext[Deps]):
    ctx.deps.call_count += 1
    return ctx.deps.call_count

推荐的做法:Deps 尽量设计成只读,工具需要累计状态就自己开 deps.counter = CounterClass() 里面装个线程安全的计数器。

九、真实例子:订单查询 Agent 全套

import asyncio
from dataclasses import dataclass
from typing import Literal
import httpx
from pydantic import BaseModel
from pydantic_ai import Agent, ModelRetry, RunContext

# ─── Domain ───────────────────────────────
class Order(BaseModel):
    id: int
    total: float
    status: Literal["paid", "shipped", "delivered", "refunded"]

class Answer(BaseModel):
    reply: str
    orders_referenced: list[int]

# ─── Deps ─────────────────────────────────
@dataclass
class Deps:
    db: object
    http: httpx.AsyncClient
    user_id: int
    user_name: str

# ─── Agent ────────────────────────────────
agent = Agent(
    "openai:gpt-4o",
    deps_type=Deps,
    output_type=Answer,
    system_prompt="你是电商客服。只能回答关于当前登录用户的订单问题。",
    retries=2,
)

@agent.system_prompt
def add_user(ctx: RunContext[Deps]) -> str:
    return f"当前用户: {ctx.deps.user_name} (ID: {ctx.deps.user_id})"

@agent.tool
async def list_my_orders(ctx: RunContext[Deps], limit: int = 10) -> list[Order]:
    """列出当前用户最近的订单。"""
    rows = await ctx.deps.db.fetch(
        "SELECT id, total, status FROM orders WHERE user_id=? ORDER BY created_at DESC LIMIT ?",
        ctx.deps.user_id, limit,
    )
    return [Order(**r) for r in rows]

@agent.tool
async def get_order(ctx: RunContext[Deps], order_id: int) -> Order:
    """查一笔具体的订单。如果订单不属于当前用户会拒绝。"""
    row = await ctx.deps.db.fetch_one(
        "SELECT user_id, id, total, status FROM orders WHERE id=?", order_id,
    )
    if not row:
        raise ModelRetry(f"订单 {order_id} 不存在。")
    if row["user_id"] != ctx.deps.user_id:
        raise ModelRetry(f"订单 {order_id} 不属于当前用户,请选别的。")
    return Order(**row)

# ─── Runtime ──────────────────────────────
async def main():
    async with httpx.AsyncClient() as http:
        deps = Deps(db=real_db, http=http, user_id=42, user_name="张三")
        result = await agent.run("我最近的订单都发货了吗?", deps=deps)
        print(result.output.reply)
        print("涉及订单:", result.output.orders_referenced)

这个 Agent:

十、Deps 里放什么,不放什么

适合放 Deps不适合放 Deps
DB 连接池单次查询的结果缓存(短命)
HTTP client / SDK clientAgent 自己(循环引用)
当前登录用户、tenant_idLLM 的 system_prompt 文本(prompt 应写在构造时)
请求 trace_id / loggermessage_history(用专门参数传)
配置(API keys, feature flags)每次 run 都变化的大对象(会拖慢 trace)
向量数据库 client绑死到某次调用的一次性计算结果

十一、八个常见坑

  1. 忘记 deps_type=:Agent 构造时没声明,工具里的 ctx.deps 就是 None,类型提示也全丢。
  2. Deps 类放了未序列化对象但又试图 log 它:Logfire/OpenTelemetry 打点时会 try 序列化,写个 __repr__ 控制打印内容。
  3. 在 async 工具里用同步 DB 客户端:依然阻塞事件循环。Deps 里的客户端必须是 async 版本。
  4. Deps 里塞了 Session/Scope 对象跨 run 共享:典型的 race condition 源头。请求级对象只给请求级 Agent run 用。
  5. 忘记 override 是上下文管理器:写成 agent.override(...) 不配 with——不会生效。
  6. Deps 类字段太多:二十个字段的 Deps 是代码味道,通常意味着该拆多个 Agent 或该按领域拆 Deps。
  7. 在动态 system_prompt 里 hit 数据库:每次 run 都多一次查询。Deps 构造时预查或加缓存。
  8. Deps 里直接放秘钥字符串:随 trace 一起打出来,泄漏风险。放 secret 用 SecretStr 或在 __repr__ 里屏蔽。

十二、本章小结

三条心法:
deps_type + RunContext.deps 是 Pydantic AI 版的 FastAPI Depends——工具声明需要,业务代码提供。
② Deps 设计成只读、长生命周期对象 + 请求级元数据的组合,别塞临时结果。
③ 测试用 agent.override(deps=...) 上下文管理器替换为 mock,自然隔离,不需要 monkey patch。