机器阅读理解实战:从零构建问答系统

本文从零开始实现一个机器阅读理解系统,涵盖数据处理、模型构建、训练和推理的完整流程。

任务定义

给定上下文 $C$ 和问题 $Q$,预测答案在上下文 $C$ 中的起止位置 $(start, end)$:

数据处理

SQuAD 数据格式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Example:
    """A single extractive-QA example with a character-level answer span."""
    context: str         # passage the answer is drawn from
    question: str
    answer_text: str     # gold answer string
    start_position: int  # character offset of the answer start within `context`
    end_position: int    # character offset one past the answer end

def load_squad(file_path: str) -> List[Example]:
    """Parse a SQuAD-format JSON file into a flat list of Example objects.

    Unanswerable questions (SQuAD v2 ``is_impossible``) are skipped, as are
    entries whose ``answers`` list is empty — the original code indexed
    ``answers[0]`` unconditionally and crashed on such entries. Only the
    first annotated answer per question is kept.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    examples = []
    for article in data['data']:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                if qa.get('is_impossible', False):
                    continue
                answers = qa.get('answers') or []
                if not answers:
                    # Robustness fix: some files contain answerable questions
                    # with no annotations; previously an IndexError.
                    continue
                answer = answers[0]
                start = answer['answer_start']
                examples.append(Example(
                    context=context,
                    question=qa['question'],
                    answer_text=answer['text'],
                    start_position=start,
                    end_position=start + len(answer['text']),
                ))

    return examples

Tokenization

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import AutoTokenizer

class MRCTokenizer:
    """Encodes (question, context) pairs for extractive QA with a HF tokenizer."""

    def __init__(self, model_name: str, max_length: int = 384, doc_stride: int = 128):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self.doc_stride = doc_stride

    def encode(self, example: Example):
        """Tokenize one example and map its character span to token indices.

        Returns a dict with ``input_ids``, ``attention_mask`` and the
        token-level ``start_position`` / ``end_position`` (0 when the answer
        does not fall inside the first window).
        """
        encoding = self.tokenizer(
            example.question,
            example.context,
            max_length=self.max_length,
            truncation='only_second',
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
        )

        # Only the first window is used here; overflowing windows are dropped.
        offset_mapping = encoding['offset_mapping'][0]
        # BUG FIX: restrict the search to context tokens (sequence id 1).
        # Question tokens also carry character offsets, so without this guard
        # an answer offset could spuriously match a question token.
        sequence_ids = encoding.sequence_ids(0)

        start_token = None
        end_token = None

        for idx, (start, end) in enumerate(offset_mapping):
            if sequence_ids[idx] != 1:
                continue
            if start <= example.start_position < end:
                start_token = idx
            if start < example.end_position <= end:
                end_token = idx
                break

        return {
            'input_ids': encoding['input_ids'][0],
            'attention_mask': encoding['attention_mask'][0],
            # BUG FIX: `start_token or 0` collapsed a legitimate index 0 and a
            # not-found None into the same value via falsiness; test for None
            # explicitly instead.
            'start_position': start_token if start_token is not None else 0,
            'end_position': end_token if end_token is not None else 0,
        }

模型实现

基于 BERT 的 MRC 模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
from transformers import AutoModel

class MRCModel(nn.Module):
    """Extractive-QA model: a pretrained encoder plus two span-boundary heads."""

    def __init__(self, model_name: str, dropout: float = 0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(dropout)
        self.start_classifier = nn.Linear(hidden_size, 1)
        self.end_classifier = nn.Linear(hidden_size, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
    ):
        """Return start/end logits and, when gold positions are given, the loss."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = self.dropout(encoded.last_hidden_state)

        # One scalar per token per boundary: (batch, seq_len, 1) -> (batch, seq_len).
        start_logits = self.start_classifier(hidden).squeeze(-1)
        end_logits = self.end_classifier(hidden).squeeze(-1)

        # Padding positions must never win the argmax / dominate the softmax.
        pad = ~attention_mask.bool()
        start_logits = start_logits.masked_fill(pad, -1e9)
        end_logits = end_logits.masked_fill(pad, -1e9)

        loss = None
        if start_positions is not None and end_positions is not None:
            criterion = nn.CrossEntropyLoss()
            loss = (criterion(start_logits, start_positions)
                    + criterion(end_logits, end_positions)) / 2

        return {
            'loss': loss,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }

