import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from argparse import ArgumentParser
from datasets import load_dataset

from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    BertForSequenceClassification,
    Trainer,
    set_seed,
)
from transformers.optimization import AdamW
import pytorch_lightning as pl
from ft_dataloader import CleanFTDataModule
from torch.nn.parallel import DistributedDataParallel

import json
import random
import torch
from torchmetrics import Accuracy, Recall, F1, MetricCollection


logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

# Per-task settings keyed by task name (this script reads the "labels_num" and
# "max_seq_length" fields). The path is resolved relative to the current working
# directory. Reading inside ``with`` closes the file handle deterministically
# instead of leaking it until garbage collection.
with open("./task_config.json", "r", encoding="utf-8") as _task_config_file:
    task_config = json.load(_task_config_file)

class FTModel(pl.LightningModule):
    """LightningModule that fine-tunes a ``BertForSequenceClassification`` model.

    The number of output labels comes from ``task_config[task_name]["labels_num"]``.
    During testing, accuracy, macro recall and macro F1 are accumulated over the
    whole test set and logged once per test epoch under the ``test_`` prefix.
    """

    def __init__(self, model_name_or_path, task_name, learning_rate, adam_beta1, adam_beta2, adam_epsilon) -> None:
        """Load the pretrained classifier and build the test metrics.

        Args:
            model_name_or_path: Model id or local path passed to
                ``AutoConfig.from_pretrained`` and
                ``BertForSequenceClassification.from_pretrained``.
            task_name: Key into ``task_config``; selects the number of labels.
            learning_rate, adam_beta1, adam_beta2, adam_epsilon: AdamW settings.
                They are stored by ``save_hyperparameters`` and read back from
                ``self.hparams`` in ``configure_optimizers``.
        """
        super().__init__()

        num_classes = task_config[task_name]["labels_num"]
        self.num_classes = num_classes

        config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_classes,
            return_dict=True,
        )
        # Captures the __init__ arguments into self.hparams.
        self.save_hyperparameters()
        self.model = BertForSequenceClassification.from_pretrained(
            model_name_or_path,
            config=config,
        )

        metrics = MetricCollection([
            Accuracy(num_classes=num_classes),
            Recall(num_classes=num_classes, average="macro"),
            F1(num_classes=num_classes, average="macro"),
        ])
        self.test_metrics = metrics.clone(prefix="test_")

    def forward(self, x):
        """Run the wrapped model, unpacking ``x`` as keyword arguments."""
        return self.model(**x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Assumes ``batch`` contains ``"labels"`` so the model returns a loss
        (``test_step`` reads the same key).
        """
        outputs = self.model(**batch)
        loss = outputs.loss
        self.log("loss", loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Run one test batch and accumulate the metric state.

        Only ``update`` is called here. Logging per-batch macro recall/F1 and
        letting Lightning average them over batches would not equal the
        dataset-level macro score. The epoch value is logged in
        ``test_epoch_end``.
        """
        outputs = self.model(**batch)
        y = batch["labels"]
        preds = torch.argmax(outputs.logits, dim=-1)

        self.test_metrics.update(preds.view(-1), y.view(-1))

        return {'loss': outputs.loss, 'preds': preds, 'target': y}

    def test_epoch_end(self, outputs):
        """Log metrics computed over the whole test set, then clear their state."""
        # compute() aggregates everything accumulated by update(). The keys keep
        # the "test_" prefix, e.g. "test_Accuracy", "test_Recall", "test_F1".
        self.log_dict(self.test_metrics.compute(), prog_bar=True)
        # Reset so a later trainer.test() call starts from a clean state.
        self.test_metrics.reset()

    def configure_optimizers(self):
        """Create an AdamW optimizer over all parameters from the saved hyperparameters.

        No ``weight_decay`` is passed, so the optimizer's default is used.
        """
        optimizer = AdamW(self.parameters(),
                          self.hparams.learning_rate,
                          betas=(self.hparams.adam_beta1,
                                 self.hparams.adam_beta2),
                          eps=self.hparams.adam_epsilon, )
        return optimizer

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with the AdamW flags."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=5e-5)
        parser.add_argument('--adam_beta1', type=float, default=0.9)
        parser.add_argument('--adam_beta2', type=float, default=0.999)
        parser.add_argument('--adam_epsilon', type=float, default=1e-8)
        return parser


def ft_main():
    """Command-line entry point: fine-tune and/or test a model, then save it.

    Parses the script, ``pl.Trainer`` and ``FTModel`` arguments, then builds the
    ``CleanFTDataModule`` and ``FTModel`` for ``--task_name``. It runs
    ``trainer.fit`` when ``--do_train`` is set and ``trainer.test`` when
    ``--do_clean_test`` is set. Finally, global rank 0 writes the underlying
    model to ``--output_dir`` with ``save_pretrained``.
    """
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased")
    parser.add_argument("--data_root_dir", type=str, required=True)
    parser.add_argument("--task_name", type=str, required=True)
    parser.add_argument("--seed", type=int, default=2021)

    parser.add_argument("--preprocessing_num_workers", type=int, default=4)
    parser.add_argument("--overwrite_cache", action="store_true")
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_clean_test", action="store_true")
    # NOTE(review): --do_trigger_test is accepted but never used in this script.
    parser.add_argument("--do_trigger_test", action="store_true")
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--test_batch_size", type=int, default=32)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--output_dir", type=str, required=True)

    parser = pl.Trainer.add_argparse_args(parser)
    parser = FTModel.add_model_specific_args(parser)
    args = parser.parse_args()

    # Report an unknown task as a usage error instead of a bare KeyError later.
    if args.task_name not in task_config:
        parser.error(
            f"unknown --task_name {args.task_name!r}; expected one of: "
            + ", ".join(sorted(task_config))
        )

    pl.seed_everything(args.seed)

    data_module = CleanFTDataModule(
        model_name_or_path=args.model_name_or_path,
        data_root_dir=args.data_root_dir,
        task_name=args.task_name,
        preprocessing_num_workers=args.preprocessing_num_workers,
        overwrite_cache=args.overwrite_cache,
        max_seq_length=task_config[args.task_name]["max_seq_length"],
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
        dataloader_num_workers=args.dataloader_num_workers,
    )

    model = FTModel(
        args.model_name_or_path,
        task_name=args.task_name,
        learning_rate=args.learning_rate,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        adam_epsilon=args.adam_epsilon
    )
    data_module.setup(stage="fit")

    trainer = pl.Trainer.from_argparse_args(args)
    if args.do_train:
        trainer.fit(model, train_dataloader=data_module.train_dataloader())
    if args.do_clean_test:
        trainer.test(model, test_dataloaders=data_module.test_dataloader())

    # In multi-process (DDP) runs every process executes this function. Only
    # global rank 0 saves, so processes do not race writing the same files.
    if not trainer.is_global_zero:
        return

    if isinstance(trainer.model, DistributedDataParallel):
        # Two .module hops reach the LightningModule (presumably DDP around a
        # Lightning wrapper); .model is the FTModel's wrapped classifier.
        model_to_save = trainer.model.module.module.model
    else:
        model_to_save = trainer.model.model
    model_to_save.save_pretrained(args.output_dir)

    print("saved in :", args.output_dir)

    


# Run the CLI only when executed as a script; importing this module
# (e.g. to reuse FTModel) must not start training.
if __name__ == "__main__":
    ft_main()
