Simple Contrastive Learning of Sentence Embeddings

《SimCSE: Simple Contrastive Learning of Sentence Embeddings》(SimCSE,2021年EMNLP)

摘要

  • 本文介绍SimCSE,一个简单的对比学习框架,极大促进了最优的句子嵌入
  • 首先使用一种无监督的方法,该方法输入语句并在对比目标中进行预测,学习过程中仅将标准dropout用作噪声
  • 从通过自然语言推理(NLI)数据集学习句子嵌入的工作中汲取灵感,并通过使用“蕴含”对作为肯定词,而“矛盾”对作为硬否定词,将NLI数据集的标注对纳入对比学习——有监督
  • 评估了标准语义文本相似性(STS)任务上的SimCSE,与之前的最佳结果相比,分别提高了7.9和4.6点

介绍

  • 提出SimCSE,一个简单的对比嵌入框架,它可以从无标记或有标记的数据中产生更好的句子嵌入

  • 无监督SimCSE只是简单地根据dropout预测输入句子本身,如图1a

    • 同一个句子传递给预训练的BERT两次——dropout两次——获得两个不同的嵌入作为“正对”(positive pairs
    • 把同一个batch中的其他句子作为“否定”(negatives),模型预测negatives中的positive
  • 有监督SimCSE建立在使用自然语言推理(NLI)数据集进行句子嵌入的基础上,在对比学习中加入了注释句子对,如图1b

    • 以前的工作将NLI视为一个3分类任务(蕴涵、中性和矛盾),而本文利用了蕴涵对可以自然地用作正实例的事实,本文还发现,添加相应的矛盾对作为硬否定(hard negatives)可以进一步提高性能

    image-20211215103320590

背景:对比学习

  • 对比学习旨在通过将语义上相近的邻居拉在一起并推开非邻居,来学习有效的表达

  • 假设有样本对集合:$D=(x_i,x_i^+)_{i=1}^m$,其中$x_i$与$x_i^+$在语义上近似。令$h_i$和$h_i^+$表示$x_i$与$x_i^+$的表达,则一个batch(batchsize为N)的训练目标为:

    image-20211215103305030

    • 其中$\tau$为温度超参数,$sim(h_1,h_2)$为余弦相似度
    • 本文工作中,通过BERT或RoBERTa对输入句子encode,然后使用上面的对比学习公式作为损失函数,微调所有参数
  • 对比学习中的一个关键问题在于如何构建正例对,即构建$x_i,x_i^+$,这里可以直接利用两次输入同一个句子(利用dropout机制)来构建;其他的构建方法包括,使用dual-encoder将当前句子和下一个句子(例如问答)构建为一个正例对

无监督SimCSE

  • 取一组句子$(x_i)_{i=1}^m$,令$x_i^+=x_i$。关键在于使用dropout来获得二者不同的encode表达,由于dropout的概率问题,每次dropout(记为$z,z’$)将mask不同的信息:

    image-20211215104217566
  • 本文对比了其他的数据扩充方法,如剪裁、单词删除和替换,但发现没有一个离散的方法的性能,会优于使用dropout noise

  • 本文还将该训练目标和“An efficient framework for learning sentence representations”的训练目标做了对比,本文的效果更好

有监督的SimCSE

  • 这里使用有监督的NLI(自然语言推理)数据集进行训练。数据集中,一条数据表示为:sentence1,sentence2,label(蕴含、中立、矛盾)

  • 直接从监督数据集中获取$(x_i,x_i^+)$

    • 对比了多个有监督的数据集,NLI数据集效果更好
    • 选择label为蕴含的两个句子作为正例对
  • 进一步利用NLI数据集:

    • NLI数据集中,给定一个句子,通常会有三个相关句子和对应的label(蕴含、中立、矛盾)

    • 取label为蕴含的sentence2为正例$x_i^+$,取label为矛盾的sentence2为负例$x_i^-$,此时一个batch内的训练目标

      image-20211215105500391

    • 本文还考虑在SimCSE中使用dual-encoder框架,但性能更差

实验结果

  • Sentence embedding在STS任务中的表现

    image-20211215105815885

结论

  • 提出SimCSE,大大提高语义文本相似性任务的句子嵌入性能
  • 本文的对比目标(训练目标),尤其是无监督目标,可能在自然语言处理中有更广泛的应用。它为文本输入的数据扩充提供了一个新的视角,并可以扩展到其他连续表示和集成在语言模型预训练中

关键代码(loss计算)

  • 无监督

    collator

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    class CSECollator(object):
    def __init__(self,
    tokenizer,
    features=("input_ids", "attention_mask", "token_type_ids"),
    max_len=100):
    self.tokenizer = tokenizer
    self.features = features
    self.max_len = max_len

    def collate(self, batch):
    new_batch = []
    for example in batch:
    for i in range(2):
    # 每个句子重复两次
    new_batch.append({fea: example[fea] for fea in self.features})
    new_batch = self.tokenizer.pad(
    new_batch,
    padding=True,
    max_length=self.max_len,
    return_tensors="pt"
    )
    return new_batch

    loss计算

    1
    2
    3
    4
    5
    6
    7
    8
    def compute_loss(y_pred, tao=0.05, device="cuda"):
    idxs = torch.arange(0, y_pred.shape[0], device=device)
    y_true = idxs + 1 - idxs % 2 * 2
    similarities = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=2)
    similarities = similarities - torch.eye(y_pred.shape[0], device=device) * 1e12
    similarities = similarities / tao
    loss = F.cross_entropy(similarities, y_true)
    return torch.mean(loss)
  • 有监督(不需要double输入)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    class SimcseModel(nn.Module):
    """Simcse有监督模型定义"""
    def __init__(self, pretrained_model: str, pooling: str):
    super(SimcseModel, self).__init__()
    # config = BertConfig.from_pretrained(pretrained_model) # 有监督不需要修改dropout
    self.bert = BertModel.from_pretrained(pretrained_model)
    self.pooling = pooling

    def forward(self, input_ids, attention_mask, token_type_ids):

    # out = self.bert(input_ids, attention_mask, token_type_ids)
    out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True)

    if self.pooling == 'cls':
    return out.last_hidden_state[:, 0] # [batch, 768]

    if self.pooling == 'pooler':
    return out.pooler_output # [batch, 768]

    if self.pooling == 'last-avg':
    last = out.last_hidden_state.transpose(1, 2) # [batch, 768, seqlen]
    return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [batch, 768]

    if self.pooling == 'first-last-avg':
    first = out.hidden_states[1].transpose(1, 2) # [batch, 768, seqlen]
    last = out.hidden_states[-1].transpose(1, 2) # [batch, 768, seqlen]
    first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1) # [batch, 768]
    last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [batch, 768]
    avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1) # [batch, 2, 768]
    return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1) # [batch, 768]

    def simcse_sup_loss(y_pred: 'tensor') -> 'tensor':
    """有监督的损失函数
    y_pred (tensor): bert的输出, [batch_size * 3, 768]

    """
    # 得到y_pred对应的label, 每第三句没有label, 跳过, label= [1, 0, 4, 3, ...]
    y_true = torch.arange(y_pred.shape[0], device=DEVICE)
    use_row = torch.where((y_true + 1) % 3 != 0)[0]
    y_true = (use_row - use_row % 3 * 2) + 1
    # batch内两两计算相似度, 得到相似度矩阵(对角矩阵)
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    # 将相似度矩阵对角线置为很小的值, 消除自身的影响
    sim = sim - torch.eye(y_pred.shape[0], device=DEVICE) * 1e12
    # 选取有效的行
    sim = torch.index_select(sim, 0, use_row)
    # 相似度矩阵除以温度系数
    sim = sim / 0.05
    # 计算相似度矩阵与y_true的交叉熵损失
    loss = F.cross_entropy(sim, y_true)
    return loss

是否可以替代rdrop中的kl-loss?