fooSynaptic

Any problem, please Contact me: 2313990450@qq.com

本文深入解读 LLM 智能体领域的三个重要应用扩展:VOYAGER(终身学习)、Project Sid(AI文明)和 Agent Hospital(可进化医疗智能体)。


一、VOYAGER:开放世界具身终身学习智能体

论文: An Open-Ended Embodied Agent with Large Language Models
会议: NeurIPS 2023 (FMDM Workshop)
作者: Guanzhi Wang 等 (NVIDIA, Caltech, UT Austin)
项目主页: voyager.minedojo.org

1.1 核心创新

VOYAGER 是首个 LLM 驱动的具身终身学习智能体,在 Minecraft 中持续探索世界、获取技能、做出新发现,无需人类干预。

三大核心组件:

组件 功能 技术实现
自动课程 提出适当难度的任务 GPT-4 + 探索进度 + 智能体状态
技能库 存储和检索可复用代码 向量数据库 + 嵌入检索
迭代提示 自我改进代码生成 环境反馈 + 执行错误 + 自我验证

1.2 系统架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
┌─────────────────────────────────────────────────────────────────────┐
│ VOYAGER 系统架构 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────┐ │
│ │ GPT-4 API │◀──────────────────────────────────┐ │
│ │ (黑盒调用) │ │ │
│ └───────────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌───────────────┐ ┌───────────────┐ ┌────────┴────────┐ │
│ │ 自动课程生成 │ │ 代码生成 │ │ 自我验证 │ │
│ │ (GPT-4提示) │ │ (GPT-4提示) │ │ (GPT-4提示) │ │
│ └───────────────┘ └───────────────┘ └─────────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌───────────────┐ ┌───────────────┐ ┌─────────────────┐ │
│ │ 任务队列 │ │ Minecraft │ │ 技能库 │ │
│ └───────────────┘ │ 环境执行 │ │ (向量数据库) │ │
│ └───────────────┘ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘

1.3 自动课程系统

设计理念: 自下而上展开,由好奇心驱动

输入提示组件:

  1. 指令: 鼓励多样化行为并施加约束
  2. 智能体当前状态: 物品栏、装备、位置、生命值等
  3. 先前任务记录: 已完成和失败的任务
  4. 额外上下文: GPT-3.5 自问自答

示例提示:

“我的最终目标是发现尽可能多的多样化事物…下一个任务不应该太难,因为我可能还没有必要的资源或学会足够的技能来完成它。”

1.4 技能库机制

技能表示: 可执行的 JavaScript 代码

1
2
3
4
5
6
7
8
9
10
11
// 示例技能: 制作木镐
async function craftWoodenPickaxe(bot) {
// 首先获取木材
await mineBlock(bot, "oak_log", 1);
// 制作木板
await craftItem(bot, "oak_planks", 4);
// 制作木棍
await craftItem(bot, "stick", 2);
// 制作木镐
await craftItem(bot, "wooden_pickaxe", 1);
}

存储与检索:

  • : 程序描述的嵌入向量(GPT-3.5生成)
  • : 可执行的JavaScript代码
  • 检索: 余弦相似度 + 任务上下文

1.5 迭代提示机制

三种反馈类型:

反馈类型 来源 作用
环境反馈 程序执行日志 显示中间进度,如”需要多7个铁锭”
执行错误 程序解释器 揭示语法错误和无效操作
自我验证 GPT-4评论家 判断任务完成,提供改进建议

代码生成的12个提示组件:

# 组件 描述
1 代码生成指南 编写规范和约束
2 控制原语API 高级API(exploreUntil, mineBlock等)
3 Mineflayer API 底层游戏控制API
4 检索的技能 从技能库检索的相关代码
5 上一轮代码 用于迭代改进
6 环境反馈 聊天日志中的执行信息
7 执行错误 解释器错误信息
8 自我验证批评 验证模块的反馈
9 智能体状态 物品栏、位置、生命值等
10 任务 自动课程提出的任务
11 任务上下文 GPT-3.5生成的解决建议
12 思维链提示 要求解释→计划→代码的顺序

1.6 实验结果

vs 基线方法:

指标 VOYAGER AutoGPT ReAct Reflexion
独特物品发现 63 19 ~10 ~10
倍数 3.3x 1x - -

科技树解锁速度:

级别 VOYAGER AutoGPT 提升
木制工具 6分钟 92分钟 15.3x
石制工具 11分钟 94分钟 8.5x
铁制工具 21分钟 135分钟 6.4x
钻石工具 102分钟 N/A 唯一成功

消融实验结论:

  • 自动课程至关重要:移除后物品发现下降93%
  • 自我验证最重要:移除后物品发现下降73%
  • GPT-4 vs GPT-3.5:GPT-4获得5.7倍更多独特物品

1.7 关键洞见

代码即记忆: VOYAGER 将”学习”转化为”运行时组合”——通过检索已有技能并迭代改进代码,而不是更新模型权重。

传统方法 VOYAGER
微调模型参数 黑盒API调用
隐式知识存储 显式代码技能库
难以解释 代码可读可执行
灾难性遗忘 技能永久保存

二、Project Sid:迈向AI文明的多智能体模拟

论文: Many-agent simulations toward AI civilization
机构: Altera.AL
发布日期: 2024年10月
规模: 10-1000+ 智能体

2.1 核心问题

为什么我们应该尝试构建AI文明?

为了让智能体与人类社会共存,他们需要是自主的和协作的。文明进步——通过智能体在人类文明中共存和进步的能力来衡量——代表了AI智能体能力的终极基准

2.2 构建AI文明的挑战

挑战 问题描述
单智能体不进展 幻觉积累、陷入重复动作循环
多智能体不协调 错误沟通导致幻觉传播
缺乏基准 无法量化文明进步

一致性问题示例:

智能体Abby被Bob要求”给我一把镐”时,聊天模块回应”当然可以!”,但函数调用模块选择”探索”。Bob可能然后尝试用想象的镐采矿。

2.3 PIANO 架构

PIANO = Parallel Information Aggregation via Neural Orchestration
(通过神经编排的并行信息聚合)

两大设计原则:

原则 问题 解决方案
并发性 慢速思考不应阻止快速反应 多模块并行运行,不同时间尺度
一致性 多输出模块可能产生冲突 认知控制器(CC)作为瓶颈

10个核心模块:

模块 功能
记忆 存储/检索对话、动作、观察
动作意识 评估自身状态和性能
目标生成 基于经验创建新目标
社会意识 解释他人社会线索
说话 解释和生成语音
技能执行 执行环境中的动作
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
┌─────────────────────────────────────────────────────────────┐
│ PIANO 架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 并发模块: 认知控制器(瓶颈): │
│ ┌─────────┐ ┌───────────────┐ │
│ │ 记忆 │──────────────▶ │ │ │
│ ├─────────┤ │ 信息综合 │ │
│ │ 社会 │──────────────▶ │ ↓ │ │
│ ├─────────┤ │ 高层决策 │ │
│ │ 目标 │──────────────▶ │ ↓ │ │
│ ├─────────┤ │ 决策广播 │ │
│ │ 动作 │──────────────▶ │ │ │
│ └─────────┘ └───────────────┘ │
│ ↑ │ │
│ │ ▼ │
│ │ ┌───────────────┐ │
│ │ │ 输出模块 │ │
│ │ │ 说话/动作/... │ │
│ │ └───────────────┘ │
│ └─────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘

2.4 文明进步基准

基准1:专业化

定义: 智能体自主发展专业角色

三个标准:

  1. 在选择和转换角色方面表现自主性
  2. 专业化通过互动涌现,无需明确指导
  3. 角色体现在与专业化一致的行为中

实验结果 (30智能体,20分钟):

现象 发现
角色多样性 农民、矿工、工程师、守卫、探险家、铁匠
角色持久性 每个智能体角色在时间上大体稳定
角色-行为一致性 艺术家专注采花,农民专注收集种子

武术社会 vs 艺术社会:

  • 武术社会特有角色:侦察兵、战略家
  • 艺术社会特有角色:策展人、收藏家

基准2:集体规则

定义: 智能体遵守和改变法律

实验设置:

  • 25个选民智能体
  • 3个影响者(亲税/反税)
  • 1个选举经理
  • 税法:交20%物品到社区箱子

关键发现:

现象 结果
遵守法律 平均交付~20%物品
影响者影响 亲税/反税影响者显著改变选民态度
宪法变更 税率从20%降到5-10%时,行为相应调整

基准3:文化传播

实验规模: 500智能体 (6城镇 + 农村)

关键现象:

现象 发现
模因多样性 不同城镇流行不同模因
模因动态 流行度随时间上升和下降
宗教传播 20个牧师传播”飞天面条神教”
皈依扩散 皈依者数量持续增加,未饱和

2.5 量化结果

指标 结果
30分钟内获取物品 平均17个独特物品
4小时物品饱和 ~320个(1/3总物品)
社会感知准确性 r = 0.81(5+观察者)
最大规模 1000+ 智能体

2.6 局限性

  1. 缺乏视觉推理: 限制空间导航和建造能力
  2. 缺乏内在驱动: 无生存、好奇心等催化社会发展
  3. 无法从头涌现: 基于预训练知识,无法模拟创新涌现

三、Agent Hospital:可进化的医疗智能体

论文: A Simulacrum of Hospital with Evolvable Medical Agents
机构: 清华大学 AIR
发布日期: 2024年5月

3.1 核心创新

医生培养的两个阶段:

阶段 内容 时长
阶段1 知识获取(学校) ~20年
阶段2 技能获取(医院) ~3年

现有医疗AI主要集中在阶段1(如Med-PaLM)。Agent Hospital 解决阶段2:从实践中获取专业技能。

3.2 系统架构

Agent Hospital = 虚拟医院,所有患者、护士、医生都是LLM驱动的智能体

系统规模:

指标 数量
科室 32个
覆盖疾病 339种
医生智能体 42个
护士智能体 4个
功能区域 16个

3.3 治疗闭环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
┌─────────────────────────────────────────────────────────────┐
│ 治疗闭环 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 疾病发作 ──▶ 2. 分诊 ──▶ 3. 挂号 │
│ │ │
│ ▼ │
│ 8. 康复反馈 ◀── 7. 取药 ◀── 6. 诊断 │
│ │ ▲ │
│ │ │ │
│ └─────▶ 4. 就诊 ──▶ 5. 检查 ─┘ │
│ │
│ 额外事件:医生智能体在非工作时间阅读医学书籍 │
│ │
└─────────────────────────────────────────────────────────────┘

3.4 SEAL 框架

SEAL = Simulacrum-based Evolutionary Agent Learning
(基于仿真的进化智能体学习)

两个组件:

组件 功能
仿真系统构建 构建虚拟世界,自动生成数据
智能体进化 从成功/失败中学习

3.5 MedAgent-Zero 进化机制

“Zero”含义: 不使用任何人工标注数据

学习来源:

来源 内容 作用
成功案例 正确的诊断和治疗 作为参考案例检索
失败案例 错误的诊断或治疗 反思避免重复错误
医学教材 专业医学知识 巩固和整合知识
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
┌─────────────────────────────────────────────────────────────┐
│ MedAgent-Zero 进化流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 治疗患者智能体 │
│ ↓ │
│ 2. 收到患者反馈(康复/未康复) │
│ ↓ │
│ ┌─────────────────┬─────────────────┐ │
│ │ 成功案例 │ 失败案例 │ │
│ │ │ │ │
│ │ 存储为参考案例 │ 反思获取经验 │ │
│ │ 用于未来检索 │ 避免重复错误 │ │
│ └─────────────────┴─────────────────┘ │
│ ↓ │
│ 3. 阅读医学教材巩固知识 │
│ ↓ │
│ 4. 能力持续提升 │
│ │
└─────────────────────────────────────────────────────────────┘

3.6 实验结果

进化效果 (诊断准确率):

治疗患者数 准确率 提升
0 (初始) ~60% -
1,000 ~72% +20%
10,000 ~85% +42%
50,000 ~93% +55%

MedQA 基准测试 (美国医师执照考试):

方法 准确率
GPT-4 (少样本) 78.4%
Med-PaLM 2 86.5%
Agent Hospital (进化后) 88.7%

亮点: 无需使用基准的标注训练数据!

3.7 与 Generative Agents 的关系

维度 Generative Agents Agent Hospital
灵感来源 原创 受GA启发
环境 虚拟小镇 虚拟医院
智能体数量 25个 46+
任务类型 社交模拟 医疗诊断
能力进化 有(核心创新)
评估方式 定性 定量(MedQA)

3.8 SEAL 的通用性

方法论公式:

1
领域工作流程 → 构建仿真系统 → 自动生成数据 → 智能体进化

优势:

优势 说明
无需人工标注 数据由虚拟世界自动生成
领域适应 直接适应特定应用需求
成本低 减少数据标注开销
可扩展 可模拟大量场景和时间

潜在应用: 法律咨询、金融投资、教育培训、客户服务


四、三大应用扩展对比

4.1 核心差异

维度 VOYAGER Project Sid Agent Hospital
核心目标 终身学习技能 AI文明模拟 医疗智能体进化
环境 Minecraft Minecraft 虚拟医院
智能体数量 1 10-1000+ 46+
时间跨度 数小时 4小时+ 持续
学习机制 技能库积累 社会互动 经验反思

4.2 创新贡献

论文 核心创新
VOYAGER 代码即记忆,技能可组合复用
Project Sid 文明进步基准:专业化、规则、文化
Agent Hospital 智能体能力可进化,虚拟技能迁移现实

4.3 适用场景

场景 推荐方法 原因
开放世界游戏 VOYAGER 技能积累和终身学习
社会科学研究 Project Sid 大规模社会动态模拟
专业领域AI Agent Hospital 从实践中持续进化
多智能体协作 Project Sid PIANO架构支持一致性

五、技术演进路线

5.1 从基础到应用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
基础框架 (2022-2023):
├── ReAct: 推理+行动
├── Reflexion: 语言反馈学习
└── Generative Agents: 记忆+反思

应用扩展 (2023-2024):
├── VOYAGER: 终身学习 + 技能库
├── Project Sid: 大规模文明模拟
└── Agent Hospital: 专业领域进化

未来趋势 (2025+):
├── Agent OS化: AutoGen, LangGraph
├── 多模态融合: 视觉+语言+行动
└── 商业化部署: Operator, Claude

