import torch
import torch.nn as nn
class CRF(nn.Module):
    """Linear-chain Conditional Random Field layer.

    Convention: ``transitions[i, j]`` is the score of moving FROM tag ``i``
    TO tag ``j``. This convention is used consistently by the score,
    forward-algorithm, and Viterbi computations below.

    Args:
        num_tags: number of distinct tags.
        batch_first: if True, inputs arrive as (batch, seq, ...) and are
            transposed internally to time-major (seq, batch, ...).
    """

    def __init__(self, num_tags, batch_first=True):
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        # Randomly initialized; learned during training.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags, mask=None):
        """Compute the mean negative log-likelihood of ``tags``.

        Args:
            emissions: float tensor, (batch, seq, num_tags) when
                ``batch_first`` else (seq, batch, num_tags).
            tags: long tensor of gold tag indices, matching layout.
            mask: bool tensor marking valid timesteps; defaults to all-valid.

        Returns:
            Scalar tensor: mean over the batch of ``log Z - score(tags)``.
        """
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
        if self.batch_first:
            # Everything below works time-major: (seq, batch, ...).
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)
        numerator = self._compute_score(emissions, tags, mask)
        denominator = self._compute_normalizer(emissions, mask)
        return (denominator - numerator).mean()

    def _compute_score(self, emissions, tags, mask):
        """Score of a given tag sequence (time-major inputs)."""
        seq_len, batch_size = tags.shape
        float_mask = mask.float()
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]
        for i in range(1, seq_len):
            # BUGFIX: the transition must be indexed [previous, current].
            # The original used [tags[i], tags[i-1]] (current, previous),
            # which is the transpose of the convention used by
            # _compute_normalizer and _viterbi_decode, so numerator and
            # partition function disagreed and the loss was incorrect.
            score += self.transitions[tags[i - 1], tags[i]] * float_mask[i]
            score += emissions[i, torch.arange(batch_size), tags[i]] * float_mask[i]
        # End-transition is taken at each sequence's last *valid* position.
        last_tag_idx = mask.sum(dim=0) - 1
        last_tags = tags.gather(0, last_tag_idx.unsqueeze(0)).squeeze(0)
        score += self.end_transitions[last_tags]
        return score

    def _compute_normalizer(self, emissions, mask):
        """Log partition function via the forward algorithm."""
        seq_len, batch_size, num_tags = emissions.shape
        score = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)             # (B, prev, 1)
            broadcast_emissions = emissions[i].unsqueeze(1)  # (B, 1, curr)
            next_score = broadcast_score + self.transitions + broadcast_emissions
            next_score = torch.logsumexp(next_score, dim=1)  # marginalize prev
            # Padded timesteps keep their previous score untouched.
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
        score += self.end_transitions
        return torch.logsumexp(score, dim=1)

    def decode(self, emissions, mask=None):
        """Viterbi decoding: best tag sequence for each batch element."""
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.bool,
                              device=emissions.device)
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)
        return self._viterbi_decode(emissions, mask)

    def _viterbi_decode(self, emissions, mask):
        """Viterbi algorithm (time-major inputs); returns a list of tag lists."""
        seq_len, batch_size, num_tags = emissions.shape
        score = self.start_transitions + emissions[0]
        history = []
        for i in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)             # (B, prev, 1)
            broadcast_emissions = emissions[i].unsqueeze(1)  # (B, 1, curr)
            next_score = broadcast_score + self.transitions + broadcast_emissions
            next_score, indices = next_score.max(dim=1)      # best previous tag
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)
        score += self.end_transitions
        best_tags_list = []
        _, best_last_tag = score.max(dim=1)
        for idx in range(batch_size):
            best_tags = [best_last_tag[idx].item()]
            seq_length = int(mask[:, idx].sum().item())
            # Walk backpointers only through the valid part of the sequence.
            for hist in reversed(history[:seq_length - 1]):
                best_last_tag_idx = best_tags[-1]
                best_tags.append(hist[idx, best_last_tag_idx].item())
            best_tags.reverse()
            best_tags_list.append(best_tags)
        return best_tags_list