改进:联合 Start-End 预测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class JointMRCModel(nn.Module):
    """Joint start/end prediction: the end head is conditioned on the start
    token's representation (gold start during training, predicted at inference).
    """

    def __init__(self, model_name: str, max_answer_length: int = 30):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.max_answer_length = max_answer_length

        self.start_classifier = nn.Linear(hidden_size, 1)
        # End head sees [token_repr ; start_repr], hence 2x hidden size.
        self.end_classifier = nn.Linear(hidden_size * 2, 1)

    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        H = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # BUG FIX: mask padding so the inference-time argmax (and the loss)
        # can never select a pad position; previously unmasked.
        pad = ~attention_mask.bool()

        start_logits = self.start_classifier(H).squeeze(-1)
        start_logits = start_logits.masked_fill(pad, -1e9)

        if self.training and start_positions is not None:
            # Teacher forcing: use the gold start position during training.
            start_indices = start_positions.unsqueeze(-1).unsqueeze(-1)
        else:
            # Inference: condition on the predicted start position.
            start_indices = start_logits.argmax(dim=-1, keepdim=True).unsqueeze(-1)
        # Gather the start token's hidden state: (batch, 1, hidden) -> (batch, hidden).
        start_repr = H.gather(1, start_indices.expand(-1, -1, H.size(-1))).squeeze(1)

        # End prediction conditioned on the start representation.
        start_repr_expanded = start_repr.unsqueeze(1).expand(-1, H.size(1), -1)
        end_input = torch.cat([H, start_repr_expanded], dim=-1)
        end_logits = self.end_classifier(end_input).squeeze(-1)
        end_logits = end_logits.masked_fill(pad, -1e9)

        # BUG FIX: the model accepted gold positions but never returned a loss,
        # making it untrainable and inconsistent with MRCModel. Compute the
        # same averaged cross-entropy here ('loss' key added, backward-compatible).
        loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = (loss_fct(start_logits, start_positions)
                    + loss_fct(end_logits, end_positions)) / 2

        # NOTE(review): constraining end >= start and end - start < max_answer_length
        # is still left to the decoding step, as in the original.
        return {'loss': loss, 'start_logits': start_logits, 'end_logits': end_logits}

训练流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torch.utils.data import DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

def train(model, train_dataloader, val_dataloader, epochs=3, lr=3e-5):
    """Fine-tune `model` with AdamW + linear warmup/decay, keeping the
    checkpoint with the best validation F1 in 'best_model.pt'.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    total_steps = len(train_dataloader) * epochs
    warmup_steps = int(0.1 * total_steps)  # 10% linear warmup
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    best_f1 = 0
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in tqdm(train_dataloader, desc=f'Epoch {epoch+1}'):
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): this assumes batch keys match forward()'s parameter
            # names ('start_positions'/'end_positions'); MRCTokenizer.encode
            # emits singular 'start_position' keys — confirm the collate fn
            # renames them.
            loss = model(**batch)['loss']

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_dataloader)
        print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')

        # Validate once per epoch and checkpoint on improvement.
        f1 = evaluate(model, val_dataloader, device)
        print(f'Validation F1: {f1:.4f}')

        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'best_model.pt')

    return model

评估与推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re
import string
from collections import Counter

def normalize_answer(s):
    """SQuAD-style normalization: lowercase, drop punctuation, drop English
    articles (a/an/the), and collapse whitespace.
    """
    text = s.lower()
    punctuation = set(string.punctuation)
    text = ''.join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def compute_f1(pred: str, gold: str) -> float:
    """Token-level F1 between the normalized prediction and gold answer."""
    pred_tokens = normalize_answer(pred).split()
    gold_tokens = normalize_answer(gold).split()

    overlap = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(overlap.values())
    if num_same == 0:
        return 0

    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def predict(model, tokenizer, context: str, question: str, device):
    """Answer a single question by decoding the argmax start/end token span."""
    model.eval()

    inputs = tokenizer(
        question, context,
        max_length=384,
        truncation='only_second',
        return_tensors='pt'
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    start_idx = outputs['start_logits'].argmax().item()
    end_idx = outputs['end_logits'].argmax().item()

    # An end before the start is invalid; fall back to a single-token span.
    end_idx = max(end_idx, start_idx)

    # Decode the selected token span back to text.
    span_ids = inputs['input_ids'][0][start_idx:end_idx + 1]
    return tokenizer.decode(span_ids, skip_special_tokens=True)

现代方法:使用 LLM

对于更复杂的问答需求,可以使用 LLM:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from openai import OpenAI

def llm_qa(context: str, question: str) -> str:
    """Answer `question` from `context` via an OpenAI chat model.

    temperature=0 keeps the output deterministic; the system prompt instructs
    the model to refuse when the answer is absent from the context.
    """
    messages = [
        {"role": "system", "content": "你是一个问答助手。根据给定的上下文回答问题。如果答案不在上下文中,请说'无法回答'。"},
        {"role": "user", "content": f"上下文:{context}\n\n问题:{question}"},
    ]

    client = OpenAI()
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=messages,
        temperature=0
    )
    return completion.choices[0].message.content

延伸阅读


转载请注明出处