5.2 规模演进

时间 论文 智能体数量 涌现现象
2023/04 Generative Agents 25 社交行为
2023/05 VOYAGER 1 终身学习
2024/05 Agent Hospital 46+ 能力进化
2024/10 Project Sid 500-1000+ 文明进步

5.3 关键技术突破

突破 论文 意义
代码作为记忆 VOYAGER 可执行、可组合的知识表示
文明进步基准 Project Sid 量化多智能体社会能力
无标注进化 Agent Hospital 从实践中自动学习
千智能体规模 Project Sid 验证大规模可行性

六、实践建议

6.1 技术选型

需求 推荐技术栈
单智能体技能学习 VOYAGER (技能库 + 迭代提示)
多智能体协作 Project Sid (PIANO架构)
专业领域应用 Agent Hospital (SEAL框架)
通用任务完成 ReAct + Reflexion

6.2 架构设计

理想组合:

1
2
3
4
理想智能体 = VOYAGER的技能库
+ Project Sid的社会意识
+ Agent Hospital的进化机制
+ Generative Agents的记忆系统

6.3 规模化考虑

规模 关键挑战 解决方案
1-10 单智能体能力 技能库 + 反思
10-50 协调一致性 PIANO架构
50-500 计算资源 并行模块
500+ 涌现管理 文明基准

七、关键论文原文引用

VOYAGER

“VOYAGER is the first LLM-powered embodied lifelong learning agent that explores the world, acquires diverse skills, and makes novel discoveries without human intervention.”

Project Sid

“We show how 10-1000+ AI agents behave and progress in agent societies. These simulations reveal that agents can achieve meaningful progress—autonomously developing specialized roles, adhering to and modifying collective rules, and engaging in cultural and religious propagation.”

Agent Hospital

“Doctor agents can evolve by treating a large number of patient agents, without the need for manually curated training data. After treating tens of thousands of patient agents (which may take several years for real-world doctors), the evolved doctor agents surpassed state-of-the-art medical AI methods on the MedQA benchmark.”


返回总览 | 上一篇:基础框架篇

本文深入解读 LLM 智能体领域的三大基础框架:ReAct、Reflexion 和 Generative Agents,分析它们的核心架构、技术创新和应用场景。


一、ReAct:推理与行动的协同

论文: Synergizing Reasoning and Acting in Language Models
会议: ICLR 2023
作者: Shunyu Yao 等 (普林斯顿大学 & Google Research)
被引用: 32次(领域内最高)

1.1 核心思想

人类智能的一个独特特征是能够无缝结合面向任务的动作与语言推理。考虑在厨房做菜的例子:

  • 在任何两个具体动作之间,我们可能用语言进行推理,以跟踪进度
  • 处理异常或根据情况调整计划
  • 认识到何时需要外部信息

ReAct 的核心理念:将智能体的动作空间扩展为 Â = A ∪ L

其中:

  • A = 原始动作空间(与环境交互)
  • L = 语言空间(思想/推理轨迹)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
┌─────────────────────────────────────────────────────────────┐
│ ReAct 工作流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 问题 ──▶ 思想1 ──▶ 动作1 ──▶ 观察1 │
│ │ │
│ ▼ │
│ 思想2 ──▶ 动作2 ──▶ 观察2 │
│ │ │
│ ▼ │
│ 思想3 ──▶ 动作3 ──▶ 答案 │
│ │
│ 思想:不影响环境,用于推理和规划 │
│ 动作:与环境交互,获取新信息 │
│ 观察:环境反馈 │
│ │
└─────────────────────────────────────────────────────────────┘

1.2 推理轨迹类型

类型 示例
问题分解 “我需要搜索x,找到y,然后找到z”
信息提取 “x于1844年创立”
常识推理 “1844 < 1989,所以x更老”
搜索重构 “也许我可以搜索/查找x来代替”
答案综合 “…所以答案是x”

1.3 实验结果

知识密集型任务 (HotpotQA, Fever):

方法 HotpotQA EM Fever Acc
CoT 29.4 56.3
Act 25.7 58.9
ReAct 27.4 60.9
ReAct + CoT-SC 35.1 64.6

决策任务 (ALFWorld, WebShop):

方法 ALFWorld WebShop
BUTLER (模仿学习) 37% -
Act 45% -
ReAct 71% +10%

1.4 成功与失败模式

类型 ReAct CoT
成功-真阳性 94% 86%
成功-假阳性 6% 14%
失败-幻觉 0% 56%
失败-推理错误 47% 16%

关键洞见

  • ReAct 的幻觉率为0%,因为可以与外部知识库交互验证
  • CoT 的主要失败模式是幻觉(56%)
  • ReAct 的主要失败是推理错误(47%),因结构约束降低了灵活性

1.5 ReAct 示例

问题: 除了Apple Remote,还有什么设备可以控制Apple Remote最初设计交互的程序?

1
2
3
4
5
6
7
8
9
10
思想1: 我需要搜索Apple Remote,找到它最初设计交互的程序
动作1: Search[Apple Remote]
观察1: Apple Remote是一个遥控器...最初设计用于控制Front Row...

思想2: Apple Remote最初设计控制Front Row。我需要搜索Front Row
动作2: Search[Front Row software]
观察2: Front Row是一个已停产的媒体中心软件...由Apple Remote或键盘功能键控制

思想3: Front Row由Apple Remote或键盘功能键控制。所以答案是键盘功能键
动作3: Finish[keyboard function keys]

二、Reflexion:语言反馈强化学习

论文: Language Agents with Verbal Reinforcement Learning
会议: NeurIPS 2023
作者: Noah Shinn 等 (Northeastern & Princeton)
被引用: 17次

2.1 核心思想

传统强化学习通过梯度更新权重学习,需要大量样本和昂贵的微调。Reflexion 提出用语言反馈替代梯度信号:

1
2
3
4
5
6
7
8
9
10
11
12
13
┌─────────────────────────────────────────────────────────────┐
│ 传统RL vs Reflexion │
├─────────────────────────────────────────────────────────────┤
│ │
│ 传统RL: │
│ 状态 ──▶ 动作 ──▶ 奖励 ──▶ 梯度更新 ──▶ 参数变化 │
│ │
│ Reflexion: │
│ 状态 ──▶ 动作 ──▶ 反馈 ──▶ 语言反思 ──▶ 记忆存储 │
│ │ │
│ └──────────────────▶ 下次尝试 │
│ │
└─────────────────────────────────────────────────────────────┘

2.2 三大核心组件

Actor(执行者)

基于 LLM 构建,生成文本和动作。可以是:

  • Chain of Thought (CoT)
  • ReAct
  • 其他智能体架构

Evaluator(评估者)

评估 Actor 生成输出的质量:

  • 精确匹配(EM)评分: 推理任务
  • 预定义启发式: 决策任务
  • LLM作为评估者: 编程任务

Self-Reflection(自我反思)

核心创新:将稀疏奖励信号转化为详细的语言化反思

1
2
3
4
5
6
7
8
9
10
输入: 
- 任务描述
- 失败轨迹: [动作1, 观察1, 动作2, 观察2, ...]
- 奖励信号: 二元或标量
- 历史反思

输出:
- 错误诊断
- 改进方案
- 具体建议

2.3 记忆机制

类型 内容 作用
短期记忆 当前轨迹历史 即时决策
长期记忆 自我反思输出(滑动窗口) 跨尝试学习

2.4 实验结果

AlfWorld 决策任务:

方法 成功率
ReAct ~50%
ReAct + Reflexion (启发式) 97% (130/134)
ReAct + Reflexion (GPT) 88%

HotPotQA 推理任务:

方法 准确率提升
CoT (GT) 基准
+ 情景记忆 +6%
+ Reflexion +14%

编程任务 (HumanEval):

基准 先前SOTA GPT-4 Reflexion
HumanEval (PY) 65.8% 80.1% 91.0%
HumanEval (RS) - 60.0% 68.0%
Leetcode Hard - 7.5% 15.0%

2.5 Reflexion 示例

任务: 用台灯检查杯子

第一次尝试(失败):

1
2
3
4
5
6
7
8
> go to drawer 1
抽屉1是关着的
> ... (搜索杯子)
> take mug 1 from desk 1
你从桌子1拿起杯子1
> use desklamp 1
什么都没发生
状态: 失败

反思:

在这个环境中,我的计划是先找杯子再找台灯并使用。然而,任务说的是用台灯检查杯子。我应该先找台灯,再找杯子。我注意到台灯在桌子1上。下一次尝试中,我会先去桌子1,找到台灯,然后找杯子并用台灯检查它。

第二次尝试(成功):

1
2
3
4
5
6
7
> go to desk 1
桌子1上有:台灯1、杯子1...
> take mug 1 from desk 1
你从桌子1拿起杯子1
> use desklamp 1
你打开了台灯1
状态: 成功

三、Generative Agents:人类行为的交互式拟像

论文: Interactive Simulacra of Human Behavior
会议: UIST 2023
作者: Joon Sung Park 等 (斯坦福大学 & Google)
被引用: 20次

3.1 核心思想

构建模拟可信人类行为的计算软件智能体:

  • 醒来、做早餐、去上班
  • 艺术家画画,作者写作
  • 形成观点,注意彼此,主动发起对话
  • 回忆和反思过去,规划未来

3.2 核心架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
┌──────────────────────────────────────────────────────────────────┐
│ 记忆流 (Memory Stream) │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │
│ │ 观察 │ │ 反思 │ │ 计划 │ │
│ │ Observations│ │ Reflections │ │ Plans │ │
│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │
└──────────────────────────────────────────────────────────────────┘


┌─────────────────────────────────────┐
│ 记忆检索 │
│ (时近性 + 重要性 + 相关性) │
└──────────────┬──────────────────────┘


┌─────────────────────────────────────┐
│ 行为生成 │
│ (Plan, React, Dialogue) │
└─────────────────────────────────────┘

3.3 记忆检索公式

组件 描述 实现
时近性 最近访问的记忆分数更高 指数衰减函数,衰减因子0.995
重要性 区分平凡记忆和核心记忆 LLM评分1-10
相关性 与当前情况相关的记忆 嵌入向量余弦相似度

3.4 反思机制

触发条件: 重要性分数总和 > 150(约每天2-3次)

反思生成过程:

  1. 确定反思内容: 用最近100条记忆查询

    • 提示:”仅根据上述信息,我们可以回答哪3个最突出的高层次问题?”
  2. 检索相关记忆: 使用问题作为检索查询

  3. 提取洞察:

    • 输出格式:”洞察(因为1, 5, 3)”

反思树: 叶节点=观察,非叶节点=越来越抽象的反思

1
2
3
4
5
          [Klaus对研究充满热情]  ← 元反思
/ \
[Klaus致力于研究] [Klaus和Maria有共同兴趣] ← 反思
/ \ / \
[写论文] [读书] [讨论项目] [图书馆相遇] ← 观察

3.5 规划机制

递归分解日程:

  1. 粗略计划: 一天的议程大纲
  2. 小时级分解: 每小时的活动块
  3. 细粒度分解: 5-15分钟的具体动作

示例:

  • 粗略:”下午1:00到5:00创作新音乐”
  • 小时级:”下午1:00:开始为音乐创作头脑风暴…”
  • 细粒度:”下午4:00:拿一些小零食。下午4:05:在工作区周围短暂散步…”

3.6 涌现的社会行为

实验设置: 25个智能体,Smallville小镇

涌现现象:

现象 描述
信息扩散 Sam的市长候选资格传播到32%智能体
关系记忆 智能体记住新认识的人及对话内容
协调活动 Isabella的情人节派对:5人自发出席
网络密度 从0.167增加到0.74

情人节派对案例:

  1. Isabella计划2月14日下午5-7点的派对
  2. 她花一天装饰咖啡馆
  3. Maria帮忙装饰,并邀请暗恋的Klaus
  4. 最终5个智能体在正确时间出现

3.7 评估结果

条件 TrueSkill评分
完整架构 29.89
无反思 26.88
无反思、无计划 25.64
人类众包 22.95
无记忆(先前SOTA) 21.21

效应大小: 完整架构 vs 先前SOTA = 8个标准差


四、三大框架对比

4.1 核心差异

维度 ReAct Reflexion Generative Agents
核心目标 任务完成 从失败学习 行为拟真
知识表示 推理轨迹 语言化反思 记忆流
学习方式 单次推理 跨尝试积累 持续记忆+反思
时间跨度 单任务 多次尝试 天/周级
是否微调

4.2 记忆机制对比

特性 ReAct Reflexion Generative Agents
存储内容 当前轨迹 语言化反思 观察+反思+计划
存储形式 上下文 滑动窗口 记忆流列表
检索方式 时间顺序 时近性+重要性+相关性
失败经验 ✅ 重点 ⚠️ 不强调
抽象层次 单层 双层 多层(反思树)

4.3 反思机制对比

特性 ReAct Reflexion Generative Agents
有无反思 ❌ 无 ✅ 核心 ✅ 核心
触发条件 - 每次失败后 重要性>150
输出 - 错误分析+改进 高层次洞察
目的 - 任务成功率 概念抽象

4.4 适用场景

场景 推荐方法 原因
知识问答 ReAct 与外部知识库交互
决策任务 Reflexion 从失败中学习
编程调试 Reflexion 需要多次尝试改进
社会模拟 Generative Agents 需要记忆和人格一致性
角色扮演 Generative Agents 需要丰富的背景记忆

五、组合使用建议

5.1 理想组合架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
┌────────────────────────────────────────────────────────────────────┐
│ 理想智能体架构 │
├────────────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ Generative Agents的记忆流 │ │
│ │ • 完整的经历记录 │ │
│ │ • 多层次反思 │ │
│ │ • 社交关系追踪 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ + │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ ReAct的推理-行动范式 │ │
│ │ • 思想与动作交替 │ │
│ │ • 与外部环境交互 │ │
│ │ • 减少幻觉 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ + │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ Reflexion的失败反思 │ │
│ │ • 失败经验的语言化 │ │
│ │ • 错误诊断与改进建议 │ │
│ │ • 跨尝试学习 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │
└────────────────────────────────────────────────────────────────────┘

5.2 实现要点

  1. 使用 ReAct 作为基础行动框架:思想+动作交替执行
  2. 添加 Generative Agents 的记忆系统:持久化所有经历
  3. 集成 Reflexion 的失败反思:从错误中学习
  4. 定期触发高层次反思:形成长期理解

