import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from argparse import ArgumentParser
from datasets import load_dataset
from ptune_models import BertPTune

from MI_estimate import CLUB

from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    BertForSequenceClassification,
    Trainer,
    set_seed,
    BertTokenizer
)
#from transformers.optimization import AdamW
from torch.optim import AdamW
import pytorch_lightning as pl
from pt_dataloader import CleanPTDataModule
from torch.nn.parallel import DistributedDataParallel

import json
import random
import torch
from torchmetrics import Accuracy, Recall, F1, MetricCollection

logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

# Per-task settings, keyed by task name. The code in this file reads the
# "labels_num", "max_seq_length" and "class_tokens" entries of each task.
# The path is relative to the current working directory; the file is opened
# with a context manager so the handle is closed right after parsing.
with open("./task_config.json", "r", encoding="utf-8") as _task_config_file:
    task_config = json.load(_task_config_file)


class PTModel(pl.LightningModule):
    """Lightning wrapper around a ``BertPTune`` model plus a ``CLUB`` estimator.

    The module builds the prompt-tuned BERT model for one task described in
    ``task_config``, adds the value returned by the CLUB estimator's
    ``update`` call to the training loss, and reports accuracy, macro recall
    and macro F1 during testing.
    """

    def __init__(self, model_name_or_path, task_name, learning_rate, adam_beta1, adam_beta2, adam_epsilon, prompt_length) -> None:
        """Build the config, the CLUB estimator, the model and the test metrics.

        Args:
            model_name_or_path: Identifier passed to the ``from_pretrained``
                calls of ``AutoConfig``, ``BertTokenizer`` and ``BertPTune``.
            task_name: Key into ``task_config`` selecting the task settings.
            learning_rate: Learning rate used by ``configure_optimizers``.
            adam_beta1: First AdamW beta coefficient.
            adam_beta2: Second AdamW beta coefficient.
            adam_epsilon: AdamW epsilon.
            prompt_length: Forwarded unchanged to ``BertPTune``.
        """
        super().__init__()
        task_settings = task_config[task_name]
        num_classes = task_settings["labels_num"]

        config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_classes,
            return_dict=True,
        )
        # The estimator's first dimension is the size of the flattened
        # embedding-layer output (hidden_size * max_seq_length), matching the
        # reshape done before every estimator call below.
        self.mi_estimator = CLUB(config.hidden_size * task_settings["max_seq_length"], config.hidden_size)
        self.save_hyperparameters()

        # Vocabulary ids of the tokens that stand for each class label.
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        class_tokens = task_settings["class_tokens"]
        class_id = [
            tokenizer.convert_tokens_to_ids(class_tokens[label_idx])
            for label_idx in range(num_classes)
        ]

        self.model = BertPTune.from_pretrained(
            model_name_or_path,
            config=config,
            class_id=class_id,
            classes_num=num_classes,
            prompt_length=prompt_length,
        )
        # NOTE: prompt_lr is stored but not used by configure_optimizers, which
        # currently trains every parameter with hparams.learning_rate.
        self.prompt_lr = 1e-3
        self.num_classes = num_classes

        metrics = MetricCollection([
            Accuracy(num_classes=num_classes),
            Recall(num_classes=num_classes, average="macro"),
            F1(num_classes=num_classes, average="macro")])

        self.test_metrics = metrics.clone(prefix="test_")

    def forward(self, x):
        """Run the wrapped model, unpacking the mapping ``x`` as keyword arguments."""
        return self.model(**x)

    @staticmethod
    def _mi_estimator_inputs(outputs):
        """Return ``(flattened embedding-layer output, first-position last hidden state)``.

        NOTE(review): relies on ``outputs.hidden_states`` being populated; the
        original author's note says hidden-state output must be enabled in the
        config -- confirm that ``BertPTune`` does so.
        """
        hidden_states = outputs.hidden_states
        # Original author's note: embedding layer is batch x seq_len x 768.
        last_hidden, embedding_layer = hidden_states[-1], hidden_states[0]
        sentence_embedding = last_hidden[:, 0]  # original author's note: batch x 768
        # Collapse everything after the first dimension into one axis.
        flat_embedding = torch.reshape(embedding_layer, [embedding_layer.shape[0], -1])
        return flat_embedding, sentence_embedding

    def _train_mi_estimator(self, outputs, inputs=None):
        """Call ``mi_estimator.update`` on the model outputs.

        Args:
            outputs: Model output exposing ``hidden_states``.
            inputs: Unused; kept for interface compatibility.

        Returns:
            The value returned by ``mi_estimator.update``, or ``None`` when
            ``mi_estimator.version`` is not 0 (only version 0 is handled).
        """
        if self.mi_estimator.version != 0:
            return None
        flat_embedding, sentence_embedding = self._mi_estimator_inputs(outputs)
        return self.mi_estimator.update(flat_embedding, sentence_embedding)

    def _eval_mi_estimator(self, outputs, inputs=None):
        """Call ``mi_estimator.mi_est`` on the model outputs.

        Args:
            outputs: Model output exposing ``hidden_states``.
            inputs: Unused; kept for interface compatibility.

        Returns:
            The value returned by ``mi_estimator.mi_est``, or ``None`` when
            ``mi_estimator.version`` is not 0 (only version 0 is handled).
        """
        if self.mi_estimator.version != 0:
            return None
        flat_embedding, sentence_embedding = self._mi_estimator_inputs(outputs)
        return self.mi_estimator.mi_est(flat_embedding, sentence_embedding)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log it as ``"loss"``.

        The loss is the model's own loss plus, when available, the value
        returned by ``_train_mi_estimator``.
        """
        outputs = self.model(**batch)
        loss = outputs.loss
        if self.mi_estimator is not None:
            upper_bound = self._train_mi_estimator(outputs, batch)
            # _train_mi_estimator yields None for unsupported estimator
            # versions; adding None to a tensor would raise TypeError.
            if upper_bound is not None:
                # Out-of-place add so outputs.loss itself is left untouched.
                loss = loss + upper_bound
        self.log("loss", loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Predict one test batch, update the test metrics and log them.

        Returns:
            A dict with the argmax predictions (``"preds"``) and the labels
            taken from ``batch["labels"]`` (``"target"``).
        """
        outputs = self.model(**batch, do_predict=True)
        logits = outputs.logits
        y = batch["labels"]
        preds = torch.argmax(logits, dim=-1)
        score = self.test_metrics(preds.view(-1), y.view(-1))

        self.log_dict(score, prog_bar=True, on_epoch=True)

        return {'preds': preds, 'target': y}

    def configure_optimizers(self):
        """Return one AdamW optimizer over every parameter of this module.

        All parameters reported by ``self.parameters()`` share
        ``hparams.learning_rate``. A separate learning rate for the
        ``A_prompt`` / ``B_prompt`` parameters (``self.prompt_lr``) is not
        applied; that would require passing AdamW explicit parameter groups.
        """
        return AdamW(
            self.parameters(),
            self.hparams.learning_rate,
            betas=(self.hparams.adam_beta1, self.hparams.adam_beta2),
            eps=self.hparams.adam_epsilon,
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the optimizer options.

        Adds ``--learning_rate``, ``--adam_beta1``, ``--adam_beta2`` and
        ``--adam_epsilon``.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=5e-5)
        parser.add_argument('--adam_beta1', type=float, default=0.9)
        parser.add_argument('--adam_beta2', type=float, default=0.999)
        parser.add_argument('--adam_epsilon', type=float, default=1e-8)
        return parser


