简化基于BERT的模型以提高效率和容量
新方法使基于BERT的自然语言处理模型能够处理更长的文本字符串,在资源受限的环境中运行——或有时同时实现这两个目标。
近年来,自然语言处理领域许多性能最佳的模型都建立在BERT语言模型之上。BERT模型在大型公共文本语料库上进行预训练,能够编码词序列的概率。由于BERT模型开始时就具备对整个语言的广泛知识,因此可以用相对较少的标注数据针对更具体的任务进行微调。
然而,BERT模型非常庞大,基于BERT的自然语言处理模型可能运行缓慢——对于计算资源有限的用户来说甚至慢得无法接受。其复杂性也限制了可接受的输入长度,因为其内存占用随输入长度的平方而增加。
在今年的计算语言学协会会议上,我们提出了一种名为Pyramid-BERT的新方法,可在不牺牲太多准确性的情况下减少基于BERT模型的训练时间、推理时间和内存占用。减少的内存占用还使BERT模型能够处理更长的文本序列。
基于BERT的模型将句子序列作为输入,并输出整个句子及其组成词的向量表示。然而,下游应用仅使用完整句子的嵌入。为了使基于BERT的模型更高效,我们在网络中间层逐步消除冗余的单个词嵌入,同时尽量减少对完整句子嵌入的影响。
我们将Pyramid-BERT与几种最先进的BERT模型效率提升技术进行比较,结果显示在仅损失1.5%准确性的情况下,可将推理速度提高3到3.5倍,而在相同速度下,现有最佳方法的准确性损失为2.5%。
此外,当我们将此方法应用于专为长文本设计的BERT模型变体Performers时,可将模型内存占用减少70%,同时实际上提高了准确性。在此压缩率下,现有最佳方法的准确性下降4%。
标记的处理过程
输入BERT模型的每个句子被分解为称为标记的单元。大多数标记是单词,但有些是多词短语,有些是子词部分,有些是首字母缩略词的单个字母等。每个句子的开头由一个特殊标记分隔。
每个标记通过一系列编码器,每个编码器为每个输入标记生成新的嵌入。每个编码器都有一个注意力机制,决定每个标记的嵌入应反映其他标记携带的信息量。
随着标记通过一系列编码器,它们的嵌入会包含关于序列中其他标记的越来越多信息。当标记通过最终编码器时,CLS标记的嵌入最终代表整个句子。但其嵌入也与句子中所有其他标记的嵌入非常相似。这就是我们试图消除的冗余。
基本思想是,在网络的每个编码器中,我们保留CLS标记的嵌入,但选择其他标记嵌入的代表性子集。
嵌入是向量,因此可以解释为多维空间中的点。为了构建核心集,理想情况下,我们将嵌入分类为等直径的簇,并选择每个簇的中心点。
不幸的是,构建跨越神经网络层的核心集的问题是NP难的,意味着耗时过长。
作为替代方案,我们的论文提出了一种贪心算法,每次选择n个核心集成员。在每一层,我们取CLS标记的嵌入,然后在表示空间中找到距离它最远的n个嵌入。我们将这些与CLS嵌入一起添加到核心集中。然后我们找到与核心集中任何点的最小距离最大的n个嵌入,并将它们添加到核心集中。
我们重复此过程,直到核心集达到所需大小。这被证明是优化核心集的充分近似。
最后,在我们的论文中,我们考虑了每层核心集应该有多大的问题。我们使用指数延迟函数来确定从一层到下一层的衰减程度,并研究选择不同衰减率时准确性与加速或内存减少之间的权衡。
致谢:Ashish Khetan, Rene Bidart, Zohar Karnin
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)
公众号二维码
公众号二维码