六、关键论文原文引用

ReAct

“We propose ReAct — a general paradigm to combine reasoning and acting with language models for solving diverse language reasoning and decision making tasks.”

Reflexion

“Reflexion converts binary or scalar feedback from the environment into verbal feedback in the form of a textual summary, which is then added as additional context for the LLM agent in the next episode.”

Generative Agents

“Generative agents wake up, cook breakfast, and head to work; artists paint, authors write; they form opinions, notice each other, and initiate conversations; they remember and reflect on days past as they plan the next day.”


返回总览 | 下一篇:应用扩展篇

本系列是 LLM 驱动的游戏智能体领域核心论文的解读与总结,涵盖 103+ 篇论文,164 条引用关系的系统性分析。


领域概述

随着大型语言模型(LLM)的快速发展,研究者们开始探索将 LLM 作为智能体”大脑”的可能性。这些智能体不仅能理解和生成文本,还能规划、反思、与环境交互,甚至形成复杂的社会行为。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
┌─────────────────────────────────────────────────────────────┐
│ LLM 游戏智能体技术栈 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ │
│ │ 应用层 │ 游戏/模拟/机器人 │
│ └────────┬────────┘ │
│ │ │
│ ┌────────▼────────┐ │
│ │ 智能体框架 │ ReAct / Reflexion / VOYAGER │
│ └────────┬────────┘ │
│ │ │
│ ┌────────▼────────┐ │
│ │ 核心能力 │ 记忆 / 规划 / 反思 / 工具使用 │
│ └────────┬────────┘ │
│ │ │
│ ┌────────▼────────┐ │
│ │ 基础模型 │ GPT-4 / Claude / Llama │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘

核心论文引用关系

基于 103 篇论文的引用网络分析,以下是领域内最具影响力的基础性工作:

排名 论文 会议 被引用 核心贡献
1 ReAct ICLR 2023 32 推理+行动交替范式
2 Generative Agents UIST 2023 20 记忆-反思-规划架构
3 Reflexion NeurIPS 2023 17 语言反馈强化学习
4 VOYAGER NeurIPS 2023 - 技能库+终身学习

技术层次金字塔

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
                ┌─────────────────────┐
│ 🎯 上层应用 │
│ 竞技/社交/特定游戏 │
│ (Werewolf, Poker, │
│ StarCraft等) │
└──────────┬──────────┘

┌────────────────┼────────────────┐
│ │ │
┌─────────▼─────────┐ ┌────▼────┐ ┌────────▼────────┐
│ 🏗️ 中间层 │ │模拟仿真 │ │ 🤝 多智能体 │
│ 环境适配层 │ │ │ │ 协作层 │
│ (Crafter, │ │Generative│ │ │
│ Minecraft) │ │ Agents │ │ │
└─────────┬───────┘ └────┬────┘ └────────┬───────┘
│ │ │
└──────────────┼───────────────┘

┌────────▼────────┐
│ 🔧 基础框架层 │
│ │
│ • ReAct │ ← 推理+行动范式
│ • Reflexion │ ← 自我反思机制
│ • Grounding RL │ ← 环境交互学习
└─────────────────┘

系列文章目录

基础框架篇

文章 核心内容
基础框架:ReAct / Reflexion / Generative Agents 三大核心框架的详细对比分析

应用扩展篇

文章 核心内容
应用扩展:VOYAGER / Project Sid / Agent Hospital 终身学习、AI文明、医疗智能体

研究脉络时间线

2023年:基础奠定

时间 论文 会议 核心贡献
2022/10 ReAct ICLR 2023 推理与行动协同范式
2023/03 Reflexion NeurIPS 2023 语言反馈强化学习
2023/04 Generative Agents UIST 2023 25智能体小镇模拟
2023/05 VOYAGER NeurIPS 2023 Minecraft终身学习

2024年:深度发展

时间 论文 核心贡献
2024/05 Agent Hospital 可进化医疗智能体
2024/10 Project Sid 500-1000+智能体文明模拟
2024/10 Claude Computer Use 商业级计算机控制

2025年:产业化

时间 趋势 代表产品
2025 Agent OS化 AutoGen, LangGraph
2025 商业化加速 OpenAI Operator
2025 多模态融合 视觉+语言+行动

游戏类型与论文分布

游戏类型 论文数 代表论文
文字冒险 22 ReAct, Reflexion, ALFWorld
Minecraft 15 VOYAGER, GITM, JARVIS-1
社会模拟 12 Generative Agents, Project Sid
竞技游戏 15 PokéLLMon, StarCraft II
合作游戏 7 Co-LLM-Agents, TeamCraft
对话游戏 16 Werewolf, Avalon

核心技术对比

记忆机制

方法 存储内容 检索方式 特点
VOYAGER 可执行代码 语义相似度 技能可复用
Generative Agents 自然语言 时近性+重要性+相关性 多层抽象
Reflexion 语言化反思 时间顺序 失败学习

反思机制

方法 触发条件 输出 目的
VOYAGER 每轮执行后 成功/失败+批评 任务验证
Generative Agents 重要性>150 高层次洞察 概念抽象
Reflexion 每次失败后 详细反思 错误诊断

学习方式

方法 是否微调 知识形式 学习目标
传统RL ✅ 梯度更新 策略网络 奖励最大化
VOYAGER ❌ 提示工程 代码技能库 技能积累
Reflexion ❌ 语言强化 反思记忆 任务成功率

关键洞见

1. 无需微调的力量

三大核心框架(ReAct、Reflexion、Generative Agents)都证明:仅通过提示工程和运行时机制,无需微调模型参数,就能实现复杂的智能体行为

2. 记忆是关键

有效的记忆机制是智能体成功的基础:

  • VOYAGER:代码即记忆,技能可复用
  • Generative Agents:记忆即人格,反思即成长
  • Reflexion:反思即学习,失败即进步

3. 协同优于孤立

单一能力 协同能力
仅推理 → 幻觉严重 推理+行动 → ReAct
仅行动 → 无法规划 行动+反思 → Reflexion
单智能体 → 能力有限 多智能体 → 涌现社会行为

4. 规模带来涌现

规模 涌现现象
25 智能体 社交行为、信息传播 (Generative Agents)
50 智能体 长期关系、角色分化 (Project Sid)
500+ 智能体 文化传播、宗教涌现 (Project Sid)

实践建议

场景匹配

场景 推荐方法 原因
开放世界游戏 VOYAGER 技能可复用、可组合
社会模拟 Generative Agents 丰富记忆和人格一致性
决策任务 Reflexion 失败反思对决策优化关键
医疗/专业领域 Agent Hospital 可进化的专业智能体

组合架构

理想的智能体应该结合三者优势:

1
2
3
理想架构 = Generative Agents的记忆流
+ VOYAGER的技能库
+ Reflexion的失败反思

工业趋势

主要玩家

公司 产品 核心能力
OpenAI GPT-4V Agent, Operator 通用Agent能力
Anthropic Claude Computer Use 计算机自主控制
Microsoft AutoGen 0.4 企业级多Agent框架
Altera AI Project Sid AI文明模拟

开源生态

框架 定位 热度
AutoGen 多Agent对话与协作 🔥🔥🔥
LangGraph 状态机Agent工作流 🔥🔥🔥
MetaGPT 多角色软件开发 🔥🔥
CrewAI 角色扮演Agent团队 🔥🔥

参考资源

论文列表

代码仓库

论文 代码
ReAct github.com/ysymyth/ReAct
Reflexion github.com/noahshinn024/reflexion
Generative Agents github.com/joonspk-research/generative_agents
VOYAGER voyager.minedojo.org

下一篇:基础框架篇 | 应用扩展篇

Designing Data-Intensive Applications: The Big Ideas Behind Reliable, Scalable, and Maintainable Systems

这是一份关于《数据密集型应用系统设计》(DDIA) 的完整读书笔记,本书被誉为”数据系统领域的圣经”。

书籍信息

项目 内容
书名 Designing Data-Intensive Applications (DDIA)
中文名 数据密集型应用系统设计
作者 Martin Kleppmann(剑桥大学分布式系统研究员)
出版时间 2017年3月

核心主题

本书围绕三个核心概念展开:

  • 可靠性 (Reliability):系统在遇到故障时仍能正确工作
  • 可扩展性 (Scalability):系统能够应对负载增长
  • 可维护性 (Maintainability):系统易于理解、修改和扩展

全书结构

第一部分:数据系统基础

查看详细笔记

章节 核心内容
第1章 可靠性、可扩展性、可维护性的定义与实践
第2章 关系模型、文档模型、图模型的对比与选择
第3章 存储引擎原理:B-Tree、LSM-Tree、OLTP vs OLAP
第4章 数据编码格式与模式演化:JSON、Protobuf、Avro

第二部分:分布式数据

查看详细笔记

章节 核心内容
第5章 数据复制:主从、多主、无主复制策略
第6章 数据分区:分区策略、再平衡、请求路由
第7章 事务:ACID、隔离级别、分布式事务
第8章 分布式系统挑战:网络、时钟、故障模型
第9章 一致性与共识:CAP、Paxos、Raft

第三部分:衍生数据

查看详细笔记

章节 核心内容
第10章 批处理:MapReduce、Spark、数据流引擎
第11章 流处理:Kafka、Flink、事件时间与水位线
第12章 数据系统未来:数据集成、端到端正确性、伦理

学习路线

入门路线(适合初学者)

1
2
3
4
5
第1章 → 第2章 → 第3章 → 第4章(建立基础)

第5章 → 第6章(理解分布式基础)

第10章 → 第11章(了解数据处理)

进阶路线(适合有经验的开发者)

1
2
3
4
5
第7章 → 第8章 → 第9章(深入分布式)

第12章(展望未来)

回顾第1-4章填补知识空白

专题路线

方向 推荐阅读顺序
数据库 2 → 3 → 5 → 6 → 7
分布式系统 5 → 6 → 8 → 9
数据工程 3 → 10 → 11 → 12

核心要点速览

数据模型选择

1
2
3
4
5
关系模型 ──── 结构化数据、复杂查询、事务支持

文档模型 ──── 灵活模式、树状结构、局部性好

图模型 ───── 复杂关系、社交网络、知识图谱

存储引擎对比

引擎 优化目标 典型应用
B-Tree 读取优化 OLTP 数据库
LSM-Tree 写入优化 日志、时序数据
列存储 分析优化 OLAP、数据仓库

分布式系统核心权衡

处理范式对比

范式 数据特性 延迟 典型框架
批处理 有界、静态 分钟~小时 Spark, Hadoop
流处理 无界、持续 毫秒~秒 Flink, Kafka Streams

延伸资源


本读书笔记整理于 2025年,基于 DDIA 第一版内容编写

本文是 DDIA 第一部分的完整读书笔记,涵盖第 1-4 章:可靠性与可扩展性、数据模型、存储引擎、数据编码。


第1章:可靠性、可扩展性与可维护性

1.1 数据密集型应用的组成

现代数据密集型应用通常由多个组件组合:

1
2
3
4
5
6
7
8
9
10
┌─────────────────────────────────────────────────────────────┐
│ 数据密集型应用架构 │
├─────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 数据库 │ │ 缓存 │ │ 搜索索引 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 流处理 │ │ 批处理 │ │ 消息队列 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────────┘
类型 特点 主要瓶颈 典型场景
数据密集型 数据量大、复杂、变化快 磁盘I/O、网络带宽 Web应用、社交网络
计算密集型 计算量大、算法复杂 CPU处理能力 科学计算、图形渲染

1.2 可靠性(Reliability)

可靠性:系统在面对故障时,仍能正确运行

术语 英文 定义
故障 Fault 系统中某个组件偏离其规格
失效 Failure 整个系统停止为用户提供服务
容错 Fault-tolerant 系统能够处理某些类型的故障

故障类型及应对

1. 硬件故障

  • 硬盘损坏、内存故障、电源中断
  • 应对:RAID、双电源、热备份、多副本

2. 软件错误

  • 比硬件故障更难预测,可能导致系统性问题
  • 应对:全面测试、进程隔离、监控告警

3. 人为错误

  • 研究表明,运维人员的配置错误是系统中断的首要原因
  • 应对:良好的API设计、沙箱环境、快速回滚

1.3 可扩展性(Scalability)

可扩展性:系统应对负载增长的能力

描述性能:百分位数

百分位 含义 用途
p50 中位数 典型响应时间
p95 95%的请求快于此值 较高标准
p99 99%的请求快于此值 SLA标准

Twitter 案例:推拉模式

方案 描述 优点 缺点
拉模式 读取时间线时查询关注者的推文 写入简单 读取开销大
推模式 发推时写入所有关注者的缓存 读取快速 写入开销大(大V问题)
混合模式 普通用户推模式,大V拉模式 平衡读写 实现复杂

扩展策略

  • 纵向扩展:使用更强大的机器(简单但有上限)
  • 横向扩展:使用多台普通机器(需要复杂设计)
  • 弹性扩展:根据负载自动增减资源

1.4 可维护性(Maintainability)

软件的大部分成本不在于最初的开发,而在于后续的维护

方面 目标
可操作性 让运维团队能够轻松保持系统运行
简单性 让新工程师能够轻松理解系统
可演化性 让工程师能够轻松对系统进行修改

第2章:数据模型与查询语言

2.1 数据模型分层

1
2
3
4
5
6
7
8
9
┌─────────────────────────────────────────────┐
│ 应用程序层(对象、数据结构、API) │
├─────────────────────────────────────────────┤
│ 数据模型层(表、文档、图) │
├─────────────────────────────────────────────┤
│ 存储层(字节序列、内存、磁盘) │
├─────────────────────────────────────────────┤
│ 硬件层(电信号、磁场) │
└─────────────────────────────────────────────┘

2.2 关系模型

概念 描述 示例
关系(Relation) users表
元组(Tuple) 行/记录 一个用户记录
属性(Attribute) name, email, age
1
2
3
4
5
6
7
8
9
10
11
-- 规范化设计:使用外键引用
CREATE TABLE positions (
position_id INT PRIMARY KEY,
title VARCHAR(100)
);

CREATE TABLE users (
user_id INT PRIMARY KEY,
name VARCHAR(100),
position_id INT REFERENCES positions(position_id)
);

