条件随机场:原理与实现

条件随机场 (CRF) 是序列标注的经典模型,尽管深度学习时代 BERT 等模型大放异彩,CRF 层仍然在 NER、词性标注等任务中发挥关键作用。

为什么需要 CRF?

独立分类的问题

如果对每个位置独立分类:

会导致标签不一致,例如:

1
2
3
输入: "北 京 是 中 国 首 都"
错误: B-LOC I-PER O B-LOC I-LOC I-LOC I-LOC
正确: B-LOC I-LOC O B-LOC I-LOC I-LOC I-LOC

CRF 的解决方案

CRF 建模整个序列的联合概率,考虑标签之间的转移约束

数学原理

条件概率

$$P(y \mid x) = \frac{1}{Z(x)} \exp\left( \sum_{t} E(y_t, x) + \sum_{t} T(y_{t-1}, y_t) \right)$$

其中:

  • $E(y_t, x)$:发射分数(emission score)
  • $T(y_{t-1}, y_t)$:转移分数(transition score)
  • $Z(x)$:配分函数(归一化项)

配分函数

直接计算 $Z(x)$ 的复杂度为 $O(K^T)$(其中 $K$ 为标签数,$T$ 为序列长度),使用前向算法可降至 $O(T \cdot K^2)$。

PyTorch 实现

CRF Layer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn

class CRF(nn.Module):
    """Linear-chain CRF layer for sequence labeling.

    Convention used throughout: ``transitions[i, j]`` is the score of
    transitioning *from tag j to tag i* (row = destination, column = source).
    """

    def __init__(self, num_tags, batch_first=True):
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first

        # Transition matrix: transitions[i, j] = score of moving from tag j to tag i.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

        # Scores for starting / ending a sequence on each tag.
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags, mask=None):
        """Return the mean negative log-likelihood of `tags` given `emissions`.

        Args:
            emissions: float tensor (batch, seq, num_tags) if batch_first.
            tags: long tensor (batch, seq) of gold tag indices.
            mask: bool tensor (batch, seq); True marks real tokens.
                  Padding is assumed to sit at the end of each sequence.
        """
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)

        if self.batch_first:
            # Work time-major internally: (seq, batch, ...).
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # log p(tags | x) = score(tags) - log Z(x)
        numerator = self._compute_score(emissions, tags, mask)
        denominator = self._compute_normalizer(emissions, mask)

        # Negative log-likelihood, averaged over the batch.
        return (denominator - numerator).mean()

    def _compute_score(self, emissions, tags, mask):
        """Unnormalized score of the given tag sequence (time-major inputs)."""
        seq_len, batch_size = tags.shape

        # Start transition + first emission.
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_len):
            # transitions[to, from] — consistent with the class-level convention.
            score += self.transitions[tags[i], tags[i - 1]] * mask[i]
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition from each sequence's last unmasked tag.
        last_tag_idx = mask.sum(dim=0) - 1
        last_tags = tags.gather(0, last_tag_idx.unsqueeze(0)).squeeze(0)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self, emissions, mask):
        """Forward algorithm: log Z(x) per batch element (time-major inputs)."""
        seq_len, batch_size, num_tags = emissions.shape

        # Initialize with start transitions + first emissions.
        score = self.start_transitions + emissions[0]

        # BUG FIX: transitions is [to, from]; the broadcast below indexes
        # (batch, from, to), so the matrix must be transposed to [from, to].
        # The original added the untransposed matrix, making the partition
        # function inconsistent with _compute_score.
        trans = self.transitions.t()

        for i in range(1, seq_len):
            # broadcast: (batch, from, 1) + (from, to) + (batch, 1, to)
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions[i].unsqueeze(1)

            next_score = broadcast_score + trans + broadcast_emissions
            next_score = torch.logsumexp(next_score, dim=1)

            # Only advance positions that are still inside the sequence.
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # Add end-transition scores before the final reduction.
        score += self.end_transitions

        return torch.logsumexp(score, dim=1)

    def decode(self, emissions, mask=None):
        """Return the best (Viterbi) tag sequence for each batch element."""
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _viterbi_decode(self, emissions, mask):
        """Viterbi algorithm over time-major emissions; returns a list of tag lists."""
        seq_len, batch_size, num_tags = emissions.shape

        score = self.start_transitions + emissions[0]
        history = []

        # Same transpose fix as in _compute_normalizer: align to (from, to).
        trans = self.transitions.t()

        for i in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions[i].unsqueeze(1)

            next_score = broadcast_score + trans + broadcast_emissions
            # Max over the source tag; indices record the best predecessor.
            next_score, indices = next_score.max(dim=1)

            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        score += self.end_transitions

        # Backtrack the best path per sequence (padding assumed at the end).
        best_tags_list = []
        _, best_last_tag = score.max(dim=1)

        for idx in range(batch_size):
            best_tags = [best_last_tag[idx].item()]
            seq_length = int(mask[:, idx].sum().item())

            for hist in reversed(history[:seq_length - 1]):
                best_last_tag_idx = best_tags[-1]
                best_tags.append(hist[idx, best_last_tag_idx].item())

            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

