
BERTology: An Introduction!

This post introduces recent models in the BERT family.

LM pretraining background

Autoregressive Language Modeling

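Given a text sequence $\pmb{x} = (x_1, \cdots, x_T)$, autoregressive (AR) language modeling factorizes the likelihood in a single direction according to the product rule, either forward:

$$p(\pmb{x}) = \prod_{t=1}^{T} p(x_t \mid \pmb{x}_{<t})$$

or backward:

$$p(\pmb{x}) = \prod_{t=T}^{1} p(x_t \mid \pmb{x}_{>t})$$

AR pretraining maximizes the likelihood under the forward AR factorization:

$$\max_{\theta} \log p_{\theta}(\pmb{x}) = \sum_{t=1}^{T} \log p_{\theta}(x_t \mid \pmb{x}_{<t}) = \sum_{t=1}^{T} \log \frac{\exp \left( h_{\theta}(\pmb{x}_{1:t-1})^{\top} e(x_t) \right)}{\sum_{x'} \exp \left( h_{\theta}(\pmb{x}_{1:t-1})^{\top} e(x') \right)}$$

where $h_{\theta}(\pmb{x}_{1:t-1})$ denotes the context representation produced by a neural network such as an RNN or a Transformer, and $e(x)$ denotes the embedding of $x$.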

Because each token is conditioned on one direction only, AR language modeling cannot effectively model deep bidirectional contexts.
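The forward factorization can be illustrated with a toy model. The sketch below assumes a hypothetical bigram table `BIGRAM` in place of a neural network, so each conditional depends only on the previous token; it only shows how the per-token log-probabilities add up to the sequence log-likelihood.

```python
import itertools
import math

# Hypothetical toy model: p(x_t | x_{<t}) depends only on the previous token.
VOCAB = ["a", "b"]
BIGRAM = {
    None: {"a": 0.5, "b": 0.5},  # distribution of the first token
    "a": {"a": 0.9, "b": 0.1},
    "b": {"a": 0.4, "b": 0.6},
}

def forward_log_likelihood(tokens):
    """Sum log p(x_t | x_{<t}) from left to right."""
    total, prev = 0.0, None
    for tok in tokens:
        total += math.log(BIGRAM[prev][tok])
        prev = tok
    return total

ll = forward_log_likelihood(["a", "a", "b"])  # log(0.5 * 0.9 * 0.1)

# A valid factorization defines a distribution: all length-3 sequences sum to 1.
total_prob = sum(
    math.exp(forward_log_likelihood(seq)) for seq in itertools.product(VOCAB, repeat=3)
)
```

Note that every conditional only looks to the left of the predicted token, which is the limitation discussed above.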

Autoencoding based pretraining

Autoencoding (AE) based pretraining does not perform explicit density estimation; instead, it recovers the original data from a corrupted (masked) input.

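Denoising autoencoding based pretraining, such as BERT, can model bidirectional contexts. Given a text sequence $\pmb{x} = (x_1, \cdots, x_T)$, BERT randomly masks a portion (15%) of the tokens in $\pmb{x}$. Let $\bar{\pmb{x}}$ denote the masked tokens and $\hat{\pmb{x}}$ the corrupted sequence. The training objective is to reconstruct the masked tokens $\bar{\pmb{x}}$ from the corrupted sequence $\hat{\pmb{x}}$:

$$\max_{\theta} \log p_{\theta}(\bar{\pmb{x}} \mid \hat{\pmb{x}}) \approx \sum_{t=1}^{T} m_t \log p_{\theta}(x_t \mid \hat{\pmb{x}}) = \sum_{t=1}^{T} m_t \log \frac{\exp \left( H_{\theta}(\hat{\pmb{x}})_t^{\top} e(x_t) \right)}{\sum_{x'} \exp \left( H_{\theta}(\hat{\pmb{x}})_t^{\top} e(x') \right)}$$

where

  • $m_t = 1$ means $x_t$ is masked;
  • the Transformer $H_{\theta}$ encodes $\pmb{x}$ into hidden vectors $H_{\theta}(\pmb{x}) = [H_{\theta}(\pmb{x})_1, \cdots, H_{\theta}(\pmb{x})_T]$.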

However, BERT relies on corrupting the input with masks, which has two drawbacks:

  1. Independence assumption: BERT cannot model the joint probability of the masked tokens. It assumes the predicted tokens are independent of each other given the unmasked tokens, and so neglects the dependency between masked positions.
  2. Pretrain-finetune discrepancy (input noise): artificial symbols such as [MASK] used during BERT pretraining never appear when finetuning on downstream tasks.
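The corruption step can be sketched in a few lines. This is a simplified illustration, not BERT's actual preprocessing: the helper `corrupt` and the example sentence are hypothetical, and every selected position is replaced with `[MASK]` (BERT additionally keeps or randomizes some of the selected tokens).

```python
import random

MASK = "[MASK]"

def corrupt(tokens, mask_prob=0.15, seed=0):
    """Return (corrupted sequence, m), where m[t] = 1 iff position t is masked."""
    rng = random.Random(seed)
    n_mask = max(1, round(mask_prob * len(tokens)))
    masked_positions = set(rng.sample(range(len(tokens)), n_mask))
    m = [1 if t in masked_positions else 0 for t in range(len(tokens))]
    x_hat = [MASK if m[t] else tok for t, tok in enumerate(tokens)]
    return x_hat, m

tokens = "new york is a city in the united states of america today".split()
x_hat, m = corrupt(tokens)

# Only masked positions contribute to the loss; each target is predicted
# independently of the other targets given x_hat.
targets = [tok for tok, m_t in zip(tokens, m) if m_t]
```

Both drawbacks are visible here: the targets are predicted independently of one another, and `x_hat` contains a symbol that downstream inputs never contain.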

XLNet

XLNet[1] (CMU & Google Brain, 2019) combines the advantages of the AR and AE language modeling objectives while avoiding their drawbacks.

  • Permuting the factorization order keeps an AR-style objective that does not rely on corrupting the input. XLNet therefore models the dependency between predicted positions that BERT neglects, and avoids the pretrain-finetune discrepancy.
  • At the same time, because of the permutation, each position can attend to context on both sides, as in BERT.

Permutation Language Model

XLNet[13] applies a permutation language modeling objective: it autoregressively pretrains on bidirectional contexts by maximizing the expected likelihood over all permutations of the factorization order of the input sequence.


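For a sequence $\pmb{x}$ of length $T$, there are $T!$ different orders to perform a valid AR factorization. Let $\mathcal{Z}_T$ be the set of all permutations of the index sequence $[1, 2, \cdots, T]$, and let $z_t$ and $z_{<t}$ denote the $t$-th element and the first $t-1$ elements of a permutation $\pmb{z} \in \mathcal{Z}_T$. The permutation language modeling objective is:

$$\max_{\theta} \mathbb{E}_{\pmb{z} \sim \mathcal{Z}_T} \left[ \sum_{t=1}^{T} \log p_{\theta}(x_{z_t} \mid \pmb{x}_{z_{<t}}) \right]$$

Permutation language modeling not only retains the benefits of AR models but also captures bidirectional contexts, as BERT does. It permutes only the factorization order, not the order of the tokens in the sequence.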
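One way to see that only the factorization order changes is to express an order as an attention mask over the unpermuted sequence. The sketch below is an illustration under that reading; the helper `permutation_attention_mask` and the example order are hypothetical, not taken from the XLNet code.

```python
import numpy as np

def permutation_attention_mask(order):
    """order[k] is the (0-based) position predicted at step k of the factorization.

    Returns a [T, T] boolean matrix `can_attend`, where can_attend[i, j] is True
    iff position j comes strictly before position i in the factorization order.
    """
    T = len(order)
    rank = np.empty(T, dtype=int)
    rank[np.asarray(order)] = np.arange(T)  # rank[pos] = step at which pos is predicted
    return rank[None, :] < rank[:, None]

# Example factorization order 3 -> 2 -> 4 -> 1 (1-based), i.e. [2, 1, 3, 0] 0-based.
can_attend = permutation_attention_mask([2, 1, 3, 0])
```

In this example the first token is predicted last, so it attends to all three tokens to its right, while the tokens themselves stay in their original positions.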

Implementation

XLNet Implementation

# XLNet implementation (TensorFlow 1.x)
import tensorflow as tf

# SEP_ID and CLS_ID are the ids of the separator and classification tokens;
# they are module-level constants defined elsewhere in the XLNet code base.


def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    """
    Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    Args:
      inputs: int64 Tensor in shape [seq_len], input ids.
      targets: int64 Tensor in shape [seq_len], target ids.
      is_masked: bool Tensor in shape [seq_len]. True means being selected
        for partial prediction.
      perm_size: the length of longest permutation. Could be set to be reuse_len.
        Should not be larger than reuse_len or there will be data leaks.
      seq_len: int, sequence length.
    """

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.transpose(tf.reshape(index, [-1, perm_size]))
    index = tf.random_shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(tf.logical_or(
        tf.equal(inputs, SEP_ID),
        tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to the
    # smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, None] <= rev_index[None, :],
        masked_or_func_tokens)
    perm_mask = tf.cast(perm_mask, tf.float32)

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[0: 1], targets[: -1]],
                            axis=0)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q

Huggingface Implementation

# Excerpt from transformers/data/data_collator.py. `DataCollatorMixin` and the
# `_torch_collate_batch`, `_tf_collate_batch` and `_numpy_collate_batch` helpers
# are defined in that same module.
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase


@dataclass
class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
    """
    Data collator used for permutation language modeling.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet
    """

    tokenizer: PreTrainedTokenizerBase
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens
    return_tensors: str = "pt"

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = _torch_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = _tf_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = _numpy_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
            2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
               masked
            3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
               span_length]` and mask tokens `start_index:start_index + span_length`
            4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
               sequence to be processed), repeat from Step 1.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )

        labels = inputs.clone()
        # Creating the mask and target_mapping tensors
        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.size(1)

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th predict corresponds to the i-th token.
            target_mapping[i] = torch.eye(labels.size(1))

        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(labels.size(1))
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
            ) & masked_indices[i]

        return inputs.long(), perm_mask, target_mapping, labels.long()

    def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
            2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
               masked
            3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
               span_length]` and mask tokens `start_index:start_index + span_length`
            4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
               sequence to be processed), repeat from Step 1.
        """
        from random import randint

        import numpy as np
        import tensorflow as tf

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if tf.shape(inputs)[1] % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )

        labels = tf.identity(inputs)
        # Creating the mask and target_mapping tensors
        # (`np.bool` was removed from recent NumPy releases; the builtin `bool` is equivalent)
        masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
        labels_shape = tf.shape(labels)
        target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)

        for i in range(len(labels)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = tf.shape(labels)[1]

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked).
                # `random.randint` is inclusive on both ends, so the upper bound is `max_span_length` itself.
                span_length = randint(1, self.max_span_length)
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + randint(0, context_length - span_length)
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th predict corresponds to the i-th token.
            target_mapping[i] = np.eye(labels_shape[1])
        masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
        target_mapping = tf.convert_to_tensor(target_mapping)
        special_tokens_mask = tf.convert_to_tensor(
            [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.numpy().tolist()
            ],
        )
        special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
        masked_indices = masked_indices & ~special_tokens_mask
        if self.tokenizer._pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
        labels = tf.where(masked_indices, labels, -100)  # We only compute loss on masked tokens

        perm_mask = []

        for i in range(len(labels)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            # tf.range is the equivalent of torch.arange
            perm_index = tf.range(labels_shape[1])
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
            # Permute the two halves such that they do not cross over
            perm_index = tf.random.shuffle(perm_index)  # Shuffles along the first dimension
            # Flatten this out into the desired permuted factorisation order
            perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask.append(
                (tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
                & masked_indices[i]
            )
        perm_mask = tf.stack(perm_mask, axis=0)

        return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)

    def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
            2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
               masked
            3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
               span_length]` and mask tokens `start_index:start_index + span_length`
            4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
               sequence to be processed), repeat from Step 1.
        """
        from random import randint

        import numpy as np

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if inputs.shape[1] % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )

        labels = np.copy(inputs)
        # Creating the mask and target_mapping tensors
        # (`np.bool` was removed from recent NumPy releases; the builtin `bool` is equivalent)
        masked_indices = np.full(labels.shape, 0, dtype=bool)
        target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)

        for i in range(labels.shape[0]):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.shape[1]

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked).
                # `random.randint` is inclusive on both ends, so the upper bound is `max_span_length` itself.
                span_length = randint(1, self.max_span_length)
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + randint(0, context_length - span_length)
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th predict corresponds to the i-th token.
            target_mapping[i] = np.eye(labels.shape[1])

        special_tokens_mask = np.array(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=bool,
        )
        masked_indices[special_tokens_mask] = 0
        if self.tokenizer._pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0.0

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)

        for i in range(labels.shape[0]):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = np.arange(labels.shape[1])
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
            # Permute the two halves such that they do not cross over
            np.random.shuffle(perm_index)
            # Flatten this out into the desired permuted factorisation order
            perm_index = perm_index.T.flatten()
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index[~masked_indices[i] & non_func_mask[i]] = -1
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
            ) & masked_indices[i]

        return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)

Two-stream self-attention

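For a factorization order $\pmb{z}$, with $z_t$ its $t$-th element and $z_{<t}$ its first $t-1$ elements, the next-token distribution with the standard softmax formulation is:

$$p_{\theta}(X_{z_t} = x \mid \pmb{x}_{z_{<t}}) = \frac{\exp \left( e(x)^{\top} h_{\theta}(\pmb{x}_{z_{<t}}) \right)}{\sum_{x'} \exp \left( e(x')^{\top} h_{\theta}(\pmb{x}_{z_{<t}}) \right)}$$

where $h_{\theta}(\pmb{x}_{z_{<t}})$ is the hidden representation of the context. The content representation $h_{\theta}(\pmb{x}_{z_{\leq t}})$, abbr. $h_{z_t}$, encodes both the context and $x_{z_t}$ itself, like the hidden states in a Transformer, i.e. standard self-attention; see figure (a) below.

However, the first $t-1$ tokens of the factorization order do not determine the prediction target uniquely: in a permuted AR factorization, different target positions can share the same preceding context. Hence, XLNet also conditions on the target position:

$$p_{\theta}(X_{z_t} = x \mid \pmb{x}_{z_{<t}}) = \frac{\exp \left( e(x)^{\top} g_{\theta}(\pmb{x}_{z_{<t}}, z_t) \right)}{\sum_{x'} \exp \left( e(x')^{\top} g_{\theta}(\pmb{x}_{z_{<t}}, z_t) \right)}$$

where $g_{\theta}(\pmb{x}_{z_{<t}}, z_t)$, abbr. $g_{z_t}$, denotes a query representation that uses only the position $z_t$ and not the content $x_{z_t}$; see figure (b) below.

  • the query stream uses $z_t$ but cannot see $x_{z_t}$: $g_{z_t}^{(m)} \leftarrow \text{Attention}(\pmb{Q} = g_{z_t}^{(m-1)}, \pmb{KV} = h_{z_{<t}}^{(m-1)}; \theta)$
  • the content stream uses both $z_t$ and $x_{z_t}$: $h_{z_t}^{(m)} \leftarrow \text{Attention}(\pmb{Q} = h_{z_t}^{(m-1)}, \pmb{KV} = h_{z_{\leq t}}^{(m-1)}; \theta)$

Here $\pmb{Q}$, $\pmb{K}$, $\pmb{V}$ denote the query, key, and value in an attention operation, and $m$ indexes the layer.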

  • During finetuning, we can simply drop the query stream and use the content stream as a normal Transformer(-XL).
  • The permutation is implemented through the attention mask, as shown in the figure, so it does not affect the original sequence order.
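The difference between the two streams comes down to whether a position may attend to itself. The sketch below builds both masks for one factorization order; the helper `two_stream_masks` and the example order are hypothetical and only illustrate the masking rule, not the XLNet implementation.

```python
import numpy as np

def two_stream_masks(order):
    """order[k] is the (0-based) position predicted at step k.

    Returns (content_mask, query_mask), boolean [T, T] matrices where entry
    [i, j] is True iff position i may attend to position j.
    """
    T = len(order)
    rank = np.empty(T, dtype=int)
    rank[np.asarray(order)] = np.arange(T)         # rank[pos] = step of pos in the order
    query_mask = rank[None, :] < rank[:, None]     # sees z_{<t} only, not itself
    content_mask = rank[None, :] <= rank[:, None]  # sees z_{<=t}, including itself
    return content_mask, query_mask

content_mask, query_mask = two_stream_masks([2, 1, 3, 0])
```

The two masks differ only on the diagonal, which is why dropping the query stream at finetuning time leaves an ordinary self-attention model.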

(Figure: two-stream self-attention; (a) content stream, (b) query stream.)

Transformer-XL

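XLNet borrows the relative positional encoding and the segment recurrence mechanism from Transformer-XL[2].
With the cached content representations $\tilde{h}^{(m-1)}$ of the previous segment as memory, the update for the next segment is:

$$h_{z_t}^{(m)} \leftarrow \text{Attention}(\pmb{Q} = h_{z_t}^{(m-1)}, \pmb{KV} = [\tilde{h}^{(m-1)}, h_{z_{\leq t}}^{(m-1)}]; \theta)$$

where $[\cdot, \cdot]$ denotes concatenation along the sequence dimension.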

Relative segment encoding

XLNet only considers "whether the two positions are within the same segment, as opposed to considering which specific segments they are from".

The idea of relative encodings is to model only the relationship between positions: $\pmb{s}_{ij}$ denotes the segment encoding between positions $i$ and $j$, with $\pmb{s}_{ij} = \pmb{s}_+$ if $i$ and $j$ are in the same segment and $\pmb{s}_{ij} = \pmb{s}_-$ otherwise.
The attention weight is $a_{ij} = (\pmb{q}_i + \mathbf{b})^{\top} \pmb{s}_{ij}$, where $\pmb{q}_i$ is the query vector as in standard attention and $\mathbf{b}$ is a learnable head-specific bias vector. Finally, $a_{ij}$ is added to the normal attention weight.
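The computation of the segment term can be sketched with NumPy. Everything below is illustrative: the head dimension, the random vectors standing in for learned parameters, and the names `s_plus` and `s_minus` (the encodings for "same segment" and "different segment") are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4                                    # head dimension (illustrative)
segment_ids = np.array([0, 0, 1, 1, 1])  # segment of each position
T = len(segment_ids)

q = rng.normal(size=(T, d))   # query vectors q_i
b = rng.normal(size=d)        # head-specific bias vector
s_plus = rng.normal(size=d)   # encoding used when i and j share a segment
s_minus = rng.normal(size=d)  # encoding used otherwise

same = segment_ids[:, None] == segment_ids[None, :]  # [T, T]
s = np.where(same[:, :, None], s_plus, s_minus)      # s_ij, shape [T, T, d]
a = np.einsum("id,ijd->ij", q + b, s)                # a_ij = (q_i + b)^T s_ij
```

Only the "same segment or not" relation enters `a`, so relabeling the segments would leave the result unchanged.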

The advantage of relative segment encodings:

  1. to introduce inductive biases to improve generalization;
  2. to allow for the multiple input segments in finetuning on tasks.

RoBERTa: “BERT is undertrained”!

RoBERTa[5] (FAIR & UW, 2019) (Robustly optimized BERT approach) revisited the BERT experiments[6] and showed that BERT was undertrained: pretraining with a larger batch size, over more data, and for more training steps leads to better results.

Recent works[1][5] also question the effectiveness of the Next Sentence Prediction (NSP) pretraining task proposed by BERT[6].

SpanBERT

SpanBERT[7] (UW & FAIR) proposed a span-level pretraining approach that masks contiguous random spans rather than individual tokens as in BERT. It consistently surpasses BERT, with substantial gains on span selection tasks such as question answering and coreference resolution. The NSP auxiliary objective is removed.

In comparison,

  • The concurrent work ERNIE[8] (Baidu, 2019) masks linguistically informed spans in Chinese, i.e. phrases and named entities, and achieves improvements on Chinese NLP tasks.

Span masking

At each iteration, the span length is sampled from a geometric distribution $\ell \sim \text{Geo}(p)$, with $P(\ell = k) = (1-p)^{k-1} p$; the starting point of each span is selected uniformly at random from the sequence. (In SpanBERT, $p = 0.2$, and $\ell$ is clipped at $\ell_{max} = 10$.)

As in BERT, 15% of the tokens are selected in total, but the replacement is applied at the span level: 80% are replaced with [MASK], 10% are replaced with random tokens (noise), and the remaining 10% are kept unchanged.
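Span-length sampling can be sketched as follows. The helper `sample_span_length` is hypothetical, and clipping is implemented here by taking the minimum with `max_len`, which is only one possible reading of "clip" (a renormalized truncated distribution would be another); the defaults follow the SpanBERT settings quoted above.

```python
import numpy as np

def sample_span_length(rng, p=0.2, max_len=10):
    """Draw a span length from Geo(p) on {1, 2, ...}, clipped at max_len."""
    return min(int(rng.geometric(p)), max_len)

rng = np.random.default_rng(0)
lengths = [sample_span_length(rng) for _ in range(10_000)]
mean_length = sum(lengths) / len(lengths)
```

Short spans dominate under this distribution, while the clip keeps rare long draws from masking an excessive stretch of text.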


Span boundary objective (SBO)

Given a masked span $(x_s, \cdots, x_e)$, where $(s, e)$ denotes its start and end positions, each token $x_i$ in the span is represented using the encodings of the outside boundary tokens $\pmb{x}_{s-1}$ and $\pmb{x}_{e+1}$ and the positional embedding $\pmb{p}_{i-s+1}$ of the target token, that is:

$$\pmb{y}_i = f(\pmb{x}_{s-1}, \pmb{x}_{e+1}, \pmb{p}_{i-s+1})$$

where $f(\cdot)$ is a 2-layer FFNN with layer normalization and GeLU activations.

The representation $\pmb{y}_i$ of each span token is used to predict $x_i$ and to compute a cross-entropy loss, like the MLM objective in BERT.
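A minimal NumPy sketch of the span boundary representation follows. It assumes the three inputs are concatenated before the first layer, uses layer normalization without learnable gain and bias, and uses random matrices `W1` and `W2` as stand-ins for learned weights; these are simplifications for illustration.

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, eps=1e-5):
    # normalization over the feature dimension, without learnable parameters
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def sbo_representation(x_left, x_right, p_rel, W1, W2):
    """y_i = f(x_{s-1}, x_{e+1}, p_{i-s+1}) with a 2-layer feed-forward network."""
    h0 = np.concatenate([x_left, x_right, p_rel])
    h1 = layer_norm(gelu(W1 @ h0))
    return layer_norm(gelu(W2 @ h1))

rng = np.random.default_rng(0)
d = 8  # hidden size (illustrative)
W1 = rng.normal(size=(d, 3 * d))
W2 = rng.normal(size=(d, d))
y = sbo_representation(rng.normal(size=d), rng.normal(size=d), rng.normal(size=d), W1, W2)
```

Note that `y` depends on the span content only through the two boundary encodings and the relative position, never through the masked tokens themselves.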

ALBERT

ALBERT[9] (A Lite BERT) (Google, 2019) adopts factorized embedding parameterization and cross-layer parameter sharing to reduce the memory cost of the BERT architecture.

Factorized embedding parameterization

ALBERT decomposes the large embedding matrix into two smaller matrices: the inputs are first projected into a lower-dimensional embedding space of size $E$, and then projected into the hidden space of size $H$. The number of embedding parameters is reduced from $O(V \times H)$ to $O(V \times E + E \times H)$, where $V$ is the vocabulary size; the reduction is significant when $H \gg E$.
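The saving is easy to check with arithmetic. The sizes below are illustrative values chosen for the example, not figures quoted from the ALBERT paper.

```python
def embedding_params(V, H, E=None):
    """Number of embedding parameters, with or without factorization."""
    return V * H if E is None else V * E + E * H

V, H, E = 30_000, 768, 128              # illustrative sizes
full = embedding_params(V, H)           # 23,040,000
factorized = embedding_params(V, H, E)  # 3,840,000 + 98,304 = 3,938,304
```

With these sizes the factorized embedding uses roughly one sixth of the parameters of the full $V \times H$ matrix.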

Cross-layer parameter sharing

All parameters are shared across layers, in both the self-attention and the FFN sublayers. Empirically, the L2 distance and cosine similarity between the input and output of each layer oscillate rather than converge, which differs from the behavior of the Deep Equilibrium Model (DEQ)[10].

Sentence-order prediction (SOP)

ALBERT uses two consecutive segments as positive samples, as in NSP, and the same two adjacent segments with their order swapped as negative samples. This consistently gives better results on multi-sentence encoding tasks.
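Constructing SOP examples can be sketched as below. The helper `make_sop_example`, the example sentences, and the 50/50 split between positive and negative samples are assumptions made for illustration.

```python
import random

def make_sop_example(seg_a, seg_b, rng):
    """seg_a and seg_b are consecutive segments from one document."""
    if rng.random() < 0.5:
        return (seg_a, seg_b), 1  # original order: positive sample
    return (seg_b, seg_a), 0      # swapped order: negative sample

rng = random.Random(0)
examples = [make_sop_example("He went out.", "Then it rained.", rng) for _ in range(100)]
```

Both classes contain the same two segments, so topic cues cannot separate them and the model has to rely on discourse order.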

ELECTRA

ELECTRA[11] (Efficiently Learning an Encoder that Classifies Token Replacements Accurately) (Stanford NLP) proposed replaced token detection, a more sample-efficient pretraining task that also avoids the pretrain-finetune discrepancy caused by [MASK] symbols.


Replaced token detection

  • Rather than masking tokens with probability 15% as in BERT, replaced token detection replaces tokens with plausible alternatives sampled from the output of a small generator network.
  • A discriminator is then trained to predict, for each token, whether it was replaced by a generator sample.

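ELECTRA trains two neural networks, a generator $G$ and a discriminator $D$. Each one primarily consists of an encoder that maps a sequence of input tokens $\pmb{x} = [x_1, \cdots, x_n]$ into a sequence of contextual representations $h(\pmb{x}) = [h_1, \cdots, h_n]$.

  1. The generator performs masked language modeling (MLM) as in BERT[6]. For position $t$, the generator outputs the distribution of $x_t$ via a softmax layer:

    $$p_G(x_t \mid \pmb{x}) = \frac{\exp \left( e(x_t)^{\top} h_G(\pmb{x})_t \right)}{\sum_{x'} \exp \left( e(x')^{\top} h_G(\pmb{x})_t \right)}$$

    where $e$ denotes the word embeddings.

  2. The discriminator $D$ predicts whether the token at position $t$ is original or has been replaced:

    $$D(\pmb{x}, t) = \text{sigmoid}(w^{\top} h_D(\pmb{x})_t)$$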

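  • MLM in BERT first randomly selects a set of positions $\pmb{m} = [m_1, \cdots, m_k]$ to mask; the tokens at the masked positions are replaced with a [MASK] token: $m_i \sim \text{unif}\{1, n\}$ for $i = 1$ to $k$, and $\pmb{x}^{masked} = \text{REPLACE}(\pmb{x}, \pmb{m}, \text{[MASK]})$.
  • In contrast, replaced token detection uses the generator $G$, trained with maximum likelihood on the masked tokens, to sample replacements, and the discriminator $D$ to detect which tokens are fake: $\hat{x}_i \sim p_G(x_i \mid \pmb{x}^{masked})$ for $i \in \pmb{m}$, and $\pmb{x}^{corrupt} = \text{REPLACE}(\pmb{x}, \pmb{m}, \hat{\pmb{x}})$.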
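The data flow of replaced token detection can be sketched without any neural network. In the sketch below the generator is a hypothetical stand-in (`generator_sample`, here a lambda that always proposes the same word), and the helper name and example sentence are made up for illustration.

```python
import random

def replaced_token_detection_inputs(tokens, generator_sample, mask_prob=0.15, seed=0):
    """Build the masked input, the corrupted input, and per-token labels."""
    rng = random.Random(seed)
    k = max(1, round(mask_prob * len(tokens)))
    m = sorted(rng.sample(range(len(tokens)), k))  # positions to mask
    x_masked = ["[MASK]" if i in m else t for i, t in enumerate(tokens)]
    x_corrupt = list(tokens)
    for i in m:
        x_corrupt[i] = generator_sample(x_masked, i)  # sample from the generator
    # Label 1: token equals the original; 0: token was replaced.
    # A generator sample that happens to be correct counts as original.
    is_original = [int(c == t) for c, t in zip(x_corrupt, tokens)]
    return x_masked, x_corrupt, is_original

# Hypothetical stand-in for the generator: always proposes "ate".
tokens = "the chef cooked the meal".split()
x_masked, x_corrupt, is_original = replaced_token_detection_inputs(
    tokens, lambda x, i: "ate"
)
```

The discriminator receives `x_corrupt`, which contains no `[MASK]` symbol, and is trained on every position rather than only on the masked ones.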

Loss function

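The loss functions are:

$$\mathcal{L}_{\text{MLM}}(\pmb{x}, \theta_G) = \mathbb{E} \left( \sum_{i \in \pmb{m}} -\log p_G(x_i \mid \pmb{x}^{masked}) \right)$$

$$\mathcal{L}_{\text{Disc}}(\pmb{x}, \theta_D) = \mathbb{E} \left( \sum_{t=1}^{n} -\mathbb{1}(x_t^{corrupt} = x_t) \log D(\pmb{x}^{corrupt}, t) - \mathbb{1}(x_t^{corrupt} \neq x_t) \log \left( 1 - D(\pmb{x}^{corrupt}, t) \right) \right)$$

The combined loss is minimized:

$$\min_{\theta_G, \theta_D} \sum_{\pmb{x} \in \chi} \mathcal{L}_{\text{MLM}}(\pmb{x}, \theta_G) + \lambda \mathcal{L}_{\text{Disc}}(\pmb{x}, \theta_D)$$

where $\chi$ denotes the corpus and $\lambda$ weights the discriminator loss.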
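For a single example, the two terms can be computed as in the NumPy sketch below. The probabilities are made-up numbers, the weight `lam` is an assumed value chosen for illustration, and the function names are hypothetical.

```python
import numpy as np

def mlm_loss(p_generator_at_targets):
    """-sum_i log p_G(x_i | x_masked) over the masked positions."""
    return -np.sum(np.log(p_generator_at_targets))

def disc_loss(d_out, is_original):
    """Binary cross-entropy over all positions; d_out[t] is the discriminator
    output for position t, is_original[t] = 1 iff the token was not replaced."""
    d_out, y = np.asarray(d_out), np.asarray(is_original)
    return -np.sum(y * np.log(d_out) + (1 - y) * np.log(1 - d_out))

lam = 50.0  # weight on the discriminator loss (assumed value for illustration)
l_mlm = mlm_loss([0.5, 0.25])                           # two masked positions
l_disc = disc_loss([0.9, 0.8, 0.3, 0.9], [1, 1, 0, 1])  # four input positions
total = l_mlm + lam * l_disc
```

The discriminator term sums over all input positions, whereas the generator term only covers the masked ones.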

Training

Weight sharing

  • The generator and discriminator share their embeddings (both token embeddings and position embeddings).
  • Following the weight tying strategy[12], only the embeddings are tied.

Two-stage training

  1. Train only the generator with the MLM loss $\mathcal{L}_{\text{MLM}}$ for $n$ steps.
  2. Initialize the weights of $D$ with those of $G$, and train $D$ with the discriminator loss $\mathcal{L}_{\text{Disc}}$ for $n$ steps, keeping the generator's weights frozen.

After pretraining, the generator is thrown away and the discriminator is fine-tuned on downstream tasks.

References


  1. Yang, Z., Dai, Z., Yang, Y., Carbonell, J., Salakhutdinov, R., & Le, Q. V. (2019). XLNet: Generalized Autoregressive Pretraining for Language Understanding. arXiv preprint arXiv:1906.08237.
  2. Dai, Z., Yang, Z., Yang, Y., Cohen, W. W., Carbonell, J., Le, Q. V., & Salakhutdinov, R. (2019). Transformer-XL: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860.
  3. Shaw, P., Uszkoreit, J., & Vaswani, A. (2018). Self-attention with relative position representations. arXiv preprint arXiv:1803.02155.
  4. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008).
  5. Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., ... & Stoyanov, V. (2019). RoBERTa: A robustly optimized BERT pretraining approach. arXiv preprint arXiv:1907.11692.
  6. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  7. Joshi, M., Chen, D., Liu, Y., Weld, D. S., Zettlemoyer, L., & Levy, O. (2019). SpanBERT: Improving pre-training by representing and predicting spans. arXiv preprint arXiv:1907.10529.
  8. Sun, Y., Wang, S., Li, Y., Feng, S., Chen, X., Zhang, H., ... & Wu, H. (2019). ERNIE: Enhanced Representation through Knowledge Integration. arXiv preprint arXiv:1904.09223.
  9. Lan, Z., Chen, M., Goodman, S., Gimpel, K., Sharma, P., & Soricut, R. (2019). ALBERT: A lite BERT for self-supervised learning of language representations. arXiv preprint arXiv:1909.11942.
  10. Deep Equilibrium Models. arXiv 2019.
  11. ELECTRA: Pre-training Text Encoders as Discriminators rather than Generators.
  12. Press, O., & Wolf, L. (2016). Using the output embedding to improve language models. arXiv preprint arXiv:1608.05859.
  13. GitHub: XLNet.