优势:数据一致性、灵活查询、ACID事务、成熟稳定
局限:对象-关系阻抗不匹配、模式僵化、扩展困难

2.3 文档模型

代表产品:MongoDB, CouchDB, Elasticsearch

1
2
3
4
5
6
7
8
9
{
"user_id": 1,
"name": "张三",
"positions": [
{"title": "软件工程师", "company": "ABC公司"},
{"title": "技术总监", "company": "XYZ公司"}
],
"skills": ["Java", "Python", "分布式系统"]
}
特性 文档模型 关系模型
数据结构 嵌套/层次化 扁平/规范化
模式 灵活(Schema-on-read) 严格(Schema-on-write)
局部性 好(整个文档一起存储) 差(需要JOIN多表)
多对多关系 较难处理 容易处理

2.4 图模型

适用场景:社交网络、知识图谱、推荐系统

1
2
3
4
5
6
7
8
// Neo4j Cypher 示例
CREATE (alice:Person {name: 'Alice', age: 30})
CREATE (bob:Person {name: 'Bob', age: 25})
CREATE (alice)-[:FOLLOWS {since: '2020-01-01'}]->(bob)

-- 找出 Alice 关注的人所关注的人
MATCH (alice:Person {name: 'Alice'})-[:FOLLOWS]->()-[:FOLLOWS]->(fof)
RETURN fof.name

2.5 数据模型选择指南

1
2
3
4
5
6
7
8
9
10
11
12
13
14
     ┌───────────────┐
│ 数据关系复杂吗?│
└───────────────┘
│ │
复杂 简单
│ │
▼ ▼
┌──────────┐ ┌─────────────┐
│多对多关系?│ │树状/层次结构?│
└──────────┘ └─────────────┘
│ │ │ │
是 否 是 否
▼ ▼ ▼ ▼
图模型 关系模型 文档模型 关系模型

第3章:存储与检索

3.1 两大存储引擎家族

类型 优化目标 典型应用 代表产品
日志结构存储 写入优化 高写入负载 LevelDB, RocksDB, Cassandra
原地更新存储 读取优化 事务处理 MySQL InnoDB, PostgreSQL

3.2 哈希索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
┌─────────────────────────────────────────────┐
│ 内存中的哈希表 │
├─────────────────────────────────────────────┤
│ key1 ──> 字节偏移量 0 │
│ key2 ──> 字节偏移量 128 │
│ key3 ──> 字节偏移量 256 │
└─────────────────────────────────────────────┘


┌─────────────────────────────────────────────┐
│ 磁盘上的日志文件 │
├─────────────────────────────────────────────┤
│ offset 0 : key1,value1 │
│ offset 128: key2,value2 │
└─────────────────────────────────────────────┘

局限:所有键必须在内存中、范围查询慢

3.3 LSM-Tree(Log-Structured Merge-Tree)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Level 0 (内存):
┌────────────┐
│ Memtable │ ← 当前写入(平衡树)
└────────────┘

Level 1 (磁盘):
┌────┐ ┌────┐ ┌────┐
│SS1 │ │SS2 │ │SS3 │ ← 较新的 SSTable
└────┘ └────┘ └────┘

Level 2 (磁盘):
┌────────┐ ┌────────┐
│ SS4 │ │ SS5 │ ← 经过合并的 SSTable
└────────┘ └────────┘

优化技术

  • 布隆过滤器:快速判断键是否存在(避免无效磁盘读取)
  • 压缩策略:Size-Tiered(写密集)、Leveled(读密集)

3.4 B-Tree

1
2
3
4
5
6
7
8
9
             ┌───────────────┐
│ [30, 70] │ ← 根节点
└───────────────┘
/ │ \
┌───────┐ ┌───────────┐ ┌───────┐
│[10,20]│ │[40,50,60] │ │[80,90]│ ← 内部节点
└───────┘ └───────────┘ └───────┘
/ │ \ / │ \ \ / │ \
叶子节点(包含实际数据或指向数据的指针)
特性 B-Tree LSM-Tree
写入 原地更新 追加写入
读取 快(一次定位) 可能需要检查多个文件
写放大 较低 较高(压缩开销)
空间利用 可能碎片化 更紧凑

3.5 OLTP vs OLAP

特性 OLTP OLAP
全称 在线事务处理 在线分析处理
主要操作 增删改查 复杂查询、聚合
访问模式 基于键的随机访问 扫描大量记录
数据量 GB~TB TB~PB
响应时间 毫秒级 秒~分钟级

列式存储优势

  • 只读取需要的列
  • 相同类型数据更易压缩
  • 向量化处理

第4章:数据编码与演化

4.1 兼容性概念

1
2
3
4
5
6
7
后向兼容(Backward Compatibility):
新代码能读取旧代码写的数据
v3 代码 ─── 读取 ───> v1 数据 ✓

前向兼容(Forward Compatibility):
旧代码能读取新代码写的数据
v1 代码 ─── 读取 ───> v3 数据 ✓

滚动升级:不停机部署新版本,同时运行新旧版本代码

4.2 编码格式对比

特性 JSON Protobuf Thrift Avro
可读性
空间效率 最高
模式 可选 必须 必须 必须
模式演化 手动 支持 支持 支持
动态模式 天然 困难 困难 支持

4.3 Protocol Buffers 示例

1
2
3
4
5
message Person {
required string user_name = 1;
optional int64 favorite_number = 2;
repeated string interests = 3;
}

兼容性规则

操作 后向兼容 前向兼容
添加可选字段
删除可选字段
添加必填字段
修改字段标签

4.4 数据流模式

1. 数据库:数据可能比代码更持久,需要能够读取多年前写入的数据

2. 服务调用(REST/RPC)

  • REST:基于HTTP,资源导向
  • gRPC:基于Protobuf,高性能

3. 消息传递

  • 解耦、缓冲、异步、可靠性
  • 代表:Kafka, RabbitMQ

本章关键要点

  1. 可靠性不是追求零故障,而是在故障发生时系统仍能正常工作
  2. 百分位数比平均值更能反映用户真实体验
  3. 抽象是管理复杂性的最重要工具
  4. 没有万能的数据模型,选择取决于应用场景
  5. LSM-Tree优化写入,B-Tree优化读取
  6. 兼容性是渐进式部署的前提

延伸阅读

  • 《Site Reliability Engineering》:Google SRE 实践
  • 《Database Internals》:数据库内部原理深入指南
  • 《Building Microservices》:服务间通信详解

本文是 DDIA 第二部分的完整读书笔记,涵盖第 5-9 章:数据复制、数据分区、事务、分布式挑战、一致性与共识。


第5章:数据复制

5.1 复制的目的

目的 说明
高可用性 部分节点故障时系统仍可用
降低延迟 将数据放在离用户更近的地方
提高读吞吐 多个副本可以并行处理读请求

5.2 主从复制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
┌──────────────────────────────────────────────┐
│ 主从复制架构 │
├──────────────────────────────────────────────┤
│ 客户端写入 │
│ │ │
│ ▼ │
│ ┌────────┐ 复制日志 ┌────────┐ │
│ │ 主节点 │ ─────────> │ 从节点1 │ │
│ │(Leader)│ │(Follower)│ │
│ └────────┘ └────────┘ │
│ │ │ │
│ │ 复制日志 │ │
│ ▼ ▼ │
│ ┌────────┐ ┌────────┐ │
│ │ 从节点2 │ │ 从节点3 │ │
│ └────────┘ └────────┘ │
│ │
│ 客户端可以从任意节点读取 │
└──────────────────────────────────────────────┘

同步复制 vs 异步复制

方式 优点 缺点
同步 数据一致,从节点数据最新 延迟高,从节点故障会阻塞
异步 延迟低,不受从节点影响 数据可能丢失
半同步 平衡一致性和可用性 实现复杂

故障转移(Failover)

1
2
3
4
5
1. 检测主节点故障(心跳超时)
2. 选举新主节点(选择数据最新的节点)
3. 重新配置系统
- 客户端发送写请求到新主节点
- 其他从节点从新主节点复制

问题:数据丢失、脑裂、超时设置

5.3 复制延迟问题

读己之写一致性:用户写入后立即读取可能看到旧数据

单调读:用户刷新页面可能路由到不同从节点,看到数据”回退”

一致前缀读:因果关系打乱(先看到回答再看到问题)

5.4 多主复制

1
2
3
4
┌──────────┐         ┌──────────┐
│ 数据中心1 │ <─────> │ 数据中心2 │
│ (主节点1) │ 异步 │ (主节点2) │
└──────────┘ └──────────┘

冲突解决策略

  • 最后写入胜出(LWW)
  • 合并值
  • CRDT
  • 自定义逻辑

5.5 无主复制(Quorum)

公式W + R > N

符号 含义
N 副本总数
W 写入需要确认的副本数
R 读取需要查询的副本数
1
2
3
4
示例:N=3, W=2, R=2
写入成功需要 2 个副本确认
读取需要查询 2 个副本
W + R = 4 > N = 3 → 保证读到最新值

第6章:数据分区

6.1 分区的目的

将大数据集分散到多个节点,每个节点只存储部分数据

6.2 分区策略

按键范围分区

1
分区1: A-F    分区2: G-N    分区3: O-Z

优点:支持范围查询
缺点:可能导致热点

按哈希分区

1
分区 = hash(key) % N

优点:均匀分布
缺点:不支持范围查询

6.3 次级索引分区

类型 特点
本地索引 写入快,查询需要scatter-gather
全局索引 查询快,写入需要分布式事务

6.4 分区再平衡

策略 说明
固定分区数 预先创建多个分区
动态分区 根据数据大小自动分裂/合并
按节点分区 每个节点固定数量的分区

第7章:事务

7.1 ACID 特性

特性 说明
原子性 (Atomicity) 全部成功或全部失败
一致性 (Consistency) 从一个有效状态到另一个有效状态
隔离性 (Isolation) 并发事务互不干扰
持久性 (Durability) 提交后数据不丢失

7.2 隔离级别

级别 脏读 不可重复读 幻读
读未提交
读已提交
可重复读
可串行化

7.3 串行化实现

真正的串行执行

  • 单线程处理事务
  • 适合事务短、数据在内存

两阶段锁定(2PL)

1
2
增长阶段:只能获取锁
收缩阶段:只能释放锁

可串行化快照隔离(SSI)

  • 乐观并发控制
  • 检测冲突后回滚

7.4 分布式事务

两阶段提交(2PC)

1
2
3
4
5
6
7
阶段1(准备):
协调者 ──准备?──> 所有参与者
<──准备好──

阶段2(提交):
协调者 ──提交──> 所有参与者
<──完成──

问题:协调者故障时参与者阻塞


第8章:分布式系统的挑战

8.1 不可靠的网络

  • 请求可能丢失
  • 请求可能排队很久
  • 响应可能丢失
  • 无法区分节点故障和网络故障

8.2 不可靠的时钟

时钟类型 用途
墙上时钟 获取当前时间(NTP同步)
单调时钟 测量时间间隔

问题:不同机器时钟可能不同步

8.3 进程暂停

  • GC 暂停
  • 虚拟机暂停
  • 页面交换

8.4 拜占庭故障

节点可能发送任意消息(故意或无意)


第9章:一致性与共识

9.1 一致性模型

1
2
3
弱 ←─────────────────────────────────────→ 强

最终一致性 因果一致性 顺序一致性 线性一致性

9.2 线性一致性

系统表现得好像只有一个数据副本,所有操作都是原子的

1
2
3
4
5
6
7
8
时间线:
────────────────────────────────────────>

客户端A: ├── 写入 x=1 ──┤
客户端B: ├── 读取 x ──┤ → 必须返回 1
客户端C: ├── 读取 x ──┤ → 必须返回 1

一旦任何客户端读到新值,所有后续读取都必须返回新值

9.3 CAP 定理

1
2
3
4
5
6
7
8
9
10
11
      C (Consistency)
╱╲
╱ ╲
╱ 只能 ╲
╱ 选择2个 ╲
A ────────── P
(Availability) (Partition Tolerance)

网络分区时必须选择:
- CP:保证一致性,牺牲可用性
- AP:保证可用性,牺牲一致性

9.4 共识算法

共识的性质

性质 说明
一致同意 所有节点决定相同的值
完整性 决定的值必须是某个节点提议的
终止性 最终会做出决定

9.5 Raft 算法

1
2
3
4
5
6
7
8
9
10
11
┌──────────┐
│ 领导者 │ ← 处理所有客户端请求
│ (Leader) │
└──────────┘

│ 复制日志

┌──────────┐ ┌──────────┐
│ 跟随者1 │ │ 跟随者2 │
│(Follower)│ │(Follower)│
└──────────┘ └──────────┘

选举流程

  1. 跟随者超时未收到心跳
  2. 转变为候选人,增加任期
  3. 向其他节点请求投票
  4. 获得多数票则成为领导者

9.6 Paxos 算法

1
2
3
4
5
6
7
阶段1(Prepare):
提议者 ──Prepare(n)──> 接受者
<──Promise──

阶段2(Accept):
提议者 ──Accept(n,v)──> 接受者
<──Accepted──

Raft vs Paxos

方面 Raft Paxos
可理解性
领导者 稳定领导者 每次共识可能不同
日志 连续日志 可能有空洞

9.7 共识的应用

  • ZooKeeper:Kafka 使用进行领导者选举
  • etcd:Kubernetes 存储集群状态
  • Consul:服务发现和配置

本章关键要点

  1. 复制的核心挑战是处理变更
  2. 异步复制有数据丢失风险
  3. Quorum 使用 W + R > N 保证一致性
  4. 分区支持水平扩展
  5. 隔离级别是性能与一致性的权衡
  6. 线性一致性是最强的一致性保证
  7. CAP 定理表明网络分区时需要取舍
  8. Raft 比 Paxos 更易理解

延伸阅读

  • 《Amazon Dynamo》论文:无主复制经典论文
  • 《Paxos Made Simple》:Lamport 的简化版 Paxos
  • 《In Search of an Understandable Consensus Algorithm》:Raft 论文

本文是 DDIA 第三部分的完整读书笔记,涵盖第 10-12 章:批处理、流处理、数据系统的未来。


第10章:批处理

10.1 系统类型对比

