from pathlib import Path

import click
import polars as pl
from tokenizers.implementations import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input corpus file path, Parquet file with 'SMILES' column",
    default="outputs/molpile_smiles.parquet",
    required=True,
)
@click.option(
    "--output-dir",
    type=click.Path(file_okay=False, dir_okay=True),
    help="Directory for tokenizer output files",
    default="outputs/chemberta",
    required=True,
)
def train_tokenizer(input_file: str, output_dir: str) -> None:
    """Train a byte-level BPE tokenizer on a SMILES corpus.

    Reads the 'SMILES' column of the input Parquet file, trains the
    tokenizer on it, and writes the tokenizer files (raw BPE model files,
    tokenizer.json, and the Hugging Face tokenizer export) to the output
    directory, which is created if it does not exist.
    """
    # NOTE: click displays the docstring above as the command's --help text,
    # so it is kept short and user-facing.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Only the SMILES column is used, so nothing else is read from the file.
    smiles = pl.read_parquet(input_file, columns=["SMILES"])["SMILES"]

    # Null entries would reach the trainer as None rather than str; drop them
    # up front so a single missing value cannot abort the whole training run.
    total_rows = len(smiles)
    smiles = smiles.drop_nulls()
    dropped_rows = total_rows - len(smiles)
    if dropped_rows:
        print(f"Skipping {dropped_rows} null SMILES entries")

    print("Training tokenizer")
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        iterator=smiles,
        # Length of the filtered series, so it matches what is iterated.
        length=len(smiles),
        vocab_size=600,
        min_frequency=2,
        special_tokens=[
            "<s>",
            "<pad>",
            "</s>",
            "<unk>",
            "<mask>",
        ],
    )

    tokenizer_file = str(out_dir / "tokenizer.json")
    vocab_file = str(out_dir / "vocab.json")
    merges_file = str(out_dir / "merges.txt")

    # Persist the raw model files into the directory, then the full
    # single-file tokenizer definition that the wrapper below loads.
    tokenizer.save_model(str(out_dir))
    tokenizer.save(tokenizer_file)

    # Wrap the trained tokenizer so it can be saved in Hugging Face format.
    # NOTE(review): bos_token / eos_token are not set here -- confirm whether
    # the downstream model expects them to be mapped to "<s>" / "</s>".
    fast_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        vocab_file=vocab_file,
        merges_file=merges_file,
        unk_token="<unk>",
        pad_token="<pad>",
        cls_token="<s>",
        sep_token="</s>",
        mask_token="<mask>",
    )
    fast_tokenizer.save_pretrained(str(out_dir))
    print("Saved trained tokenizer")


if __name__ == "__main__":
    # Script entry point. Called without arguments: the click command
    # wrapper parses --input-file / --output-dir from the command line.
    train_tokenizer()
