How do we build an agent that can answer questions about a SQL database?

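At a high level, a SQL agent performs the following steps:

  1. Fetch the available tables from the database
  2. Decide which tables are relevant to the question
  3. Fetch the schemas of the relevant tables
  4. Generate a query based on the question and the schema information
  5. Use an LLM to double-check the query for common mistakes
  6. Execute the query and return the results
  7. Correct any errors surfaced by the database engine until the query succeeds
  8. Formulate a response based on the results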
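The non-LLM parts of these steps (1, 3, 6 and 7) are ordinary SQLite operations. Below is a minimal stdlib sketch, assuming a throwaway in-memory database with a toy schema (the rows are hypothetical). Steps 2, 4, 5 and 8 are where an LLM would pick tables, write and double-check SQL, and phrase the answer, so here the "generated" queries are simply hardcoded, including one with a deliberately wrong column to exercise the retry of step 7.

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> list[str]:
    """Step 1: list user tables (SQLite records them in sqlite_master)."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


def get_schema(conn: sqlite3.Connection, tables: list[str]) -> str:
    """Step 3: return the CREATE TABLE statements of the chosen tables."""
    placeholders = ", ".join("?" for _ in tables)
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type = 'table' AND name IN ({placeholders}) ORDER BY name",
        tables,
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


def run_with_retry(conn: sqlite3.Connection, candidates: list[str]):
    """Steps 6-7: run candidate queries in order until one succeeds.

    In a real agent each new candidate would be written by the LLM after it
    sees the previous error message; here the candidates are fixed up front.
    """
    errors = []
    for query in candidates:
        try:
            return conn.execute(query).fetchall(), errors
        except sqlite3.Error as exc:
            errors.append(f"{query!r} failed: {exc}")
    raise RuntimeError("no candidate query succeeded:\n" + "\n".join(errors))


# Toy database standing in for Chinook (hypothetical rows).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);
    INSERT INTO Artist VALUES (1, 'AC/DC'), (2, 'Accept');
    """
)

tables = list_tables(conn)             # step 1
schema = get_schema(conn, ["Artist"])  # step 3 (step 2 would have picked "Artist")
rows, errors = run_with_retry(conn, [
    "SELECT ArtistName FROM Artist",                      # wrong column: fails
    "SELECT Name FROM Artist ORDER BY ArtistId LIMIT 5",  # corrected query
])
print(tables, rows, errors)
```

The error string collected in step 7 is exactly what an agent would feed back to the model so it can rewrite the query.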

Setup

Install dependencies

pip install -U langgraph langchain_community "langchain[openai]" dashscope

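Select an LLM

import os

from langchain_community.chat_models.tongyi import ChatTongyi

# Read the DashScope API key from an environment variable
api_key = os.environ["DASHSCOPE_API_KEY"]
llm = ChatTongyi(model='qwen-plus', api_key=api_key)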

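Configure the database

We need a database to interact with, so we create a SQLite database. SQLite is a lightweight database that is easy to set up and use. We will load the chinook database, a sample database that represents a digital media store.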

import requests

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"

response = requests.get(url)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")
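The download code only checks the HTTP status, but a 200 response can still carry something other than a database (an HTML error page, for instance). Every SQLite 3 database file begins with the 16-byte header string "SQLite format 3" followed by a NUL byte, which allows a cheap extra check. A minimal stdlib sketch; looks_like_sqlite is a hypothetical helper, and the demo writes a real database into a temporary directory to obtain genuine bytes:

```python
import os
import sqlite3
import tempfile

# The first 16 bytes of every SQLite 3 database file.
SQLITE_MAGIC = b"SQLite format 3\x00"


def looks_like_sqlite(data: bytes) -> bool:
    """Return True if the bytes start with the SQLite 3 file header."""
    return data[:16] == SQLITE_MAGIC


# Create a small real database file so we have genuine SQLite bytes to test.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.db")
    conn = sqlite3.connect(path)
    conn.execute("CREATE TABLE t (x INTEGER)")
    conn.commit()
    conn.close()
    with open(path, "rb") as f:
        real_db_bytes = f.read()

print(looks_like_sqlite(real_db_bytes), looks_like_sqlite(b"<html>Access denied</html>"))
```

In the download code, the success condition would then become response.status_code == 200 and looks_like_sqlite(response.content).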

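We use the convenient SQL database wrapper provided by the langchain_community package to interact with the database. The wrapper offers a simple interface for executing SQL queries and fetching their results: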

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")

print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM Artist LIMIT 5;")}')
Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]

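Tools for interacting with the database

langchain-community implements several built-in tools for interacting with our SQLDatabase, including tools for listing tables, reading table schemas, and checking and running queries: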

from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

for tool in tools:
    print(f"{tool.name}: {tool.description}\n")
sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.

sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3

sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.

sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!
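The sql_db_schema description says the tool returns "the schema and sample rows" for each table, and the trace further down shows that shape: a CREATE TABLE statement followed by a comment block with three rows. A minimal stdlib sketch producing similar text from a plain sqlite3 connection. The describe_table helper is hypothetical and is not the toolkit's implementation (which builds this text through SQLAlchemy); the rows reuse the artist names from the sample output above.

```python
import sqlite3


def describe_table(conn: sqlite3.Connection, table: str, sample_rows: int = 3) -> str:
    """Return a table's CREATE TABLE text followed by a few sample rows."""
    found = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if found is None:
        raise ValueError(f"no such table: {table}")
    # Identifiers cannot be bound as parameters, so quote the name explicitly.
    quoted = '"' + table.replace('"', '""') + '"'
    cur = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cur.description)
    rows = ["\t".join(str(value) for value in row) for row in cur.fetchall()]
    body = "\n".join([header, *rows])
    return f"{found[0]}\n\n/*\n{len(rows)} rows from {table} table:\n{body}\n*/"


conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    INSERT INTO Artist VALUES
      (1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette');
    """
)
out = describe_table(conn, "Artist")
print(out)
```

Showing a few real rows next to the DDL helps the model see value formats (for example the "2009-01-01 00:00:00" date strings), which is why the agent later reaches for strftime on InvoiceDate.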

Using a prebuilt agent

With the tools in hand, we can now initialize a prebuilt agent in a single line of code.

To define the agent's behavior, we need to write a system prompt.

from langgraph.prebuilt import create_react_agent

system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
""".format(
    dialect=db.dialect,
    top_k=5,
)

agent = create_react_agent(
    llm,
    tools,
    prompt=system_prompt,
)

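This system prompt contains a number of instructions, such as always running a particular tool before or after the others. Later, these behaviors can instead be enforced through the structure of the graph, which gives a greater degree of control and simplifies the prompt.

Running the agent and observing the result: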

================================ Human Message =================================

Which sales agent made the most in sales in 2009?
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (call_542d412460ad4a76b337b4)
Call ID: call_542d412460ad4a76b337b4
Args:
tool_input:
================================= Tool Message =================================
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_6475c9b2c7374bbe87d96e)
Call ID: call_6475c9b2c7374bbe87d96e
Args:
table_names: Employee, Invoice
================================= Tool Message =================================
Name: sql_db_schema


CREATE TABLE "Employee" (
"EmployeeId" INTEGER NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"FirstName" NVARCHAR(20) NOT NULL,
"Title" NVARCHAR(30),
"ReportsTo" INTEGER,
"BirthDate" DATETIME,
"HireDate" DATETIME,
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60),
PRIMARY KEY ("EmployeeId"),
FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId LastName FirstName Title ReportsTo BirthDate HireDate Address City State Country PostalCode Phone Fax Email
1 Adams Andrew General Manager None 1962-02-18 00:00:00 2002-08-14 00:00:00 11120 Jasper Ave NW Edmonton AB Canada T5K 2N1 +1 (780) 428-9482 +1 (780) 428-3457 andrew@chinookcorp.com
2 Edwards Nancy Sales Manager 1 1958-12-08 00:00:00 2002-05-01 00:00:00 825 8 Ave SW Calgary AB Canada T2P 2T3 +1 (403) 262-3443 +1 (403) 262-3322 nancy@chinookcorp.com
3 Peacock Jane Sales Support Agent 2 1973-08-29 00:00:00 2002-04-01 00:00:00 1111 6 Ave SW Calgary AB Canada T2P 5M5 +1 (403) 262-3443 +1 (403) 262-6712 jane@chinookcorp.com
*/


CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_f1709a0a36104b1392acf6)
Call ID: call_f1709a0a36104b1392acf6
Args:
table_names: Customer
================================= Tool Message =================================
Name: sql_db_schema


CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query_checker (call_d7acd55bb78b468d97fc00)
Call ID: call_d7acd55bb78b468d97fc00
Args:
query: SELECT e.EmployeeId, e.FirstName, e.LastName, SUM(i.Total) AS TotalSales FROM Employee e JOIN Customer c ON e.EmployeeId = c.SupportRepId JOIN Invoice i ON i.CustomerId = c.CustomerId WHERE strftime('%Y', i.InvoiceDate) = '2009' GROUP BY e.EmployeeId ORDER BY TotalSales DESC LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query_checker

SELECT e.EmployeeId, e.FirstName, e.LastName, SUM(i.Total) AS TotalSales
FROM Employee e
JOIN Customer c ON e.EmployeeId = c.SupportRepId
JOIN Invoice i ON i.CustomerId = c.CustomerId
WHERE strftime('%Y', i.InvoiceDate) = '2009'
GROUP BY e.EmployeeId
ORDER BY TotalSales DESC
LIMIT 1;
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_f7366014021146ae819af7)
Call ID: call_f7366014021146ae819af7
Args:
query: SELECT e.EmployeeId, e.FirstName, e.LastName, SUM(i.Total) AS TotalSales FROM Employee e JOIN Customer c ON e.EmployeeId = c.SupportRepId JOIN Invoice i ON i.CustomerId = c.CustomerId WHERE strftime('%Y', i.InvoiceDate) = '2009' GROUP BY e.EmployeeId ORDER BY TotalSales DESC LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query

[(5, 'Steve', 'Johnson', 164.34)]
================================== Ai Message ==================================

The sales agent who made the most in sales in 2009 is Steve Johnson, with a total sales amount of $164.34.

This works well: the agent correctly listed the tables, fetched the schemas, wrote a query, checked it, and ran it to inform its final response.
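One detail worth noticing: Invoice has no column naming a salesperson, so the agent credited each invoice to the employee stored in the customer's SupportRepId. That is an interpretation the agent chose, not something the schema states. A minimal sketch running the same query shape on a hypothetical toy dataset (the rows below are invented, not Chinook data) shows how the two joins route invoice totals to employees and how the year filter drops other years:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, FirstName TEXT, LastName TEXT);
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, SupportRepId INTEGER);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER,
                          InvoiceDate TEXT, Total NUMERIC);
    -- Hypothetical toy rows, not taken from Chinook.
    INSERT INTO Employee VALUES (1, 'Ada', 'Lee'), (2, 'Bob', 'Kim');
    INSERT INTO Customer VALUES (10, 1), (11, 2), (12, 2);
    INSERT INTO Invoice VALUES
      (100, 10, '2009-03-01 00:00:00', 5.00),
      (101, 11, '2009-05-01 00:00:00', 2.00),
      (102, 12, '2009-07-01 00:00:00', 2.50),
      (103, 10, '2010-01-01 00:00:00', 9.00);  -- 2010: excluded by the filter
    """
)

# Same structure as the agent's query, without LIMIT so both employees show up.
query = """
SELECT e.EmployeeId, e.FirstName, e.LastName, SUM(i.Total) AS TotalSales
FROM Employee e
JOIN Customer c ON e.EmployeeId = c.SupportRepId
JOIN Invoice i ON i.CustomerId = c.CustomerId
WHERE strftime('%Y', i.InvoiceDate) = '2009'
GROUP BY e.EmployeeId
ORDER BY TotalSales DESC
"""
rows = conn.execute(query).fetchall()
print(rows)
```

Ada's 2010 invoice is ignored, so she ranks first with 5 against Bob's 2.00 + 2.50 spread over two customers.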

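Customizing the agent

The prebuilt agent gets us started quickly, but at every step it is free to call any of the tools.

In the example above, we constrained its behavior through the system prompt.

For example, we told the agent to always start with the "list tables" tool: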

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

and to always run the query-checking tool before executing a query:

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

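To address this, we can build a custom agent in LangGraph, which gives a higher degree of control.

Below we implement a simple ReAct-style agent with dedicated nodes for specific tool calls, using the same state as the prebuilt agent.

We build dedicated nodes for the following steps:

  • Listing DB tables
  • Calling the "get schema" tool
  • Generating a query
  • Checking the query

Putting these steps in dedicated nodes lets us:

  • Force a tool call when needed
  • Customize the prompt for each step
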
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

# Find the tool named "sql_db_schema" in the toolkit (fetches table schemas)
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
# Wrap it in a tool node that executes schema lookups
get_schema_node = ToolNode([get_schema_tool], name="get_schema")

# Find the tool named "sql_db_query" (executes SQL queries)
run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
# Wrap it in a tool node that runs the SQL query
run_query_node = ToolNode([run_query_tool], name="run_query")


# Node: manually construct a tool call that lists all tables (no LLM involved)
def list_tables(state: MessagesState):
    tool_call = {
        "name": "sql_db_list_tables",  # name of the tool to call
        "args": {},  # empty args: this tool needs no input
        "id": "abc123",  # unique ID of the tool call
        "type": "tool_call",  # marks this dict as a tool call
    }

    # AI message announcing the tool call
    tool_call_message = AIMessage(content="", tool_calls=[tool_call])

    # Look up the actual tool object in the tools list
    list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")

    # Invoking a tool with a tool-call dict returns a ToolMessage
    tool_message = list_tables_tool.invoke(tool_call)
    # AI message summarizing which tables are available
    response = AIMessage(f"Available tables: {tool_message.content}")

    # Return the state update (messages to append)
    return {"messages": [tool_call_message, tool_message, response]}


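# Node: ask the LLM to call the schema tool
def call_get_schema(state: MessagesState):
    # Note: LangChain expects chat models to accept tool_choice="any" (the model
    # must call one of the bound tools) as well as tool_choice="<tool name>"
    # (the model must call that specific tool)

    # Bind the schema tool and force the model to call it
    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")

    # Pass the conversation so far; the model picks the table names from context
    response = llm_with_tools.invoke(state["messages"])

    # Return the state update containing the model's reply
    return {"messages": [response]}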


generate_query_system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
""".format(
    dialect=db.dialect,  # dialect of the current database (e.g. sqlite, mysql, postgresql)
    top_k=5,  # maximum number of rows to return
)


def generate_query(state: MessagesState):
    system_message = {
        "role": "system",
        "content": generate_query_system_prompt,
    }
    # We do not force a tool call here, to allow the model to
    # respond naturally when it obtains the solution.
    llm_with_tools = llm.bind_tools([run_query_tool])
    response = llm_with_tools.invoke([system_message] + state["messages"])

    return {"messages": [response]}


check_query_system_prompt = """
You are a SQL expert with a strong attention to detail.
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes,
just reproduce the original query.

You will call the appropriate tool to execute the query after running this check.
""".format(dialect=db.dialect)

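To make the first item of that checklist concrete: in SQL, x NOT IN (subquery) evaluates to NULL, not true, for every row whose value is absent from the subquery as soon as the subquery returns a NULL, so the filter silently returns nothing. A small sqlite3 demo on a hypothetical toy table, alongside the usual NOT EXISTS rewrite:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, SupportRepId INTEGER);
    INSERT INTO Employee VALUES (1, 'Ada'), (2, 'Bob'), (3, 'Cy');
    -- Customer 11 has no support rep, so its SupportRepId is NULL.
    INSERT INTO Customer VALUES (10, 1), (11, NULL);
    """
)

# Goal: employees who support no customer (Bob and Cy).
not_in = conn.execute(
    "SELECT Name FROM Employee "
    "WHERE EmployeeId NOT IN (SELECT SupportRepId FROM Customer) ORDER BY Name"
).fetchall()

# NOT EXISTS compares row by row, so the NULL never poisons the result.
not_exists = conn.execute(
    "SELECT Name FROM Employee e WHERE NOT EXISTS "
    "(SELECT 1 FROM Customer c WHERE c.SupportRepId = e.EmployeeId) ORDER BY Name"
).fetchall()

print(not_in, not_exists)
```

Chinook's Customer.SupportRepId is declared nullable, so a checker that flags this pattern is guarding against a realistic mistake rather than a contrived one.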

def check_query(state: MessagesState):
    # System prompt for the checker
    system_message = {
        "role": "system",
        "content": check_query_system_prompt,
    }

    # The last message holds the tool call with the SQL query to check
    tool_call = state["messages"][-1].tool_calls[0]
    # Send the raw SQL query as the user message
    user_message = {"role": "user", "content": tool_call["args"]["query"]}

    # Bind the query tool and force the model to call it with the (possibly rewritten) query
    llm_with_tools = llm.bind_tools([run_query_tool], tool_choice="any")
    # Submit the system prompt and the SQL query for review
    response = llm_with_tools.invoke([system_message, user_message])
    # Reuse the previous message's ID so this response replaces it in the state
    response.id = state["messages"][-1].id

    return {"messages": [response]}
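Reusing the previous message's ID is what makes check_query's output replace the draft query instead of sitting next to it: MessagesState merges updates with LangGraph's add_messages reducer, which overwrites an existing message that has the same ID and appends messages with new IDs. A simplified stdlib model of just that rule; the merge_by_id helper and the dict messages are hypothetical, and the real reducer additionally assigns missing IDs and handles RemoveMessage:

```python
def merge_by_id(left: list[dict], right: list[dict]) -> list[dict]:
    """Append messages with new IDs; replace in place those whose ID already exists."""
    merged = list(left)
    index_by_id = {msg["id"]: i for i, msg in enumerate(merged)}
    for msg in right:
        if msg["id"] in index_by_id:
            merged[index_by_id[msg["id"]]] = msg  # same ID: overwrite
        else:
            index_by_id[msg["id"]] = len(merged)
            merged.append(msg)  # new ID: append
    return merged


history = [
    {"id": "h1", "role": "user", "content": "Which sales agent made the most in sales in 2009?"},
    {"id": "a1", "role": "ai", "query": "SELECT Name FROM Artist"},  # draft from generate_query
]
checked = {"id": "a1", "role": "ai", "query": "SELECT Name FROM Artist LIMIT 5"}

after_replace = merge_by_id(history, [checked])               # ID reused, as in check_query
after_append = merge_by_id(history, [{**checked, "id": "a2"}])  # fresh ID, for comparison
print(len(after_replace), len(after_append))
```

With the ID reused, the history keeps a single AI message carrying the checked query, so the run_query node only ever sees the reviewed version.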

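There is a catch here.

LangChain's tool_choice="any" is designed for models such as those from OpenAI and Anthropic, but Alibaba's Tongyi models (e.g. Qwen) currently accept only "none", "auto", or an object structure; any other value is rejected with a 400 parameter error.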

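So I looked up the Tongyi Qianwen API documentation, which describes the parameter as follows:

tool_choice: string or object (optional)
When the tools parameter is used, controls which tool the model calls. Three kinds of values are accepted:

  • "none": do not call any tool. This is the default when the tools parameter is empty.
  • "auto": the model decides whether to call a tool; it may or may not call one. This is the default when the tools parameter is not empty.
  • An object structure that specifies the tool the model must call, e.g. tool_choice={"type": "function", "function": {"name": "user_function"}}.
    • type: only "function" is supported.
    • function
      • name: the name of the tool to call, e.g. "get_current_time".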

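The configuration for each scenario:

  • Scenario 1: let the model decide whether to call a tool

    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="auto")
  • Scenario 2: never call a tool

    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="none")
  • Scenario 3: force a specific tool (here the schema tool). The name must be the tool's own name, "sql_db_schema", not the graph node name "get_schema"

    llm_with_tools = llm.bind_tools(
        [get_schema_tool],
        tool_choice={"type": "function", "function": {"name": get_schema_tool.name}},
    )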
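Because call_get_schema and check_query each bind exactly one tool, "any" there means the same as "call this one tool", which Qwen can express with the object form. A minimal sketch of a hypothetical adapter that performs this translation; it is an assumption layered on the documented values above, not part of LangChain or DashScope:

```python
def to_qwen_tool_choice(choice, tool_names: list[str]):
    """Map a LangChain-style tool_choice onto the values Qwen accepts.

    Hypothetical helper: passes "auto"/"none" and object structures through,
    and turns a tool name (or "any" with exactly one bound tool) into the
    object form that forces that tool.
    """
    if isinstance(choice, dict):
        return choice  # already the object structure
    if choice in ("auto", "none"):
        return choice
    if choice == "any":
        if len(tool_names) != 1:
            # "call one of several tools" has no equivalent among the documented values
            raise ValueError('"any" can only be mapped when exactly one tool is bound')
        choice = tool_names[0]
    if choice not in tool_names:
        raise ValueError(f"tool {choice!r} is not among the bound tools {tool_names}")
    return {"type": "function", "function": {"name": choice}}


print(to_qwen_tool_choice("any", ["sql_db_schema"]))
```

In call_get_schema one could then write llm.bind_tools([get_schema_tool], tool_choice=to_qwen_tool_choice("any", [get_schema_tool.name])), assuming, as Scenario 3 does, that ChatTongyi forwards the dict to the API unchanged. Raising on "any" with several tools is deliberate: silently falling back to "auto" would drop the guarantee that a tool gets called.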

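Finally, we use the Graph API to assemble these steps into a workflow. We define a conditional edge on the query-generation step: if a query was generated, it routes to the query checker; if there is no tool call, the conditional edge ends the run, since the LLM has already answered the question.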

def should_continue(state: MessagesState) -> Literal[END, "check_query"]:
    messages = state["messages"]  # current message history
    last_message = messages[-1]  # most recent message
    if not last_message.tool_calls:  # the model made no tool call
        return END  # nothing left to run: finish the workflow
    else:
        return "check_query"  # a query was generated: route to "check_query"


builder = StateGraph(MessagesState)
builder.add_node(list_tables)
builder.add_node(call_get_schema)
builder.add_node(get_schema_node, "get_schema")
builder.add_node(generate_query)
builder.add_node(check_query)
builder.add_node(run_query_node, "run_query")

builder.add_edge(START, "list_tables")
builder.add_edge("list_tables", "call_get_schema")
builder.add_edge("call_get_schema", "get_schema")
builder.add_edge("get_schema", "generate_query")
builder.add_conditional_edges(
    "generate_query",  # source node
    should_continue,  # router that picks the next node
)
builder.add_edge("check_query", "run_query")
builder.add_edge("run_query", "generate_query")

agent = builder.compile()
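The wiring above fixes the order of every step except the branch after generate_query, which is what replaces the ordering instructions in the earlier system prompt. To see the resulting control flow without an LLM or a database, here is a plain-Python mirror of the same edges. It is a sketch: the stub nodes, the STOP sentinel and the run helper are hypothetical stand-ins, not LangGraph's runtime.

```python
STOP = "__end__"  # stand-in for LangGraph's END in this sketch

# Fixed edges, mirroring the builder.add_edge(...) calls above.
EDGES = {
    "list_tables": "call_get_schema",
    "call_get_schema": "get_schema",
    "get_schema": "generate_query",
    "check_query": "run_query",
    "run_query": "generate_query",
}


def route_after_generate(state: dict) -> str:
    """Same rule as should_continue: a tool call goes to the checker, else stop."""
    return "check_query" if state["messages"][-1].get("tool_calls") else STOP


def make_stub_nodes() -> dict:
    """Hypothetical stub nodes that only append fake messages (no LLM, no DB)."""

    def tool_result(name):
        return lambda state: {"role": "tool", "name": name}

    def generate_query(state):
        # Answer once a query result is in the history; otherwise emit a query.
        if any(m.get("name") == "sql_db_query" for m in state["messages"]):
            return {"role": "ai", "content": "final answer"}
        return {"role": "ai", "tool_calls": [{"name": "sql_db_query"}]}

    return {
        "list_tables": tool_result("sql_db_list_tables"),
        "call_get_schema": lambda state: {"role": "ai", "tool_calls": [{"name": "sql_db_schema"}]},
        "get_schema": tool_result("sql_db_schema"),
        "generate_query": generate_query,
        "check_query": lambda state: {"role": "ai", "tool_calls": [{"name": "sql_db_query"}]},
        "run_query": tool_result("sql_db_query"),
    }


def run(question: str, max_steps: int = 20) -> list[str]:
    """Walk the graph from START and return the visited node names."""
    nodes = make_stub_nodes()
    state = {"messages": [{"role": "user", "content": question}]}
    node, visited = "list_tables", []  # START -> list_tables
    while node != STOP and len(visited) < max_steps:
        visited.append(node)
        state["messages"].append(nodes[node](state))
        # Only generate_query has a conditional edge; every other edge is fixed.
        node = route_after_generate(state) if node == "generate_query" else EDGES[node]
    return visited


print(run("Which sales agent made the most in sales in 2009?"))
```

Whatever the model does, list_tables and the schema lookup always run first and every query passes through check_query before run_query; the only decision left to the model is whether to query again or answer.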

We can visualize the graph:

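from IPython.display import Image, display

# Render the graph once and reuse the PNG bytes
png_bytes = agent.get_graph().draw_mermaid_png()
display(Image(png_bytes))  # shows the image inline in a Jupyter notebook
with open("agent.png", "wb") as f:
    f.write(png_bytes)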


We can then invoke the graph just as before:

# The user's question
question = "Which sales agent made the most in sales in 2009?"

# Stream the agent so each step of the workflow can be observed
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},  # initial state: the user message
    stream_mode="values",  # emit the full state after each step
):
    step["messages"][-1].pretty_print()  # print the latest message in the state