类型 特点 示例
在线服务 请求-响应模式,低延迟 Web 服务、API
批处理系统 处理大量数据,高吞吐 MapReduce、Spark
流处理系统 实时处理数据流 Kafka Streams、Flink
1
2
3
4
5
6
7
8
9
在线服务: 用户请求 ──> [服务] ──> 响应 (毫秒级)

批处理:
┌────────────────┐ ┌────────────────┐
│ 大量输入数据 │ ──> │ 批处理作业 │ ──> 输出结果
│ (TB级别) │ │ (运行数小时) │
└────────────────┘ └────────────────┘

流处理: 事件流 ──> [处理] ──> 输出流 (持续进行)

10.2 Unix 工具的批处理

1
2
3
4
5
6
7
# 找出访问量最高的 URL
cat access.log |
awk '{print $7}' | # 提取 URL 字段
sort | # 排序
uniq -c | # 计数
sort -rn | # 按计数降序排序
head -n 10 # 取前10

Unix 哲学

  • 每个程序做好一件事
  • 输出可以成为另一个程序的输入
  • 快速原型开发

10.3 MapReduce

1
2
3
4
5
6
7
8
9
10
11
12
13
14
MapReduce 流程:

输入数据 Map 阶段 Shuffle Reduce 阶段 输出
┌─────┐ ┌─────────┐ ┌──────────┐
│分片1│ ──────>│ Map 1 │ ─┐ ┌─>│ Reduce 1 │ ──> 结果1
└─────┘ └─────────┘ │ │ └──────────┘
│ ┌───────────┐ │
┌─────┐ ┌─────────┐ └>│ 按键分组 │ ─┤
│分片2│ ──────>│ Map 2 │ ───│ Shuffle │ │
└─────┘ └─────────┘ ┌─>│ │ ─┤
│ └───────────┘ │ ┌──────────┐
┌─────┐ ┌─────────┐ │ └>│ Reduce 2 │ ──> 结果2
│分片3│ ──────>│ Map 3 │ ─┘ └──────────┘
└─────┘ └─────────┘

词频统计示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Map 函数
def map(document):
for word in document.split():
emit(word, 1)

# Reduce 函数
def reduce(word, counts):
emit(word, sum(counts))

# 执行流程:
# 输入: "hello world hello"
# Map 输出: ("hello", 1), ("world", 1), ("hello", 1)
# Shuffle 后: "hello": [1, 1], "world": [1]
# Reduce 输出: ("hello", 2), ("world", 1)

10.4 Join 策略

策略 特点 适用场景
排序-合并 Join 两边都排序后合并 Reduce 端 Join
广播 Join 小表广播到所有节点 一边数据量小
分区 Join 按相同键分区 两边都很大

10.5 现代批处理框架

框架 特点
Spark 内存计算,比 MapReduce 快 10-100x
Flink 统一批流处理
Presto/Trino 交互式 SQL 查询

第11章:流处理

11.1 批处理 vs 流处理

方面 批处理 流处理
数据 有界,静态 无界,持续到达
延迟 分钟~小时 毫秒~秒
触发 定时/手动 事件驱动
结果 一次性输出 持续更新

11.2 消息系统

传统消息队列

代表:RabbitMQ, ActiveMQ, Amazon SQS

1
2
3
4
5
6
7
8
9
10
┌─────────────────────────────┐
│ 队列 │
├─────────────────────────────┤
│ msg1 │ msg2 │ msg3 │ msg4 │
└──┬───┴──┬───┴──┬───┴──┬────┘
▼ ▼ ▼ ▼
消费1 消费2 消费1 消费2

每条消息只被一个消费者处理
处理后消息被删除

日志型消息系统(Kafka)

1
2
3
4
5
6
7
8
分区0: [msg0] [msg3] [msg6] [msg9]  ──> 消费者A
分区1: [msg1] [msg4] [msg7] [msg10] ──> 消费者B
分区2: [msg2] [msg5] [msg8] [msg11] ──> 消费者C

特点:
- 消息持久化,可重放
- 保序(分区内)
- 多消费者组可独立消费

11.3 Kafka 架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
┌─────────────────────────────────────────────────────┐
│ Kafka Cluster │
├─────────────────────────────────────────────────────┤
│ Topic: orders │
│ ┌────────────────┐ ┌────────────────┐ │
│ │ Partition 0 │ │ Partition 1 │ │
│ │ [0][1][2][3]...│ │ [0][1][2][3]...│ │
│ │ Leader: B1 │ │ Leader: B2 │ │
│ │ Replica: B2 │ │ Replica: B1 │ │
│ └────────────────┘ └────────────────┘ │
│ │
│ Broker 1 Broker 2 │
└─────────────────────────────────────────────────────┘

┌───────────────┼───────────────┐
▼ ▼ ▼
Consumer 1 Consumer 2 Consumer 3
Group A Group A Group B

11.4 变更数据捕获(CDC)

捕获数据库的变更,将其转换为事件流

1
2
3
4
5
6
7
8
9
10
┌─────────────┐      ┌─────────────┐      ┌─────────────┐
│ 应用程序 │ ──> │ 数据库 │ ──> │ CDC 工具 │
└─────────────┘ └─────────────┘ │ (Debezium) │
│ └─────────────┘
变更日志 │
│ 事件流到 Kafka
▼ │
┌───────────┐ ┌───────────┐
│ binlog │ │ Kafka │
└───────────┘ └───────────┘

应用场景

  • 同步搜索索引、缓存
  • 微服务间数据同步
  • 实时 ETL

11.5 事件溯源(Event Sourcing)

1
2
3
4
5
6
7
8
9
10
11
传统方式:直接修改状态
账户余额: 100 ──> 90 ──> 140 ──> 110

事件溯源:存储事件
事件日志:
1. 初始存款 100
2. 取款 10
3. 存款 50
4. 取款 30

重放事件 ──> 计算当前状态

11.6 时间语义

时间类型 定义 用途
事件时间 事件实际发生的时间 业务逻辑
处理时间 事件到达处理系统的时间 系统监控
摄入时间 事件进入流处理系统的时间 折中方案

11.7 窗口操作

窗口类型 特点
滚动窗口 固定大小,不重叠
滑动窗口 固定大小,可重叠
会话窗口 按活动间隙分割
1
2
3
4
5
6
7
滚动窗口 (5分钟):
[00:00-05:00] [05:00-10:00] [10:00-15:00]

滑动窗口 (5分钟窗口, 1分钟滑动):
[00:00-05:00]
[01:00-06:00]
[02:00-07:00]

第12章:数据系统的未来

12.1 数据集成

将多个系统的数据统一管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
┌─────────────────────────────────────────────────────┐
│ 数据集成架构 │
├─────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 数据库 │ │ 缓存 │ │ 搜索引擎 │ │
│ └────┬─────┘ └────┬─────┘ └────┬─────┘ │
│ │ │ │ │
│ └──────────────┼──────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ 事件日志 │ │
│ │ (Kafka) │ │
│ └──────────────┘ │
│ │ │
│ ┌──────────────┼──────────────┐ │
│ ▼ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 分析系统 │ │ ML平台 │ │ 监控系统 │ │
│ └──────────┘ └──────────┘ └──────────┘ │
│ │
└─────────────────────────────────────────────────────┘

12.2 Lambda 架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Lambda 架构:
输入数据

┌─────────┴─────────┐
▼ ▼
┌──────────────┐ ┌──────────────┐
│ 批处理层 │ │ 速度层 │
│ (全量计算) │ │ (增量计算) │
│ │ │ │
│ MapReduce │ │ Storm/Flink │
└──────┬───────┘ └──────┬───────┘
│ │
▼ ▼
┌──────────────┐ ┌──────────────┐
│ 批处理视图 │ │ 实时视图 │
└──────┬───────┘ └──────┬───────┘
│ │
└─────────┬─────────┘

┌──────────────┐
│ 服务层 │
│ (合并结果) │
└──────────────┘

问题:需要维护两套代码

12.3 Kappa 架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
Kappa 架构:
输入数据


┌──────────────┐
│ 日志存储 │
│ (Kafka) │
└──────┬───────┘


┌──────────────┐
│ 流处理 │
│ (Flink) │
└──────┬───────┘


┌──────────────┐
│ 服务层 │
└──────────────┘

统一批流处理:重放日志即可重新计算

12.4 解绑数据库

将数据库的各个功能分解为独立组件

功能 传统数据库 解绑后
存储 B-Tree 分布式文件系统
事务 内置 独立事务管理器
索引 内置 外部搜索引擎
缓存 内置 Redis

12.5 端到端论证

某些功能只能在端到端层面正确实现

示例:exactly-once 消息传递

  • 中间件可能声称 exactly-once
  • 但网络可能中断
  • 真正的 exactly-once 需要应用层去重

12.6 数据流系统设计原则

  1. 不可变性:日志是追加的,事件不可修改
  2. 可重放性:可以从日志重建状态
  3. 分离计算和存储:计算层无状态
  4. 模式演化:支持模式向前/向后兼容

本章关键要点

  1. 批处理适合高吞吐的离线计算
  2. MapReduce 是分布式批处理的基础范式
  3. 流处理适合低延迟的实时计算
  4. Kafka 是现代流处理的核心组件
  5. CDC 连接传统数据库和事件流
  6. 事件溯源将状态变更作为事件序列存储
  7. 时间语义在流处理中至关重要
  8. Lambda/Kappa 架构是批流融合的尝试
  9. 解绑数据库提供更大的灵活性

延伸阅读

  • 《Streaming Systems》:流处理权威指南
  • 《Kafka: The Definitive Guide》:Kafka 权威指南
  • 《Designing Data-Intensive Applications》原书

本文综述神经网络在机器阅读理解和对话系统中的发展历程,从早期的注意力机制到现代大语言模型。

发展时间线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
2015-2016: 注意力机制兴起
└── Attentive Reader, Impatient Reader, BiDAF

2017-2018: 深度交互与预训练
└── R-Net, QANet, BERT

2019-2020: 大规模预训练
└── RoBERTa, ALBERT, T5

2021-2023: 大语言模型时代
└── GPT-3, ChatGPT, GPT-4, LLaMA

2024-: 检索增强与多模态
└── RAG, Vision-Language Models

核心技术演进

阶段一:注意力机制 (2015-2017)

问题:如何让模型”关注”与问题相关的上下文?

代表模型:Attentive Reader, BiDAF

阶段二:深度交互 (2017-2018)

问题:如何建模问题和上下文的复杂交互?

技术:多轮注意力、自注意力、门控机制

1
2
3
4
5
6
# 多轮推理 (R-Net 风格)
for layer in range(num_layers):
# 自注意力
context = self_attention(context, context)
# 交叉注意力
context = cross_attention(context, question)

阶段三:预训练语言模型 (2018-2020)

范式转变:从 task-specific 到 pretrain-finetune

$$
\theta^* = \arg\min_\theta \mathcal{L}{task}(\text{PLM}\theta(x), y)
$$

代表模型:BERT, RoBERTa, ALBERT

1
2
3
4
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
# Fine-tune on SQuAD

阶段四:大语言模型 (2020-至今)

范式转变:从 fine-tuning 到 prompting

1
2
3
4
5
6
7
8
9
# Few-shot prompting
prompt = """
Context: The Eiffel Tower was built in 1889.
Question: When was the Eiffel Tower built?
Answer: 1889

Context: {context}
Question: {question}
Answer:"""

架构对比

模型 参数量 训练范式 SQuAD 2.0 F1
BiDAF ~2M 从零训练 77.3
BERT-base 110M 预训练+微调 88.5
BERT-large 340M 预训练+微调 90.9
RoBERTa-large 355M 预训练+微调 91.4
GPT-3 175B Few-shot ~88
GPT-4 ~1.8T Zero-shot ~95

现代 MRC 系统设计

RAG 架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class ModernMRC:
def __init__(self, retriever, reader):
self.retriever = retriever # Dense retriever
self.reader = reader # LLM

def answer(self, question: str, knowledge_base: str = None):
# 1. 检索
if knowledge_base:
docs = self.retriever.retrieve(question, knowledge_base)
context = "\n\n".join([d.text for d in docs])
else:
context = ""

# 2. 阅读理解/生成
prompt = self._build_prompt(question, context)
answer = self.reader.generate(prompt)

# 3. 后处理(可选:验证、引用)
return self._postprocess(answer, docs)

def _build_prompt(self, question, context):
if context:
return f"""Based on the following context, answer the question.

Context:
{context}

Question: {question}
Answer:"""
else:
return f"Question: {question}\nAnswer:"

多跳推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class MultiHopReasoner:
def __init__(self, retriever, llm, max_hops=3):
self.retriever = retriever
self.llm = llm
self.max_hops = max_hops

def reason(self, question):
reasoning_chain = []
current_query = question

for hop in range(self.max_hops):
# 检索
docs = self.retriever.retrieve(current_query)

# 生成中间推理
intermediate = self.llm.generate(
f"Based on: {docs}\nQuestion: {current_query}\n"
f"Provide intermediate reasoning or the final answer:"
)

reasoning_chain.append({
'query': current_query,
'docs': docs,
'reasoning': intermediate
})

# 检查是否已得到答案
if self._is_final_answer(intermediate):
break

# 生成下一跳查询
current_query = self._generate_next_query(question, reasoning_chain)

return self._synthesize_answer(question, reasoning_chain)

对话系统中的 MRC

对话式问答

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class ConversationalQA:
def __init__(self, mrc_model, history_length=5):
self.mrc_model = mrc_model
self.history = []
self.history_length = history_length

def ask(self, question, context=None):
# 将对话历史纳入问题
contextualized_question = self._contextualize(question)

# 获取答案
answer = self.mrc_model.answer(contextualized_question, context)

# 更新历史
self.history.append({'q': question, 'a': answer})
if len(self.history) > self.history_length:
self.history.pop(0)

return answer

def _contextualize(self, question):
if not self.history:
return question

history_text = "\n".join([
f"Q: {turn['q']}\nA: {turn['a']}"
for turn in self.history
])

return f"Conversation history:\n{history_text}\n\nCurrent question: {question}"

评估体系

传统指标

指标 定义 适用场景
EM 精确匹配 抽取式 QA
F1 Token 重叠 抽取式 QA
BLEU N-gram 重叠 生成式 QA
ROUGE 召回导向重叠 摘要、长答案

