| |
| |
| |
| |
| |
| |
| |
| from dataclasses import dataclass |
| from typing import Optional, List |
|
|
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from ..utils.drop_path import DropPath |
| from .dinov2_adapter import Dinov2_Adapter |
|
|
def get_causal_mask(seq_length):
    # Additive attention mask: 0 on and below the diagonal, -inf above it, so each
    # position can only attend to itself and to earlier positions.
    bool_mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool), diagonal=1)
    mask = torch.zeros(seq_length, seq_length)
    mask = mask.masked_fill(bool_mask, float('-inf'))
    return mask
|
|
| def find_multiple(n: int, k: int): |
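    # Round n up to the nearest multiple of k.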
| if n % k == 0: |
| return n |
| return n + k - (n % k) |
|
|
| @dataclass |
| class ModelArgs: |
| dim: int = 4096 |
| n_layer: int = 32 |
| n_head: int = 32 |
| n_kv_head: Optional[int] = None |
    multiple_of: int = 256  # round the SwiGLU hidden size up to a multiple of this
| ffn_dim_multiplier: Optional[float] = None |
| rope_base: float = 10000 |
| norm_eps: float = 1e-5 |
| initializer_range: float = 0.02 |
| |
| token_dropout_p: float = 0.1 |
| attn_dropout_p: float = 0.0 |
| resid_dropout_p: float = 0.1 |
| ffn_dropout_p: float = 0.1 |
| drop_path_rate: float = 0.0 |
|
|
| num_classes: int = 1000 |
| caption_dim: int = 2048 |
| class_dropout_prob: float = 0.1 |
| model_type: str = 'c2i' |
|
|
| vocab_size: int = 16384 |
| cls_token_num: int = 1 |
| block_size: int = 256 |
| max_batch_size: int = 32 |
| max_seq_len: int = 2048 |
| adapter_size: str = 'small' |
| condition_type: str = 'canny' |
|
|
|
|
|
|
| |
| |
| |
| class LabelEmbedder(nn.Module): |
| """ |
| Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. |
| """ |
| def __init__(self, num_classes, hidden_size, dropout_prob): |
| super().__init__() |
| use_cfg_embedding = dropout_prob > 0 |
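        # Reserve one extra row as the "null" class used when a label is dropped for CFG.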
| self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) |
| self.num_classes = num_classes |
| self.dropout_prob = dropout_prob |
|
|
| def token_drop(self, labels, force_drop_ids=None): |
| """ |
| Drops labels to enable classifier-free guidance. |
| """ |
| if force_drop_ids is None: |
| drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob |
| else: |
| drop_ids = force_drop_ids == 1 |
| labels = torch.where(drop_ids, self.num_classes, labels) |
| return labels, drop_ids |
|
|
| def forward(self, labels, train, force_drop_ids=None): |
| use_dropout = self.dropout_prob > 0 |
| if (train and use_dropout) or (force_drop_ids is not None): |
            labels, drop_ids = self.token_drop(labels, force_drop_ids)
| embeddings = self.embedding_table(labels).unsqueeze(1) |
| if (train and use_dropout) or (force_drop_ids is not None): |
            return embeddings, drop_ids
| else: |
| return embeddings |
|
|
|
|
| class ConditionEmbedder(nn.Module): |
| """ |
    Embeds condition features into vector representations. Also handles condition dropout for classifier-free guidance.
| """ |
| def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384): |
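        # NOTE: in_channels and vocab_size are unused here; the projection operates at hidden_size.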
| super().__init__() |
| self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size) |
| self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5) |
| self.uncond_prob = uncond_prob |
|
|
| def token_drop(self, caption, force_drop_ids=None, drop_ids=None): |
| """ |
| Drops labels to enable classifier-free guidance. |
| """ |
| if force_drop_ids is None: |
| if drop_ids is None: |
| drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob |
| else: |
| drop_ids = force_drop_ids == 1 |
|
|
| caption = torch.where(drop_ids[:, None, None], self.uncond_embedding[:caption.shape[1]], caption) |
| return caption |
|
|
| def forward(self, caption, train, force_drop_ids=None, drop_ids=None): |
| use_dropout = self.uncond_prob > 0 |
| if (train and use_dropout) or (force_drop_ids is not None): |
| caption = self.token_drop(caption, force_drop_ids, drop_ids) |
| embeddings = self.cap_proj(caption) |
| return embeddings |
|
|
| |
| |
| |
| class CaptionEmbedder(nn.Module): |
| """ |
| Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance. |
| """ |
| def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120): |
| super().__init__() |
| self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size) |
| self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) |
| self.uncond_prob = uncond_prob |
|
|
| def token_drop(self, caption, force_drop_ids=None): |
| """ |
| Drops labels to enable classifier-free guidance. |
| """ |
| if force_drop_ids is None: |
| drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob |
| else: |
| drop_ids = force_drop_ids == 1 |
| caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption) |
| return caption, drop_ids |
|
|
| def forward(self, caption, train, force_drop_ids=None): |
| use_dropout = self.uncond_prob > 0 |
| if (train and use_dropout) or (force_drop_ids is not None): |
| caption, drop_ids = self.token_drop(caption, force_drop_ids) |
| embeddings = self.cap_proj(caption) |
| if (train and use_dropout) or (force_drop_ids is not None): |
| return embeddings, drop_ids |
| else: |
| return embeddings |
|
|
|
|
| class MLP(nn.Module): |
| def __init__(self, in_features, hidden_features, out_features): |
| super().__init__() |
| out_features = out_features or in_features |
| hidden_features = hidden_features or in_features |
| self.fc1 = nn.Linear(in_features, hidden_features, bias=False) |
| self.act = nn.GELU(approximate='tanh') |
| self.fc2 = nn.Linear(hidden_features, out_features, bias=False) |
| |
        # Zero-init both projections so a freshly constructed MLP initially outputs zeros.
        nn.init.zeros_(self.fc1.weight)
        nn.init.zeros_(self.fc2.weight)
|
|
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.act(x) |
| x = self.fc2(x) |
| return x |
|
|
|
|
| |
| |
| |
| class RMSNorm(torch.nn.Module): |
| def __init__(self, dim: int, eps: float = 1e-5): |
| super().__init__() |
| self.eps = eps |
| self.weight = nn.Parameter(torch.ones(dim)) |
|
|
| def _norm(self, x): |
| return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) |
|
|
| def forward(self, x): |
| output = self._norm(x.float()).type_as(x) |
| return output * self.weight |
|
|
|
|
| class FeedForward(nn.Module): |
| def __init__(self, config: ModelArgs): |
| super().__init__() |
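        # SwiGLU hidden size as in LLaMA: 2/3 of 4*dim, optionally rescaled, rounded up to a multiple of `multiple_of`.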
| hidden_dim = 4 * config.dim |
| hidden_dim = int(2 * hidden_dim / 3) |
| |
| if config.ffn_dim_multiplier is not None: |
| hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) |
| hidden_dim = find_multiple(hidden_dim, config.multiple_of) |
|
|
| self.w1 = nn.Linear(config.dim, hidden_dim, bias=False) |
| self.w3 = nn.Linear(config.dim, hidden_dim, bias=False) |
| self.w2 = nn.Linear(hidden_dim, config.dim, bias=False) |
| self.ffn_dropout = nn.Dropout(config.ffn_dropout_p) |
|
|
| def forward(self, x): |
| return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) |
|
|
|
|
| class KVCache(nn.Module): |
| def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype): |
| super().__init__() |
| cache_shape = (max_batch_size, n_head, max_seq_length, head_dim) |
| self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) |
| self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) |
|
|
| def update(self, input_pos, k_val, v_val): |
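        # Write k_val/v_val into the cache at the given positions (in-place) and return the full cache.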
| |
| assert input_pos.shape[0] == k_val.shape[2] |
| k_out = self.k_cache |
| v_out = self.v_cache |
| k_out[:, :, input_pos] = k_val |
| v_out[:, :, input_pos] = v_val |
|
|
| return k_out, v_out |
|
|
|
|
| class Attention(nn.Module): |
| def __init__(self, config: ModelArgs): |
| super().__init__() |
| assert config.dim % config.n_head == 0 |
| self.dim = config.dim |
| self.head_dim = config.dim // config.n_head |
| self.n_head = config.n_head |
| self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head |
| total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim |
|
|
| |
| self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False) |
| self.wo = nn.Linear(config.dim, config.dim, bias=False) |
| self.kv_cache = None |
|
|
| |
| self.attn_dropout_p = config.attn_dropout_p |
| self.resid_dropout = nn.Dropout(config.resid_dropout_p) |
|
|
| def forward( |
| self, x: torch.Tensor, freqs_cis: torch.Tensor = None, |
| input_pos: Optional[torch.Tensor] = None, |
| mask: Optional[torch.Tensor] = None |
| ): |
| bsz, seqlen, _ = x.shape |
| kv_size = self.n_kv_head * self.head_dim |
| xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) |
|
|
| xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) |
| xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) |
| xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) |
| |
| xq = apply_rotary_emb(xq, freqs_cis) |
| xk = apply_rotary_emb(xk, freqs_cis) |
|
|
| xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) |
|
|
| if self.kv_cache is not None: |
| keys, values = self.kv_cache.update(input_pos, xk, xv) |
| else: |
| keys, values = xk, xv |
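        # Grouped-query attention: each KV head is shared by n_head // n_kv_head query heads.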
| keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1) |
| values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1) |
|
|
| output = F.scaled_dot_product_attention( |
| xq, keys, values, |
| attn_mask=mask, |
| is_causal=True if mask is None else False, |
| dropout_p=self.attn_dropout_p if self.training else 0) |
| |
| output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) |
|
|
| output = self.resid_dropout(self.wo(output)) |
| return output |
|
|
|
|
| class TransformerBlock(nn.Module): |
| def __init__(self, config: ModelArgs, drop_path: float): |
| super().__init__() |
| self.attention = Attention(config) |
| self.feed_forward = FeedForward(config) |
| self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) |
| self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) |
| self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
| def forward( |
| self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None): |
| h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)) |
| out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) |
| return out |
|
|
|
|
| class Transformer(nn.Module): |
| def __init__(self, config: ModelArgs): |
| super().__init__() |
| self.config = config |
| self.vocab_size = config.vocab_size |
| self.n_layer = config.n_layer |
| self.block_size = config.block_size |
| self.num_classes = config.num_classes |
| self.model_type = config.model_type |
| self.cls_token_num = config.cls_token_num |
| self.layer_internal = config.n_layer // 3 |
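        # Condition tokens are injected into the hidden states every `layer_internal` blocks (see forward()).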
| |
| |
| |
| self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type) |
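        # DINOv2-S features are 384-d and DINOv2-B features are 768-d; project them to the transformer width.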
| |
| if config.adapter_size == "small": |
| self.adapter_mlp = MLP(384, config.dim, config.dim) |
| elif config.adapter_size == 'base': |
| self.adapter_mlp = MLP(768, config.dim, config.dim) |
|
|
| if self.model_type == 'c2i': |
| self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob) |
| elif self.model_type == 't2i': |
| self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob) |
| else: |
            raise ValueError(f"unsupported model_type {self.model_type!r}; expected 'c2i' or 't2i'")
| self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) |
| self.tok_dropout = nn.Dropout(config.token_dropout_p) |
|
|
| self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim) |
| self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size) |
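        # One projection per condition-injection point (three points; see the layer loop in forward()).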
| self.condition_layers = torch.nn.ModuleList() |
| for layer_id in range(3): |
| self.condition_layers.append(MLP(config.dim,config.dim,config.dim)) |
|
|
| |
| dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)] |
| self.layers = torch.nn.ModuleList() |
| for layer_id in range(config.n_layer): |
| self.layers.append(TransformerBlock(config, dpr[layer_id])) |
|
|
| |
| self.norm = RMSNorm(config.dim, eps=config.norm_eps) |
| self.output = nn.Linear(config.dim, config.vocab_size, bias=False) |
|
|
| |
| grid_size = int(self.block_size ** 0.5) |
| assert grid_size * grid_size == self.block_size |
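        # 2D RoPE table: a zero block for the class/condition prefix followed by one entry per grid position.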
| self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num) |
| |
| |
| self.max_batch_size = -1 |
| self.max_seq_length = -1 |
|
|
| self.initialize_weights() |
| self.condition_token = None |
| self.mask = get_causal_mask(256) |
| self.global_token = None |
|
|
| self.control_strength = 1 |
|
|
| def initialize_weights(self): |
| |
        # Initialize every nn.Linear / nn.Embedding in the model.
        self.apply(self._init_weights)
|
|
| |
        # Zero-init the output head so the initial logits are uniform.
        nn.init.constant_(self.output.weight, 0)
|
|
| |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| if isinstance(module, nn.Linear): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
|
|
| |
| def setup_caches(self, max_batch_size, max_seq_length, dtype): |
| |
| |
| head_dim = self.config.dim // self.config.n_head |
| max_seq_length = find_multiple(max_seq_length, 8) |
| self.max_seq_length = max_seq_length |
| self.max_batch_size = max_batch_size |
| for b in self.layers: |
| b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype) |
|
|
        # Boolean causal mask (True = may attend), batched so it can be indexed with input_pos during decoding.
        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
| grid_size = int(self.config.block_size ** 0.5) |
| assert grid_size * grid_size == self.block_size |
| self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num) |
|
|
|
|
| def forward( |
| self, |
| idx: torch.Tensor, |
| cond_idx: torch.Tensor, |
| input_pos: Optional[torch.Tensor] = None, |
| targets: Optional[torch.Tensor] = None, |
| mask: Optional[torch.Tensor] = None, |
| valid: Optional[torch.Tensor] = None, |
| condition: Optional[torch.Tensor] = None, |
| control_strength: Optional[int] = 1 |
| ): |
| if idx is not None and cond_idx is not None: |
            cond_embeddings, drop_ids = self.cls_embedding(cond_idx, train=self.training)
            cond_embeddings = cond_embeddings[:, :self.cls_token_num]
| token_embeddings = self.tok_embeddings(idx) |
            if condition is not None:
                # Encode the control input with the DINOv2 adapter, project it to the model width,
                # and drop it with the same CFG mask as the class embedding.
                condition_embeddings = self.adapter(condition)
                condition_embeddings = self.adapter_mlp(condition_embeddings)
                self.condition_token = self.condition_mlp(condition_embeddings, train=self.training, drop_ids=drop_ids)
| token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1) |
|
|
| h = self.tok_dropout(token_embeddings) |
| self.freqs_cis = self.freqs_cis.to(h.device) |
| else: |
| if cond_idx is not None: |
| self.control_strength = control_strength |
| token_embeddings = self.cls_embedding(cond_idx, train=self.training) |
                token_embeddings = token_embeddings[:, :self.cls_token_num]
| if condition is not None: |
| condition_embeddings = self.condition_mlp(condition, train=self.training) |
| self.condition_token = condition_embeddings |
| self.condition_token = [self.condition_layers[0](self.condition_token), |
| self.condition_layers[1](self.condition_token), |
| self.condition_layers[2](self.condition_token)] |
| |
| else: |
| token_embeddings = self.tok_embeddings(idx) |
| bs = token_embeddings.shape[0] |
| mask = self.causal_mask[:bs, None, input_pos] |
| h = self.tok_dropout(token_embeddings) |
            self.freqs_cis = self.freqs_cis.to(h.device)
|
|
| if self.training: |
| freqs_cis = self.freqs_cis[:token_embeddings.shape[1]] |
| else: |
| freqs_cis = self.freqs_cis[input_pos] |
| |
        for i, layer in enumerate(self.layers):
            # Inject the projected condition tokens every `layer_internal` blocks.
            if i % self.layer_internal == 0:
                if self.training:
                    h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i // self.layer_internal](self.condition_token)
                else:
                    if len(input_pos) > 1:
                        # prefill: the last prefix position is the one that predicts the first image token
                        h[:, -1:] = h[:, -1:] + self.control_strength * self.condition_token[i // self.layer_internal][:, 0:1]
                    else:
                        # single-token decoding: add the condition token for the image position being predicted
                        h = h + self.control_strength * self.condition_token[i // self.layer_internal][:, input_pos - self.cls_token_num + 1]
            h = layer(h, freqs_cis, input_pos, mask)
| |
| h = self.norm(h) |
| logits = self.output(h).float() |
| |
        if self.training:
            # Keep only the positions that predict image tokens (from the last class token onward).
            logits = logits[:, self.cls_token_num - 1:].contiguous()
| |
| loss = None |
| if valid is not None: |
| loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') |
| valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1) |
| loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1) |
| elif targets is not None: |
| loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) |
|
|
|
|
| return logits, loss |
|
|
|
|
| def get_fsdp_wrap_module_list(self) -> List[nn.Module]: |
| return list(self.layers) |
|
|
|
|
|
|
| |
| |
| |
| |
| def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120): |
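    # 1D rotary table; the Transformer above uses the 2D variant (precompute_freqs_cis_2d).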
| freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) |
| t = torch.arange(seq_len, device=freqs.device) |
| freqs = torch.outer(t, freqs) |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) |
| cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) |
| return cond_cache |
|
|
|
|
| def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120): |
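    # 2D rotary table for a grid_size x grid_size token grid: half of the head dim is rotated by the
    # row index and the other half by the column index; a zero block is prepended for the cls prefix.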
| |
| half_dim = n_elem // 2 |
| freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim)) |
| t = torch.arange(grid_size, device=freqs.device) |
| freqs = torch.outer(t, freqs) |
| freqs_grid = torch.concat([ |
| freqs[:, None, :].expand(-1, grid_size, -1), |
| freqs[None, :, :].expand(grid_size, -1, -1), |
| ], dim=-1) |
| cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) |
| cache = cache_grid.flatten(0, 1) |
| cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) |
| return cond_cache |
|
|
| def precompute_freqs_cis_2d_new(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120, spe_token_num=3, ar_token_num=4): |
| |
| half_dim = n_elem // 2 |
| freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim)) |
| t = torch.arange(grid_size, device=freqs.device) |
| freqs = torch.outer(t, freqs) |
| freqs_grid = torch.concat([ |
| freqs[:, None, :].expand(-1, grid_size, -1), |
| freqs[None, :, :].expand(grid_size, -1, -1), |
| ], dim=-1) |
| cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) |
| sub_num = int(ar_token_num**0.5) |
|
|
| cache_grid = cache_grid.reshape(sub_num, grid_size//sub_num, sub_num, grid_size//sub_num, half_dim, 2) |
| cache_grid = cache_grid.permute(1, 3, 0, 2, 4, 5) |
| cache = cache_grid.flatten(0, 3) |
| cache_one, cache_two = cache[:ar_token_num], cache[ar_token_num:] |
| sep_cache = torch.zeros(spe_token_num, n_elem // 2, 2) |
| cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache_one, sep_cache, cache_two]) |
| |
| return cond_cache |
|
|
|
|
| def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor): |
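    # x: (bs, seq_len, n_head, head_dim); freqs_cis: (seq_len, head_dim // 2, 2) holding [cos, sin].
    # Consecutive channel pairs are rotated by the per-position angles.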
| |
| |
| xshaped = x.float().reshape(*x.shape[:-1], -1, 2) |
| freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) |
| x_out2 = torch.stack([ |
| xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], |
| xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], |
| ], dim=-1) |
| x_out2 = x_out2.flatten(3) |
| return x_out2.type_as(x) |
|
|
|
|
|
|
| |
| |
| |
| |
| def GPT_7B(**kwargs): |
| return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) |
|
|
| def GPT_3B(**kwargs): |
| return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) |
|
|
| def GPT_1B(**kwargs): |
| return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) |
|
|
| |
| def GPT_XXXL(**kwargs): |
| return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) |
|
|
| def GPT_XXL(**kwargs): |
| return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) |
|
|
| def GPT_XL(**kwargs): |
| return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) |
|
|
| def GPT_L(**kwargs): |
| return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) |
|
|
| def GPT_B(**kwargs): |
| return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) |
| |
|
|
| GPT_models = { |
| 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL, |
| 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B, |
| } |
|
|