import argparse
from pado.data.datasets.speech.librispeech import PadoLibriSpeech


def _load_script_datasets(args, modes):
    """Build one script-only ``PadoLibriSpeech`` dataset per LibriSpeech subset name.

    Args:
        args: Parsed CLI namespace; ``data_dir`` and ``clean_script`` are read.
        modes: Subset names passed as ``mode=`` to ``PadoLibriSpeech``.

    Returns:
        A list of datasets, in the same order as ``modes``.
    """
    return [
        PadoLibriSpeech(args.data_dir, mode=mode, clean_script=args.clean_script, script_only=True)
        for mode in modes
    ]


def _write_scripts(save_path, datasets, names, log_every=100):
    """Write every item of each dataset to ``save_path``, one item per line.

    Args:
        save_path: Output text file; it is overwritten if it already exists.
        datasets: Datasets supporting ``len()`` and integer indexing. Items are
            assumed to be transcript strings (they are built with
            ``script_only=True``) -- TODO confirm against PadoLibriSpeech.
        names: Display labels, parallel to ``datasets``, used in progress output.
        log_every: Print a progress line every ``log_every`` items.
    """
    # Explicit UTF-8 so the output bytes do not depend on the platform's locale encoding.
    with open(save_path, "w", encoding="utf-8") as f:
        for dataset, name in zip(datasets, names):
            total = len(dataset)
            # Index-based access: the dataset is only known to expose __len__/__getitem__;
            # whether direct iteration is supported is not visible from here.
            for i in range(total):
                if i % log_every == 0:
                    print(f"... {name}, {i} / {total}")
                f.write(dataset[i] + "\n")


def extract_train(args):
    """Write the transcripts of the train-clean-100, train-clean-360 and train-other-500 subsets.

    All datasets are constructed before ``args.save_path`` is opened, so a bad
    ``data_dir`` or subset fails before the output file is truncated.

    Args:
        args: Parsed CLI namespace with ``data_dir``, ``clean_script`` and ``save_path``.
    """
    datasets = _load_script_datasets(args, ["train-clean-100", "train-clean-360", "train-other-500"])
    _write_scripts(args.save_path, datasets, ["Train-100", "Train-360", "Train-500"])


def extract_dev(args):
    """Write the transcripts of the dev-clean and dev-other subsets to ``args.save_path``.

    Args:
        args: Parsed CLI namespace with ``data_dir``, ``clean_script`` and ``save_path``.
    """
    datasets = _load_script_datasets(args, ["dev-clean", "dev-other"])
    _write_scripts(args.save_path, datasets, ["Dev-clean", "Dev-other"])


def extract_test(args):
    """Write the transcripts of the test-clean and test-other subsets to ``args.save_path``.

    Args:
        args: Parsed CLI namespace with ``data_dir``, ``clean_script`` and ``save_path``.
    """
    datasets = _load_script_datasets(args, ["test-clean", "test-other"])
    _write_scripts(args.save_path, datasets, ["Test-clean", "Test-other"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract LibriSpeech transcripts to a text file, one utterance per line."
    )
    # No usable default exists for these two paths: without them, None would be
    # passed straight to the extractors, so let argparse reject the call up front.
    parser.add_argument("--data_dir", type=str, required=True, help="LibriSpeech data path")
    parser.add_argument("--save_path", type=str, required=True, help="Save path")
    parser.add_argument(
        "--mode",
        type=str,
        default="train",
        choices=["train", "dev", "test"],
        help="Mode: train, dev, test",
    )
    parser.add_argument("--clean_script", action="store_true", help="Clean script (default: F)")
    cfg = parser.parse_args()

    # Dispatch table from --mode to the extractor for that split.
    extractors = {"train": extract_train, "dev": extract_dev, "test": extract_test}
    extract = extractors.get(cfg.mode)
    if extract is None:
        # Defensive: argparse `choices` already rejects unknown modes.
        raise ValueError(f"Unsupported extract type {cfg.mode}.")
    extract(cfg)