LLM 时代指标

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# LLM-as-Judge
def llm_evaluate(question, reference, prediction):
prompt = f"""Evaluate the answer quality on a scale of 1-5:

Question: {question}
Reference Answer: {reference}
Model Answer: {prediction}

Criteria:
- Correctness: Is the information accurate?
- Completeness: Does it fully answer the question?
- Conciseness: Is it appropriately brief?

Score (1-5):"""

return llm.generate(prompt)

延伸阅读


转载请注明出处

本文从零开始实现一个机器阅读理解系统,涵盖数据处理、模型构建、训练和推理的完整流程。

任务定义

给定上下文 和问题 ,预测答案 中的位置:

数据处理

SQuAD 数据格式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Example:
context: str
question: str
answer_text: str
start_position: int
end_position: int

def load_squad(file_path: str) -> List[Example]:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)

examples = []
for article in data['data']:
for paragraph in article['paragraphs']:
context = paragraph['context']
for qa in paragraph['qas']:
question = qa['question']
if qa.get('is_impossible', False):
continue
answer = qa['answers'][0]
examples.append(Example(
context=context,
question=question,
answer_text=answer['text'],
start_position=answer['answer_start'],
end_position=answer['answer_start'] + len(answer['text'])
))

return examples

Tokenization

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import AutoTokenizer

class MRCTokenizer:
def __init__(self, model_name: str, max_length: int = 384, doc_stride: int = 128):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.max_length = max_length
self.doc_stride = doc_stride

def encode(self, example: Example):
# Tokenize question and context
encoding = self.tokenizer(
example.question,
example.context,
max_length=self.max_length,
truncation='only_second',
stride=self.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding='max_length',
)

# 找到答案在 token 序列中的位置
offset_mapping = encoding['offset_mapping'][0]

start_token = None
end_token = None

for idx, (start, end) in enumerate(offset_mapping):
if start <= example.start_position < end:
start_token = idx
if start < example.end_position <= end:
end_token = idx
break

return {
'input_ids': encoding['input_ids'][0],
'attention_mask': encoding['attention_mask'][0],
'start_position': start_token or 0,
'end_position': end_token or 0,
}

模型实现

基于 BERT 的 MRC 模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
from transformers import AutoModel

class MRCModel(nn.Module):
def __init__(self, model_name: str, dropout: float = 0.1):
super().__init__()
self.bert = AutoModel.from_pretrained(model_name)
hidden_size = self.bert.config.hidden_size

self.dropout = nn.Dropout(dropout)
self.start_classifier = nn.Linear(hidden_size, 1)
self.end_classifier = nn.Linear(hidden_size, 1)

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
sequence_output = self.dropout(outputs.last_hidden_state)

# (batch, seq_len, 1) -> (batch, seq_len)
start_logits = self.start_classifier(sequence_output).squeeze(-1)
end_logits = self.end_classifier(sequence_output).squeeze(-1)

# Mask padding tokens
start_logits = start_logits.masked_fill(~attention_mask.bool(), -1e9)
end_logits = end_logits.masked_fill(~attention_mask.bool(), -1e9)

loss = None
if start_positions is not None and end_positions is not None:
loss_fct = nn.CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
loss = (start_loss + end_loss) / 2

return {
'loss': loss,
'start_logits': start_logits,
'end_logits': end_logits,
}

改进:联合 Start-End 预测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class JointMRCModel(nn.Module):
"""联合预测 start 和 end,考虑 start-end 依赖"""

def __init__(self, model_name: str, max_answer_length: int = 30):
super().__init__()
self.bert = AutoModel.from_pretrained(model_name)
hidden_size = self.bert.config.hidden_size
self.max_answer_length = max_answer_length

self.start_classifier = nn.Linear(hidden_size, 1)
self.end_classifier = nn.Linear(hidden_size * 2, 1)

def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
H = outputs.last_hidden_state # (batch, seq_len, hidden)

# Start prediction
start_logits = self.start_classifier(H).squeeze(-1)

if self.training and start_positions is not None:
# 训练时使用真实的 start 位置
start_indices = start_positions.unsqueeze(-1).unsqueeze(-1)
start_repr = H.gather(1, start_indices.expand(-1, -1, H.size(-1))).squeeze(1)
else:
# 推理时使用预测的 start 位置
start_indices = start_logits.argmax(dim=-1, keepdim=True).unsqueeze(-1)
start_repr = H.gather(1, start_indices.expand(-1, -1, H.size(-1))).squeeze(1)

# End prediction conditioned on start
start_repr_expanded = start_repr.unsqueeze(1).expand(-1, H.size(1), -1)
end_input = torch.cat([H, start_repr_expanded], dim=-1)
end_logits = self.end_classifier(end_input).squeeze(-1)

# 只允许 end >= start 且在 max_answer_length 范围内
# 这里简化处理,完整实现需要更复杂的 mask

return {'start_logits': start_logits, 'end_logits': end_logits}

训练流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torch.utils.data import DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

def train(model, train_dataloader, val_dataloader, epochs=3, lr=3e-5):
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=int(0.1 * total_steps),
num_training_steps=total_steps
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

best_f1 = 0
for epoch in range(epochs):
model.train()
total_loss = 0

for batch in tqdm(train_dataloader, desc=f'Epoch {epoch+1}'):
batch = {k: v.to(device) for k, v in batch.items()}

outputs = model(**batch)
loss = outputs['loss']

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

optimizer.step()
scheduler.step()
optimizer.zero_grad()

total_loss += loss.item()

avg_loss = total_loss / len(train_dataloader)
print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')

# Validation
f1 = evaluate(model, val_dataloader, device)
print(f'Validation F1: {f1:.4f}')

if f1 > best_f1:
best_f1 = f1
torch.save(model.state_dict(), 'best_model.pt')

return model

评估与推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re
import string
from collections import Counter

def normalize_answer(s):
"""标准化答案用于评估"""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)

def white_space_fix(text):
return ' '.join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_f1(pred: str, gold: str) -> float:
pred_tokens = normalize_answer(pred).split()
gold_tokens = normalize_answer(gold).split()

common = Counter(pred_tokens) & Counter(gold_tokens)
num_same = sum(common.values())

if num_same == 0:
return 0

precision = num_same / len(pred_tokens)
recall = num_same / len(gold_tokens)

return 2 * precision * recall / (precision + recall)

def predict(model, tokenizer, context: str, question: str, device):
"""单条推理"""
model.eval()

encoding = tokenizer(
question, context,
max_length=384,
truncation='only_second',
return_tensors='pt'
)

encoding = {k: v.to(device) for k, v in encoding.items()}

with torch.no_grad():
outputs = model(**encoding)

start_idx = outputs['start_logits'].argmax().item()
end_idx = outputs['end_logits'].argmax().item()

# 确保 end >= start
if end_idx < start_idx:
end_idx = start_idx

# 解码答案
answer_tokens = encoding['input_ids'][0][start_idx:end_idx+1]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

return answer

现代方法:使用 LLM

对于更复杂的问答需求,可以使用 LLM:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from openai import OpenAI

def llm_qa(context: str, question: str) -> str:
client = OpenAI()

response = client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": "你是一个问答助手。根据给定的上下文回答问题。如果答案不在上下文中,请说'无法回答'。"},
{"role": "user", "content": f"上下文:{context}\n\n问题:{question}"}
],
temperature=0
)

return response.choices[0].message.content

延伸阅读


转载请注明出处

条件随机场 (CRF) 是序列标注的经典模型,尽管深度学习时代 BERT 等模型大放异彩,CRF 层仍然在 NER、词性标注等任务中发挥关键作用。

为什么需要 CRF?

独立分类的问题

如果对每个位置独立分类:

会导致标签不一致,例如:

1
2
3
输入: "北 京 是 中 国 首 都"
错误: B-LOC I-PER O B-LOC I-LOC I-LOC I-LOC
正确: B-LOC I-LOC O B-LOC I-LOC I-LOC I-LOC

CRF 的解决方案

CRF 建模整个序列的联合概率,考虑标签之间的转移约束

数学原理

条件概率

其中:

  • :发射分数(emission score)
  • :转移分数(transition score)
  • :配分函数(归一化项)

配分函数

直接计算复杂度为 ,使用前向算法可降至

PyTorch 实现

CRF Layer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn

class CRF(nn.Module):
def __init__(self, num_tags, batch_first=True):
super().__init__()
self.num_tags = num_tags
self.batch_first = batch_first

# 转移矩阵: transitions[i, j] = 从标签 j 转移到标签 i 的分数
self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

# 起始和结束转移
self.start_transitions = nn.Parameter(torch.randn(num_tags))
self.end_transitions = nn.Parameter(torch.randn(num_tags))

def forward(self, emissions, tags, mask=None):
"""计算负对数似然损失"""
if mask is None:
mask = torch.ones_like(tags, dtype=torch.bool)

if self.batch_first:
emissions = emissions.transpose(0, 1)
tags = tags.transpose(0, 1)
mask = mask.transpose(0, 1)

# 计算分子(正确路径的分数)
numerator = self._compute_score(emissions, tags, mask)

# 计算分母(配分函数)
denominator = self._compute_normalizer(emissions, mask)

# 负对数似然
return (denominator - numerator).mean()

def _compute_score(self, emissions, tags, mask):
"""计算给定标签序列的分数"""
seq_len, batch_size = tags.shape

# 起始分数
score = self.start_transitions[tags[0]]
score += emissions[0, torch.arange(batch_size), tags[0]]

for i in range(1, seq_len):
# 转移分数 + 发射分数
score += self.transitions[tags[i], tags[i-1]] * mask[i]
score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

# 结束分数
last_tag_idx = mask.sum(dim=0) - 1
last_tags = tags.gather(0, last_tag_idx.unsqueeze(0)).squeeze(0)
score += self.end_transitions[last_tags]

return score

def _compute_normalizer(self, emissions, mask):
"""前向算法计算配分函数"""
seq_len, batch_size, num_tags = emissions.shape

# 初始化
score = self.start_transitions + emissions[0]

for i in range(1, seq_len):
# broadcast: (batch, num_tags, 1) + (num_tags, num_tags) + (batch, 1, num_tags)
broadcast_score = score.unsqueeze(2)
broadcast_emissions = emissions[i].unsqueeze(1)

next_score = broadcast_score + self.transitions + broadcast_emissions
next_score = torch.logsumexp(next_score, dim=1)

# 应用 mask
score = torch.where(mask[i].unsqueeze(1), next_score, score)

# 添加结束分数
score += self.end_transitions

return torch.logsumexp(score, dim=1)

def decode(self, emissions, mask=None):
"""Viterbi 解码"""
if mask is None:
mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)

if self.batch_first:
emissions = emissions.transpose(0, 1)
mask = mask.transpose(0, 1)

return self._viterbi_decode(emissions, mask)

def _viterbi_decode(self, emissions, mask):
"""Viterbi 算法"""
seq_len, batch_size, num_tags = emissions.shape

# 初始化
score = self.start_transitions + emissions[0]
history = []

for i in range(1, seq_len):
broadcast_score = score.unsqueeze(2)
broadcast_emissions = emissions[i].unsqueeze(1)

next_score = broadcast_score + self.transitions + broadcast_emissions
next_score, indices = next_score.max(dim=1)

score = torch.where(mask[i].unsqueeze(1), next_score, score)
history.append(indices)

# 添加结束分数
score += self.end_transitions

# 回溯
best_tags_list = []
_, best_last_tag = score.max(dim=1)

for idx in range(batch_size):
best_tags = [best_last_tag[idx].item()]
seq_length = int(mask[:, idx].sum().item())

for hist in reversed(history[:seq_length-1]):
best_last_tag_idx = best_tags[-1]
best_tags.append(hist[idx, best_last_tag_idx].item())

best_tags.reverse()
best_tags_list.append(best_tags)

return best_tags_list

与 BiLSTM 结合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class BiLSTM_CRF(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_tags):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(embed_dim, hidden_dim // 2,
num_layers=2, bidirectional=True, batch_first=True)
self.fc = nn.Linear(hidden_dim, num_tags)
self.crf = CRF(num_tags)

def forward(self, x, tags, mask=None):
embeddings = self.embedding(x)
lstm_out, _ = self.lstm(embeddings)
emissions = self.fc(lstm_out)

return self.crf(emissions, tags, mask)

def predict(self, x, mask=None):
embeddings = self.embedding(x)
lstm_out, _ = self.lstm(embeddings)
emissions = self.fc(lstm_out)

return self.crf.decode(emissions, mask)

现代应用:BERT + CRF

尽管 BERT 已经很强大,但 CRF 层仍能带来一致性提升:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import BertModel

class BERT_CRF(nn.Module):
def __init__(self, bert_name, num_tags):
super().__init__()
self.bert = BertModel.from_pretrained(bert_name)
self.dropout = nn.Dropout(0.1)
self.fc = nn.Linear(self.bert.config.hidden_size, num_tags)
self.crf = CRF(num_tags)

def forward(self, input_ids, attention_mask, tags=None):
outputs = self.bert(input_ids, attention_mask=attention_mask)
sequence_output = self.dropout(outputs.last_hidden_state)
emissions = self.fc(sequence_output)

if tags is not None:
return self.crf(emissions, tags, attention_mask.bool())
else:
return self.crf.decode(emissions, attention_mask.bool())

性能对比(CoNLL-2003 NER)

模型 F1
BiLSTM 88.2
BiLSTM + CRF 90.1
BERT 92.4
BERT + CRF 92.8
RoBERTa + CRF 93.2

训练技巧

1. 标签平滑

1
2
3
4
5
6
7
8
def label_smoothing_loss(crf, emissions, tags, mask, epsilon=0.1):
"""带标签平滑的 CRF 损失"""
nll_loss = crf(emissions, tags, mask)

# 均匀分布的损失
uniform_loss = -torch.logsumexp(emissions, dim=-1).mean()

return (1 - epsilon) * nll_loss + epsilon * uniform_loss

2. 约束解码

1
2
3
4
5
6
7
8
# 添加硬约束:B-X 后面只能接 I-X 或 O
def add_constraints(transitions, tag2idx):
for tag_from, idx_from in tag2idx.items():
for tag_to, idx_to in tag2idx.items():
if tag_from.startswith('B-') or tag_from.startswith('I-'):
entity = tag_from[2:]
if tag_to.startswith('I-') and tag_to[2:] != entity:
transitions.data[idx_to, idx_from] = -1e9

