import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
class MRCModel(nn.Module): """基于 Transformer 的 MRC 模型""" def __init__( self, model_name: str = "bert-base-chinese", dropout: float = 0.1, max_answer_length: int = 30 ): super().__init__() self.encoder = AutoModel.from_pretrained(model_name) hidden_size = self.encoder.config.hidden_size self.max_answer_length = max_answer_length self.dropout = nn.Dropout(dropout) self.start_fc = nn.Linear(hidden_size, 1) self.end_fc = nn.Linear(hidden_size, 1) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor = None, start_positions: torch.Tensor = None, end_positions: torch.Tensor = None, ): outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, ) sequence_output = self.dropout(outputs.last_hidden_state) start_logits = self.start_fc(sequence_output).squeeze(-1) end_logits = self.end_fc(sequence_output).squeeze(-1) mask = attention_mask.bool() start_logits = start_logits.masked_fill(~mask, float('-inf')) end_logits = end_logits.masked_fill(~mask, float('-inf')) loss = None if start_positions is not None and end_positions is not None: loss_fct = nn.CrossEntropyLoss(ignore_index=-1) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) loss = (start_loss + end_loss) / 2 return { 'loss': loss, 'start_logits': start_logits, 'end_logits': end_logits, } def decode( self, start_logits: torch.Tensor, end_logits: torch.Tensor, attention_mask: torch.Tensor, ): """解码最佳答案 span""" batch_size, seq_len = start_logits.shape start_probs = F.softmax(start_logits, dim=-1) end_probs = F.softmax(end_logits, dim=-1) results = [] for b in range(batch_size): best_score = float('-inf') best_start, best_end = 0, 0 for start in range(seq_len): if not attention_mask[b, start]: continue for end in range(start, min(start + self.max_answer_length, seq_len)): if not attention_mask[b, end]: continue score = start_probs[b, start] * end_probs[b, end] if score > 
best_score: best_score = score best_start, best_end = start, end results.append((best_start, best_end)) return results