与 BiLSTM 结合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class BiLSTM_CRF(nn.Module):
    """BiLSTM encoder with a CRF decoding layer on top."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_tags):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Half the hidden size per direction so the concatenated
        # bidirectional output is exactly hidden_dim wide.
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, num_tags)
        self.crf = CRF(num_tags)

    def _emissions(self, x):
        """Encode token ids into per-tag emission scores."""
        features, _ = self.lstm(self.embedding(x))
        return self.fc(features)

    def forward(self, x, tags, mask=None):
        """Negative log-likelihood loss for training."""
        return self.crf(self._emissions(x), tags, mask)

    def predict(self, x, mask=None):
        """Viterbi-decoded tag sequences for inference."""
        return self.crf.decode(self._emissions(x), mask)

现代应用:BERT + CRF

尽管 BERT 已经很强大,但 CRF 层仍能带来一致性提升:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import BertModel

class BERT_CRF(nn.Module):
    """BERT encoder with a CRF layer for token classification."""

    def __init__(self, bert_name, num_tags):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags)

    def forward(self, input_ids, attention_mask, tags=None):
        """Return the CRF loss when `tags` is given, otherwise decoded tag paths."""
        hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        emissions = self.fc(self.dropout(hidden))

        mask = attention_mask.bool()
        if tags is None:
            return self.crf.decode(emissions, mask)
        return self.crf(emissions, tags, mask)

性能对比(CoNLL-2003 NER)

模型 F1
BiLSTM 88.2
BiLSTM + CRF 90.1
BERT 92.4
BERT + CRF 92.8
RoBERTa + CRF 93.2

训练技巧

1. 标签平滑

1
2
3
4
5
6
7
8
def label_smoothing_loss(crf, emissions, tags, mask, epsilon=0.1):
    """CRF negative log-likelihood blended with a uniform-target term.

    Returns (1 - epsilon) * NLL + epsilon * uniform_loss, where the
    uniform term penalizes over-confident emission scores.
    """
    nll = crf(emissions, tags, mask)

    # Loss against a uniform target over tags: negative mean log-sum-exp.
    smooth = -torch.logsumexp(emissions, dim=-1).mean()

    return (1 - epsilon) * nll + epsilon * smooth

2. 约束解码

1
2
3
4
5
6
7
8
# Hard BIO constraints: I-X may only follow B-X or I-X.
def add_constraints(transitions, tag2idx):
    """Write -1e9 into `transitions` for transitions that violate BIO.

    Uses the convention transitions[to, from] (row = destination,
    column = source), matching the CRF layer above. Modifies
    `transitions` in place.

    Args:
        transitions: (num_tags, num_tags) transition tensor/parameter.
        tag2idx: mapping from tag string ('O', 'B-X', 'I-X') to index.
    """
    for tag_from, idx_from in tag2idx.items():
        for tag_to, idx_to in tag2idx.items():
            if not tag_to.startswith('I-'):
                continue
            entity = tag_to[2:]
            # BUG FIX: the original only blocked B-Y/I-Y -> I-X with a
            # mismatched entity, still allowing the invalid O -> I-X.
            # Block every source except B-X and I-X of the same entity.
            if tag_from not in ('B-' + entity, 'I-' + entity):
                transitions.data[idx_to, idx_from] = -1e9

延伸阅读

  • Lafferty et al., Conditional Random Fields (2001)
  • Huang et al., Bidirectional LSTM-CRF Models for Sequence Tagging (2015)
  • pytorch-crf Documentation

转载请注明出处