延伸阅读

  • Lafferty et al., Conditional Random Fields (2001)
  • Huang et al., Bidirectional LSTM-CRF Models for Sequence Tagging (2015)
  • pytorch-crf Documentation

转载请注明出处

BiDAF (Bi-Directional Attention Flow) 是机器阅读理解领域的经典模型,其双向注意力机制对后续 Transformer 架构产生了深远影响。

核心创新

1. Memory-less Attention

传统动态注意力 vs BiDAF 的无记忆注意力:

特性 Dynamic Attention Memory-less Attention
依赖 前一时间步的 attended vector 仅当前 query 和 context
优势 可建模时序依赖 避免错误累积
缺点 错误会传播 无法建模长程依赖

2. 双向注意力

同时计算:

  • Context-to-Query (C2Q):每个 context 词最相关的 query 词
  • Query-to-Context (Q2C):对回答问题最关键的 context 词

模型架构

1
2
3
Input → Embedding → Encoding → Attention → Modeling → Output
│ │ │ │ │ │
词向量 字符CNN BiLSTM 双向注意力 BiLSTM Span预测

数学表达

相似度矩阵

其中 是 context 表示, 是 query 表示。

C2Q Attention

$$
\tilde{U}i = \sum_j a{ij} U_j, \quad a_i = \text{softmax}(S_i)
$$

Q2C Attention

融合表示

PyTorch 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn

class BiDAFAttention(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.W = nn.Linear(hidden_size * 3, 1, bias=False)

def forward(self, context, query, c_mask, q_mask):
"""
Args:
context: (batch, c_len, hidden)
query: (batch, q_len, hidden)
c_mask: (batch, c_len)
q_mask: (batch, q_len)
"""
batch, c_len, hidden = context.size()
q_len = query.size(1)

# 扩展维度以计算所有 (i, j) 对
c_expand = context.unsqueeze(2).expand(-1, -1, q_len, -1)
q_expand = query.unsqueeze(1).expand(-1, c_len, -1, -1)

# 计算相似度矩阵 S
cq = torch.cat([c_expand, q_expand, c_expand * q_expand], dim=-1)
S = self.W(cq).squeeze(-1) # (batch, c_len, q_len)

# Mask
q_mask_expand = q_mask.unsqueeze(1).expand(-1, c_len, -1)
S = S.masked_fill(~q_mask_expand, -1e9)

# C2Q attention
a = torch.softmax(S, dim=-1)
c2q = torch.bmm(a, query) # (batch, c_len, hidden)

# Q2C attention
b = torch.softmax(S.max(dim=-1)[0], dim=-1)
q2c = torch.bmm(b.unsqueeze(1), context) # (batch, 1, hidden)
q2c = q2c.expand(-1, c_len, -1)

# 融合
G = torch.cat([context, c2q, context * c2q, context * q2c], dim=-1)

return G

与 Transformer 的对比

特性 BiDAF Transformer
注意力方向 双向(C2Q, Q2C) 全方向自注意力
位置编码 BiLSTM 隐式编码 显式位置编码
并行化 受限于 RNN 完全并行
长距离依赖 受限 理论上无限
参数量 较少 较多

现代演进

BiDAF 的思想在现代模型中的体现:

1. Cross-Attention in Transformer

1
2
3
4
5
6
7
8
class CrossAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.mha = nn.MultiheadAttention(d_model, n_heads)

def forward(self, query, key_value):
# query 来自一个序列,key/value 来自另一个序列
return self.mha(query, key_value, key_value)

2. FiD (Fusion-in-Decoder)

用于 RAG 的架构,类似 BiDAF 的融合思想:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class FiD(nn.Module):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def forward(self, question, passages):
# 独立编码每个 passage
encoded = []
for passage in passages:
enc = self.encoder(question + passage)
encoded.append(enc)

# 融合解码
fused = torch.cat(encoded, dim=1)
return self.decoder(fused)

实验结果(原论文)

在 SQuAD 1.1 上的表现:

模型 EM F1
BiDAF 67.7 77.3
BiDAF + Self Attention 72.1 81.1
BERT-base 80.8 88.5
GPT-4 (few-shot) ~90 ~95

延伸阅读


转载请注明出处

本文介绍机器阅读理解模型的完整实现,涵盖经典架构和现代最佳实践。

问题定义

输入

  • 问题
  • 文档

输出

  • 答案起始位置
  • 答案结束位置

经典架构

1
Input → Embedding → Encoding → Matching → Fusion → Decoding

各层详解

功能 现代替代
Embedding Token → Vector Subword Tokenization
Encoding 序列编码 Transformer Encoder
Matching Q-P 交互 Cross-Attention
Fusion 信息融合 Self-Attention
Decoding Span 预测 Linear + Softmax

PyTorch 实现

完整模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class MRCModel(nn.Module):
"""基于 Transformer 的 MRC 模型"""

def __init__(
self,
model_name: str = "bert-base-chinese",
dropout: float = 0.1,
max_answer_length: int = 30
):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
hidden_size = self.encoder.config.hidden_size
self.max_answer_length = max_answer_length

self.dropout = nn.Dropout(dropout)
self.start_fc = nn.Linear(hidden_size, 1)
self.end_fc = nn.Linear(hidden_size, 1)

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
start_positions: torch.Tensor = None,
end_positions: torch.Tensor = None,
):
# 编码
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
sequence_output = self.dropout(outputs.last_hidden_state)

# 预测 start/end
start_logits = self.start_fc(sequence_output).squeeze(-1)
end_logits = self.end_fc(sequence_output).squeeze(-1)

# Mask padding
mask = attention_mask.bool()
start_logits = start_logits.masked_fill(~mask, float('-inf'))
end_logits = end_logits.masked_fill(~mask, float('-inf'))

# 计算损失
loss = None
if start_positions is not None and end_positions is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
loss = (start_loss + end_loss) / 2

return {
'loss': loss,
'start_logits': start_logits,
'end_logits': end_logits,
}

def decode(
self,
start_logits: torch.Tensor,
end_logits: torch.Tensor,
attention_mask: torch.Tensor,
):
"""解码最佳答案 span"""
batch_size, seq_len = start_logits.shape

# 计算所有有效 (start, end) 对的分数
start_probs = F.softmax(start_logits, dim=-1)
end_probs = F.softmax(end_logits, dim=-1)

results = []
for b in range(batch_size):
best_score = float('-inf')
best_start, best_end = 0, 0

for start in range(seq_len):
if not attention_mask[b, start]:
continue
for end in range(start, min(start + self.max_answer_length, seq_len)):
if not attention_mask[b, end]:
continue
score = start_probs[b, start] * end_probs[b, end]
if score > best_score:
best_score = score
best_start, best_end = start, end

results.append((best_start, best_end))

return results

数据处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from dataclasses import dataclass
from typing import List, Optional
import json

@dataclass
class MRCExample:
qid: str
question: str
context: str
answer: Optional[str] = None
start_position: Optional[int] = None

@dataclass
class MRCFeature:
input_ids: List[int]
attention_mask: List[int]
token_type_ids: List[int]
start_position: int
end_position: int
offset_mapping: List[tuple]

class MRCProcessor:
def __init__(self, model_name: str, max_length: int = 512):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.max_length = max_length

def process(self, example: MRCExample) -> MRCFeature:
encoding = self.tokenizer(
example.question,
example.context,
max_length=self.max_length,
truncation='only_second',
return_offsets_mapping=True,
padding='max_length',
)

# 定位答案位置
start_token, end_token = 0, 0
if example.start_position is not None:
offset = encoding['offset_mapping']
for idx, (start, end) in enumerate(offset):
if start <= example.start_position < end:
start_token = idx
if start < example.start_position + len(example.answer) <= end:
end_token = idx
break

return MRCFeature(
input_ids=encoding['input_ids'],
attention_mask=encoding['attention_mask'],
token_type_ids=encoding.get('token_type_ids', [0] * len(encoding['input_ids'])),
start_position=start_token,
end_position=end_token,
offset_mapping=encoding['offset_mapping'],
)

训练循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, scheduler, device):
model.train()
total_loss = 0

for batch in tqdm(dataloader, desc="Training"):
batch = {k: v.to(device) for k, v in batch.items()}

outputs = model(**batch)
loss = outputs['loss']

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

optimizer.step()
scheduler.step()
optimizer.zero_grad()

total_loss += loss.item()

return total_loss / len(dataloader)

def evaluate(model, dataloader, device):
model.eval()
predictions = []

with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
batch = {k: v.to(device) for k, v in batch.items()}

outputs = model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
token_type_ids=batch.get('token_type_ids'),
)

spans = model.decode(
outputs['start_logits'],
outputs['end_logits'],
batch['attention_mask'],
)
predictions.extend(spans)

return predictions

# 主训练流程
def main():
# 配置
model_name = "bert-base-chinese"
batch_size = 16
learning_rate = 3e-5
num_epochs = 3

# 初始化
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MRCModel(model_name).to(device)

# 优化器
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=500,
num_training_steps=num_epochs * len(train_dataloader),
)

# 训练
for epoch in range(num_epochs):
loss = train_epoch(model, train_dataloader, optimizer, scheduler, device)
print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

# 验证
predictions = evaluate(model, val_dataloader, device)
f1 = compute_f1(predictions, val_labels)
print(f"Validation F1: {f1:.4f}")

评估指标

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import re
import string
from collections import Counter

def normalize_answer(s: str) -> str:
"""标准化答案文本"""
s = s.lower()
s = re.sub(r'\b(a|an|the)\b', ' ', s)
s = ''.join(ch for ch in s if ch not in string.punctuation)
s = ' '.join(s.split())
return s

def compute_f1(prediction: str, ground_truth: str) -> float:
pred_tokens = normalize_answer(prediction).split()
gold_tokens = normalize_answer(ground_truth).split()

if not pred_tokens or not gold_tokens:
return int(pred_tokens == gold_tokens)

common = Counter(pred_tokens) & Counter(gold_tokens)
num_same = sum(common.values())

precision = num_same / len(pred_tokens)
recall = num_same / len(gold_tokens)

if precision + recall == 0:
return 0

return 2 * precision * recall / (precision + recall)

def compute_em(prediction: str, ground_truth: str) -> float:
return float(normalize_answer(prediction) == normalize_answer(ground_truth))

与现代方法对比

方面 经典 MRC (BiDAF) BERT-based LLM-based
参数量 ~2M 110M-340M 7B-70B+
训练数据 Task-specific 预训练+微调 大规模预训练
推理方式 Span extraction Span extraction Generation
长文档 需要切分 需要切分 更大上下文窗口
多跳推理 困难 有限 较好

生产环境优化

量化推理

1
2
3
4
5
6
7
8
import torch.quantization as quant

# 动态量化
model_int8 = quant.quantize_dynamic(
model.cpu(),
{nn.Linear},
dtype=torch.qint8
)

ONNX 导出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.onnx

dummy_input = {
'input_ids': torch.ones(1, 512, dtype=torch.long),
'attention_mask': torch.ones(1, 512, dtype=torch.long),
'token_type_ids': torch.zeros(1, 512, dtype=torch.long),
}

torch.onnx.export(
model,
(dummy_input['input_ids'], dummy_input['attention_mask'], dummy_input['token_type_ids']),
"mrc_model.onnx",
input_names=['input_ids', 'attention_mask', 'token_type_ids'],
output_names=['start_logits', 'end_logits'],
dynamic_axes={
'input_ids': {0: 'batch', 1: 'seq'},
'attention_mask': {0: 'batch', 1: 'seq'},
'token_type_ids': {0: 'batch', 1: 'seq'},
}
)

延伸阅读


转载请注明出处

本文整理了 NLP 领域的学习路线,结合经典理论与现代大语言模型技术。

推荐学习资源

经典教材

书籍 内容 难度
Speech and Language Processing (Jurafsky) NLP 全面综述 ⭐⭐
Introduction to Information Retrieval 信息检索基础 ⭐⭐
Pattern Recognition and Machine Learning 机器学习理论 ⭐⭐⭐⭐
Deep Learning (Goodfellow) 深度学习基础 ⭐⭐⭐

现代资源

阶段一:NLP 基础

语言模型基础

N-gram 模型:N-1 阶马尔可夫假设

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from collections import defaultdict
import numpy as np

class NGramLM:
def __init__(self, n=3):
self.n = n
self.counts = defaultdict(lambda: defaultdict(int))
self.totals = defaultdict(int)

def train(self, corpus):
for sentence in corpus:
tokens = ['<s>'] * (self.n - 1) + sentence + ['</s>']
for i in range(len(tokens) - self.n + 1):
context = tuple(tokens[i:i+self.n-1])
word = tokens[i+self.n-1]
self.counts[context][word] += 1
self.totals[context] += 1

def probability(self, word, context):
context = tuple(context[-(self.n-1):])
return self.counts[context][word] / max(self.totals[context], 1)

词向量

从 One-hot 到 Dense Embedding 的演进:

方法 年份 特点
One-hot - 稀疏,无语义
Word2Vec 2013 分布式表示
GloVe 2014 全局统计
FastText 2016 子词信息
ELMo 2018 上下文相关
BERT 2018 双向上下文

阶段二:深度学习 NLP

Transformer 架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_k = d_model // n_heads
self.n_heads = n_heads

self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)

def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)

# Linear projections
Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)

# Concatenate and project
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
return self.W_o(output)

注意力机制的数学表达

阶段三:大语言模型

LLM 架构演进

1
2
3
4
5
6
7
GPT-1 (2018) → GPT-2 → GPT-3 → ChatGPT → GPT-4

BERT → RoBERTa → DeBERTa

T5 → Flan-T5 → UL2

LLaMA → LLaMA 2 → Mistral → Mixtral

Prompt Engineering

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 1. Zero-shot
prompt = "Translate to French: Hello, how are you?"

# 2. Few-shot
prompt = """Translate to French:
Hello -> Bonjour
Goodbye -> Au revoir
How are you? ->"""

# 3. Chain-of-Thought
prompt = """Q: If I have 3 apples and buy 5 more, how many do I have?
A: Let's think step by step.
1. I start with 3 apples.
2. I buy 5 more apples.
3. Total = 3 + 5 = 8 apples.
The answer is 8.

Q: If I have 7 oranges and eat 2, how many remain?
A: Let's think step by step."""

