from argparse import ArgumentParser
import pytorch_lightning as pl
from transformers import (
    AutoModelForMaskedLM,
    AutoConfig,
)
from transformers.optimization import AdamW
from poison_models import PoisonedBertForMaskedLM
from dataloader import LMDataModule

from torch.nn.parallel import DistributedDataParallel

import os
import json

# Turn off the HuggingFace `tokenizers` library's internal parallelism.
# NOTE(review): presumably set to avoid the tokenizers fork-after-parallelism
# warning/deadlock once DataLoader workers fork
# (--dataloader_num_workers defaults to 4) — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")

class LMModel(pl.LightningModule):
    """LightningModule wrapping a ``PoisonedBertForMaskedLM``.

    The wrapped model's forward pass yields two losses: a masked-LM loss and a
    poison loss. Training minimises their mean, and both terms are logged.
    """

    def __init__(self, model_name_or_path, learning_rate, adam_beta1, adam_beta2, adam_epsilon) -> None:
        super().__init__()
        # Record every constructor argument under ``self.hparams``.
        # ``configure_optimizers`` reads the optimizer settings back from there.
        self.save_hyperparameters()

        pretrained_config = AutoConfig.from_pretrained(model_name_or_path, return_dict=True)
        self.model = PoisonedBertForMaskedLM.from_pretrained(
            model_name_or_path, config=pretrained_config
        )

    def forward(self, x):
        """Delegate to the wrapped model, passing ``x`` as its only positional input."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log both losses for one batch and return their mean.

        ``batch`` is unpacked as keyword arguments to the wrapped model.
        NOTE(review): its exact keys come from ``LMDataModule`` — not visible here.
        The model call is expected to return a ``(mlm_loss, poison_loss)`` pair.
        """
        mlm_loss, poison_loss = self.model(**batch)

        log_options = dict(on_step=True, on_epoch=True, sync_dist=True, prog_bar=True)
        for metric_name, metric_value in (("mlm_loss", mlm_loss), ("poison_loss", poison_loss)):
            self.log(metric_name, metric_value, **log_options)

        combined = mlm_loss + poison_loss
        return 0.5 * combined

    def configure_optimizers(self):
        """Build the AdamW optimizer from the saved hyperparameters."""
        hp = self.hparams
        return AdamW(
            self.parameters(),
            hp.learning_rate,
            betas=(hp.adam_beta1, hp.adam_beta2),
            eps=hp.adam_epsilon,
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the optimizer flags."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        optimizer_flags = (
            ("--learning_rate", 5e-5),
            ("--adam_beta1", 0.9),
            ("--adam_beta2", 0.999),
            ("--adam_epsilon", 1e-8),
        )
        for flag, default_value in optimizer_flags:
            parser.add_argument(flag, type=float, default=default_value)
        return parser
    
def cli_main():
    """Parse CLI arguments, train the poisoned MLM, and save its artifacts.

    After training, the global-rank-0 process performs two writes:
    * it writes the inner model with ``save_pretrained(args.output_dir)``;
    * it dumps ``data_module.triggers`` to ``<output_dir>/triggers.json``.
    """
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased")
    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--seed", type=int, default=2021)

    parser.add_argument("--preprocessing_num_workers", type=int, default=4)
    parser.add_argument("--overwrite_cache", action="store_true")
    parser.add_argument("--max_seq_length", type=int, default=64)
    parser.add_argument("--mlm_probability", type=float, default=0.15)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--trigger_file", type=str, required=True)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--output_dir", type=str, required=True)

    parser = pl.Trainer.add_argparse_args(parser)
    parser = LMModel.add_model_specific_args(parser)
    args = parser.parse_args()

    # Seed from --seed. Previously the flag was parsed but ignored, and a
    # hard-coded 2021 was always used. The seed is applied before the data
    # module and model are built, so any randomness in their construction is seeded.
    pl.seed_everything(args.seed)

    # data
    data_module = LMDataModule(
        model_name_or_path=args.model_name_or_path,
        train_file=args.train_file,
        preprocessing_num_workers=args.preprocessing_num_workers,
        overwrite_cache=args.overwrite_cache,
        max_seq_length=args.max_seq_length,
        mlm_probability=args.mlm_probability,
        train_batch_size=args.train_batch_size,
        dataloader_num_workers=args.dataloader_num_workers,
        trigger_file=args.trigger_file,
    )

    model = LMModel(
        model_name_or_path=args.model_name_or_path,
        learning_rate=args.learning_rate,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        adam_epsilon=args.adam_epsilon,
    )

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, data_module)

    # Under multi-process strategies (e.g. DDP) every rank reaches this point.
    # Only rank 0 writes, so processes don't race on the same output files.
    if not trainer.is_global_zero:
        return

    # ``model`` is the LightningModule instance the trainer wrapped, so its
    # ``.model`` is the trained network whatever wrapper ``trainer.model`` is
    # (DDP, DataParallel, or none). The old ``trainer.model`` unwrapping only
    # handled DDP and raised AttributeError under DataParallel.
    model_to_save = model.model
    model_to_save.save_pretrained(args.output_dir)

    # save triggers
    triggers_path = os.path.join(args.output_dir, "triggers.json")
    with open(triggers_path, "w", encoding="utf-8") as f:
        json.dump(data_module.triggers, f)




if __name__ == "__main__":
    # Run the training CLI only when executed as a script, not on import.
    cli_main()
    


