1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
class JointMRCModel(nn.Module):
    """Jointly predict answer start and end, modelling the start-end dependency.

    Start scores come from each token's own hidden state. End scores come from
    each token's hidden state concatenated with the hidden state of one chosen
    start token, so the end distribution is conditioned on where the answer
    starts.
    """

    def __init__(self, model_name: str, max_answer_length: int = 30):
        """Build the encoder and the two scoring heads.

        Args:
            model_name: Name or path passed to ``AutoModel.from_pretrained``.
            max_answer_length: Stored on the instance only; it is not applied
                anywhere in this class (presumably consumed by span decoding
                elsewhere — verify against callers).
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.max_answer_length = max_answer_length
        # Scores each token as a candidate start from its own hidden state.
        self.start_classifier = nn.Linear(hidden_size, 1)
        # Scores each token as a candidate end from [token state ; start state],
        # hence the doubled input width.
        self.end_classifier = nn.Linear(hidden_size * 2, 1)

    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):
        """Compute start logits, then end logits conditioned on one start token.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Forwarded to the encoder. When the start is
                predicted (not teacher-forced), positions where the mask is 0
                are excluded from the argmax that picks the conditioning start.
            start_positions: Gold start index per example, shape ``(batch,)``
                or ``(batch, 1)``. Used only when ``self.training`` is True,
                for teacher forcing.
            end_positions: Accepted for interface compatibility; not used here
                (no loss is computed in this method).

        Returns:
            Dict with ``'start_logits'`` and ``'end_logits'``, each of shape
            ``(batch, seq_len)``. ``start_logits`` is returned unmasked.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        H = outputs.last_hidden_state  # (batch, seq_len, hidden)
        start_logits = self.start_classifier(H).squeeze(-1)  # (batch, seq_len)

        if self.training and start_positions is not None:
            # Teacher forcing: condition the end head on the gold start.
            # reshape(-1) accepts labels shaped (batch,) or (batch, 1).
            start_idx = start_positions.reshape(-1).long()
        else:
            # Condition on the predicted start. Padding is masked out of the
            # selection only, so the end head is never conditioned on a pad
            # token while the returned start_logits stay untouched.
            selectable = start_logits
            if attention_mask is not None:
                selectable = start_logits.masked_fill(
                    attention_mask == 0, torch.finfo(start_logits.dtype).min
                )
            start_idx = selectable.argmax(dim=-1)

        # Pick each example's start-token hidden state: (batch, hidden).
        batch_idx = torch.arange(H.size(0), device=H.device)
        start_repr = H[batch_idx, start_idx]

        # Broadcast the start state to every position and score each as an end.
        start_repr_expanded = start_repr.unsqueeze(1).expand_as(H)
        end_input = torch.cat([H, start_repr_expanded], dim=-1)
        end_logits = self.end_classifier(end_input).squeeze(-1)
        return {'start_logits': start_logits, 'end_logits': end_logits}
|