
Subword Tokenizers for Pre-trained Models

Summary of subword tokenization for pre-trained models.

Preliminaries

The table summarizes the differences between various tokenization approaches.

| Tokenization | Methods | Pros | Cons |
|---|---|---|---|
| word-based | 1. Space and punctuation tokenization; 2. Rule-based tokenization. | Easy to use. | 1. Very large vocabularies; 2. OOV problem; 3. Loss of meaning across very similar words. |
| char-based | Splitting into chars | 1. Slimmer vocabularies; 2. Mostly open-vocabulary: fewer OOV words. | 1. Very long sequences; 2. Less meaningful individual tokens. |
| subword-based | WordPiece, BPE, Unigram | 1. Good balance between vocabulary size and sequence length; 2. Helps identify similar syntactic or semantic situations in texts; 3. Can identify start-of-word tokens. | Requires training an additional subword tokenizer. |

Why subword?

  • Subword-based tokenization lies between character- and word-based tokenization, which arises from the idea that:

    Frequently used words should not be split into smaller subwords;
    Rare words should be decomposed into meaningful subwords.

  • Subwords help identify similar syntactic or semantic situations in texts, such as a shared prefix or suffix.
  • Subword tokenization can identify start of word tokens, such as “##” in WordPiece (BERT).
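The "##" marker mentioned above can be illustrated with a small sketch. It assumes the convention that a piece starting with "##" continues the previous word and any other piece starts a new word; the helper name `join_wordpieces` is hypothetical and not part of any library.

```python
def join_wordpieces(pieces):
    """Rebuild whitespace-separated words from WordPiece-style tokens."""
    words = []
    for piece in pieces:
        if piece.startswith("##") and words:
            # Continuation piece: glue it onto the previous word.
            words[-1] += piece[2:]
        else:
            # No marker: this piece starts a new word.
            words.append(piece)
    return " ".join(words)


pieces = ["un", "##aff", "##able", "people"]
print(join_wordpieces(pieces))
```

Because word starts are marked implicitly, the original word boundaries can be recovered from the token sequence alone.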

Summary

It can be seen from the table that:

  • OpenAI and Facebook favor BPE tokenization whereas Google prefers self-proposed WordPiece and Unigram methods. ;)

| Model | Tokenization | Corpus | Authors |
|---|---|---|---|
| GPT | BPE (40,478) [Spacy/ftfy pre-tokenizer] | BooksCorpus | OpenAI |
| GPT-2 | BBPE (50,257) | WebText (40GB) | OpenAI |
| GPT-3 | BBPE | Common Crawl, WebText2, Books1/2, Wikipedia | OpenAI |
| RoBERTa | BBPE | BooksCorpus, enwiki | Facebook |
| BART | BBPE | BooksCorpus, enwiki | Facebook |
| BERT | WordPiece (30k) | BooksCorpus, enwiki | Google |
| T5 | Unigram (SentencePiece) | C4 | Google |
| XLNet | Unigram (SentencePiece) | BooksCorpus, enwiki, Giga5, ClueWeb 2012-B, Common Crawl | Google |
| ELECTRA | WordPiece (30k) | base: same as BERT; large: same as XLNet | Google |
| ALBERT | Unigram (SentencePiece) | BooksCorpus, enwiki | Google |

The following sections compare common subword methods used in pre-trained models. Refer to the links in [13] for the detailed tokenization process.

TL;DR

  • WordPiece $\Uparrow$ (probability-based) merges tokens based on bigram likelihood. In each iteration it uses a language model to score candidate subword pairs and merges the adjacent pair that most increases the likelihood of the training data.
  • Byte Pair Encoding (BPE) $\Uparrow$ (frequency-based) merges tokens based on bigram frequency. It greedily merges the most frequent adjacent pair, which effectively balances vocabulary size against sequence length. Because segmentation is a greedy, deterministic symbol replacement, it cannot produce multiple segmentations with probabilities.
  • Unigram Language Model $\Downarrow$ (subword regularization) prunes tokens based on unigram LM perplexity. It can be viewed as a probabilistic mixture of character, subword, and word segmentations, where the mixture probability is computed with the EM algorithm. It shrinks the vocabulary by removing the subwords whose removal reduces the unigram LM likelihood the least.

| Tokenization | #Vocab | Update method | New symbols |
|---|---|---|---|
| WordPiece | $\Uparrow$ | Bottom-up merge | |
| BPE | $\Uparrow$ | Bottom-up merge | |
| Unigram | $\Downarrow$ | Prune | |

Byte-Pair Encoding (BPE)

Byte-Pair Encoding (BPE)[8] first uses a pre-tokenizer to split the text into words, then builds a base vocabulary from all character symbols in the training data, and finally grows that vocabulary with frequency-based merges.

Pre-tokenization  The pre-tokenization can be:

  • Space tokenization, e.g. GPT-2, RoBERTa.
  • Rule-based tokenization (Moses), e.g. XLM.
  • Spacy and ftfy: GPT.

Frequency-based Merge  Starting from the base vocabulary, BPE counts the frequency of each adjacent unit pair, merges the most frequent pair into a new unit, and adds that unit to the vocabulary. It then repeats the search on the updated corpus until the desired number of merges is reached.

BPE algorithm <small>[15]</small>
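The merge loop described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the reference implementation: the function name `learn_bpe` and the toy word frequencies are assumptions made for the example, and ties between equally frequent pairs are broken by first occurrence.

```python
from collections import Counter


def learn_bpe(word_freqs, num_merges):
    """Learn BPE merges from a {word: frequency} dict (toy sketch)."""
    # Each word starts as a tuple of single-character symbols.
    corpus = {tuple(word): freq for word, freq in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair, weighted by word frequency.
        pair_counts = Counter()
        for symbols, freq in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Replace every occurrence of the best pair with the merged symbol.
        new_corpus = {}
        for symbols, freq in corpus.items():
            merged, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            key = tuple(merged)
            new_corpus[key] = new_corpus.get(key, 0) + freq
        corpus = new_corpus
    return merges, corpus


toy_freqs = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}
merges, corpus = learn_bpe(toy_freqs, num_merges=3)
print(merges)
print(corpus)
```

The final vocabulary is the set of base characters plus one new symbol per learned merge.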

Byte-level BPE (BBPE)

Background

Unicode vs Byte:

  • Unicode: Unicode is a character standard that can represent text from different languages. Each character is assigned a Unicode code point, and encodings such as UTF-8 map code points to bytes. Unicode 12.1 defines a total of 137,929 characters.
  • Byte: a byte is 8 bits, so a single-byte character set can contain at most 256 characters.

>>> u'Hi'.encode('ASCII')
b'Hi'
>>> b'\x48\x69'
b'Hi'
>>> b'\x48\x69'.decode('ASCII') # hex (ascii) to str (unicode)
'Hi'
>>> chr(72)
'H'
>>> b'\x48'
b'H'
>>> ord(b'\x48')
72
>>> hex(72)
'0x48'
>>> b'\x48'.decode()
'H'

| BPE Tokenization | #Code points | Seq length | OOV |
|---|---|---|---|
| Unicode-level | 130k+ | L | possible |
| Byte-level | 256 | ≤4L | none |
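The ≤4L bound in the table follows from UTF-8, which encodes each code point in one to four bytes. A short check, assuming UTF-8 as the byte encoding (the helper name `utf8_len` is only for illustration):

```python
def utf8_len(text):
    """Number of bytes in the UTF-8 encoding of text."""
    return len(text.encode("utf-8"))


# One sample per UTF-8 width: ASCII, Latin-1 letter, CJK ideograph, emoji.
samples = ["H", "\u00e9", "\u529b", "\U0001F600"]
for ch in samples:
    print(repr(ch), utf8_len(ch), list(ch.encode("utf-8")))

text = "Hi \u529b\U0001F600"
print(len(text), "code points ->", utf8_len(text), "bytes")
```

ASCII-heavy text therefore stays close to L bytes, while text dominated by CJK characters or emoji approaches 3L to 4L.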

Unicode defines 130k+ code points to cover the full space of textual characters, which would inflate the base vocabulary of BPE. GPT-2[14] therefore applies BPE to the byte sequence of the text, which shrinks the base vocabulary to 256 symbols. However, applying byte-level BPE directly is suboptimal: the greedy frequency-based heuristic tends to merge common words with neighboring punctuation, producing overfitted sub-tokens such as “-ing.”, “-ing!”, “-ing?”.

To avoid this, GPT-2[14] prevents BPE from merging across character categories within any byte sequence, with an exception for spaces. With byte-level subwords, BBPE can represent any text with a moderate vocabulary size and no out-of-vocabulary problem. The cost is a longer sequence: up to 4x the number of characters.
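The effect of the category restriction can be sketched with the standard library. GPT-2 itself uses a pattern from the third-party `regex` module; the version below is only a rough stand-in that groups characters by coarse Unicode category and keeps a single leading space attached to the following chunk. The names `coarse_category` and `pre_split` are hypothetical, and contractions and runs of whitespace are not handled the way GPT-2 handles them.

```python
import unicodedata


def coarse_category(ch):
    """Map a character to 'letter', 'number', 'space', or 'other'."""
    cat = unicodedata.category(ch)
    if cat.startswith("L"):
        return "letter"
    if cat.startswith("N"):
        return "number"
    if ch.isspace():
        return "space"
    return "other"


def pre_split(text):
    """Split text so that no chunk mixes character categories."""
    chunks, current, current_cat = [], "", None
    for ch in text:
        cat = coarse_category(ch)
        if current == " " and cat != "space":
            # Exception for spaces: a single space joins the next chunk.
            current, current_cat = current + ch, cat
        elif cat == current_cat:
            current += ch
        else:
            if current:
                chunks.append(current)
            current, current_cat = ch, cat
    if current:
        chunks.append(current)
    return chunks


print(pre_split("playing. playing! playing?"))
```

Since BPE merges are only learned inside each chunk, pieces such as "ing." or "ing!" can never become vocabulary entries.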

BBPE

The base vocabulary contains all base characters in the training data, and it can become large if all Unicode characters are included. GPT-2[14] therefore uses byte-level BPE (BBPE), building the base vocabulary from the byte sequence of the text instead of from Unicode character strings. BBPE is also adopted by RoBERTa, BART, and GPT-3.

Vocabulary Size  The final vocabulary size is the size of the base vocabulary plus the number of merges, where the number of merges is a hyperparameter. For instance,

  • GPT (character-level BPE) has a vocabulary of 40,478 tokens: 478 base characters + 40k merges.
  • GPT-2 (byte-level BPE) has a vocabulary of 50,257 tokens: 256 base bytes + 1 [EOS] token + 50k merges.
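The arithmetic behind these two numbers can be written out directly. The helper `vocab_size` is a hypothetical name used only for this check; the counts are the ones quoted above.

```python
def vocab_size(base, merges, special=0):
    """Final vocabulary size = base symbols + learned merges + special tokens."""
    return base + merges + special


gpt_vocab = vocab_size(base=478, merges=40_000)
gpt2_vocab = vocab_size(base=256, merges=50_000, special=1)
print(gpt_vocab, gpt2_vocab)
```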

Examples of Ja-En tokenization with various vocabulary sizes

The implementation of GPT-2 & RoBERTa.

"""
GPT-2 & RoBERTa
Byte pair encoding utilities from GPT-2.

Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""

import json
from functools import lru_cache


@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping from utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            # Remap non-printable bytes to code points above 255.
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder  # bpe-vocab.json -> {subword: id}
        self.decoder = {v: k for k, v in self.encoder.items()}  # {id: subword}
        self.errors = errors  # how to handle errors in decoding
        # {byte: unicode}
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}  # {unicode: byte}
        # bpe-merges.txt -> {tuple: rank}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        try:
            import regex as re
            self.re = re
        except ImportError:
            raise ImportError("Please install regex with: pip install regex")

        # Should have added re.IGNORECASE so BPE merges
        # can happen for capitalized versions of contractions
        self.pat = self.re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        # check if already processed
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)  # collect adjacent symbol pairs

        if not pairs:
            return token

        while True:
            # pick the pair with the lowest merge rank (learned earliest)
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            # find all possible merges for a bigram
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # no further merge
                    new_word.extend(word[i:])
                    break

                # bigram match & satisfy length limit
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word  # update merged tokens
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)  # new possible pairs
        word = " ".join(word)
        self.cache[token] = word  # cache raw tokens
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            # map each utf-8 byte to its printable unicode stand-in
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder.get(token, token) for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", errors=self.errors
        )
        return text


def get_encoder(encoder_json_path, vocab_bpe_path):
    # encoder_json_path: e.g. "bpe-vocab.json"; vocab_bpe_path: e.g. "bpe-merge.txt"
    with open(encoder_json_path, "r") as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, "r", encoding="utf-8") as f:
        bpe_data = f.read()
    # skip the header line and the trailing empty line
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )


# RoBERTa source code.
import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool

def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple processes.

    The encoder.json and vocab.bpe files can be obtained here:
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    assert len(args.inputs) == len(
        args.outputs
    ), "number of input and output paths should match"

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-"
            else sys.stdin
            for input in args.inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-"
            else sys.stdout
            for output in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        pool = Pool(args.workers, initializer=encoder.initializer)
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

        stats = Counter()
        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(enc_line, file=output_h)
            else:
                stats["num_filtered_" + filt] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)


class MultiprocessingEncoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # one encoder per worker process, stored in a module-level global
        global bpe
        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)

    def encode(self, line):
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0 and not self.args.keep_empty:
                return ["EMPTY", None]
            tokens = self.encode(line)
            enc_lines.append(" ".join(tokens))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        dec_lines = []
        for line in lines:
            tokens = map(int, line.strip().split())
            dec_lines.append(self.decode(tokens))
        return ["PASS", dec_lines]


if __name__ == "__main__":
    main()

WordPiece

WordPiece[9][10] can be viewed as a language-model-based variant of BPE. Its training process is similar to BPE's but uses a different merge rule: instead of choosing the most frequent pair, WordPiece selects the unit pair that most increases the likelihood of the training data once merged. Equivalently, it chooses the subword pair with the highest mutual information.

WordPiece scores the likelihood of candidate pairs using an n-gram LM. As [9] notes, training an LM for every possible merge is prohibitively expensive, so the authors used aggressive heuristics to reduce the cost. However, no public training implementation is available.
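Since the original training code is not public, the following is only a sketch of a commonly cited approximation, not the authors' method: score each adjacent pair by count(pair) / (count(first) × count(second)), which ranks pairs in the same order as pointwise mutual information. The function name `pair_statistics` and the toy corpus are assumptions made for the example.

```python
from collections import Counter


def pair_statistics(corpus):
    """Return (pair_counts, scores) for a {symbol_tuple: frequency} corpus."""
    unit_counts, pair_counts = Counter(), Counter()
    for symbols, freq in corpus.items():
        for s in symbols:
            unit_counts[s] += freq
        for a, b in zip(symbols, symbols[1:]):
            pair_counts[(a, b)] += freq
    # Normalizing by the unit counts favors pairs whose parts rarely occur apart.
    scores = {
        (a, b): count / (unit_counts[a] * unit_counts[b])
        for (a, b), count in pair_counts.items()
    }
    return pair_counts, scores


toy_corpus = {
    ("h", "u", "g"): 10,
    ("p", "u", "g"): 5,
    ("p", "u", "n"): 12,
    ("b", "u", "n"): 4,
    ("h", "u", "g", "s"): 5,
}
pair_counts, scores = pair_statistics(toy_corpus)
bpe_choice = max(pair_counts, key=pair_counts.get)
wordpiece_choice = max(scores, key=scores.get)
print("most frequent pair:", bpe_choice)
print("highest-scoring pair:", wordpiece_choice)
```

On this toy corpus the two criteria disagree: the most frequent pair contains the very common symbol "u", while the normalized score prefers a rarer pair whose parts almost always appear together.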

The BERT tokenization applies two tokenizers one after another:

  • BasicTokenizer:
    1. Convert text to unicode.
    2. Clean text: invalid character removal and whitespace cleanup.
    3. Use whitespace to separate Chinese characters.
    4. Whitespace tokenization.
    5. Lowercase and strip accents.
    6. Split on punctuation.
  • WordpieceTokenizer:
    1. Convert text to unicode.
    2. Apply WordPiece, a greedy longest-match-first algorithm, to tokenize with the given vocabulary.

# BERT Implementation
import collections
import unicodedata


# Helpers the classes below rely on. These are simplified Python 3 versions
# written for this listing (assumption: the original tokenization.py defines
# equivalents with the same names).
def convert_to_unicode(text):
    """Converts `text` to a unicode string, assuming utf-8 input."""
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    return text


def load_vocab(vocab_file):
    """Loads a vocabulary file (one token per line) into an ordered dict."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of tokens or ids using the given mapping."""
    return [vocab[item] for item in items]


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    return text.split()


class FullTokenizer(object):
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
          do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":  # nonspacing mark, i.e. a combining accent
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
          input = "unaffable"
          output = ["un", "##aff", "##able"]

        Args:
          text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.

        Returns:
          A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                # try the longest remaining substring first, then shrink it
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

if __name__ == "__main__":
    vocab_file="./cased_L-12_H-768_A-12/vocab.txt"
    tokenizer = FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
    output_tokens = tokenizer.tokenize("""This text is included to
make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত""")
    print(output_tokens)

Unigram Language Model

Unigram Language Model[11] initializes its base vocabulary with a large number of candidate units and gradually removes a portion (e.g., 20%) of them according to the change in likelihood. It uses a unigram LM to evaluate how much the likelihood drops when a subword is removed, where the probability of each unit is computed with the EM algorithm. Pruning stops once the pre-defined vocabulary size is reached.

Unigram LM algorithm <small>[11]</small>

Since unigram is not based on merge rules (in contrast to BPE and WordPiece), there are several ways to tokenize new text after training. Therefore, in addition to the vocabulary, unigram also saves the probability of each token in the training corpus, so that the probability of each possible tokenization can be computed after training. In practice it simply picks the most likely tokenization, but it can also sample a tokenization according to these probabilities.
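To make this concrete, here is a small sketch that enumerates every tokenization of one word under a toy unigram vocabulary, scores each as the product of its token probabilities, and then either picks the best or samples one. The vocabulary, its probabilities, and the function names are made up for illustration; real implementations use dynamic programming (Viterbi) rather than brute-force enumeration.

```python
import math
import random


def segmentations(word, vocab):
    """Enumerate all ways to split word into pieces that are in vocab."""
    if not word:
        return [[]]
    results = []
    for end in range(1, len(word) + 1):
        piece = word[:end]
        if piece in vocab:
            for rest in segmentations(word[end:], vocab):
                results.append([piece] + rest)
    return results


def seg_probability(seg, vocab):
    """Unigram assumption: tokens are independent, so probabilities multiply."""
    return math.prod(vocab[token] for token in seg)


# Toy token probabilities (hypothetical values, not from a trained model).
toy_vocab = {"h": 0.05, "u": 0.05, "g": 0.05, "hu": 0.10, "ug": 0.20, "hug": 0.15}

all_segs = segmentations("hug", toy_vocab)
weights = [seg_probability(seg, toy_vocab) for seg in all_segs]
best_seg = max(all_segs, key=lambda seg: seg_probability(seg, toy_vocab))
sampled_seg = random.choices(all_segs, weights=weights, k=1)[0]

for seg, weight in zip(all_segs, weights):
    print(seg, weight)
print("best:", best_seg)
print("sampled:", sampled_seg)
```

Sampling instead of always taking the best segmentation is what subword regularization uses as data augmentation.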

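Assume that the training data consists of the words $x_1, \dots, x_N$ and that the set of all possible tokenizations of a word $x_i$ is $S(x_i)$. The overall loss is then defined as:

$$\mathcal{L} = -\sum_{i=1}^{N} \log \Big( \sum_{x \in S(x_i)} p(x) \Big)$$

where $p(x)$ is the probability of the tokenization $x$ under the unigram model.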

import re
import collections
import numpy as np
from scipy.special import digamma

# To efficiently determine the next possible words
# We need a Trie data structure
class Trie:
    def __init__(self):
        self.root = {}

    def add(self, word, value):
        node = self.root
        for ch in word:
            if ch not in node:
                node[ch] = {}
            node = node[ch]
        node['<END>'] = value

    def get_value(self, word):
        node = self.root
        for ch in word:
            if ch not in node:
                return 0
            node = node[ch]
        if '<END>' not in node:
            return 0
        return node['<END>']

    def set_value(self, word, value):
        node = self.root
        for ch in word:
            if ch not in node:
                raise ValueError("word not in trie")
            node = node[ch]
        if '<END>' not in node:
            raise ValueError("word not in trie")
        node['<END>'] = value


class SentencePieceTrainer:
    def __init__(self):
        self.trie = None
        self.maxlen = None
        self.vocab_size = None

    def _initialize_trie(self, tokens):
        trie = Trie()
        norm = sum(list(tokens.values()))
        logsum = digamma(norm)

        maxlen = 0
        for tok, val in tokens.items():
            trie.add(tok, digamma(val)-logsum)
            maxlen = max(maxlen, len(tok))

        return trie, maxlen

    def forward_step(self, text, trie):
        N = len(text)

        # d[i] contains the maximum log_prob of any tokenization
        # of text[:i], initialized to -inf (i.e. log(0))
        d = [-np.inf]*(N+1)

        # p[i] (stands for parent) contains the number of characters of
        # the final token in the most likely sequence that ends at index i
        p = [None]*(N+1)
        d[0]=0

        for i in range(1, N+1):

            # find all possible final words. Have to look back
            # a distance set by the length of the longest token
            for j in range(max(i-self.maxlen, 0), i):

                final_token = text[j:i]
                final_value = trie.get_value(final_token)

                # if the current ending word has a higher log-probability,
                # save that value and store the word (i.e. # chars to backtrack)
                if final_value and d[j]+final_value > d[i]:
                    d[i] = d[j]+final_value
                    p[i] = len(final_token)
            if p[i] is None:
                raise ValueError(f"Encountered unknown token '{text[i-1]}'.")

        loss = d[-1]
        return loss, p

    def backward_step(self, text, p):
        idx = len(p)
        tokenization = []
        while idx > 1:
            # move back the number of steps p tells you to
            next_idx = idx-p[idx-1]

            # extract the final token
            tok = text[next_idx-1:idx-1]
            tokenization.append(tok)

            idx = next_idx
        tokenization = list(reversed(tokenization))
        return tokenization

    def E_step(self, tokenization, trie):
        # get the new token counts based on updated tokenization
        counts = collections.Counter(tokenization)
        norm = sum(list(counts.values()))

        # Bayesianify them: https://cs.stanford.edu/~pliang/papers/tutorial-acl2007-talk.pdf
        # https://github.com/google/sentencepiece/blob/master/src/unigram_model_trainer.cc
        # we are returning the log probabilities here (alpha=0 prior)
        logsum = digamma(norm)
        for k, v in counts.items():
            counts[k] = digamma(v)-logsum

        for k, v in counts.items():
            trie.set_value(k, v)
        return trie

    def M_step(self, text, trie):
        loss, p = self.forward_step(text, trie)
        tokenization = self.backward_step(text, p)
        return tokenization, loss

    def EM_step(self, text, tokenization, trie):
        trie = self.E_step(tokenization, trie)
        tokenization, loss = self.M_step(text, trie)
        return loss, tokenization, trie

    def EM_round(self, text, tokens, delta=0.01, max_iter=10):
        tokenization, old_loss = self.M_step(text, self.trie)
        for step in range(max_iter):
            print(f"EM iter {step}: ", end='')
            loss, tokenization, trie = self.EM_step(text, tokenization, self.trie)
            print(f"Loss={loss:.2f}")
            if abs(old_loss-loss) < delta:
                break
            old_loss = loss

    def prune_tokens(self, tokens, characters, vocab_size, trim_frac=0.2):
        """ Tokens are passed by reference and modified in place.
        Returns:
            True: to indicate to caller that more rounds are needed
            False: to indicate we successfully hit the target vocab size
            ValueError: if the vocab size cannot be reached."""
        sorted_tokens = tokens.most_common()
        N = len(sorted_tokens)
        n_trim = int(trim_frac*N)
        # walk from the least common token upwards
        for i in reversed(range(N)):
            if N <= vocab_size:
                return False
            if n_trim <= 0:
                return True
            tok = sorted_tokens[i][0]
            if tok not in characters:
                self.trie.set_value(tok, 0) # we need to delete it from the trie (that sticks around)
                tokens.pop(tok) # also need to delete from tokens, so the next round doesn't see it
                n_trim -= 1
                N -= 1
        if n_trim > 0:
            raise ValueError('Could not reduce tokens further. Please increase vocab size')
        return False

    def fit(self, text, tokens, characters, vocab_size, delta=0.01, max_iter=5, max_rounds=5):
        """ To turn off pruning, just set max_rounds=1 """
        text = re.sub(' ', '_', text)
        if vocab_size > len(tokens):
            raise ValueError(f"Vocab size is larger than the available number of tokens {len(tokens)}.")
        self.trie, self.maxlen = self._initialize_trie(tokens)
        for i in range(1, max_rounds+1):
            print(f"--- Round {i}. Vocab size: {len(tokens)} ---")
            self.EM_round(text, tokens, delta, max_iter)
            if not self.prune_tokens(tokens, characters, vocab_size):
                break
        self.vocab_size = len(tokens)

    def generalized_forward_step(self, text, trie, nbest_size=1):
        N = len(text)
        d = [-np.inf]*(N+1)
        p = [None]*(N+1)
        d[0]=0
        for i in range(1, N+1):
            d_queue = []
            p_queue = []
            for j in range(max(i-self.maxlen, 0), i):
                final_token = text[j:i]
                final_value = trie.get_value(final_token)
                if final_value:
                    curr_d = d[j]+final_value
                    curr_p = len(final_token)
                    d[i] = max(d[i], curr_d)
                    d_queue.append(curr_d)
                    p_queue.append(curr_p)
            # keep the backtrack lengths of the nbest_size best candidates
            ids = np.argsort(d_queue)[-nbest_size:]
            p[i] = [p_queue[z] for z in ids]
        return p

    def generalized_backward_step(self, text, p):
        idx = len(p)
        tokenization = []
        while idx > 1:
            # sample one of the stored backtrack lengths
            back_steps = np.random.choice(p[idx-1])
            next_idx = idx-back_steps
            tok = text[next_idx-1:idx-1]
            tokenization.append(tok)
            idx = next_idx
        tokenization = list(reversed(tokenization))
        return tokenization

    def tokenize(self, text, nbest_size=1):
        text = re.sub(' ', '_', text)
        if self.trie is None:
            raise ValueError("Trainer has not yet been fit. Cannot tokenize.")
        p = self.generalized_forward_step(text, self.trie, nbest_size)
        tokenization = self.generalized_backward_step(text, p)
        return tokenization

Refer to [18] for details.

SentencePiece Library

SentencePiece[12][17] includes the space in the base vocabulary and then uses the BPE or unigram algorithm to tokenize. XLNet, T5, and ALBERT use SentencePiece for subword tokenization. It uses the unigram algorithm by default.

# SentencePiece
--byte_fallback: (type: bool, default: false)
    decompose unknown pieces into UTF-8 byte pieces.
    Note: need to set --character_coverage less than 1.0, otherwise byte-fall-backed tokens may not appear in the training data.
--character_coverage: (type: double; default: 0.9995)
    character coverage to determine the minimum symbols.

# see: https://github.com/google/sentencepiece/blob/master/doc/options.md

Pros:

  1. Its C++ implementation makes tokenization very fast.
  2. It is whitespace agnostic: it can be trained on languages that do not delimit words with whitespace, such as Chinese and Japanese, with the same ease as English or French.[18]
  3. It can work at the byte level (see --byte_fallback above), so unknown characters rarely need an unknown token.
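The whitespace-agnostic behavior comes from treating the space as an ordinary symbol, written as "▁" (U+2581) in SentencePiece output. The sketch below shows only this escaping idea in plain Python; it does not call the library, and the helper names `escape_whitespace` and `detokenize` are hypothetical.

```python
WS = "\u2581"  # the visible whitespace marker used in SentencePiece pieces


def escape_whitespace(text):
    """Turn spaces into ordinary symbols, with a marker before the first word."""
    return WS + text.replace(" ", WS)


def detokenize(pieces):
    """Concatenate pieces and turn whitespace markers back into spaces."""
    text = "".join(pieces).replace(WS, " ")
    # Drop the space introduced by the marker in front of the first word.
    return text[1:] if text.startswith(" ") else text


escaped = escape_whitespace("This is a test")
print(escaped)
print(detokenize([WS + "This", WS + "is", WS + "a", WS + "t", "est"]))
```

Because the marker is part of the pieces, detokenization is a plain string operation and needs no language-specific rules.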

Basic usage

# env / data
pip install sentencepiece
wget https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt

import sentencepiece as spm

# train sentencepiece model from `botchan.txt` and makes `m.model` and `m.vocab`
# `m.vocab` is just a reference. not used in the segmentation.
spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')

# makes segmenter instance and loads the model file (m.model)
sp = spm.SentencePieceProcessor()
sp.load('m.model')

# encode: text => id
print(sp.encode_as_pieces('This is a test'))
print(sp.encode_as_ids('This is a test'))

# decode: id => text
print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))
print(sp.decode_ids([209, 31, 9, 375, 586]))

# returns vocab size
print(sp.get_piece_size())

# id <=> piece conversion
print(sp.id_to_piece(209))
print(sp.piece_to_id('▁This'))

# returns 0 for unknown tokens (we can change the id for UNK)
print(sp.piece_to_id('__MUST_BE_UNKNOWN__'))

# <unk>, <s>, </s> are defined by default. Their ids are (0, 1, 2)
# <s> and </s> are defined as 'control' symbol.
for id in range(3):
    print(sp.id_to_piece(id), sp.is_control(id))
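The `▁` (U+2581) prefix in the pieces above is how SentencePiece records whitespace, which is what makes detokenization a simple string operation. The following plain-Python sketch mimics that round trip for pieces that are already segmented; `detokenize` is a hypothetical helper, not a library call.

```python
WS = "\u2581"  # the "▁" whitespace marker

def detokenize(pieces):
    """Concatenate pieces, turn markers back into spaces, drop the leading space."""
    text = "".join(pieces).replace(WS, " ")
    return text[1:] if text.startswith(" ") else text

pieces = [WS + "This", WS + "is", WS + "a", WS + "t", "est"]
print(detokenize(pieces))  # This is a test
```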

User defined and control symbols

## Example of user defined symbols
spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_user --user_defined_symbols=<sep>,<cls> --vocab_size=2000')

sp_user = spm.SentencePieceProcessor()
sp_user.load('m_user.model')

# ids are reserved in both modes.
# <unk>=0, <s>=1, </s>=2, <sep>=3, <cls>=4
# user defined symbols allow these symbols to appear in the text.
print(sp_user.encode_as_pieces('this is a test<sep> hello world<cls>'))
print(sp_user.piece_to_id('<sep>')) # 3
print(sp_user.piece_to_id('<cls>')) # 4
print('3=', sp_user.decode_ids([3])) # decoded to <sep>
print('4=', sp_user.decode_ids([4])) # decoded to <cls>

Unigram: sampling and nbest segmentation for subword regularization

When --model_type=unigram (the default) is used, we can perform sampling and n-best segmentation for data augmentation. See the subword regularization paper [11] for more detail. nbest_size is the number of highest-ranked segmentations to sample from each time, where -1 means all possible segmentations.

spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')

# Can obtain different segmentations per request.
# There are two hyperparameters for sampling (nbest_size and inverse temperature). See the paper [kudo18] for details.
for n in range(10):
    print(sp.sample_encode_as_pieces('hello world', -1, 0.1))

for n in range(10):
    print(sp.sample_encode_as_ids('hello world', -1, 0.1))

# sample
for _ in range(10):
    result = sp.encode('This is a test', out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
    print(result)

# get 10 best
print(sp.nbest_encode_as_pieces('hello world', 10))
print(sp.nbest_encode_as_ids('hello world', 10))
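A minimal sketch of the sampling idea, assuming a toy unigram vocabulary: enumerate every segmentation of a short string, score each one by the sum of its piece log-probabilities, and sample one with probability proportional to P(segmentation)^alpha. Exhaustive enumeration stands in for nbest_size=-1 and is only feasible for very short inputs. The names `toy_log_probs`, `all_segmentations`, `segmentation_probs`, and `sample_segmentation` are illustrative and are not SentencePiece APIs.

```python
import math
import random

def all_segmentations(text, vocab):
    """Yield every way to split `text` into pieces from `vocab`."""
    if not text:
        yield []
        return
    for end in range(1, len(text) + 1):
        head = text[:end]
        if head in vocab:
            for tail in all_segmentations(text[end:], vocab):
                yield [head] + tail

def segmentation_probs(text, log_probs, alpha):
    """Return (segmentations, probabilities) with P(seg) proportional to exp(alpha * score)."""
    segs = list(all_segmentations(text, log_probs))
    weights = [math.exp(alpha * sum(log_probs[p] for p in seg)) for seg in segs]
    total = sum(weights)
    return segs, [w / total for w in weights]

def sample_segmentation(text, log_probs, alpha, rng=random):
    """Draw one segmentation according to the smoothed distribution."""
    segs, probs = segmentation_probs(text, log_probs, alpha)
    return rng.choices(segs, weights=probs, k=1)[0]

# Hypothetical log-probabilities for a tiny vocabulary.
toy_log_probs = {"a": -2.0, "b": -2.0, "ab": -1.0, "c": -1.5, "bc": -3.0}
rng = random.Random(0)
for _ in range(5):
    print(sample_segmentation("abc", toy_log_probs, alpha=0.1, rng=rng))
```

A small alpha flattens the distribution so that less likely segmentations are sampled more often; a large alpha concentrates the mass on the best segmentation.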

BPE model

SentencePiece also supports the BPE (byte pair encoding) model via --model_type=bpe. The BPE model does not support sampling or n-best segmentation.

spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_bpe --vocab_size=2000 --model_type=bpe')
sp_bpe = spm.SentencePieceProcessor()
sp_bpe.load('m_bpe.model')

print('*** BPE ***')
print(sp_bpe.encode_as_pieces('thisisatesthelloworld')) # ['▁this', 'is', 'at', 'est', 'he', 'llow', 'or', 'ld']
print(sp_bpe.nbest_encode_as_pieces('hello world', 5)) # [] (returns an empty list)

Character and word model

SentencePiece supports character and word segmentation with the --model_type=char and --model_type=word flags.
In word segmentation, SentencePiece just segments tokens on whitespace, so the input text must be pre-tokenized. We can apply different segmentation algorithms transparently without changing pre/post processors.

# char model
spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_char --model_type=char --vocab_size=400')

sp_char = spm.SentencePieceProcessor()
sp_char.load('m_char.model')

print(sp_char.encode_as_pieces('this is a test.'))
print(sp_char.encode_as_ids('this is a test.'))

# word model
spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m_word --model_type=word --vocab_size=2000')

sp_word = spm.SentencePieceProcessor()
sp_word.load('m_word.model')

print(sp_word.encode_as_pieces('this is a test.')) # '.' will not be one token.
print(sp_word.encode_as_ids('this is a test.'))

Text normalization

SentencePiece provides the following general pre-defined normalization rules. We can change the normalizer with the --normalization_rule_name=<NAME> flag.

  • nmt_nfkc: NFKC normalization with some additional normalization around spaces. (default)
  • nfkc: original NFKC normalization.
  • nmt_nfkc_cf: nmt_nfkc + Unicode case folding (mostly lower casing)
  • nfkc_cf: nfkc + Unicode case folding.
  • identity: no normalization
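The building blocks of these rules can be previewed with Python's standard `unicodedata` module. This only approximates what the rules do: the `nmt_` variants add extra whitespace handling that is not reproduced here, and the helper names `nfkc` and `nfkc_cf` are chosen for this example.

```python
import unicodedata

def nfkc(text):
    """Plain NFKC normalization (roughly what the `nfkc` rule does)."""
    return unicodedata.normalize("NFKC", text)

def nfkc_cf(text):
    """NFKC followed by Unicode case folding (approximating `nfkc_cf`)."""
    return nfkc(text).casefold()

print(nfkc("\uff48\uff45\uff4c\uff4c\uff4f"))  # full-width letters -> hello
print(nfkc("\ufb01ne"))                        # "fi" ligature -> fine
print(nfkc_cf("Hello World"))                  # hello world
```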

Custom normalization rules can also be defined in a TSV file, which is passed with the --normalization_rule_tsv=<FILE> flag.

def tocode(s):
    out = []
    for c in s:
        out.append(str(hex(ord(c))).replace('0x', 'U+'))
    return ' '.join(out)

# TSV format: source Unicode code points <tab> target code points
# normalize "don't => do not, I'm => I am"
with open('normalization_rule.tsv', 'w') as f:
    f.write(tocode("I'm") + '\t' + tocode("I am") + '\n')
    f.write(tocode("don't") + '\t' + tocode("do not") + '\n')

print(open('normalization_rule.tsv', 'r').read())

spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000 --normalization_rule_tsv=normalization_rule.tsv')

sp = spm.SentencePieceProcessor()
# m.model embeds the normalization rule compiled into an FST.
sp.load('m.model')
print(sp.encode_as_pieces("I'm busy")) # normalized to 'I am busy'
print(sp.encode_as_pieces("I don't know it.")) # normalized to 'I do not know it.'

Vocabulary restriction

We can encode the text using only the tokens specified with the set_vocabulary method.

spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')

sp = spm.SentencePieceProcessor()
sp.load('m.model')

print(sp.encode_as_pieces('this is a test.'))

# Gets all tokens as Python list.
vocabs = [sp.id_to_piece(id) for id in range(sp.get_piece_size())]

# Aggregates the frequency of each token in the training data.
freq = {}
with open('botchan.txt', 'r') as f:
    for line in f:
        line = line.rstrip()
        for piece in sp.encode_as_pieces(line):
            freq.setdefault(piece, 0)
            freq[piece] += 1

# only uses the token appearing more than 1000 times in the training data.
vocabs = list(filter(lambda x : x in freq and freq[x] > 1000, vocabs))
sp.set_vocabulary(vocabs)
print(sp.encode_as_pieces('this is a test.'))

# reset the restriction
sp.reset_vocabulary()
print(sp.encode_as_pieces('this is a test.'))

Extracting crossing-words pieces

By default, SentencePiece does not extract pieces that cross multiple words (here a word means a space-delimited token). A piece will never contain the whitespace marker (▁) in the middle.

--split_by_whitespace=false disables this restriction and allows pieces that cross multiple words to be extracted. In CJK (Chinese/Japanese/Korean), this flag does not affect the final segmentation much, as words are not delimited by whitespace in CJK.

import re

spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000 --split_by_whitespace=false')

sp = spm.SentencePieceProcessor()
sp.load('m.model')

# Gets all tokens as Python list.
vocabs = [sp.id_to_piece(id) for id in range(sp.get_piece_size())]

for piece in vocabs[0:500]:
    if re.match(r'\w+▁\w+', piece):
        print(piece)

Getting byte offsets of tokens

SentencePiece keeps track of the byte offset (span) of each token, which is useful for highlighting the token on top of unnormalized text.

We first need to install the protobuf module and sentencepiece_pb2.py, as the byte offsets and all other metadata for segmentation are encoded in a protocol buffer. The encode_as_serialized_proto method returns a serialized SentencePieceText proto. You can get the deserialized object by calling the ParseFromString method.

pip install protobuf
wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_pb2.py

import sentencepiece_pb2
import sentencepiece as spm

spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=m --vocab_size=2000')

sp = spm.SentencePieceProcessor()
sp.load('m.model')

# One best result
spt = sentencepiece_pb2.SentencePieceText()
spt.ParseFromString(sp.encode_as_serialized_proto('ｈｅｌｌｏ')) # Full width hello

# begin/end (offsets) are pointing to the original input.
print(spt)

# Nbest results
nspt = sentencepiece_pb2.NBestSentencePieceText()
nspt.ParseFromString(sp.nbest_encode_as_serialized_proto('hello', 5))
# print(nspt)

"""
text: "\357\275\210\357\275\205\357\275\214\357\275\214\357\275\217"
pieces {
  piece: "\342\226\201he"
  id: 28
  surface: "\357\275\210\357\275\205"
  begin: 0
  end: 6
}
pieces {
  piece: "ll"
  id: 98
  surface: "\357\275\214\357\275\214"
  begin: 6
  end: 12
}
pieces {
  piece: "o"
  id: 38
  surface: "\357\275\217"
  begin: 12
  end: 15
}
"""
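The offsets in this output are UTF-8 byte positions, not character positions: each full-width letter occupies three bytes, which is why the first piece spans 0 to 6. The sketch below recomputes the begin/end values from the surfaces using only the standard library; `byte_spans` is a hypothetical helper written for this illustration.

```python
def byte_spans(surfaces):
    """Return (begin, end) UTF-8 byte offsets for consecutive surface strings."""
    spans = []
    begin = 0
    for surface in surfaces:
        end = begin + len(surface.encode("utf-8"))
        spans.append((begin, end))
        begin = end
    return spans

# Surfaces of the three pieces above, written with escapes for the full-width letters.
surfaces = ["\uff48\uff45", "\uff4c\uff4c", "\uff4f"]
print(byte_spans(surfaces))  # [(0, 6), (6, 12), (12, 15)]
```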

Add new special tokens

Sometimes we need to add new special tokens, such as [MASK0-99] or [DOMAIN0-99], to a pre-trained SentencePiece model.
Ref: [19]

## Run this code in google/sentencepiece/python/
# Load pre-trained sentencepiece model
import sentencepiece_model_pb2 as model
m = model.ModelProto()
with open("old.model", "rb") as f:
    m.ParseFromString(f.read())

# Prepare the list of new tokens to add (one per line; empty lines are skipped)
with open("special_tokens.txt", "r") as f:
    special_tokens = [line for line in f.read().split("\n") if line]

# Add new tokens to sentencepiece model
for token in special_tokens:
    new_token = model.ModelProto().SentencePiece()
    new_token.piece = token
    new_token.score = 0
    m.pieces.append(new_token)

# Save new sentencepiece model
with open('new.model', 'wb') as f:
    f.write(m.SerializeToString())

Handle whitespaces/newlines

GitHub Issue

--remove_extra_whitespaces=false
# In addition, newlines are all normalized to whitespace internally by default. You can stop all normalizations with
--normalization_rule_name=identity

Ref:[20].

% cd src
% protoc --python_out=. sentencepiece_model.proto

>>> import sentencepiece_model_pb2 as model
>>> m = model.ModelProto()
>>> m.ParseFromString(open('../python/test/test_ja_model.model', 'rb').read())
352301
>>> for p in m.pieces:
...     p.score += 10.0
...
>>> with open('new.model', 'wb') as f:
...     f.write(m.SerializeToString())

Refer to Sentencepiece python module example

Huggingface tokenizers

Add special tokens

[22]

from tokenizers import AddedToken
# `tokenizer` is assumed to be an already loaded Hugging Face transformers tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
print(tokenizer.special_tokens_map)

Handle non-space-separated language

# https://github.com/huggingface/tokenizers/issues/990
from tokenizers import trainers, models, Tokenizer, pre_tokenizers

pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.WhitespaceSplit(),
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
    ]
)

print(pre_tokenizer.pre_tokenize_str("私 は りんご が 好き です"))
# [('ç§ģ', (0, 1)), ('ãģ¯', (2, 3)), ('ãĤĬãĤĵãģĶ', (4, 7)), ('ãģĮ', (8, 9)), ('好ãģį', (10, 12)), ('ãģ§ãģĻ', (13, 15))]
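The odd-looking strings in this output are not corruption: the byte-level pre-tokenizer renders every UTF-8 byte as a printable Unicode character. The sketch below rebuilds the byte-to-character table popularized by GPT-2 in plain Python. That `ByteLevel` uses this same table internally is an assumption here, although it is consistent with the output shown. `bytes_to_unicode` and `byte_level` are names chosen for this example.

```python
def bytes_to_unicode():
    """Build a reversible map from the 256 byte values to printable characters."""
    # Bytes that are already printable keep their own code point.
    keep = (list(range(ord("!"), ord("~") + 1))
            + list(range(0xA1, 0xAC + 1))
            + list(range(0xAE, 0xFF + 1)))
    mapping = {b: chr(b) for b in keep}
    # The remaining bytes are shifted to code points starting at 256.
    offset = 0
    for b in range(256):
        if b not in mapping:
            mapping[b] = chr(256 + offset)
            offset += 1
    return mapping

BYTE_TO_CHAR = bytes_to_unicode()

def byte_level(text):
    """Encode text as UTF-8 and render each byte with the table above."""
    return "".join(BYTE_TO_CHAR[b] for b in text.encode("utf-8"))

print(byte_level("私"))  # ç§ģ
```

The character 私 is three bytes in UTF-8 (E7 A7 81), so it becomes three printable characters, while its offsets (0, 1) still refer to one character of the original text.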

Empirical Analysis

Unigram vs Char-based BPE

Unigram aligns better with morphology than char-based BPE does. [15] argued that Unigram LM tokenization can recover subword units that align with morphology much better than BPE does, using the SentencePiece [12] implementation on English and Japanese Wikipedia.

It can be seen from the figure below that Unigram tends to produce longer subword units than BPE on average and has more tokens of moderate frequency.

English subword token vocabulary comparison between Unigram and BPE tokenization.

As shown in the table, BPE tokenization tends to merge common tokens, such as English inflectional suffixes and Japanese particles, into their neighbors even though the resulting units are not semantically meaningful. This may be due to the greedy construction of BPE tokenization.

Unigram vs char-based BPE tokenization [11]

[15] found that segmentations produced by Unigram LM align more closely to the morphological references in both English and Japanese.
Subword boundaries between tokenized subwords and morphological segmentations.

Models using Unigram outperform their BPE counterparts when fine-tuned on downstream tasks. [15] claimed that fine-tuning models pretrained with Unigram LM tokenization produces better performance than with BPE tokenization on the tasks evaluated.

Unigram vs char-based BPE on finetuning downstream tasks

For attribution in academic contexts, please cite this work as:

@misc{chai2021tokenization-PTMs,
  author = {Chai, Yekun},
  title = {{Word Tokenization for Pre-trained Models}},
  year = {2021},
  howpublished = {\url{https://cyk1337.github.io/notes/2021/10/09/PTMs/Scaling-Up-Pre-trained-Models/}},
}

References