| import argparse |
|
|
| import torch.nn as nn |
| |
|
|
| from .macros import ( |
| NUM_AUDIO_TOKENS, |
| NUM_MEL_BINS, |
| NUM_SPEAKER_CLASSES, |
| NUM_TEXT_TOKENS, |
| SPEAKER_EMBEDDING_DIM, |
| ) |
| from .vallex import VALLE, VALLF |
|
|
|
|
def _str2bool(value) -> bool:
    """Parse a command-line string into a bool.

    argparse's ``type=bool`` is a well-known trap: ``bool("False")`` is
    ``True`` because any non-empty string is truthy, so users could never
    turn a defaulted-True flag off.  This converter accepts the usual
    spellings and rejects anything else with a clear error.

    Raises:
      argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    if isinstance(value, bool):  # defaults pass through untouched
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def add_model_arguments(parser: argparse.ArgumentParser):
    """Register all model-construction options on *parser*.

    The attributes added here (model_name, decoder_dim, nhead,
    num_decoder_layers, scale_factor, norm_first, add_prenet, prefix_mode,
    share_embedding, prepend_bos, num_quantizers, scaling_xformers) are the
    ones ``get_model`` reads.
    """
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=_str2bool,  # was type=bool: any non-empty string parsed as True
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=_str2bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=_str2bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=_str2bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )
    parser.add_argument(
        "--scaling-xformers",
        type=_str2bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )
|
|
|
def get_model(params) -> nn.Module:
    """Build the model selected by ``params.model_name``.

    Args:
      params: a namespace carrying the attributes registered by
        ``add_model_arguments`` (decoder_dim, nhead, num_decoder_layers,
        norm_first, add_prenet, prefix_mode, share_embedding,
        scale_factor, prepend_bos, num_quantizers, model_name).

    Returns:
      A constructed ``VALLE`` or ``VALLF`` instance.

    Raises:
      ValueError: if ``params.model_name`` matches neither family.
    """
    # NOTE(review): add_model_arguments advertises "Transformer" as a
    # model name, but no branch here handles it — confirm intent.
    name = params.model_name.lower()
    if name in ("vall-f", "vallf"):
        model_cls = VALLF
    elif name in ("vall-e", "valle"):
        model_cls = VALLE
    else:
        # Include the offending name so the CLI error is actionable.
        raise ValueError(f"No such model: {params.model_name!r}")

    # Both classes share the same constructor signature, so we dispatch on
    # the class and keep a single call site instead of two duplicated ones.
    return model_cls(
        params.decoder_dim,
        params.nhead,
        params.num_decoder_layers,
        norm_first=params.norm_first,
        add_prenet=params.add_prenet,
        prefix_mode=params.prefix_mode,
        share_embedding=params.share_embedding,
        nar_scale_factor=params.scale_factor,
        prepend_bos=params.prepend_bos,
        num_quantizers=params.num_quantizers,
    )
|
|