import argparse

from datasets import Audio, DatasetDict, load_dataset
from huggingface_hub import login
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

my_parser = argparse.ArgumentParser()
my_parser.add_argument(
    "--model_name",
    "-model_name",
    type=str,
    action="store",
    default="openai/whisper-tiny",
)
my_parser.add_argument("--hf_token", "-hf_token", type=str, action="store")
my_parser.add_argument(
    "--dataset_name", "-dataset_name", type=str, action="store", default="google/fleurs"
)
my_parser.add_argument("--split", "-split", type=str, action="store", default="test")
my_parser.add_argument("--subset", "-subset", type=str, action="store")
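
# Example (assumed) invocation -- the script file name and the subset value below are
# illustrative, not taken from the original source:
#   python prepare_dataset.py --model_name openai/whisper-small \
#       --dataset_name google/fleurs --subset ar_eg --hf_token <YOUR_HF_TOKEN>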

args = my_parser.parse_args()

dataset_name = args.dataset_name
model_name = args.model_name
subset = args.subset
hf_token = args.hf_token
login(hf_token)

# FLEURS stores its transcripts in a "transcription" column; most other ASR datasets
# (e.g. Common Voice) use "sentence".
text_column = "sentence"
if dataset_name == "google/fleurs":
    text_column = "transcription"

do_lower_case = False
do_remove_punctuation = False
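
# Text normalizer (used only when do_remove_punctuation is enabled) and the Whisper
# processor, which bundles the feature extractor and tokenizer for Arabic transcription.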
normalizer = BasicTextNormalizer()
processor = WhisperProcessor.from_pretrained(
    model_name, language="Arabic", task="transcribe"
)
dataset = load_dataset(dataset_name, subset, use_auth_token=True)

print(dataset)

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)

tokenizer = WhisperTokenizer.from_pretrained(
    model_name, language="Arabic", task="transcribe"
)

# Whisper expects 16 kHz audio, so resample every example when it is accessed.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
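

# Convert one example: compute log-Mel input features, record the audio length in
# seconds (so it can be used for length-based filtering later), and tokenize the
# reference transcript into label ids.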
def prepare_dataset(batch):
    audio = batch["audio"]

    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    transcription = batch[text_column]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()

    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch
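

# Apply the preprocessing to every split, dropping the original columns
# (column names are taken from the "train" split, which the other splits share).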
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])
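
# Push the preprocessed dataset to the Hub as a private repo under the arbml organization.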
login(hf_token)
repo_id = f"arbml/{dataset_name.split('/')[-1]}_preprocessed_{model_name.split('/')[-1]}"
print(f"pushing to {repo_id}")
dataset.push_to_hub(repo_id, private=True)