| | import torch |
| |
|
| |
|
def column_split(x, num_heads, head_size):
    """Split a fused projection tensor into its three per-head components.

    ``x`` is viewed as ``num_heads`` groups of ``3 * head_size`` values, so
    the three projections are laid out side by side inside each head rather
    than as three contiguous blocks spanning all heads. The number of
    projections is fixed to three.

    Args:
        x: Tensor that can be reshaped to
            ``(x.shape[0], num_heads, 3 * head_size)``.
        num_heads: Number of heads packed into ``x``.
        head_size: Width of one projection inside a single head.

    Returns:
        A tuple ``(x2, x1, v)``. Each element has shape
        ``(x.shape[0], num_heads * head_size)`` and gathers, in order, the
        first, second and third ``head_size``-wide slice of every head.
    """
    rows = x.shape[0]
    per_head = x.reshape(rows, num_heads, 3 * head_size)

    def take(slot):
        # Pick the slot-th projection of every head and flatten the heads
        # back into a single trailing dimension.
        lo = slot * head_size
        return per_head[:, :, lo : lo + head_size].reshape(rows, -1)

    return take(0), take(1), take(2)
| |
|
| |
|
def get_init_from_string(init_str):
    """Resolve the dotted name of a torch initializer to the function itself.

    Args:
        init_str: One of ``"torch.nn.init.zeros_"``,
            ``"torch.nn.init.xavier_uniform_"`` or
            ``"torch.nn.init.xavier_normal_"``.

    Returns:
        The matching ``torch.nn.init`` function, or ``None`` when
        ``init_str`` is not a string.

    Raises:
        ValueError: If ``init_str`` is a string that is not one of the
            recognized names.
    """
    # isinstance (instead of an exact type comparison) so that str
    # subclasses are resolved as well.
    if not isinstance(init_str, str):
        # NOTE(review): non-string input has always produced None here;
        # confirm callers do not expect an already-resolved callable to be
        # passed through.
        return None

    known_inits = {
        "torch.nn.init.zeros_": torch.nn.init.zeros_,
        "torch.nn.init.xavier_uniform_": torch.nn.init.xavier_uniform_,
        "torch.nn.init.xavier_normal_": torch.nn.init.xavier_normal_,
    }
    try:
        return known_inits[init_str]
    except KeyError:
        raise ValueError(f"Unrecognized init {init_str}") from None
| |
|
| |
|
def print_rank_0(message, debug=False, end="\n"):
    """Print ``message`` on rank 0 only, or unconditionally when
    ``torch.distributed`` has not been initialized.

    Args:
        message: Object to print.
        debug: Accepted but not used by this function.
        end: String appended after the message; forwarded to ``print``.
    """
    dist = torch.distributed
    # Every rank other than 0 stays silent once the process group exists.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message, flush=True, end=end)
| |
|
| |
|
class dotdict(dict):
    """Dictionary whose entries can also be accessed as attributes.

    ``d.key`` reads ``d["key"]`` and yields ``None`` when the key is absent,
    ``d.key = value`` stores ``d["key"] = value``, and ``del d.key`` removes
    the entry (raising ``KeyError`` if it does not exist). Names that resolve
    through normal attribute lookup, such as the ``dict`` methods, are not
    affected by the read path.
    """

    def __getattr__(self, name, default=None):
        # Only reached when regular attribute lookup fails.
        return dict.get(self, name, default)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
| |
|
| |
|
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Args:
        numerator: Value to be divided.
        denominator: Value that must divide ``numerator`` without remainder.

    Raises:
        AssertionError: If ``numerator % denominator`` is not zero.
    """
    # Raised explicitly instead of through an `assert` statement so the
    # check is still performed when Python runs with -O, which strips
    # assert statements. Exception type and message are unchanged.
    if numerator % denominator != 0:
        raise AssertionError(f"{numerator} is not divisible by {denominator}")
| |
|
| |
|
def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking divisibility.

    Args:
        numerator: Value to be divided.
        denominator: Value to divide by.

    Returns:
        The floor-division quotient of the two arguments.

    Raises:
        AssertionError: Propagated from ``ensure_divisibility`` when
            ``numerator`` is not a multiple of ``denominator``.
    """
    # Validate first so a non-exact division never yields a quotient.
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
| |
|
| |
|
class VocabUtility:
    """Compute the slice of a partitioned vocabulary that belongs to a rank.

    The vocabulary is laid out as consecutive, equally sized chunks, one per
    rank. Both helpers return ``(first, last)`` for the chunk of ``rank``,
    where ``first`` is inclusive and ``last`` is exclusive, i.e. the chunk
    covers the half-open index range ``[first, last)``.
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        """Return ``(first, last)`` given the size of a single partition.

        ``world_size`` is part of the signature but is not used in the
        computation.
        """
        start = rank * per_partition_vocab_size
        return start, start + per_partition_vocab_size

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return ``(first, last)`` given the size of the whole vocabulary.

        The per-partition size is obtained with ``divide``, so
        ``global_vocab_size`` has to be a multiple of ``world_size``.
        """
        chunk = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(chunk, rank, world_size)
| |
|