
In this cookbook, we'll work through a Text2SQL use case where we start from scratch, without a nice, clean dataset of questions, SQL queries, or expected responses. Although eval datasets are popular in academic settings, they are often unavailable in the real world. In this case, we'll use a handful of simple handwritten questions plus an LLM to build a dataset, generating samples based on the SQL data itself.


Along the way, we'll cover the core components of the eval process: building a dataset, writing a task function, and defining scoring functions.

Before you get started, make sure you have a Braintrust account. If you don't, please sign up.

 

Setting up the environment


The next few commands install some libraries and include helper code for the text2sql application. Feel free to copy/paste/tweak/reuse this code in your own tools.

%pip install -U autoevals braintrust duckdb datasets openai pyarrow pydantic --quiet
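If you're running this notebook yourself, you'll also need API keys available. Here's a minimal sketch of the assumed setup: the OpenAI client below reads OPENAI_API_KEY, and the Braintrust SDK conventionally reads BRAINTRUST_API_KEY (the placeholder values are yours to fill in).

import os

os.environ["OPENAI_API_KEY"] = "sk-..."  # replace with your OpenAI API key
os.environ["BRAINTRUST_API_KEY"] = "..."  # replace with your Braintrust API key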


Downloading the data


We'll use an NBA dataset that includes game information from 2014-2018. Let's start by downloading it and poking around.
We'll use DuckDB as the database, since it's easy to embed directly in a notebook.

import duckdb
from datasets import load_dataset

data = load_dataset("suzyanil/nba-data")["train"]

conn = duckdb.connect(database=":memory:", read_only=False)
conn.register("nba", data.to_pandas())

conn.query("SELECT * FROM nba LIMIT 5").to_df().to_dict(orient="records")[0]
{'Unnamed: 0': 1,
 'Team': 'ATL',
 'Game': 1,
 'Date': '10/29/14',
 'Home': 'Away',
 'Opponent': 'TOR',
 'WINorLOSS': 'L',
 'TeamPoints': 102,
 'OpponentPoints': 109,
 'FieldGoals': 40,
 'FieldGoalsAttempted': 80,
 'FieldGoals.': 0.5,
 'X3PointShots': 13,
 'X3PointShotsAttempted': 22,
 'X3PointShots.': 0.591,
 'FreeThrows': 9,
 'FreeThrowsAttempted': 17,
 'FreeThrows.': 0.529,
 'OffRebounds': 10,
 'TotalRebounds': 42,
 'Assists': 26,
 'Steals': 6,
 'Blocks': 8,
 'Turnovers': 17,
 'TotalFouls': 24,
 'Opp.FieldGoals': 37,
 'Opp.FieldGoalsAttempted': 90,
 'Opp.FieldGoals.': 0.411,
 'Opp.3PointShots': 8,
 'Opp.3PointShotsAttempted': 26,
 'Opp.3PointShots.': 0.308,
 'Opp.FreeThrows': 27,
 'Opp.FreeThrowsAttempted': 33,
 'Opp.FreeThrows.': 0.818,
 'Opp.OffRebounds': 16,
 'Opp.TotalRebounds': 48,
 'Opp.Assists': 26,
 'Opp.Steals': 13,
 'Opp.Blocks': 9,
 'Opp.Turnovers': 9,
 'Opp.TotalFouls': 22}
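As a quick sanity check (our own addition, not part of the original flow), we can confirm which seasons the data covers. The Date column is stored as a string like '10/29/14', so we pull out the two-digit year:

# Extract the two-digit year from the MM/DD/YY date strings.
conn.query("""
    SELECT DISTINCT '20' || split_part(Date, '/', 3) AS year
    FROM nba
    ORDER BY year
""").to_df()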

 

Prototyping text2sql

Now that we have the basic data in place, let's implement the text2sql logic. Don't overcomplicate it at the start. We can always improve the implementation later!

import os
from textwrap import dedent

import braintrust
import openai

client = braintrust.wrap_openai(
    openai.AsyncClient(
        api_key=os.environ["OPENAI_API_KEY"],
        base_url="https://api.braintrust.dev/v1/proxy",  # This is optional and allows us to cache responses
    )
)

columns = conn.query("DESCRIBE nba").to_df().to_dict(orient="records")

TASK_MODEL = "gpt-4o"


@braintrust.traced
async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": dedent(
                    f"""\
        You are a SQL expert, and you are given a single table named nba with the following columns:
        {", ".join(column["column_name"] + ": " + column["column_type"] for column in columns)}

        Write a SQL query corresponding to the user's request. Return just the query text, with no
        formatting (backticks, markdown, etc.).
"""
                ),
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content


query = await generate_query("Who won the most games?")
print(query)

 

SELECT Team, COUNT(*) AS Wins
FROM nba
WHERE WINorLOSS = 'W'
GROUP BY Team
ORDER BY Wins DESC
LIMIT 1;

Awesome, let’s try running the query!

 

def execute_query(query):
    return conn.query(query).fetchdf().to_dict(orient="records")


execute_query(query)

 

[{'Team': 'GSW', 'Wins': 265}]

Initial eval

An Eval() consists of three parts: data, task, and scores. We'll start with data.

 

Creating an initial dataset


Let's handwrite a few examples to bootstrap the dataset. Handwriting both the questions and the SQL queries/outputs would be quite painful and probably brittle. Instead, we'll just write a few questions and try to evaluate the outputs without an expected output.

 

questions = [
    "Which team won the most games?",
    "Which team won the most games in 2015?",
    "Who led the league in 3 point shots?",
    "Which team had the biggest difference in records across two consecutive years?",
    "What is the average number of free throws per year?",
]

Task function

Now let's write a task function. The function should take an input (the question) and return an output (the SQL query and results).

@braintrust.traced
async def text2sql(question):
    query = await generate_query(question)
    results = None
    error = None
    try:
        results = execute_query(query)
    except duckdb.Error as e:
        error = str(e)

    return {
        "query": query,
        "results": results,
        "error": error,
    }

Scores


At this point, there isn't much we can score, but we can at least check whether the SQL query is valid. If we generate an invalid query, the error field will be non-empty.

async def no_error(output):
    return output["error"] is None

Eval

That's it! Now let's plug these pieces together and run an eval.

from braintrust import EvalAsync

PROJECT_NAME = "LLM Eval for Text2SQL"

await EvalAsync(
    PROJECT_NAME,
    experiment_name="Initial dataset",
    data=[{"input": q} for q in questions],
    task=text2sql,
    scores=[no_error],
)
Experiment Initial dataset is running at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/Initial%20dataset
LLM Eval for Text2SQL [experiment_name=Initial dataset] (data): 5it [00:00, 33078.11it/s]
LLM Eval for Text2SQL [experiment_name=Initial dataset] (tasks):   0%|          | 0/5 [00:00<?, ?it/s]
=========================SUMMARY=========================
60.00% 'no_error' score

See results for Initial dataset at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/Initial%20dataset
EvalResultWithSummary(...)

Ok! It looks like 3/5 of our queries are valid. Let's take a closer look in the Braintrust UI.
eval results

Interpreting the results

Now that we've run the initial eval, it looks like two of the results are valid, two produce SQL errors, and one is incorrect. To make the best use of these results:

  • Let's capture the good data into a dataset. Since our eval pipeline did the hard work of generating reference queries and results, we can now save them and make sure that future changes don't regress the results. (If you'd rather do this programmatically than through the UI, see the sketch after this list.)
add to dataset

  • The incorrect query didn’t seem to get the date format correct. That would probably be improved by showing a sample of the data to the model.
invalid query

  • There are two binder errors, which may also have to do with not understanding the data format.
binder errors
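A minimal programmatic sketch of capturing the golden rows might look like the following. It assumes the Braintrust SDK's init_dataset/insert API and a hypothetical good_rows list containing the examples you judged to be correct.

from braintrust import init_dataset

# Hypothetical: rows from the experiment that we judged to be correct.
good_rows = [
    {
        "input": "Which team won the most games?",
        "expected": {
            "query": query,  # the SQL generated earlier
            "results": [{"Team": "GSW", "Wins": 265}],
            "error": None,
        },
    },
]

golden = init_dataset(PROJECT_NAME, "Golden data")
for row in good_rows:
    golden.insert(input=row["input"], expected=row["expected"])
golden.flush()  # persist the rows before reading them back in load_data()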

Updating the eval

Let's start by reworking our data function to pull the golden data we're storing in Braintrust and extend it with the handwritten questions. Since there may be some overlap, we automatically exclude any questions that are already in the golden dataset.

from braintrust import init_dataset


def load_data():
    golden_data = init_dataset(PROJECT_NAME, "Golden data")
    golden_questions = set(d["input"] for d in golden_data)
    return list(golden_data) + [
        {"input": q} for q in questions if q not in golden_questions
    ]


load_data()[0]
{'id': '614006b1-a8b1-40c2-b700-3634c4fb14f5',
 '_xact_id': '1000193117554478505',
 'created': '2024-05-29 16:23:59.989+00',
 'project_id': 'b8d44d19-7999-49b0-911b-1f0aaafc5bac',
 'dataset_id': 'a6c337e3-f7f7-4a96-8529-05cb172f847e',
 'input': 'Which team won the most games?',
 'expected': {'error': None,
  'query': "SELECT Team, COUNT(*) AS Wins\nFROM nba\nWHERE WINorLOSS = 'W'\nGROUP BY Team\nORDER BY Wins DESC\nLIMIT 1;",
  'results': [{'Team': 'GSW', 'Wins': 265}]},
 'metadata': {},
 'tags': [],
 'span_id': '614006b1-a8b1-40c2-b700-3634c4fb14f5',
 'root_span_id': '614006b1-a8b1-40c2-b700-3634c4fb14f5'}

Now, let’s tweak the prompt to include a sample of each row.

samples = conn.query("SELECT * FROM nba LIMIT 1").to_df().to_dict(orient="records")[0]


@braintrust.traced
async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": dedent(f"""\
        You are a SQL expert, and you are given a single table named nba with the following columns:

        Column | Type | Example
        -------|------|--------
        {"\n".join(f"{column['column_name']} | {column['column_type']} | {samples[column['column_name']]}" for column in columns)}

        Write a DuckDB SQL query corresponding to the user's request. Return just the query text, with no
        formatting (backticks, markdown, etc.).
"""),
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content


print(await generate_query("Which team won the most games in 2015?"))
SELECT Team, COUNT(*) AS Wins
FROM nba
WHERE WINorLOSS = 'W' AND Date LIKE '%/15'
GROUP BY Team
ORDER BY Wins DESC
LIMIT 1;

That looks much better! Finally, let's add a scoring function that compares the results, if they exist, to the expected results.

from autoevals import JSONDiff, Sql


def extract_values(results):
    return [list(result.values()) for result in results]


def correct_result(output, expected):
    if (
        expected is None
        or expected.get("results") is None
        or output.get("results") is None
    ):
        return None
    return JSONDiff()(
        output=extract_values(output["results"]),
        expected=extract_values(expected["results"]),
    ).score


def correct_sql(input, output, expected):
    if expected is None or expected.get("query") is None or output.get("query") is None:
        return None
    return Sql()(input=input, output=output["query"], expected=expected["query"]).score
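As a quick local sanity check (a hypothetical example of our own, not part of the original flow), correct_result should return 1.0 for identical results and slightly less when a value differs:

example = {
    "query": "SELECT Team, COUNT(*) AS Wins FROM nba WHERE WINorLOSS = 'W' GROUP BY Team ORDER BY Wins DESC LIMIT 1;",
    "results": [{"Team": "GSW", "Wins": 265}],
    "error": None,
}
slightly_off = {**example, "results": [{"Team": "GSW", "Wins": 264}]}

print(correct_result(example, example))  # 1.0, since the results match exactly
print(correct_result(slightly_off, example))  # just under 1.0, since the Wins value differs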

Great. Let’s plug these pieces together and run an eval!

await EvalAsync(
    PROJECT_NAME,
    experiment_name="With samples",
    data=load_data,
    task=text2sql,
    scores=[no_error, correct_result, correct_sql],
)
Experiment With samples is running at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/With%20samples
LLM Eval for Text2SQL [experiment_name=With samples] (data): 5it [00:00, 17848.10it/s]
LLM Eval for Text2SQL [experiment_name=With samples] (tasks):   0%|          | 0/5 [00:00<?, ?it/s]
=========================SUMMARY=========================
With samples compared to Initial dataset:
80.00% (+20.00%) 'no_error'       score	(1 improvements, 0 regressions)
100.00% 'correct_result' score
100.00% 'correct_sql'    score

5.78s duration

See results for With samples at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/With%20samples
EvalResultWithSummary(...)

Amazing. It looks like we removed one of the errors, and got a result for the incorrect query.
updated eval
Let's add the “Which team won the most games in 2015?” row to our dataset, since its answer now looks correct.

Generating more data

Now that we have a basic flow in place, let's generate some data. We'll use the dataset itself to generate expected queries, and use a model to describe each query. This is a more robust approach than having the model generate queries, because we'd expect a model to describe a query more accurately than to generate one from scratch.

import json

from pydantic import BaseModel


class Question(BaseModel):
    sql: str
    question: str


class Questions(BaseModel):
    questions: list[Question]


logger = braintrust.init_logger("question generator")

response = await client.chat.completions.create(
    model="gpt-4o",
    temperature=0,
    messages=[
        {
            "role": "user",
            "content": dedent(f"""\
        You are a SQL expert, and you are given a single table named nba with the following columns:

        Column | Type | Example
        -------|------|--------
        {"\n".join(f"{column['column_name']} | {column['column_type']} | {samples[column['column_name']]}" for column in columns)}

        Generate SQL queries that would be interesting to ask about this table. Return the SQL query as a string, as well as the
        question that the query answers."""),
        }
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "generate_questions",
                "description": "Generate SQL queries that would be interesting to ask about this table.",
                "parameters": Questions.model_json_schema(),
            },
        }
    ],
    tool_choice={"type": "function", "function": {"name": "generate_questions"}},
)

generated_questions = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["questions"]
generated_questions[0]
{'sql': "SELECT Team, COUNT(*) as TotalGames, SUM(CASE WHEN WINorLOSS = 'W' THEN 1 ELSE 0 END) as Wins, SUM(CASE WHEN WINorLOSS = 'L' THEN 1 ELSE 0 END) as Losses FROM nba GROUP BY Team;",
 'question': 'What is the total number of games played, wins, and losses for each team?'}
generated_dataset = []
for q in generated_questions:
    try:
        result = execute_query(q["sql"])
        generated_dataset.append(
            {
                "input": q["question"],
                "expected": {
                    "results": result,
                    "error": None,
                    "query": q["sql"],
                },
                "metadata": {
                    "category": "Generated",
                },
            }
        )
    except duckdb.Error as e:
        print(f"Query failed: {q['sql']}", e)
        print("Skipping...")

generated_dataset[0]
Query failed: SELECT Team, AVG(FieldGoals.) as AvgFieldGoalPercentage, AVG(X3PointShots.) as Avg3PointPercentage, AVG(FreeThrows.) as AvgFreeThrowPercentage FROM nba GROUP BY Team; Parser Error: syntax error at or near ")"
Skipping...
Query failed: SELECT Team, AVG(Opp.FieldGoals.) as AvgOppFieldGoalPercentage, AVG(Opp.3PointShots.) as AvgOpp3PointPercentage, AVG(Opp.FreeThrows.) as AvgOppFreeThrowPercentage FROM nba GROUP BY Team; Parser Error: syntax error at or near ")"
Skipping...
{'input': 'What is the total number of games played, wins, and losses for each team?',
 'expected': {'results': [{'Team': 'ATL',
    'TotalGames': 328,
    'Wins': 175.0,
    'Losses': 153.0},
   {'Team': 'CHI', 'TotalGames': 328, 'Wins': 160.0, 'Losses': 168.0},
   {'Team': 'NYK', 'TotalGames': 328, 'Wins': 109.0, 'Losses': 219.0},
   {'Team': 'POR', 'TotalGames': 328, 'Wins': 185.0, 'Losses': 143.0},
   {'Team': 'DEN', 'TotalGames': 328, 'Wins': 149.0, 'Losses': 179.0},
   {'Team': 'UTA', 'TotalGames': 328, 'Wins': 177.0, 'Losses': 151.0},
   {'Team': 'BRK', 'TotalGames': 328, 'Wins': 107.0, 'Losses': 221.0},
   {'Team': 'CHO', 'TotalGames': 328, 'Wins': 153.0, 'Losses': 175.0},
   {'Team': 'DAL', 'TotalGames': 328, 'Wins': 149.0, 'Losses': 179.0},
   {'Team': 'LAC', 'TotalGames': 328, 'Wins': 202.0, 'Losses': 126.0},
   {'Team': 'DET', 'TotalGames': 328, 'Wins': 152.0, 'Losses': 176.0},
   {'Team': 'GSW', 'TotalGames': 328, 'Wins': 265.0, 'Losses': 63.0},
   {'Team': 'IND', 'TotalGames': 328, 'Wins': 173.0, 'Losses': 155.0},
   {'Team': 'MIA', 'TotalGames': 328, 'Wins': 170.0, 'Losses': 158.0},
   {'Team': 'MIL', 'TotalGames': 328, 'Wins': 160.0, 'Losses': 168.0},
   {'Team': 'SAC', 'TotalGames': 328, 'Wins': 121.0, 'Losses': 207.0},
   {'Team': 'OKC', 'TotalGames': 328, 'Wins': 195.0, 'Losses': 133.0},
   {'Team': 'PHI', 'TotalGames': 328, 'Wins': 108.0, 'Losses': 220.0},
   {'Team': 'PHO', 'TotalGames': 328, 'Wins': 107.0, 'Losses': 221.0},
   {'Team': 'SAS', 'TotalGames': 328, 'Wins': 230.0, 'Losses': 98.0},
   {'Team': 'BOS', 'TotalGames': 328, 'Wins': 196.0, 'Losses': 132.0},
   {'Team': 'HOU', 'TotalGames': 328, 'Wins': 217.0, 'Losses': 111.0},
   {'Team': 'LAL', 'TotalGames': 328, 'Wins': 99.0, 'Losses': 229.0},
   {'Team': 'MIN', 'TotalGames': 328, 'Wins': 123.0, 'Losses': 205.0},
   {'Team': 'TOR', 'TotalGames': 328, 'Wins': 215.0, 'Losses': 113.0},
   {'Team': 'CLE', 'TotalGames': 328, 'Wins': 211.0, 'Losses': 117.0},
   {'Team': 'MEM', 'TotalGames': 328, 'Wins': 162.0, 'Losses': 166.0},
   {'Team': 'NOP', 'TotalGames': 328, 'Wins': 157.0, 'Losses': 171.0},
   {'Team': 'ORL', 'TotalGames': 328, 'Wins': 114.0, 'Losses': 214.0},
   {'Team': 'WAS', 'TotalGames': 328, 'Wins': 179.0, 'Losses': 149.0}],
  'error': None,
  'query': "SELECT Team, COUNT(*) as TotalGames, SUM(CASE WHEN WINorLOSS = 'W' THEN 1 ELSE 0 END) as Wins, SUM(CASE WHEN WINorLOSS = 'L' THEN 1 ELSE 0 END) as Losses FROM nba GROUP BY Team;"},
 'metadata': {'category': 'Generated'}}

Awesome, let’s update our dataset with the new data.

def load_data():
    golden_data = init_dataset(PROJECT_NAME, "Golden data")
    golden_questions = set(d["input"] for d in golden_data)
    return (
        [{**x, "metadata": {"category": "Golden data"}} for x in golden_data]
        + [
            {"input": q, "metadata": {"category": "Handwritten question"}}
            for q in questions
            if q not in golden_questions
        ]
        + [x for x in generated_dataset if x["input"] not in golden_questions]
    )


await EvalAsync(
    PROJECT_NAME,
    experiment_name="Generated data",
    data=load_data,
    task=text2sql,
    scores=[no_error, correct_result, correct_sql],
)
Experiment Generated data is running at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/Generated%20data
LLM Eval for Text2SQL [experiment_name=Generated data] (data): 13it [00:00, 36916.69it/s]
LLM Eval for Text2SQL [experiment_name=Generated data] (tasks):   0%|          | 0/13 [00:00<?, ?it/s]
=========================SUMMARY=========================
Generated data compared to With samples:
84.62% (-) 'no_error'       score	(0 improvements, 0 regressions)
69.72% (-) 'correct_result' score	(0 improvements, 0 regressions)
63.64% (-) 'correct_sql'    score	(0 improvements, 0 regressions)

4.23s (-155.93%) 'duration'	(0 improvements, 0 regressions)

See results for Generated data at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/Generated%20data
EvalResultWithSummary(...)
eval 3

Amazing! Now we have a rich dataset to work with, as well as some failures to debug. From here, you could investigate whether some of the generated data needs improvement, tweak the prompt to improve accuracy, or even try something more adventurous, like feeding the errors back to the model and letting it iterate on a better query (a sketch of that idea follows below). Most importantly, we now have a solid workflow for iterating on both the application and the dataset.
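For instance, a sketch of the error-feedback idea might look like the following. The retry prompt and the text2sql_with_retry name are our own (the cookbook stops at the suggestion), and in practice you'd probably reuse the same schema-aware system prompt as generate_query.

@braintrust.traced
async def text2sql_with_retry(question, max_retries=1):
    query = await generate_query(question)
    results, error = None, None
    for attempt in range(max_retries + 1):
        try:
            results = execute_query(query)
            error = None
            break
        except duckdb.Error as e:
            error = str(e)
            if attempt == max_retries:
                break
            # Feed the DuckDB error back to the model and ask for a corrected query.
            response = await client.chat.completions.create(
                model=TASK_MODEL,
                temperature=0,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            "The following DuckDB query failed:\n\n"
                            f"{query}\n\n"
                            f"Error: {error}\n\n"
                            "Rewrite the query so it runs correctly against the nba table. "
                            "Return just the query text, with no formatting (backticks, markdown, etc.)."
                        ),
                    },
                ],
            )
            query = response.choices[0].message.content
    return {"query": query, "results": results, "error": error}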

Trying GPT-4

Just for fun, let’s wrap things up by trying out GPT-4. All we need to do is switch the model name, and run our Eval() function again.

 

TASK_MODEL = "gpt-4"

await EvalAsync(
    PROJECT_NAME,
    experiment_name="Try gpt-4",
    data=load_data,
    task=text2sql,
    scores=[no_error, correct_result, correct_sql],
)
Experiment Try gpt-4 is running at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/Try%20gpt-4
LLM Eval for Text2SQL [experiment_name=Try gpt-4] (data): 13it [00:00, 25491.33it/s]
LLM Eval for Text2SQL [experiment_name=Try gpt-4] (tasks):   0%|          | 0/13 [00:00<?, ?it/s]
=========================SUMMARY=========================
Try gpt-4 compared to Generated data:
46.14% (-23.58%) 'correct_result' score	(1 improvements, 5 regressions)
84.62% (-) 'no_error'       score	(1 improvements, 1 regressions)
54.55% (-09.09%) 'correct_sql'    score	(1 improvements, 2 regressions)

6.77s (+254.27%) 'duration'	(0 improvements, 1 regressions)

See results for Try gpt-4 at https://www.braintrust.dev/app/braintrustdata.com/p/LLM%20Eval%20for%20Text2SQL/experiments/Try%20gpt-4
EvalResultWithSummary(...)

Interesting. It looks like that wasn't a slam dunk. Each score had a few regressions. Braintrust makes it easy to filter down to the regressions and view side-by-side diffs:
diff

Conclusion

In this cookbook, we walked through the process of building a dataset for a text2sql application. We started with a few handwritten examples and used an LLM to iterate on the dataset and generate more. We used the eval framework to track our progress, and iterated on both the model and the dataset to improve the results. Finally, we tried a more powerful model to see whether it could improve the results. Happy evaling!
