Paper Summary of BART
This post covers the paper BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension.
The recipe is language-model pre-training followed by fine-tuning on downstream tasks:
- Corrupt text with an arbitrary noising function: randomly shuffle sentence order, replace spans of text with a single mask token, etc.
- Learn a model to reconstruct the original text.
BART is an encoder-decoder architecture (a bidirectional encoder over noised text, as in BERT, plus a left-to-right autoregressive decoder, as in GPT). The encoder's input is the corrupted sequence; the decoder's input is the right-shifted original sequence, and its target is the original sequence. The design goal is clear: exploit the encoder's bidirectional modeling ability while retaining the autoregressive property, so the model suits generation tasks.
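A toy illustration of how the three sequences line up (my own sketch with hypothetical token lists; real BART operates on BPE subwords):

```python
# Toy example of the inputs/targets in BART's denoising objective.
original      = ['A', 'B', 'C', '.', 'D', 'E', '.']
encoder_input = ['A', '<mask>', 'C', '.', 'D', 'E', '.']  # corrupted text
decoder_input = ['<s>'] + original[:-1]                   # right-shifted original
target        = original                                  # reconstruction target
```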
Noising strategies:
- Token Masking: as in BERT, random tokens are replaced with [MASK].
- Token Deletion: random tokens are deleted from the input.
- Text Infilling: a contiguous run of tokens (a span) is replaced with a single [MASK]; span lengths are drawn from a Poisson distribution with λ=3. Note that a span of length 0 amounts to inserting a [MASK] token.
- Sentence Permutation: the sentences of a document are shuffled into random order.
- Document Rotation: a token is chosen at random from the document, and the document is rotated so that it begins with that token.

These strategies can be combined; a sketch of text infilling follows below.
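As a concrete illustration, here is a minimal sketch of text infilling. It is my own simplification: it works on a whitespace-tokenized list, samples span lengths with numpy's Poisson sampler, and ignores subtleties such as overlapping or adjacent spans.

```python
import random
import numpy as np

def text_infilling(tokens, mask_ratio=0.3, lam=3, mask_token='<mask>'):
    """Replace random spans with a single mask token; span lengths ~ Poisson(lam)."""
    tokens = list(tokens)
    budget = int(len(tokens) * mask_ratio)  # number of tokens to corrupt
    while budget > 0 and tokens:
        span = min(np.random.poisson(lam), budget, len(tokens))
        start = random.randrange(len(tokens) - span + 1)
        if span == 0:
            tokens.insert(start, mask_token)  # length-0 span = pure insertion
            budget -= 1
        else:
            tokens[start:start + span] = [mask_token]  # whole span -> one mask
            budget -= span
    return tokens

print(text_infilling('the quick brown fox jumps over the lazy dog'.split()))
```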
Trying fill-in-the-blank prediction with the BARTModel class from the fairseq library:
```python
import torch
from fairseq.models.bart import BARTModel

# Alternatively, load via torch.hub:
# bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart = BARTModel.from_pretrained('<path-of-bart.large>', checkpoint_file='model.pt')
# bart.cuda()  # move the model to GPU if one is available
bart.eval()  # disable dropout for deterministic inference

# fill_mask accepts raw strings directly; no manual tensorization is needed.
print(bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=5, beam=20))
# e.g. [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
```
The `topk` parameter specifies how many of the highest-scoring candidate predictions to return: with `topk=3`, `bart.fill_mask()` returns the top three candidates for each masked input. The `beam` parameter controls beam search, i.e. how many candidate hypotheses are kept while searching for the best predictions: with `beam=10`, the ten best hypotheses are tracked during decoding. Increasing `topk` or `beam` generally increases the computational cost.
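For instance (hypothetical follow-up calls, reusing the `bart` model loaded above):

```python
# Same query with two search budgets: the larger beam explores more
# hypotheses and tends to surface better-scored candidates, but is slower.
fast = bart.fill_mask(['The cat <mask> on the mat.'], topk=3, beam=10)
slow = bart.fill_mask(['The cat <mask> on the mat.'], topk=3, beam=50)
```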
```
2023-03-14 13:31:44 | INFO | fairseq.tasks.denoising | dictionary: 50264 types
[[('The cat is still on the.', tensor(-1.6241)), ('The cat is sleeping on the.', tensor(-1.6287)), ('The cat is sitting on the.', tensor(-1.7007)), ('The cat is back on the.', tensor(-1.7144)), ('The cat is asleep on the.', tensor(-1.7318))], [('The dog is still on the.', tensor(-1.7400)), ('The dog jumped up on the.', tensor(-1.8462)), ('The dog was still on the.', tensor(-1.8523)), ('The dog is sleeping on the.', tensor(-1.8564)), ('The dog was not on the.', tensor(-1.8638))]]
```
Only the first mask in each sentence was filled; the second `<mask>` was simply dropped from the output. This may be because the candidates completing the second mask scored too low to survive the `topk` and `beam` filtering.
```python
print(bart.fill_mask(['Snow on the <mask>.', 'She <mask> to the front with her <mask> on Sunday night.'], topk=5, beam=20))
```
```
[[('Snow on the ground.', tensor(-2.2770)), ('Snow on the way.', tensor(-2.3293)), ('Snow on the horizon.', tensor(-2.4373)), ('Snow on the roads.', tensor(-2.4572)), ('Snow on the mountains.', tensor(-2.5085))], [('She made her way to the front with her husband on Sunday', tensor(-1.2503)), ('She made her way to the front with her family on Sunday', tensor(-1.2704)), ('She returned to the front with her husband on Sunday night.', tensor(-1.3483)), ('She made her way to the front with her children on Sunday', tensor(-1.3583)), ('She made her way to the front with her daughter on Sunday', tensor(-1.3607))]]
```
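The return value is one list of `(hypothesis, score)` tuples per input sentence, where the scores appear to be log-probabilities (closer to zero is better). A small sketch for unpacking them:

```python
# Unpack the nested results: one list of (sentence, score) pairs per input.
results = bart.fill_mask(['Snow on the <mask>.'], topk=5, beam=20)
for hypotheses in results:
    for sentence, score in hypotheses:
        print(f'{score.item():.4f}  {sentence}')
```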
References:
- https://zhuanlan.zhihu.com/p/173858031
- [fairseq/examples/bart at main · facebookresearch/fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/bart)
- [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension | Papers With Code](https://paperswithcode.com/paper/bart-denoising-sequence-to-sequence-pre)