import shutil
from pathlib import Path

import click
import polars as pl
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input corpus file path, Parquet file with 'SMILES' column",
    default="outputs/molpile_smiles.parquet",
    required=True,
)
@click.option(
    "--tokenizer-dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    help="Directory with pretrained tokenizer",
    default="outputs/chemberta",
    required=True,
)
@click.option(
    "--output-dir",
    type=click.Path(file_okay=False, dir_okay=True),
    help="Directory for tokenized dataset",
    default="outputs/chemberta_dataset",
    required=True,
)
def tokenize_dataset(input_file: str, tokenizer_dir: str, output_dir: str):
    """Tokenize a SMILES corpus and save it as a train/valid dataset.

    Reads the 'SMILES' column from the input Parquet file, holds out 1% of
    the rows as the 'valid' split (fixed random seed), tokenizes both splits
    with the pretrained tokenizer and writes the result to the output
    directory. Any existing content of the output directory is removed first.
    """
    # Start from a clean output directory so stale files from an earlier run
    # cannot end up mixed into the saved dataset.
    output_path = Path(output_dir)
    shutil.rmtree(output_path, ignore_errors=True)
    output_path.mkdir(parents=True)

    print("Loading dataset")
    df = pl.read_parquet(input_file, columns=["SMILES"])

    df_train, df_valid = train_test_split(
        df,
        test_size=0.01,
        random_state=0,
        shuffle=True,
    )
    # Drop the source frame as soon as it has been split to lower peak memory.
    del df

    ds = DatasetDict(
        {
            "train": Dataset.from_polars(df_train),
            "valid": Dataset.from_polars(df_valid),
        }
    )
    del df_train
    del df_valid

    ds = ds.rename_column("SMILES", "text")

    # Path.mkdir() returns None, so the path object has to be kept separately
    # from the mkdir call; otherwise the temporary path degrades to "None".
    dataset_tmp_dir = output_path / "tmp_tokenized_dataset"
    dataset_tmp_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Round-trip through disk so the dataset is reloaded from the saved
        # files instead of staying bound to the in-memory frames.
        ds.save_to_disk(str(dataset_tmp_dir))
        ds = DatasetDict.load_from_disk(str(dataset_tmp_dir))

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

        print("Tokenizing dataset")
        ds = ds.map(
            lambda examples: tokenizer(
                examples["text"], truncation=True, padding=True, max_length=512
            ),
            batched=True,
            batch_size=10000,
        )
        ds.save_to_disk(output_dir)
    finally:
        # Remove the intermediate copy even when tokenization or saving fails.
        shutil.rmtree(dataset_tmp_dir, ignore_errors=True)


# Script entry point: tokenize_dataset is a Click command, so calling it with
# no arguments makes Click read the options from the command line.
if __name__ == "__main__":
    tokenize_dataset()
