 """Computes the flops needed for training/running transformer networks."""
import collections
# Assumed FLOP cost per element of each elementwise operation; the class below
# multiplies each constant by the number of elements the op is applied to.
DROPOUT_FLOPS = 4
LAYER_NORM_FLOPS = 5
ACTIVATION_FLOPS = 8
SOFTMAX_FLOPS = 5


class TransformerHparams(object):
    """Computes the train/inference FLOPs for transformers.

    Costs are accumulated per token and multiplied by the sequence length
    ``s``; matrix multiplies are counted as 2 FLOPs per multiply-add.

    Args:
        h: hidden size.
        l: number of transformer blocks.
        s: sequence length.
        v: vocabulary size.
        e: embedding size; defaults to ``h``.
        i: intermediate (feed-forward) size; defaults to ``4 * h``.
        heads: number of attention heads; defaults to ``max(h // 64, 1)``.
        head_size: per-head size. When given, the key/query/value width is
            ``head_size * heads``; otherwise it equals ``h``.
        output_frac: fraction of the ``s`` positions whose output-layer FLOPs
            are counted by ``get_embedding_flops(output=True)``.
        sparse_embed_lookup: if True, the input embedding lookup is not
            counted as a dense matmul over the vocabulary.
        decoder: if True, every attention term in a block is counted twice
            (presumably self-attention plus cross-attention).
    """

    def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None,
                 head_size=None, output_frac=0.15625,
                 sparse_embed_lookup=False, decoder=False):
        self.h = h
        self.l = l
        self.s = s
        self.v = v
        self.e = h if e is None else e
        self.i = h * 4 if i is None else i
        # Resolve the head count before using it, so that passing head_size
        # without heads uses the default head count instead of raising
        # TypeError on ``head_size * None``.
        self.heads = max(h // 64, 1) if heads is None else heads
        self.kqv = h if head_size is None else head_size * self.heads
        self.output_frac = output_frac
        self.sparse_embed_lookup = sparse_embed_lookup
        self.decoder = decoder

    def get_block_flops(self):
        """Get the forward-pass FLOPs of one transformer block for one sequence."""
        attn_mul = 2 if self.decoder else 1
        block_flops = dict(
            kqv=3 * 2 * self.h * self.kqv * attn_mul,
            kqv_bias=3 * self.kqv * attn_mul,
            attention_scores=2 * self.kqv * self.s * attn_mul,
            attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul,
            attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul,
            attention_scale=self.s * self.heads * attn_mul,
            # Values are kqv wide, so weighting them over s positions costs
            # 2 * kqv * s per token (equal to 2 * h * s when kqv == h).
            attention_weighted_avg_values=2 * self.kqv * self.s * attn_mul,
            # The output projection maps the kqv-wide attention result to h.
            attn_output=2 * self.kqv * self.h * attn_mul,
            attn_output_bias=self.h * attn_mul,
            attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul,
            attn_output_residual=self.h * attn_mul,
            # Layer norm runs over all h hidden elements, like output_layer_norm.
            attn_output_layer_norm=LAYER_NORM_FLOPS * self.h * attn_mul,
            intermediate=2 * self.h * self.i,
            intermediate_act=ACTIVATION_FLOPS * self.i,
            intermediate_bias=self.i,
            output=2 * self.h * self.i,
            output_bias=self.h,
            output_dropout=DROPOUT_FLOPS * self.h,
            output_residual=self.h,
            output_layer_norm=LAYER_NORM_FLOPS * self.h,
        )
        return sum(block_flops.values()) * self.s

    def get_embedding_flops(self, output=False):
        """Get the forward-pass FLOPs of the input embeddings or output layer.

        Args:
            output: False counts the input embedding layer; True counts the
                output layer, scaled by ``output_frac``.

        Returns:
            FLOPs for one sequence of length ``s``.
        """
        embedding_flops = {}
        if output or (not self.sparse_embed_lookup):
            # Counted as a dense matmul between size-e vectors and the vocab.
            embedding_flops["main_multiply"] = 2 * self.e * self.v
        if not output:
            embedding_flops.update(dict(
                # presumably one-hot lookups into s position embeddings and
                # 2 token-type embeddings -- TODO confirm
                tok_type_and_position=2 * self.e * (self.s + 2),
                add_tok_type_and_position=2 * self.e,
                emb_layer_norm=LAYER_NORM_FLOPS * self.e,
                emb_dropout=DROPOUT_FLOPS * self.e,
            ))
        if self.e != self.h or output:
            # Projection between embedding size e and hidden size h; the bias
            # is sized to the projection's output side.
            embedding_flops.update(dict(
                hidden_kernel=2 * self.h * self.e,
                hidden_bias=self.e if output else self.h,
            ))
        if output:
            embedding_flops.update(dict(
                hidden_activation=ACTIVATION_FLOPS * self.e,
                hidden_layernorm=LAYER_NORM_FLOPS * self.e,
                output_softmax=SOFTMAX_FLOPS * self.v,
                output_target_word=2 * self.v,
            ))
            # Only output_frac of the s positions pay for the output layer.
            return self.output_frac * sum(embedding_flops.values()) * self.s
        return sum(embedding_flops.values()) * self.s

    def get_binary_classification_flops(self):
        """Get the forward-pass FLOPs of a per-position binary classifier head.

        The head is a hidden dense layer with an activation followed by a
        single-logit projection, counted at all ``s`` positions.
        """
        classification_flops = dict(
            hidden=2 * self.h * self.h,
            hidden_bias=self.h,
            hidden_act=ACTIVATION_FLOPS * self.h,
            logits=2 * self.h,
        )
        return sum(classification_flops.values()) * self.s

    def get_train_flops(self, batch_size, train_steps, discriminator=False):
        """Get the FLOPs for pretraining the transformer.

        Args:
            batch_size: sequences per training step.
            train_steps: number of training steps.
            discriminator: if True, the head is the binary classifier;
                otherwise it is the output (vocabulary) layer.
        """
        forward = (
            (self.l * self.get_block_flops())
            + self.get_embedding_flops(output=False)
            + (self.get_binary_classification_flops() if discriminator
               else self.get_embedding_flops(output=True))
        )
        # NOTE(review): the factor 2 presumably adds a backward pass assumed
        # to cost the same as the forward pass -- confirm this convention.
        return 2 * batch_size * train_steps * forward

    def get_infer_flops(self):
        """Get the FLOPs for running inference with the transformer on a classification task."""
        return ((self.l * self.get_block_flops())
                + self.get_embedding_flops(output=False)
                + self.get_binary_classification_flops())
def get_electra_train_flops(
        h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings,
        e=None, s=512, output_frac=0.15625):
    """Get the FLOPs needed for pretraining ELECTRA.

    The result is the discriminator's training FLOPs plus the generator's.

    Args:
        h_d, l_d: discriminator hidden size and number of blocks.
        h_g, l_g: generator hidden size and number of blocks.
        batch_size, train_steps: forwarded to ``get_train_flops`` of both.
        tied_embeddings: if True, the generator uses the same embedding size
            as the discriminator; otherwise its embedding size is left to
            ``TransformerHparams``' own default.
        e: embedding size; defaults to ``h_d``.
        s: sequence length for both models.
        output_frac: forwarded to both models.
    """
    embed_size = h_d if e is None else e
    shared_kwargs = dict(s=s, output_frac=output_frac)

    discriminator = TransformerHparams(h_d, l_d, e=embed_size, **shared_kwargs)
    generator = TransformerHparams(
        h_g, l_g, e=embed_size if tied_embeddings else None, **shared_kwargs)

    disc_flops = discriminator.get_train_flops(
        batch_size, train_steps, discriminator=True)
    gen_flops = generator.get_train_flops(batch_size, train_steps)
    return disc_flops + gen_flops
# Estimated pre-training FLOPs for each model, keyed by model name; kept in
# insertion order so main() prints the entries in this order.
MODEL_FLOPS = collections.OrderedDict([
    # NOTE(review): "elmo" and "xlnet" are hard-coded products rather than
    # TransformerHparams estimates; the meaning of each factor is not shown in
    # this file -- confirm against the source these figures were taken from.
    ("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)),
    ("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0),
    # Single-model baselines: get_train_flops(batch_size, train_steps).
    # output_frac=1.0 counts the output layer at every position.
    ("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops(
        128, 960800)),
    ("bert_small", TransformerHparams(256, 12, e=128, s=128).get_train_flops(128, 1.45e6)),
    ("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)),
    ("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)),
    # get_electra_train_flops(h_d, l_d, h_g, l_g, batch_size, train_steps,
    #                         tied_embeddings, ...)
    ("electra_small", get_electra_train_flops(256, 12, 64, 12, 128, 1e6, True, s=128, e=128)),
    ("electra_base", get_electra_train_flops(768, 12, 256, 12, 256, 766000, True)),
    ("electra_400k", get_electra_train_flops(1024, 24, 256, 24, 2048, 400000, True)),
    ("electra_1.75M", get_electra_train_flops(1024, 24, 256, 24, 2048, 1750000, True)),
    ("roberta", TransformerHparams(1024, 24, v=50265).get_train_flops(8000, 500000)),
    ("albert", TransformerHparams(4096, 12, v=30000, e=128).get_train_flops(
        4096, 1.5e6)),
    # Sum of two stacks: one with output_frac=0.0 (its output layer adds no
    # FLOPs) and one with decoder=True (attention counted twice) whose output
    # layer is counted at every position (output_frac=1.0).
    ("t5_11b", TransformerHparams(
        1024, 24, v=32000, i=65536, heads=128, head_size=128, output_frac=0.0
    ).get_train_flops(2048, 1e6) + TransformerHparams(
        1024, 24, v=32000, i=65536, heads=128, head_size=128, output_frac=1.0,
        decoder=True
    ).get_train_flops(2048, 1e6))
])
def main():
    """Print every model name followed by its estimated FLOPs, one per line."""
    for entry in MODEL_FLOPS.items():
        # entry is a (name, flops) pair; unpacking prints both space-separated.
        print(*entry)


if __name__ == "__main__":
    main()
