因果关系推断介绍

因果推断是机器学习领域的重要研究方向,特别是在大语言模型时代,理解因果关系对于构建可解释、可信赖的 AI 系统至关重要。

为什么需要因果推断?

传统机器学习依赖相关性,但相关性不等于因果性。例如:

  • 冰淇淋销量与溺水事件正相关(共同原因:夏天)
  • LLM 可能学到虚假相关性,导致 hallucination

因果推断帮助我们:

  1. 理解干预效果(如果我做 X,会发生什么?)
  2. 进行反事实推理(如果当时做了 Y,结果会怎样?)
  3. 构建更鲁棒的模型

核心概念

因果图 (Causal Graph)

使用有向无环图 (DAG) 表示变量之间的因果关系:

1
2
3
X → Y → Z    (链式结构)
X ← W → Y (混杂结构)
X → W ← Y (对撞结构)

结构因果模型 (SCM)

其中 是原因, 是结果, 是噪声项。

do 算子与干预

区分观测和干预:

  • 观测 — 看到 X=x 时 Y 的分布
  • 干预 — 强制设置 X=x 时 Y 的分布

因果发现算法

PC 算法

基于条件独立性检验的经典算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# PC 算法伪代码
def pc_algorithm(data, alpha=0.05):
# 1. 初始化完全图
G = complete_graph(variables)

# 2. 骨架学习:移除条件独立的边
for (X, Y) in edges(G):
for S in subsets(neighbors):
if conditional_independent(X, Y, S, alpha):
remove_edge(G, X, Y)
sep_set[X, Y] = S

# 3. 方向确定:识别 v-structure
orient_v_structures(G, sep_set)

return G

Python 实现参考:fooSynaptic/py_pcalg

现代方法

方法 特点 适用场景
NOTEARS 连续优化,可微分 线性/非线性因果发现
DAG-GNN 基于图神经网络 大规模因果图学习
Causal Transformer 结合注意力机制 时序因果推断

因果推断与大语言模型

LLM 中的因果问题

  1. Hallucination:模型学到虚假相关性
  2. Bias:训练数据中的混杂因素
  3. Robustness:分布外泛化能力差

解决方案

1
2
3
4
5
6
7
8
9
# 因果提示 (Causal Prompting) 示例
prompt = """
请分析以下事件的因果关系,而非相关性:

事件A: 公司增加广告投入
事件B: 销售额上升

问:A 是否导致了 B?请考虑可能的混杂因素。
"""

因果推理增强 RAG

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class CausalRAG:
def __init__(self, retriever, causal_graph):
self.retriever = retriever
self.causal_graph = causal_graph

def retrieve(self, query):
# 1. 识别查询中的因果关系
cause, effect = extract_causal_pair(query)

# 2. 基于因果图过滤无关文档
relevant_vars = self.causal_graph.ancestors(effect)

# 3. 检索因果相关的文档
docs = self.retriever.search(query)
return filter_by_causal_relevance(docs, relevant_vars)

工具与资源

工具 语言 功能
DoWhy Python 因果推断框架
CausalNex Python 贝叶斯网络 + 因果发现
pgmpy Python 概率图模型
Tetrad Java 因果搜索算法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# DoWhy 示例
import dowhy
from dowhy import CausalModel

model = CausalModel(
data=df,
treatment='treatment',
outcome='outcome',
graph='digraph {treatment -> outcome; confounder -> treatment; confounder -> outcome}'
)

# 识别因果效应
identified = model.identify_effect()

# 估计因果效应
estimate = model.estimate_effect(identified, method_name="backdoor.propensity_score_matching")

延伸阅读


转载请注明出处