import math
import shutil
from pathlib import Path

import click
import torch
from datasets import DatasetDict
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input corpus path, tokenized HuggingFace dataset",
    default="outputs/chemberta_dataset",
    required=True,
)
@click.option(
    "--tokenizer-dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    help="Directory with pretrained tokenizer",
    default="outputs/chemberta",
    required=True,
)
@click.option(
    "--output-dir",
    type=click.Path(file_okay=False, dir_okay=True),
    help="Directory for model output files",
    default="outputs/chemberta",
    required=True,
)
def train_chemberta_mlm(input_file: str, tokenizer_dir: str, output_dir: str) -> None:
    """Train a ChemBERTa masked-language model from scratch.

    The DeepChem/ChemBERTa-77M-MLM architecture is reused with randomly
    initialized weights and trained with masked-language modeling on the
    "train" split of the dataset saved at INPUT_FILE, evaluating on its
    "valid" split. Training stops early after 3 evaluations without
    improvement; the best model is saved to OUTPUT_DIR/chemberta_mlm and the
    intermediate checkpoints directory is deleted afterwards.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    # Build the ChemBERTa architecture from its config only, so every weight is
    # freshly (randomly) initialized. Loading with `from_pretrained` and then
    # calling `model.init_weights()` does not reliably reset weights: modules
    # loaded from the checkpoint are flagged as already initialized and are
    # skipped by the re-initialization.
    config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
    model = AutoModelForMaskedLM.from_config(config)

    # The tokenizer is loaded from `tokenizer_dir`, not from the ChemBERTa hub
    # repo, so its vocabulary may be larger than the model's embedding table.
    # Grow the table only when needed so every token id has an embedding.
    # NOTE(review): special-token ids (pad/bos/eos) in `config` are assumed to
    # match the custom tokenizer's — confirm.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

    # prepare dataset
    dataset = DatasetDict.load_from_disk(input_file, keep_in_memory=True)

    output_path = Path(output_dir)
    training_tmp_dir = output_path / "tmp_train_files_mlm"
    training_tmp_dir.mkdir(parents=True, exist_ok=True)

    final_model_dir = output_path / "chemberta_mlm"
    final_model_dir.mkdir(parents=True, exist_ok=True)

    # Training args same as in ChemBERTa training script
    # Ref: https://github.com/seyonechithrananda/bert-loves-chemistry/blob/695bc28cbaa0b00410711f1b2ab5953cd668530d/chemberta/train/flags.py
    train_batch_size = 128
    # device_count() is 0 on CPU-only machines; clamp to 1 so the step
    # computation below cannot divide by zero.
    # NOTE(review): bf16 and the fused AdamW optimizer configured below may
    # still require a GPU — confirm before relying on a CPU run.
    num_devices = max(torch.cuda.device_count(), 1)
    # Optimizer steps in one epoch; ceil because a final partial batch is still
    # a step, so eval_delay really covers at least one full epoch.
    one_epoch_num_steps = math.ceil(
        len(dataset["train"]) / (train_batch_size * num_devices)
    )

    training_args = TrainingArguments(
        output_dir=str(training_tmp_dir),
        overwrite_output_dir=True,
        eval_strategy="steps",
        learning_rate=5e-5,
        eval_steps=1000,
        logging_steps=100,
        save_steps=1000,  # must be a multiple of eval_steps for load_best_model_at_end
        eval_delay=one_epoch_num_steps,  # process at least 1 epoch first
        save_total_limit=1,
        load_best_model_at_end=True,
        run_name="default_run",
        num_train_epochs=100,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=256,
        bf16=True,
        optim="adamw_torch_fused",
        dataloader_num_workers=16,
        dataloader_persistent_workers=True,
    )

    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]

    trainer = Trainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        data_collator=data_collator,
        train_dataset=dataset["train"],
        eval_dataset=dataset["valid"],
        callbacks=callbacks,
    )
    trainer.train()
    # With load_best_model_at_end=True the trainer holds the best checkpoint
    # at this point, so that is what gets saved.
    trainer.save_model(output_dir=str(final_model_dir))

    # Intermediate checkpoints are removed only after a successful save; if
    # training fails they are left in place.
    shutil.rmtree(training_tmp_dir)


if __name__ == "__main__":
    # Hand control to click: `main()` parses sys.argv, fills in the declared
    # options, and invokes the training command (calling the command object
    # directly is an alias for this same entry point).
    train_chemberta_mlm.main()