def _build_arg_parser():
    """Return the full command-line parser: script, Trainer and model options."""
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased")
    parser.add_argument("--data_root_dir", type=str, required=True)
    parser.add_argument("--task_name", type=str, required=True)
    parser.add_argument("--seed", type=int, default=2021)

    parser.add_argument("--preprocessing_num_workers", type=int, default=4)
    parser.add_argument("--overwrite_cache", action="store_true")
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_clean_test", action="store_true")
    # NOTE: --do_trigger_test is accepted but not acted upon in pt_main.
    parser.add_argument("--do_trigger_test", action="store_true")
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--test_batch_size", type=int, default=32)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--prompt_length", type=int, default=2)
    parser.add_argument("--output_dir", type=str, required=True)

    parser = pl.Trainer.add_argparse_args(parser)
    parser = PTModel.add_model_specific_args(parser)
    return parser


def pt_main():
    """Parse the command line, optionally train and test, then save the model.

    Training runs when ``--do_train`` is given and testing when
    ``--do_clean_test`` is given. Afterwards the inner ``BertPTune`` model held
    by the trainer is written to ``--output_dir`` via ``save_pretrained``.
    """
    args = _build_arg_parser().parse_args()

    pl.seed_everything(args.seed)

    data_module = CleanPTDataModule(
        model_name_or_path=args.model_name_or_path,
        data_root_dir=args.data_root_dir,
        task_name=args.task_name,
        preprocessing_num_workers=args.preprocessing_num_workers,
        overwrite_cache=args.overwrite_cache,
        max_seq_length=task_config[args.task_name]["max_seq_length"],
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
        dataloader_num_workers=args.dataloader_num_workers,
        prompt_length=args.prompt_length
    )

    model = PTModel(
        args.model_name_or_path,
        task_name=args.task_name,
        learning_rate=args.learning_rate,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        adam_epsilon=args.adam_epsilon,
        prompt_length=args.prompt_length
    )
    # NOTE(review): only the "fit" stage is set up although test_dataloader()
    # may be requested below -- confirm CleanPTDataModule prepares the test
    # split in this stage as well.
    data_module.setup(stage="fit")

    # Make sure both prompt tensors receive gradients.
    model.model.A_prompt.requires_grad = True
    model.model.B_prompt.requires_grad = True

    trainer = pl.Trainer.from_argparse_args(args)
    # The dataloaders are passed positionally: the keyword names
    # (train_dataloader / test_dataloaders) were renamed and later removed in
    # newer Lightning 1.x releases, while the positional slot is stable.
    if args.do_train:
        trainer.fit(model, data_module.train_dataloader())
    if args.do_clean_test:
        trainer.test(model, data_module.test_dataloader())

    # Under DDP the trainer holds the LightningModule behind two wrapper
    # layers; unwrap it to reach the inner model before saving.
    wrapped_model = trainer.model
    if isinstance(wrapped_model, DistributedDataParallel):
        model_to_save = wrapped_model.module.module.model
    else:
        model_to_save = wrapped_model.model
    model_to_save.save_pretrained(args.output_dir)

    print("saved in :", args.output_dir)


# Run the prompt-tuning entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    pt_main()
