Skip to content

自定义 RAG Agent

本教程将引导你使用 LangGraph 的基本组件构建一个可自主决策的 RAG(检索增强生成)Agent。与简单的 "检索-生成" 管道不同,Agentic RAG 允许模型根据查询自主决定是否需要检索、检索什么内容,以及是否需要多次检索或直接生成答案。

概述

传统 RAG 系统遵循固定的 "检索 → 生成" 流水线。Agentic RAG 则赋予 LLM 更大的控制权:

  • 判断是否需要检索:对于模型已掌握的知识,可直接回答。
  • 自主改写查询:如果检索结果不理想,可重写查询再次检索。
  • 多步推理:对于复杂问题,可分解为多个子查询分别检索。

我们使用 LangGraph 的 StateGraph 来实现这一流程。

环境准备

bash
pip install langchain langgraph langchain-openai chromadb

配置 OpenAI API 密钥:

python
import getpass
import os

if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key: ")

定义状态

LangGraph 的核心是状态图(StateGraph)。我们首先定义图的状态——即在节点间传递的数据结构。

python
from typing import Annotated, List, Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, MessagesState
from langchain_core.messages import BaseMessage

class AgentState(MessagesState):
    """Agent 的状态,继承自 MessagesState 以支持消息历史。"""
    question: str
    generation: str
    documents: List[str]
    retries: int

构建检索工具

使用 Chroma 向量数据库创建检索器:

python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# 加载文档并分块
loader = TextLoader("data/knowledge_base.txt")
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200
)
docs = text_splitter.split_documents(documents)

# 创建向量存储
vectorstore = Chroma.from_documents(
    documents=docs,
    embedding=OpenAIEmbeddings()
)

retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

构建 Agent 节点

1. 检索节点

python
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

def retrieve_node(state: AgentState) -> AgentState:
    """根据问题检索相关文档片段。"""
    question = state["question"]
    retrieved_docs = retriever.invoke(question)
    return {
        **state,
        "documents": [doc.page_content for doc in retrieved_docs]
    }

2. 生成节点

python
SYSTEM_PROMPT = """你是一个专业的问答助手。根据提供的上下文信息回答用户问题。
如果上下文信息不足以回答问题,请如实说明,不要编造答案。

上下文:
{context}"""

def generate_node(state: AgentState) -> AgentState:
    """基于检索到的文档生成答案。"""
    context = "\n\n".join(state.get("documents", []))
    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("human", "{question}")
    ])
    
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    chain = prompt | llm
    response = chain.invoke({
        "context": context,
        "question": state["question"]
    })
    
    return {
        **state,
        "generation": response.content
    }

3. 路由节点——决定是否需要检索

python
from langchain_core.output_parsers import JsonOutputParser

ROUTER_PROMPT = """根据用户问题,判断是否需要检索外部知识来回答。
如果问题涉及最新的信息、特定数据、或模型训练知识范围之外的内容,请回答"retrieve"。
如果问题模型可以直接回答(如常识、推理、翻译等),请回答"direct"。

仅返回 JSON 格式:{{"decision": "retrieve" 或 "direct"}}"""

def router_node(state: AgentState) -> AgentState:
    """判断是否需要执行检索。"""
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    prompt = ChatPromptTemplate.from_messages([
        ("system", ROUTER_PROMPT),
        ("human", state["question"])
    ])
    chain = prompt | llm | JsonOutputParser()
    result = chain.invoke({})
    return {**state, "next_action": result["decision"]}

def should_retrieve(state: AgentState) -> Literal["retrieve", "direct"]:
    """条件边:根据 router 节点的输出决定下一步。"""
    return state.get("next_action", "retrieve")

4. 直接回答节点

python
def direct_answer_node(state: AgentState) -> AgentState:
    """不检索,直接回答。"""
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    response = llm.invoke(state["question"])
    return {
        **state,
        "generation": response.content,
        "documents": []
    }

5. 查询重写节点(可选)

python
QUERY_REWRITE_PROMPT = """根据原始问题和已检索到的文档,判断是否需要改写查询以获得更好的结果。
如果需要改写,返回改写后的查询;如果当前结果已足够,返回原始问题。

原始问题:{question}

已检索到的文档摘要:
{summary}"""

