import os
import torch
import gdown
import logging
import langid
# Constrain language identification to the three languages this model
# handles (English, Chinese, Japanese) so auto-detection never returns
# a language without a lang2token entry.
langid.set_languages(['en', 'zh', 'ja'])
|
|
import pathlib
import platform
# Cross-platform pathlib shim: alias the "foreign" path class to the native
# one. Presumably this lets torch.load un-pickle checkpoints that embed
# PosixPath/WindowsPath objects created on a different OS — TODO confirm.
# NOTE(review): `temp` stashes the original class but is never used to
# restore it, so the patch is global for the lifetime of the process.
if platform.system().lower() == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
|
|
| import numpy as np |
| from data.tokenizer import ( |
| AudioTokenizer, |
| tokenize_audio, |
| ) |
| from data.collation import get_text_token_collater |
| from models.vallex import VALLE |
| from utils.g2p import PhonemeBpeTokenizer |
| from utils.sentence_cutter import split_text_into_sentences |
|
|
| from macros import * |
|
|
# Run on GPU 0 when CUDA is available; otherwise fall back to CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda", 0)

# Google Drive page for the pretrained checkpoint; the same file id is
# passed to gdown in preload_models().
url = 'https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing'

# Local cache directory for the downloaded checkpoint.
checkpoints_dir = "./checkpoints/"

model_checkpoint_name = "vallex-checkpoint.pt"

# Lazily initialized by preload_models(); None until that is called.
model = None

# Audio tokenizer/decoder, also set by preload_models().
codec = None

# Text front-end: phoneme BPE tokenizer plus the collater that pads and
# batches token sequences into tensors.
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()
|
def preload_models():
    """Download (if necessary) and load the VALL-E model and audio codec.

    Populates the module-level ``model`` and ``codec`` globals: builds the
    VALLE network, loads the cached checkpoint into it, moves it to
    ``device``, switches it to eval mode, and constructs the audio
    tokenizer. Must be called once before ``generate_audio*``.
    """
    global model, codec
    # makedirs + exist_ok: creates parents if needed and avoids the
    # check-then-create race that `os.path.exists` + `os.mkdir` had.
    os.makedirs(checkpoints_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoints_dir, model_checkpoint_name)
    if not os.path.exists(checkpoint_path):
        gdown.download(id="10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl", output=checkpoint_path, quiet=False)

    model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)
    # Load weights on CPU first so the raw checkpoint never has to fit in
    # GPU memory alongside the model.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=True
    )
    # `assert` is stripped under `python -O`; fail loudly instead.
    if missing_keys:
        raise RuntimeError(f"Checkpoint is missing keys: {missing_keys}")
    model.eval()

    # Audio tokenizer used later to decode generated codes back to waveform.
    codec = AudioTokenizer(device)
|
|
@torch.no_grad()
def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
    """Synthesize ``text`` to a waveform, optionally conditioning on ``prompt``.

    ``prompt`` may be a path to an ``.npz`` enrollment file or the bare name
    of a preset/custom prompt. ``language`` of ``'auto'`` triggers langid
    detection; ``accent`` other than ``'no-accent'`` overrides the text
    language passed to the model. Returns a 1-D numpy array of samples.
    """
    global model, codec, text_tokenizer, text_collater
    text = text.replace("\n", "").strip(" ")

    # Resolve the language and wrap the text in its language token.
    if language == "auto":
        language = langid.classify(text)[0]
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    if prompt is None:
        # No enrollment: empty audio/text prompts; the prompt language
        # defaults to the text language ('mix' falls back to English).
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = lang if lang != 'mix' else 'en'
    else:
        # Try a literal path first, then the presets and customs folders.
        for prompt_path in (prompt, "./presets/" + prompt + ".npz", "./customs/" + prompt + ".npz"):
            if os.path.exists(prompt_path):
                break
        else:
            raise ValueError(f"Cannot find prompt {prompt}")
        prompt_data = np.load(prompt_path)
        lang_pr = code2lang[int(prompt_data['lang_code'])]
        # numpy arrays -> int32 tensors (audio prompt lives on the device).
        audio_prompts = torch.tensor(prompt_data['audio_tokens']).type(torch.int32).to(device)
        text_prompts = torch.tensor(prompt_data['text_tokens']).type(torch.int32)

    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {text}")
    # Phonemize, collate into a batch of one, then prepend the text prompt.
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens

    # An explicit accent overrides the detected language for generation.
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    # Decode the generated codes back to audio and return raw samples.
    samples = codec.decode([(encoded_frames.transpose(2, 1), None)])
    return samples[0][0].cpu().numpy()
|
|
@torch.no_grad()
def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
    """
    For long audio generation, two modes are available.
    fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
    sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.
    """
    global model, codec, text_tokenizer, text_collater
    # Without an enrollment prompt there is nothing fixed to condition on,
    # so fall back to sliding-window mode.
    if prompt is None or prompt == "":
        mode = 'sliding-window'
    sentences = split_text_into_sentences(text)

    # Detect the language once, from the whole text (not per sentence).
    if language == "auto":
        language = langid.classify(text)[0]

    if prompt is not None and prompt != "":
        # Resolve the prompt: literal path first, then ./presets, then ./customs.
        prompt_path = prompt
        if not os.path.exists(prompt_path):
            prompt_path = "./presets/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            prompt_path = "./customs/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")
        prompt_data = np.load(prompt_path)
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        # numpy arrays -> int32 tensors; the audio prompt lives on the device.
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        # Empty prompts; prompt language defaults to the text language
        # ('mix' falls back to English).
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = language if language != 'mix' else 'en'
    if mode == 'fixed-prompt':
        # Accumulator for the generated codes of every sentence,
        # concatenated along the time axis.
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            # Wrap each sentence in its language token.
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            # Phonemize and collate a batch of one, then prepend the text prompt.
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens

            # An explicit accent overrides the detected language.
            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=-100,
                temperature=1,
                prompt_language=lang_pr,
                text_language=langs if accent == "no-accent" else lang,
            )
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
        # Decode the accumulated codes in one pass and return raw samples.
        samples = codec.decode(
            [(complete_tokens, None)]
        )
        return samples[0][0].cpu().numpy()
    elif mode == "sliding-window":
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        # Keep the user-supplied prompt around so we can periodically reset to it.
        original_audio_prompts = audio_prompts
        original_text_prompts = text_prompts
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens

            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=-100,
                temperature=1,
                prompt_language=lang_pr,
                text_language=langs if accent == "no-accent" else lang,
            )
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
            # Coin flip: half the time carry the just-generated sentence
            # forward as the next prompt, otherwise reset to the original.
            # NOTE: this randomness makes the output nondeterministic.
            if torch.rand(1) < 0.5:
                # NOTE(review): this slices the last NUM_QUANTIZERS entries of
                # the FINAL axis; given complete_tokens above, encoded_frames
                # looks like (1, T, NUM_QUANTIZERS), so this keeps the whole
                # sentence, not a time suffix — verify the intended axis.
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                text_prompts = text_tokens[:, enroll_x_lens:]
            else:
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts
        samples = codec.decode(
            [(complete_tokens, None)]
        )
        return samples[0][0].cpu().numpy()
    else:
        raise ValueError(f"No such mode {mode}")