
Mask Denoising Strategy for Pre-trained Models

Mask modeling plays a crucial role in pre-training language models. This note gives a short summary of common masking strategies and their implementations.

BERT/RoBERTa Mask

BERT[1] applies masked language modeling (MLM) to the input sequence of text segments. Specifically, BERT selects 15% of the tokens uniformly at random after WordPiece tokenization and replaces each selected token with
1) [MASK] 80% of the time,
2) a random word 10% of the time, and
3) the original token (unchanged) 10% of the time, to bias the representation towards the actual observed word.

The random replacement only occurs for 1.5% of all tokens (i.e., 10% of 15%), so it does not seem to harm the model’s language understanding capacity.
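To make the arithmetic concrete: 15% × 80% = 12% of all tokens become [MASK], 15% × 10% = 1.5% become a random word, and 1.5% are selected but left unchanged. The following is a minimal sketch, not taken from any of the implementations quoted in this note, that applies the rule to a toy token stream and checks the first two fractions empirically. The function name `mask_tokens_80_10_10`, the toy vocabulary, and the per-token independent selection are assumptions made for illustration.

```python
import random

MASK = "[MASK]"


def mask_tokens_80_10_10(tokens, vocab, mask_prob=0.15, rng=None):
    """Return (corrupted, labels); labels[i] is None where no prediction is made."""
    if rng is None:
        rng = random.Random()
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # token is not selected for prediction
        labels[i] = tok  # the model must predict the original token here
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK  # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random word
        # remaining 10%: keep the original token unchanged
    return corrupted, labels


rng = random.Random(0)
vocab = [f"w{i}" for i in range(1000)]  # hypothetical toy vocabulary
tokens = [rng.choice(vocab) for _ in range(100_000)]
corrupted, labels = mask_tokens_80_10_10(tokens, vocab, rng=rng)

n = len(tokens)
frac_selected = sum(label is not None for label in labels) / n
frac_mask = sum(c == MASK for c in corrupted) / n
print(f"selected: {frac_selected:.3f} (expected 0.150)")
print(f"[MASK]:   {frac_mask:.3f} (expected 0.120)")
```

The random-word and unchanged fractions are not measured separately here because a randomly drawn word can coincide with the original token.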

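BERT applies static masking: the masks are generated ahead of time during data preprocessing (the data is duplicated so that each sequence receives several different masks) and stay fixed afterwards. RoBERTa instead adopts dynamic masking, generating a new mask on the fly each time a sequence is fed to the model during training.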
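The difference can be sketched as follows. This is a toy illustration under stated assumptions rather than either library's code: the helper names and the `dupe_factor` parameter are chosen here for illustration, and only mask positions are sampled.

```python
import random


def sample_mask_positions(seq_len, mask_prob, rng):
    """Pick round(seq_len * mask_prob) distinct positions (at least one)."""
    k = max(1, round(seq_len * mask_prob))
    return sorted(rng.sample(range(seq_len), k))


def static_masks(seq_len, mask_prob, dupe_factor, seed=0):
    """Static masking: precompute a fixed set of masks once, before training."""
    rng = random.Random(seed)
    return [sample_mask_positions(seq_len, mask_prob, rng) for _ in range(dupe_factor)]


def dynamic_mask(seq_len, mask_prob, rng):
    """Dynamic masking: draw a fresh mask every time the sequence is used."""
    return sample_mask_positions(seq_len, mask_prob, rng)


seq_len, mask_prob, epochs = 20, 0.15, 6
precomputed = static_masks(seq_len, mask_prob, dupe_factor=2)
rng = random.Random(1)
for epoch in range(epochs):
    static = precomputed[epoch % len(precomputed)]  # cycles through the same masks
    dynamic = dynamic_mask(seq_len, mask_prob, rng)  # new sample each epoch
    print(f"epoch {epoch}: static={static} dynamic={dynamic}")
```

With static masking the model sees at most `dupe_factor` distinct masks per sequence no matter how long training runs, whereas dynamic masking keeps producing new ones.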

Google BERT Implementation

# 1. Google BERT implementation. (w/ wwm)
# FLAGS is the command-line flag object of the original script (create_pretraining_data.py).
import collections

MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
    """Creates the predictions for the masked LM objective."""

    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        # Whole Word Masking means that if we mask all of the wordpieces
        # corresponding to an original word. When a word has been split into
        # WordPieces, the first token does not have any marker and any subsequence
        # tokens are prefixed with ##. So whenever we see the ## token, we
        # append it to the previous set of word indexes.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
                token.startswith("##")):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    rng.shuffle(cand_indexes)

    output_tokens = list(tokens)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    masked_lms = []
    covered_indexes = set()
    for index_set in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if rng.random() < 0.8:
                masked_token = "[MASK]"
            else:
                # 10% of the time, keep original
                if rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

            output_tokens[index] = masked_token

            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict
    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = []
    masked_lm_labels = []
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions, masked_lm_labels)

Hugging Face Implementation

# Huggingface implementation: https://github.com/huggingface/transformers/blob/d72343d2b804d0304d93bac1c1b58e0dafd5e820/src/transformers/data/data_collator.py#L606
@dataclass
class DataCollatorForLanguageModeling(DataCollatorMixin):
"""
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
are not all of the same length.
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
mlm (`bool`, *optional*, defaults to `True`):
Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
tokens and the value to predict for the masked token.
mlm_probability (`float`, *optional*, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
<Tip>
For best performance, this data collator should be used with a dataset having items that are dictionaries or
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
</Tip>"""

tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
pad_to_multiple_of: Optional[int] = None
tf_experimental_compile: bool = False
return_tensors: str = "pt"

def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
if self.tf_experimental_compile:
import tensorflow as tf

self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)

@staticmethod
def tf_bernoulli(shape, probability):
import tensorflow as tf

prob_matrix = tf.fill(shape, probability)
return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)

def tf_mask_tokens(
self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
) -> Tuple[Any, Any]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
import tensorflow as tf

input_shape = tf.shape(inputs)
# 1 for a special token, 0 for a normal token in the special tokens mask
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
labels = tf.where(masked_indices, inputs, -100)

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices

inputs = tf.where(indices_replaced, mask_token_id, inputs)

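# 10% of the time, we replace masked input tokens with random word
# NOTE: with probability 0.1 here, only about 2% of the masked tokens (0.1 of the remaining 20%) are
# randomized and about 18% stay unchanged; the torch/numpy paths use 0.5 to obtain the stated 10%/10% split.
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced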
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs)

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels

def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
import tensorflow as tf

# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
else:
batch = {
"input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
}

# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
for val in batch["input_ids"].numpy().tolist()
]
# Cannot directly create as bool
special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
else:
special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
tf.cast(batch["input_ids"], tf.int64),
special_tokens_mask=special_tokens_mask,
mask_token_id=self.tokenizer.mask_token_id,
vocab_size=len(self.tokenizer),
)
else:
labels = batch["input_ids"]
if self.tokenizer.pad_token_id is not None:
# Replace self.tokenizer.pad_token_id with -100
labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
else:
labels = tf.identity(labels) # Makes a copy, just in case
batch["labels"] = labels
return batch

def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
else:
batch = {
"input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
}

# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = batch["input_ids"].clone()
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch

def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
import torch

labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()

probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels

def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
import numpy as np

# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
else:
batch = {
"input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
}

# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = np.copy(batch["input_ids"])
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch

def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
import numpy as np

labels = np.copy(inputs)
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = np.full(labels.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = np.array(special_tokens_mask, dtype=np.bool)
else:
special_tokens_mask = special_tokens_mask.astype(np.bool)

probability_matrix[special_tokens_mask] = 0
# Numpy doesn't have bernoulli, so we use a binomial with 1 trial
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(np.bool)
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(np.bool) & masked_indices
inputs[indices_replaced] = self.tokenizer.mask_token_id

# 10% of the time, we replace masked input tokens with random word
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
)
inputs[indices_random] = random_words

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels



# w/ wwm
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
"""
Data collator used for language modeling that masks entire words.
- collates batches of tensors, honoring their tokenizer's pad_token
- preprocesses batches for masked language modeling
<Tip>
This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
</Tip>"""

def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]

batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

mask_labels = []
for e in examples:
ref_tokens = []
for id in tolist(e["input_ids"]):
token = self.tokenizer._convert_id_to_token(id)
ref_tokens.append(token)

# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
if "chinese_ref" in e:
ref_pos = tolist(e["chinese_ref"])
len_seq = len(e["input_ids"])
for i in range(len_seq):
if i in ref_pos:
ref_tokens[i] = "##" + ref_tokens[i]
mask_labels.append(self._whole_word_mask(ref_tokens))
batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}

def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]

batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

mask_labels = []
for e in examples:
ref_tokens = []
for id in tolist(e["input_ids"]):
token = self.tokenizer._convert_id_to_token(id)
ref_tokens.append(token)

# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
if "chinese_ref" in e:
ref_pos = tolist(e["chinese_ref"])
len_seq = len(e["input_ids"])
for i in range(len_seq):
if i in ref_pos:
ref_tokens[i] = "##" + ref_tokens[i]
mask_labels.append(self._whole_word_mask(ref_tokens))
batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}

def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]

batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

mask_labels = []
for e in examples:
ref_tokens = []
for id in tolist(e["input_ids"]):
token = self.tokenizer._convert_id_to_token(id)
ref_tokens.append(token)

# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
if "chinese_ref" in e:
ref_pos = tolist(e["chinese_ref"])
len_seq = len(e["input_ids"])
for i in range(len_seq):
if i in ref_pos:
ref_tokens[i] = "##" + ref_tokens[i]
mask_labels.append(self._whole_word_mask(ref_tokens))
batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}

def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
"""
Get 0/1 labels for masked tokens with whole word mask proxy
"""
if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
warnings.warn(
"DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
"Please refer to the documentation for more information."
)

cand_indexes = []
for (i, token) in enumerate(input_tokens):
if token == "[CLS]" or token == "[SEP]":
continue

if len(cand_indexes) >= 1 and token.startswith("##"):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])

random.shuffle(cand_indexes)
num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_lms.append(index)

if len(covered_indexes) != len(masked_lms):
raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
return mask_labels

def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
"""
import torch

if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

probability_matrix = mask_labels

special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)

masked_indices = probability_matrix.bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels

def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
"""
import tensorflow as tf

input_shape = tf.shape(inputs)
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = tf.identity(inputs)
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

masked_indices = tf.cast(mask_labels, tf.bool)

special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
]
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
if self.tokenizer._pad_token is not None:
padding_mask = inputs == self.tokenizer.pad_token_id
masked_indices = masked_indices & ~padding_mask

# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
labels = tf.where(masked_indices, inputs, -100)

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices

inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)

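# 10% of the time, we replace masked input tokens with random word
# NOTE: with probability 0.1 here, only about 2% of the masked tokens (0.1 of the remaining 20%) are
# randomized and about 18% stay unchanged; the torch/numpy paths use 0.5 to obtain the stated 10%/10% split.
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced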
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs)

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels

def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
"""
import numpy as np

if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = np.copy(inputs)
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

masked_indices = mask_labels.astype(np.bool)

special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
masked_indices[np.array(special_tokens_mask, dtype=np.bool)] = 0
if self.tokenizer._pad_token is not None:
padding_mask = labels == self.tokenizer.pad_token_id
masked_indices[padding_mask] = 0

labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(np.bool) & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
)
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels

SpanBERT Implementation

# 3. SpanBERT implementation (BERT-style random token masking)
import math

import numpy as np


class BertRandomMaskingScheme(object):
    def __init__(self, args, tokens, pad, mask_id):
        self.args = args
        self.mask_ratio = getattr(self.args, 'mask_ratio', None)
        self.pad = pad
        self.tokens = tokens
        self.mask_id = mask_id

    def mask(self, sentence, tagmap=None):
        """mask tokens for masked language model training
        Args:
            sentence: 1d tensor, token list to be masked
            mask_ratio: ratio of tokens to be masked in the sentence
        Return:
            masked_sent: masked sentence
        """
        sent_length = len(sentence)
        mask_num = math.ceil(sent_length * self.mask_ratio)
        mask = np.random.choice(sent_length, mask_num, replace=False)
        return bert_masking(sentence, mask, self.tokens, self.pad, self.mask_id)


def bert_masking(sentence, mask, tokens, pad, mask_id):
    sentence = np.copy(sentence)
    sent_length = len(sentence)
    target = np.copy(sentence)
    mask = set(mask)
    for i in range(sent_length):
        if i in mask:
            rand = np.random.random()
            if rand < 0.8:
                sentence[i] = mask_id
            elif rand < 0.9:
                # sample random token according to input distribution
                sentence[i] = np.random.choice(tokens)
        else:
            # positions that are not selected get the pad id as target
            target[i] = pad
    return sentence, target, None

Span Mask

Span masking covers random span masking, named-entity masking, and related variants.

  1. ERNIE[6] applies knowledge masking to the input sequence, including entity-level and phrase-level masking, to inject knowledge composition.
  2. SpanBERT[2] employs random span masking with span lengths drawn from a clamped geometric distribution.
  3. BERT-WWM[7] uses whole word masking (for Chinese BERT) rather than randomly masking individual subword pieces, so that the whole meaning of a word is retained.
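As a small illustration of whole word masking, the sketch below groups WordPiece indices into whole words using the same `##` convention as the BERT code quoted earlier in this note, and then masks one whole group. The function names and the example sentence are assumptions made for illustration; replacing every piece with [MASK] ignores the 80/10/10 rule for brevity.

```python
def group_wordpieces(tokens):
    """Group token indices into whole words; '##' pieces join the previous group."""
    groups = []
    for i, tok in enumerate(tokens):
        if tok in ("[CLS]", "[SEP]"):
            continue  # special tokens are never masking candidates
        if groups and tok.startswith("##"):
            groups[-1].append(i)
        else:
            groups.append([i])
    return groups


def mask_group(tokens, group, mask_token="[MASK]"):
    """Replace every piece of one whole word with the mask token."""
    out = list(tokens)
    for i in group:
        out[i] = mask_token
    return out


tokens = ["[CLS]", "the", "phil", "##am", "##mon", "sang", "[SEP]"]
groups = group_wordpieces(tokens)
print(groups)  # [[1], [2, 3, 4], [5]]
print(mask_group(tokens, groups[1]))
# ['[CLS]', 'the', '[MASK]', '[MASK]', '[MASK]', 'sang', '[SEP]']
```

Masking by group means the model never sees part of a word as a hint for the rest of it, which is the motivation given for whole word masking.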

SpanBERT[2] iteratively samples each span’s length $\ell$ from a (clamped) geometric distribution, i.e.,

$$\ell \sim \mathrm{Geo}(p),$$

which is skewed towards shorter spans ($p=0.2$). It also clips $\ell$ at $\ell_{\max}=10$, yielding a mean span length of $\bar{\ell}=3.8$. SpanBERT measures span length in complete words, not subword tokens, so the masked spans are even longer when counted in subword tokens.
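The mean span length can be checked numerically. The sketch below truncates the geometric distribution to lengths 1 to 10 and renormalises it, which is the same computation the `len_distrib` lines perform in the SpanBERT code quoted in this note. The function name `clamped_geometric` and its default arguments are assumptions made for illustration.

```python
def clamped_geometric(p=0.2, lower=1, upper=10):
    """Geometric weights p * (1 - p)**(l - lower), truncated to [lower, upper] and renormalised."""
    lens = list(range(lower, upper + 1))
    weights = [p * (1 - p) ** (length - lower) for length in lens]
    total = sum(weights)
    probs = [w / total for w in weights]
    return lens, probs


lens, probs = clamped_geometric()
mean_len = sum(length * q for length, q in zip(lens, probs))
print(round(mean_len, 2))  # 3.8
```

Shorter spans are the most likely: a length of 1 has the highest probability, and each additional word multiplies the probability by $1-p=0.8$.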

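The masking budget and replacement rule are the same as in BERT: 15% of the tokens are masked in total, of which 80% are replaced with [MASK], 10% with random tokens, and 10% are left unchanged. In SpanBERT, however, this replacement is applied at the span level, so all tokens in a span share the same outcome.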
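A span-level variant of the 80/10/10 rule can be sketched as follows: one random draw is made per span and shared by all tokens in that span, which mirrors the behaviour of the quoted `span_masking` function when `replacement` is not `'word_piece'`. The function name `corrupt_spans`, the toy tokens, and the fixed spans are assumptions made for illustration; sampling the spans themselves is left out.

```python
import random


def corrupt_spans(tokens, spans, vocab, mask_token="[MASK]", rng=None):
    """Apply 80/10/10 per span; spans are inclusive (start, end) index pairs."""
    if rng is None:
        rng = random.Random()
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    for start, end in spans:
        r = rng.random()  # one draw per span, shared by all of its tokens
        for i in range(start, end + 1):
            labels[i] = tokens[i]  # every token in the span is a prediction target
            if r < 0.8:
                corrupted[i] = mask_token  # 80%: whole span becomes [MASK]
            elif r < 0.9:
                corrupted[i] = rng.choice(vocab)  # 10%: whole span gets random tokens
            # remaining 10%: whole span stays unchanged
    return corrupted, labels


tokens = ["a", "b", "c", "d", "e", "f", "g", "h"]
spans = [(1, 3), (6, 6)]
vocab = ["x", "y", "z"]  # hypothetical toy vocabulary
corrupted, labels = corrupt_spans(tokens, spans, vocab, rng=random.Random(0))
print(corrupted)
print(labels)
```

Drawing once per span avoids mixed spans in which some tokens are [MASK] and neighbouring tokens of the same span remain visible.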

SpanBERT

# SpanBERT implementation
class PairWithSpanMaskingScheme(object):
def __init__(self, args, tokens, pad, mask_id, paragraph_info):
self.args = args
self.mask_ratio = getattr(self.args, 'mask_ratio', None)
self.args = args
self.max_pair_targets = args.max_pair_targets
self.lower = args.span_lower
self.upper = args.span_upper
self.pad = pad
self.mask_id = mask_id
self.tokens = tokens
self.paragraph_info = paragraph_info
self.lens = list(range(self.lower, self.upper + 1))
self.p = args.geometric_p
self.len_distrib = [self.p * (1-self.p)**(i - self.lower) for i in range(self.lower, self.upper + 1)] if self.p >= 0 else None
self.len_distrib = [x / (sum(self.len_distrib)) for x in self.len_distrib]
print(self.len_distrib, self.lens)


def mask(self, sentence, tagmap=None):
"""mask tokens for masked language model training
Args:
sentence: 1d tensor, token list to be masked
mask_ratio: ratio of tokens to be masked in the sentence
Return:
masked_sent: masked sentence
"""
sent_length = len(sentence)
mask_num = math.ceil(sent_length * self.mask_ratio)
mask = set()
word_piece_map = self.paragraph_info.get_word_piece_map(sentence)
spans = []
while len(mask) < mask_num:
span_len = np.random.choice(self.lens, p=self.len_distrib)
tagged_indices = None
if tagmap is not None:
tagged_indices = [max(0, i - np.random.randint(0, span_len)) for i in range(tagmap.length()) if tagmap[i]]
tagged_indices += [np.random.choice(sent_length)] * int(len(tagged_indices) == 0)
anchor = np.random.choice(sent_length) if np.random.rand() >= self.args.tagged_anchor_prob else np.random.choice(tagged_indices)
if anchor in mask:
continue
# find word start, end
left1, right1 = self.paragraph_info.get_word_start(sentence, anchor, word_piece_map), self.paragraph_info.get_word_end(sentence, anchor, word_piece_map)
spans.append([left1, left1])
for i in range(left1, right1):
if len(mask) >= mask_num:
break
mask.add(i)
spans[-1][-1] = i
num_words = 1
right2 = right1
while num_words < span_len and right2 < len(sentence) and len(mask) < mask_num:
# complete current word
left2 = right2
right2 = self.paragraph_info.get_word_end(sentence, right2, word_piece_map)
num_words += 1
for i in range(left2, right2):
if len(mask) >= mask_num:
break
mask.add(i)
spans[-1][-1] = i
sentence, target, pair_targets = span_masking(sentence, spans, self.tokens, self.pad, self.mask_id, self.max_pair_targets, mask, replacement=self.args.replacement_method, endpoints=self.args.endpoints)
if self.args.return_only_spans:
pair_targets = None
return sentence, target, pair_targets

class ParagraphInfo(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def get_word_piece_map(self, sentence):
        return [self.dictionary.is_start_word(i) for i in sentence]

    def get_word_at_k(self, sentence, left, right, k, word_piece_map=None):
        num_words = 0
        while num_words < k and right < len(sentence):
            # complete current word
            left = right
            right = self.get_word_end(sentence, right, word_piece_map)
            num_words += 1
        return left, right

    def get_word_start(self, sentence, anchor, word_piece_map=None):
        word_piece_map = word_piece_map if word_piece_map is not None else self.get_word_piece_map(sentence)
        left = anchor
        while left > 0 and word_piece_map[left] == False:
            left -= 1
        return left

    # word end is next word start
    def get_word_end(self, sentence, anchor, word_piece_map=None):
        word_piece_map = word_piece_map if word_piece_map is not None else self.get_word_piece_map(sentence)
        right = anchor + 1
        while right < len(sentence) and word_piece_map[right] == False:
            right += 1
        return right

def span_masking(sentence, spans, tokens, pad, mask_id, pad_len, mask, replacement='word_piece', endpoints='external'):
    sentence = np.copy(sentence)
    sent_length = len(sentence)
    target = np.full(sent_length, pad)
    pair_targets = []
    spans = merge_intervals(spans)
    assert len(mask) == sum([e - s + 1 for s,e in spans])
    # print(list(enumerate(sentence)))
    for start, end in spans:
        lower_limit = 0 if endpoints == 'external' else -1
        upper_limit = sent_length - 1 if endpoints == 'external' else sent_length
        if start > lower_limit and end < upper_limit:
            if endpoints == 'external':
                pair_targets += [[start - 1, end + 1]]
            else:
                pair_targets += [[start, end]]
            pair_targets[-1] += [sentence[i] for i in range(start, end + 1)]
        rand = np.random.random()
        for i in range(start, end + 1):
            assert i in mask
            target[i] = sentence[i]
            if replacement == 'word_piece':
                rand = np.random.random()
            if rand < 0.8:
                sentence[i] = mask_id
            elif rand < 0.9:
                # sample random token according to input distribution
                sentence[i] = np.random.choice(tokens)
    pair_targets = pad_to_len(pair_targets, pad, pad_len + 2)
    # if pair_targets is None:
    return sentence, target, pair_targets

def merge_intervals(intervals):
    intervals = sorted(intervals, key=lambda x : x[0])
    merged = []
    for interval in intervals:
        # if the list of merged intervals is empty or if the current
        # interval does not overlap with the previous, simply append it.
        if not merged or merged[-1][1] + 1 < interval[0]:
            merged.append(interval)
        else:
            # otherwise, there is overlap, so we merge the current and previous
            # intervals.
            merged[-1][1] = max(merged[-1][1], interval[1])
    return merged

def pad_to_len(pair_targets, pad, max_pair_target_len):
    for i in range(len(pair_targets)):
        pair_targets[i] = pair_targets[i][:max_pair_target_len]
        this_len = len(pair_targets[i])
        for j in range(max_pair_target_len - this_len):
            pair_targets[i].append(pad)
    return pair_targets
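
As a quick check of the span-length sampler built in the constructor above (`len_distrib`), the following sketch rebuilds the clipped geometric distribution outside the class. The values `p = 0.2`, `lower = 1` and `upper = 10` are the settings reported in the SpanBERT paper rather than anything fixed by the code above, and the function name `span_length_distribution` is made up for this illustration.

```python
import numpy as np


def span_length_distribution(p=0.2, lower=1, upper=10):
    """Clipped geometric distribution over span lengths (in words)."""
    lens = list(range(lower, upper + 1))
    # Unnormalised geometric weights, mirroring `len_distrib` above.
    weights = [p * (1 - p) ** (n - lower) for n in lens]
    total = sum(weights)
    # Renormalise so that the clipped distribution sums to one.
    probs = [w / total for w in weights]
    return lens, probs


lens, probs = span_length_distribution()
mean_len = sum(n * q for n, q in zip(lens, probs))

# Draw a few span lengths, analogous to the np.random.choice call in mask().
rng = np.random.default_rng(0)
samples = rng.choice(lens, size=5, p=probs)
print(f"mean span length: {mean_len:.2f}")
print("sampled lengths:", samples.tolist())
```

With these assumed settings the mean span length comes out at roughly 3.8 words, and short spans are sampled far more often than long ones.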

Results of SpanBERT mask scheme.

It can be seen from the table that, with the exception of coreference resolution, masking random spans is preferable to other strategies. Although linguistic masking schemes (named entities and noun phrases) are often competitive with random spans, their performance is not consistent. For coreference resolution, masking random subword tokens is preferable to any form of span masking.

# NER mask
class NERSpanMaskingScheme(object):
    def __init__(self, args, tokens, pad, mask_id, paragraph_info):
        self.args = args
        self.mask_ratio = getattr(self.args, 'mask_ratio', None)
        self.max_pair_targets = args.max_pair_targets
        self.lower = args.span_lower
        self.upper = args.span_upper
        self.pad = pad
        self.mask_id = mask_id
        self.tokens = tokens
        self.paragraph_info = paragraph_info
        self.lens = list(range(self.lower, self.upper + 1))
        self.p = args.geometric_p
        self.len_distrib = [self.p * (1-self.p)**(i - self.lower) for i in range(self.lower, self.upper + 1)] if self.p >= 0 else None
        self.len_distrib = [x / (sum(self.len_distrib)) for x in self.len_distrib]
        print(self.len_distrib, self.lens)

    def mask_random_span(self, sentence, mask_num, word_piece_map, spans, mask, span_len, anchor):
        # find word start, end
        left1, right1 = self.paragraph_info.get_word_start(sentence, anchor, word_piece_map), self.paragraph_info.get_word_end(sentence, anchor, word_piece_map)
        spans.append([left1, left1])
        for i in range(left1, right1):
            if len(mask) >= mask_num:
                break
            mask.add(i)
            spans[-1][-1] = i
        num_words = 1
        right2 = right1
        while num_words < span_len and right2 < len(sentence) and len(mask) < mask_num:
            # complete current word
            left2 = right2
            right2 = self.paragraph_info.get_word_end(sentence, right2, word_piece_map)
            num_words += 1
            for i in range(left2, right2):
                if len(mask) >= mask_num:
                    break
                mask.add(i)
                spans[-1][-1] = i

    def mask_entity(self, sentence, mask_num, word_piece_map, spans, mask, entity_spans):
        if len(entity_spans) > 0:
            entity_span = entity_spans[np.random.choice(range(len(entity_spans)))]
            spans.append([entity_span[0], entity_span[0]])
            for idx in range(entity_span[0], entity_span[1] + 1):
                if len(mask) >= mask_num:
                    break
                spans[-1][-1] = idx
                mask.add(idx)

    def mask(self, sentence, entity_map=None):
        """mask tokens for masked language model training
        Args:
            sentence: 1d tensor, token list to be masked
            mask_ratio: ratio of tokens to be masked in the sentence
        Return:
            masked_sent: masked sentence
        """
        sent_length = len(sentence)
        mask_num = math.ceil(sent_length * self.mask_ratio)
        mask = set()
        word_piece_map = self.paragraph_info.get_word_piece_map(sentence)
        # get entity spans
        entity_spans, spans = [], []
        new_entity = True
        for i in range(entity_map.length()):
            if entity_map[i] and new_entity:
                entity_spans.append([i, i])
                new_entity = False
            elif entity_map[i] and not new_entity:
                entity_spans[-1][-1] = i
            else:
                new_entity = True
        while len(mask) < mask_num:
            if np.random.random() <= self.args.ner_masking_prob:
                self.mask_entity(sentence, mask_num, word_piece_map, spans, mask, entity_spans)
            else:
                span_len = np.random.choice(self.lens, p=self.len_distrib)
                anchor = np.random.choice(sent_length)
                if anchor in mask:
                    continue
                self.mask_random_span(sentence, mask_num, word_piece_map, spans, mask, span_len, anchor)
        sentence, target, pair_targets = span_masking(sentence, spans, self.tokens, self.pad, self.mask_id, self.max_pair_targets, mask, replacement=self.args.replacement_method, endpoints=self.args.endpoints)
        if self.args.return_only_spans:
            pair_targets = None
        return sentence, target, pair_targets

MASS Mask

The MASS[3] encoder replaces each masked token with a special [MASK] token, so the input length is unchanged. The decoder then predicts the masked tokens autoregressively.

MASS Mask
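
A minimal sketch of the MASS input/target construction described above, assuming a single contiguous masked fragment. The function name `mass_mask` and its `start` and `length` arguments are hypothetical, and how the fragment is chosen is left to the caller.

```python
def mass_mask(tokens, start, length, mask_token="[MASK]"):
    """Build a MASS-style encoder input and decoder target.

    tokens[start:start + length] is the fragment to be masked.
    """
    end = start + length
    if not (0 <= start and end <= len(tokens)):
        raise ValueError("fragment must lie inside the sentence")
    # Encoder side: one [MASK] per masked token, so the length is unchanged.
    encoder_input = tokens[:start] + [mask_token] * length + tokens[end:]
    # Decoder side: the masked tokens, to be predicted left to right.
    decoder_target = tokens[start:end]
    return encoder_input, decoder_target


tokens = ["x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8"]
enc, tgt = mass_mask(tokens, start=2, length=4)
print(enc)  # ['x1', 'x2', '[MASK]', '[MASK]', '[MASK]', '[MASK]', 'x7', 'x8']
print(tgt)  # ['x3', 'x4', 'x5', 'x6']
```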

BART Mask

BART[4] replaces each corrupted contiguous span of the encoder input with a single [MASK] token, and trains the decoder autoregressively using a Transformer encoder-decoder architecture.

BART masking

BART allows any type of document corruption, including:

  • Token Masking: BERT masking.
  • Token Deletion: random tokens are deleted from the input.
  • Text Infilling: a number of text spans are corrupted, with span lengths drawn from a Poisson distribution ($\lambda=3$). Each span is replaced with a single [MASK] token. 0-length spans correspond to the insertion of [MASK] tokens.
  • Sentence Permutation: a document is divided into sentences based on full stops, and the sentences are randomly shuffled.
  • Document Rotation: a token is chosen uniformly at random, and the document is rotated so that it begins with that token.
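
The corruption schemes listed above can be sketched on a plain token list as follows. This is an illustrative sketch, not BART's actual implementation: the function names, the deletion probability `p`, and the span start probability `start_prob` are assumptions made for the example, and spans are picked in a single left-to-right pass for simplicity.

```python
import numpy as np


def token_deletion(tokens, rng, p=0.15):
    """Delete each token independently with (assumed) probability p."""
    keep = rng.random(len(tokens)) >= p
    return [t for t, k in zip(tokens, keep) if k]


def text_infilling(tokens, rng, start_prob=0.1, lam=3.0, mask_token="[MASK]"):
    """Replace spans with a single [MASK]; span lengths ~ Poisson(lam)."""
    out, i = [], 0
    while i < len(tokens):
        if rng.random() < start_prob:
            span_len = int(rng.poisson(lam))
            out.append(mask_token)  # one [MASK] for the whole span
            i += span_len           # span_len == 0 only inserts a [MASK]
            if i < len(tokens):     # keep the next token so masks never touch
                out.append(tokens[i])
                i += 1
        else:
            out.append(tokens[i])
            i += 1
    return out


def sentence_permutation(tokens, rng, full_stop="."):
    """Split into sentences at full stops and shuffle the sentences."""
    sentences, current = [], []
    for t in tokens:
        current.append(t)
        if t == full_stop:
            sentences.append(current)
            current = []
    if current:
        sentences.append(current)
    order = rng.permutation(len(sentences))
    return [t for idx in order for t in sentences[idx]]


def document_rotation(tokens, rng):
    """Rotate the document so it starts at a uniformly chosen token."""
    if not tokens:
        return []
    pivot = int(rng.integers(len(tokens)))
    return tokens[pivot:] + tokens[:pivot]


rng = np.random.default_rng(0)
doc = "the cat sat . it purred . then it slept .".split()
print(token_deletion(doc, rng))
print(text_infilling(doc, rng))
print(sentence_permutation(doc, rng))
print(document_rotation(doc, rng))
```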

T5 Span Mask

T5[5] replaces each consecutive span of corrupted tokens in the input sequence with a unique sentinel token. The target sequence is then the concatenation of the corrupted spans, each prefixed by the sentinel token that replaced it in the input.
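
A minimal sketch of this sentinel-based span corruption, assuming the corrupted positions are already given as a boolean list. The function name `span_corruption` is hypothetical, and the `<extra_id_N>` sentinel naming is borrowed from common T5 vocabularies as an assumption.

```python
def span_corruption(tokens, corrupted, sentinel="<extra_id_{}>"):
    """Replace each run of corrupted tokens with one unique sentinel.

    corrupted: list of booleans, True where the token is corrupted.
    Returns (inputs, targets) as token lists.
    """
    inputs, targets = [], []
    n_spans = 0
    prev = False
    for tok, is_corrupted in zip(tokens, corrupted):
        if is_corrupted:
            if not prev:
                # A new consecutive span starts: emit a fresh sentinel
                # on both the input and the target side.
                marker = sentinel.format(n_spans)
                inputs.append(marker)
                targets.append(marker)
                n_spans += 1
            targets.append(tok)
        else:
            inputs.append(tok)
        prev = is_corrupted
    return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
corrupted = [i in {2, 3, 8} for i in range(len(tokens))]
inputs, targets = span_corruption(tokens, corrupted)
print(" ".join(inputs))
# Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(targets))
# <extra_id_0> for inviting <extra_id_1> last
```

The illustration in the T5 paper also ends the target with one extra closing sentinel; that detail is left out of this sketch.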

As shown in the table, a simple variant of the BERT-style objective replaces 15% of the input tokens with mask tokens, without the original random token swapping step, and reconstructs the original uncorrupted sequence.

The first two rows (i.e., BERT-style and MASS-style objectives) predict the entire uncorrupted text span, which requires self-attention over long sequences in the decoder. To avoid this, T5 applies the strategies in the last two rows. The last row (i.e., drop corrupted tokens) simply drops the corrupted tokens from the input sequence completely and tasks the model with reconstructing the dropped tokens in order.

It can be seen from the table that dropping corrupted spans completely produced a small improvement in the GLUE score, thanks to the significantly higher score on CoLA (60.45 vs. an average baseline of 53.84). However, dropping tokens completely performed worse than replacing them with sentinel tokens on SuperGLUE. The variants in the last two rows make the target sequence shorter and consequently make training faster.
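
For comparison, the "drop corrupted tokens" variant discussed above can be sketched in a few lines. The function name `drop_corrupted_tokens` and the boolean `corrupted` list are assumptions made for this illustration.

```python
def drop_corrupted_tokens(tokens, corrupted):
    """Drop corrupted tokens from the input; the target is the dropped
    tokens in their original order, with no sentinels on either side."""
    inputs = [t for t, c in zip(tokens, corrupted) if not c]
    targets = [t for t, c in zip(tokens, corrupted) if c]
    return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
corrupted = [i in {2, 3, 8} for i in range(len(tokens))]
inputs, targets = drop_corrupted_tokens(tokens, corrupted)
print(" ".join(inputs))   # Thank you me to your party week .
print(" ".join(targets))  # for inviting last
```

Both the input and the target are shorter than the original sequence here, which is where the training speed-up mentioned above comes from.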

For attribution in academic contexts, please cite this work as:

@misc{chai2022mask-PTMs,
author = {Chai, Yekun},
title = {{Mask Strategy for Pre-trained Models}},
year = {2022},
howpublished = {\url{https://cyk1337.github.io/notes/2022/01/10/Mask-Denoising-Strategy-for-Pre-trained-Models/}},
}

References