Fine-tuning 技术

方法 可训练参数 适用场景
Full Fine-tuning 100% 大量数据,充足算力
LoRA 0.1-1% 资源受限
QLoRA 0.1% 消费级 GPU
Prefix Tuning 0.1% 多任务
Prompt Tuning <0.01% 极端资源受限
1
2
3
4
5
6
7
8
9
10
11
12
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.1,
bias="none",
)

model = get_peft_model(base_model, lora_config)
print(f"Trainable params: {model.print_trainable_parameters()}")

阶段四:高级主题

检索增强生成 (RAG)

1
2
3
4
5
6
7
8
9
10
11
12
13
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA

# 构建向量库
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh")
vectorstore = Chroma.from_documents(documents, embeddings)

# 创建 RAG 链
qa = RetrievalQA.from_chain_type(
llm=llm,
retriever=vectorstore.as_retriever(search_kwargs={"k": 3})
)

模型评估

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 困惑度 (Perplexity)
def perplexity(model, tokenizer, text):
encodings = tokenizer(text, return_tensors='pt')
max_length = model.config.n_positions

nlls = []
for i in range(0, encodings.input_ids.size(1), max_length):
begin_loc = max(i - max_length, 0)
end_loc = i + max_length
input_ids = encodings.input_ids[:, begin_loc:end_loc]
target_ids = input_ids.clone()
target_ids[:, :-1] = -100

with torch.no_grad():
outputs = model(input_ids, labels=target_ids)
nlls.append(outputs.loss)

return torch.exp(torch.stack(nlls).mean())

实践项目建议

  1. 入门:情感分析、文本分类
  2. 进阶:命名实体识别、机器翻译
  3. 高级:问答系统、RAG 应用
  4. 专家:LLM 预训练、RLHF

延伸阅读


转载请注明出处

核心问题:当我们期望机器”理解”文本时,我们的期望到底是什么?

机器阅读理解的演进

传统 MRC (2015-2019)

基于 span extraction 的方法:

1
2
输入: Context + Question
输出: (start_idx, end_idx)

代表模型:BiDAF, R-Net, QANet, BERT

LLM 时代的 MRC (2020-至今)

从”抽取”到”生成”的范式转变:

1
2
输入: Context + Question + Instruction
输出: 自由形式的答案

任务分类与难度

类型 传统方法 LLM 方法 难度
抽取式 ✅ 擅长 ✅ 擅长
多跳推理 ❌ 困难 ⚠️ 有限 ⭐⭐⭐
数值推理 ❌ 几乎不能 ⚠️ 需要 CoT ⭐⭐⭐⭐
常识推理 ❌ 不能 ✅ 较好 ⭐⭐⭐
开放生成 ❌ 不能 ✅ 擅长 ⭐⭐

现代方法:RAG

检索增强生成 (Retrieval-Augmented Generation) 结合了检索和生成的优势:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class RAGSystem:
def __init__(self, retriever, generator):
self.retriever = retriever # e.g., Dense Retriever
self.generator = generator # e.g., LLM

def answer(self, question: str) -> str:
# 1. 检索相关文档
docs = self.retriever.retrieve(question, top_k=5)

# 2. 构建上下文
context = "\n\n".join([d.text for d in docs])

# 3. 生成答案
prompt = f"""基于以下文档回答问题:

{context}

问题:{question}
答案:"""

return self.generator.generate(prompt)

检索器选择

检索器 特点 适用场景
BM25 关键词匹配,快速 短查询,精确匹配
Dense Retriever 语义匹配 语义相似查询
ColBERT 延迟交互 平衡效率与效果
Hybrid 结合稀疏+稠密 生产环境

Chain-of-Thought 推理

对于需要推理的问题,CoT prompting 显著提升效果:

1
2
3
4
5
6
7
8
9
10
# 标准 Prompting
prompt_standard = "Q: 小明有5个苹果,给了小红2个,还剩几个?\nA:"

# Chain-of-Thought Prompting
prompt_cot = """Q: 小明有5个苹果,给了小红2个,还剩几个?
A: 让我们一步步思考:
1. 小明最初有 5 个苹果
2. 他给了小红 2 个苹果
3. 剩余苹果数 = 5 - 2 = 3
答案是 3 个苹果。"""

评估指标

传统指标

𝟙

LLM 时代的指标

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 使用 LLM 作为评估器
def llm_evaluate(question, gold_answer, pred_answer):
prompt = f"""评估预测答案的质量(1-5分):

问题:{question}
标准答案:{gold_answer}
预测答案:{pred_answer}

评分标准:
5分 - 完全正确且信息完整
4分 - 基本正确,略有遗漏
3分 - 部分正确
2分 - 有相关信息但不正确
1分 - 完全错误

分数:"""
return llm.generate(prompt)

实践建议

何时用传统 MRC

  • 答案明确在文档中
  • 需要精确的位置标注
  • 低延迟要求
  • 资源受限

何时用 RAG + LLM

  • 需要整合多个文档
  • 答案需要推理或总结
  • 开放域问答
  • 用户期望自然语言回答

代码示例:现代 RAG 系统

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA

# 初始化组件
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.load_local("my_index", embeddings)
llm = ChatOpenAI(model="gpt-4", temperature=0)

# 创建 RAG 链
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff", # 或 "map_reduce", "refine"
retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
return_source_documents=True
)

# 使用
result = qa_chain({"query": "什么是机器阅读理解?"})
print(result["result"])

延伸阅读


转载请注明出处

因果推断是机器学习领域的重要研究方向,特别是在大语言模型时代,理解因果关系对于构建可解释、可信赖的 AI 系统至关重要。

为什么需要因果推断?

传统机器学习依赖相关性,但相关性不等于因果性。例如:

  • 冰淇淋销量与溺水事件正相关(共同原因:夏天)
  • LLM 可能学到虚假相关性,导致 hallucination

因果推断帮助我们:

  1. 理解干预效果(如果我做 X,会发生什么?)
  2. 进行反事实推理(如果当时做了 Y,结果会怎样?)
  3. 构建更鲁棒的模型

核心概念

因果图 (Causal Graph)

使用有向无环图 (DAG) 表示变量之间的因果关系:

1
2
3
X → Y → Z    (链式结构)
X ← W → Y (混杂结构)
X → W ← Y (对撞结构)

结构因果模型 (SCM)

其中 是原因, 是结果, 是噪声项。

do 算子与干预

区分观测和干预:

  • 观测 — 看到 X=x 时 Y 的分布
  • 干预 — 强制设置 X=x 时 Y 的分布

因果发现算法

PC 算法

基于条件独立性检验的经典算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# PC 算法伪代码
def pc_algorithm(data, alpha=0.05):
# 1. 初始化完全图
G = complete_graph(variables)

# 2. 骨架学习:移除条件独立的边
for (X, Y) in edges(G):
for S in subsets(neighbors):
if conditional_independent(X, Y, S, alpha):
remove_edge(G, X, Y)
sep_set[X, Y] = S

# 3. 方向确定:识别 v-structure
orient_v_structures(G, sep_set)

return G

Python 实现参考:fooSynaptic/py_pcalg

现代方法

方法 特点 适用场景
NOTEARS 连续优化,可微分 线性/非线性因果发现
DAG-GNN 基于图神经网络 大规模因果图学习
Causal Transformer 结合注意力机制 时序因果推断

因果推断与大语言模型

LLM 中的因果问题

  1. Hallucination:模型学到虚假相关性
  2. Bias:训练数据中的混杂因素
  3. Robustness:分布外泛化能力差

解决方案

1
2
3
4
5
6
7
8
9
# 因果提示 (Causal Prompting) 示例
prompt = """
请分析以下事件的因果关系,而非相关性:

事件A: 公司增加广告投入
事件B: 销售额上升

问:A 是否导致了 B?请考虑可能的混杂因素。
"""

因果推理增强 RAG

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class CausalRAG:
def __init__(self, retriever, causal_graph):
self.retriever = retriever
self.causal_graph = causal_graph

def retrieve(self, query):
# 1. 识别查询中的因果关系
cause, effect = extract_causal_pair(query)

# 2. 基于因果图过滤无关文档
relevant_vars = self.causal_graph.ancestors(effect)

# 3. 检索因果相关的文档
docs = self.retriever.search(query)
return filter_by_causal_relevance(docs, relevant_vars)

工具与资源

工具 语言 功能
DoWhy Python 因果推断框架
CausalNex Python 贝叶斯网络 + 因果发现
pgmpy Python 概率图模型
Tetrad Java 因果搜索算法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# DoWhy 示例
import dowhy
from dowhy import CausalModel

model = CausalModel(
data=df,
treatment='treatment',
outcome='outcome',
graph='digraph {treatment -> outcome; confounder -> treatment; confounder -> outcome}'
)

# 识别因果效应
identified = model.identify_effect()

# 估计因果效应
estimate = model.estimate_effect(identified, method_name="backdoor.propensity_score_matching")

延伸阅读


转载请注明出处

矩阵分解是机器学习的基石技术,从传统的推荐系统到现代大语言模型的参数高效微调(LoRA),都离不开矩阵分解的思想。

奇异值分解 (SVD)

基本形式

任意矩阵 可以分解为:

其中:

  • :左奇异向量(正交矩阵)
  • :奇异值对角矩阵
  • :右奇异向量(正交矩阵)

Truncated SVD

保留前 个最大奇异值:

这是最优的秩 近似(Eckart-Young 定理):

Randomized SVD

当矩阵规模巨大时,精确 SVD 计算代价过高。Randomized SVD 提供了高效的近似方法。

算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
from scipy import linalg

def randomized_svd(A, n_components, n_oversamples=10, n_iter=4):
"""
Randomized SVD for large matrices.

Args:
A: Input matrix (m x n)
n_components: Number of singular values to compute
n_oversamples: Additional random vectors for accuracy
n_iter: Number of power iterations

Returns:
U, s, Vt: Truncated SVD components
"""
m, n = A.shape
n_random = n_components + n_oversamples

# Step 1: Random projection
Q = np.random.randn(n, n_random)

# Step 2: Power iteration for accuracy
for _ in range(n_iter):
Q, _ = linalg.lu(A @ Q, permute_l=True)
Q, _ = linalg.lu(A.T @ Q, permute_l=True)

Q, _ = linalg.qr(A @ Q, mode='economic')

# Step 3: Project and compute SVD
B = Q.T @ A
Uhat, s, Vt = linalg.svd(B, full_matrices=False)
U = Q @ Uhat

return U[:, :n_components], s[:n_components], Vt[:n_components, :]

复杂度对比

方法 时间复杂度 空间复杂度
精确 SVD
Randomized SVD
Truncated SVD (Lanczos)

现代应用:LoRA

LoRA (Low-Rank Adaptation) 是大语言模型参数高效微调的核心技术,直接利用了低秩分解的思想。

LoRA 原理

预训练权重 固定,只训练低秩增量:

其中

实现示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, alpha=1.0):
super().__init__()
self.rank = rank
self.alpha = alpha

# 原始权重(冻结)
self.W = nn.Linear(in_features, out_features, bias=False)
self.W.weight.requires_grad = False

# 低秩分解
self.A = nn.Linear(in_features, rank, bias=False)
self.B = nn.Linear(rank, out_features, bias=False)

# 初始化
nn.init.kaiming_uniform_(self.A.weight)
nn.init.zeros_(self.B.weight)

self.scaling = alpha / rank

def forward(self, x):
# W(x) + scaling * B(A(x))
return self.W(x) + self.scaling * self.B(self.A(x))

参数效率

对于 LLaMA-7B:

方法 可训练参数 显存占用
全量微调 7B (100%) ~140GB
LoRA (r=8) 4.7M (0.07%) ~14GB
LoRA (r=16) 9.4M (0.13%) ~16GB

其他应用

1. 推荐系统

矩阵分解用于协同过滤:

1
2
3
4
5
6
7
# 使用 surprise 库
from surprise import SVD, Dataset, Reader

reader = Reader(rating_scale=(1, 5))
data = Dataset.load_from_df(df[['user', 'item', 'rating']], reader)
model = SVD(n_factors=100)
model.fit(trainset)

2. 文本表示 (LSA)

潜在语义分析:

1
2
3
4
5
6
7
8
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer(max_features=10000)
X = vectorizer.fit_transform(documents)

svd = TruncatedSVD(n_components=100)
X_reduced = svd.fit_transform(X)

3. 图像压缩

1
2
3
4
5
6
7
8
9
10
11
from PIL import Image
import numpy as np

def compress_image(image_path, n_components=50):
img = np.array(Image.open(image_path).convert('L'))
U, s, Vt = np.linalg.svd(img, full_matrices=False)

# 保留前 n_components 个奇异值
compressed = U[:, :n_components] @ np.diag(s[:n_components]) @ Vt[:n_components, :]

return compressed.astype(np.uint8)

数值稳定性

条件数

条件数过大会导致数值不稳定。

正则化 SVD

1
2
3
4
5
def regularized_svd(A, lambda_reg=0.01):
"""Add regularization for numerical stability."""
U, s, Vt = np.linalg.svd(A, full_matrices=False)
s_reg = s / (s**2 + lambda_reg)
return U, s_reg, Vt

延伸阅读

  • Halko et al., Finding Structure with Randomness (2011)
  • Hu et al., LoRA: Low-Rank Adaptation of Large Language Models (2021)
  • NumPy SVD Documentation

转载请注明出处

各位读者朋友们大家好,我是 fooSynaptic。

欢迎来到我的技术博客!这里记录我在 AI 和 NLP 领域的学习与思考。

关于这个博客

这个博客主要记录以下内容:

  • 自然语言处理 (NLP):从传统方法到大语言模型
  • 机器学习:算法原理与实现细节
  • 深度学习:模型架构与训练技巧
  • 数学基础:线性代数、概率论、优化理论
  • 工程实践:Python、PyTorch、分布式训练

技术栈

1
2
3
NLP: Transformers, LLMs, RAG, Prompt Engineering
ML: PyTorch, JAX, scikit-learn
Infra: CUDA, Triton, vLLM, DeepSpeed

关于我

NLP Researcher,专注于:

  • 大语言模型 (LLM) 训练与推理优化
  • 检索增强生成 (RAG)
  • 机器阅读理解与问答系统

GitHub: fooSynaptic


欢迎交流讨论,转载请注明出处

0%