def query_rewrite_node(state: AgentState) -> AgentState:
    """当检索结果不理想时重写查询。"""
    # 将文档合并为摘要
    documents = state.get("documents", [])
    summary = "\n".join([doc[:200] for doc in documents[:2]])
    
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.3)
    prompt = ChatPromptTemplate.from_messages([
        ("system", QUERY_REWRITE_PROMPT),
        ("human", "请决定是否改写查询。")
    ])
    response = llm.invoke({
        "question": state["question"],
        "summary": summary
    })
    
    return {
        **state,
        "question": response.content,
        "retries": state.get("retries", 0) + 1
    }

def should_rewrite(state: AgentState) -> Literal["rewrite", "generate"]:
    """判断是否需要重写查询。"""
    if state.get("retries", 0) < 2:
        return "rewrite"
    return "generate"

组装 StateGraph

现在将所有节点连接成图:

python
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.memory import MemorySaver

# 初始化图
workflow = StateGraph(AgentState)

# 添加节点
workflow.add_node("router", router_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("rewrite", query_rewrite_node)
workflow.add_node("generate", generate_node)
workflow.add_node("direct_answer", direct_answer_node)

# 设置入口点
workflow.set_entry_point("router")

# 添加条件边
workflow.add_conditional_edges(
    "router",
    should_retrieve,
    {
        "retrieve": "retrieve",
        "direct": "direct_answer"
    }
)

# 检索后判断是否需要重写
workflow.add_conditional_edges(
    "retrieve",
    should_rewrite,
    {
        "rewrite": "rewrite",
        "generate": "generate"
    }
)

# 重写后重新检索
workflow.add_edge("rewrite", "retrieve")

# 生成和直接回答结束
workflow.add_edge("generate", END)
workflow.add_edge("direct_answer", END)

# 编译图
memory = MemorySaver()
agent = workflow.compile(checkpointer=memory)

可视化编译后的图结构:

python
from IPython.display import Image, display

display(Image(agent.get_graph().draw_mermaid_png()))

运行 Agent

python
# 需要检索的问题
config = {"configurable": {"thread_id": "1"}}
result = agent.invoke(
    {"question": "LangGraph 中的 StateGraph 是什么?"},
    config
)
print(result["generation"])
python
# 无需检索的问题
config = {"configurable": {"thread_id": "2"}}
result = agent.invoke(
    {"question": "Python 的列表推导式怎么用?"},
    config
)
print(result["generation"])

使用 create_agent API 简化

LangGraph 提供了 create_agent 便捷函数,可以快速创建标准的 ReAct Agent:

python
from langgraph.prebuilt import create_agent

# 将检索器包装为工具
from langchain_core.tools import tool

@tool
def retrieve_knowledge(question: str) -> str:
    """从知识库中检索与问题相关的信息。"""
    docs = retriever.invoke(question)
    return "\n\n".join([doc.page_content for doc in docs])

# 使用 create_agent
llm = ChatOpenAI(model="gpt-4o-mini")
rag_agent = create_agent(
    llm,
    tools=[retrieve_knowledge],
    system_prompt="你是一个 RAG 助手,使用检索工具获取信息后回答用户问题。"
)

# 添加消息历史记忆
from langgraph.checkpoint.memory import MemorySaver
rag_agent_with_memory = create_agent(
    llm,
    tools=[retrieve_knowledge],
    system_prompt="你是一个 RAG 助手,使用检索工具获取信息后回答用户问题。",
    checkpointer=MemorySaver()
)

# 运行
result = rag_agent_with_memory.invoke(
    {"messages": [{"role": "user", "content": "LangGraph 是什么?"}]},
    config={"configurable": {"thread_id": "3"}}
)
print(result["messages"][-1].content)

进阶:自问自答模式

更高级的 RAG Agent 可以自主分解问题:

python
SUBQUERY_PROMPT = """将复杂问题分解为多个子问题,逐一检索后再综合回答。

原始问题:{question}
最多子问题数:{max_subqueries}"""

def decompose_and_retrieve(state: AgentState) -> AgentState:
    """将复杂问题分解为子查询并分别检索。"""
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    prompt = ChatPromptTemplate.from_messages([
        ("system", SUBQUERY_PROMPT),
        ("human", "请分解问题。")
    ])
    chain = prompt | llm
    response = chain.invoke({
        "question": state["question"],
        "max_subqueries": 3
    })
    
    # 分别检索每个子查询
    all_docs = []
    for subq in response.content.split("\n"):
        if subq.strip():
            docs = retriever.invoke(subq)
            all_docs.extend([doc.page_content for doc in docs])
    
    return {
        **state,
        "documents": all_docs
    }

下一步

本站为非官方中文学习站点,不代表 LangChain 官方。部分内容参考官方文档并重新整理为中文学习笔记。