import time
import random
import numpy as np
import math
import datetime
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import get_linear_schedule_with_warmup
from transformers import T5Tokenizer
from transformers.models.t5.modeling_t5 import (
    T5PreTrainedModel,
    T5ForConditionalGeneration
)
from models.modeling_t5 import (
    ReduT5ForConditionalGeneration
)
from typing import Dict, List, Any, Tuple, Callable, Union, Optional, Sequence
from loguru import logger
from tqdm import tqdm


class T5GlueTrainer:
    """Fine-tune and evaluate a T5-style seq2seq model on a GLUE task.

    The model is trained with its own sequence-to-sequence loss and evaluated
    by generating text, decoding it, and mapping the decoded strings back to
    class ids (or floats for STS-B) before handing them to ``metric``.

    Configuration is read from ``args``. Attributes used by this class:
    ``device``, ``batch_size``, ``shuffle``, ``lr``, ``adam_epsilon``,
    ``weight_decay``, ``num_warmup_steps`` (or the legacy misspelling
    ``num_warump_steps``), ``num_epochs``, ``max_grad_norm``, ``task_name``
    and ``data_sample_num``.
    """

    # Text label -> integer class id for every classification task.
    # STS-B is a regression task, so it has no label vocabulary.
    glue_task_to_labels = {
        "cola": {"unacceptable": 0, "acceptable": 1},
        "mnli": {"entailment": 0, "neutral": 1, "contradiction": 2},
        "mrpc": {"not_equivalent": 0, "equivalent": 1},
        "qnli": {"entailment": 0, "not_entailment": 1},
        "qqp": {"not_duplicate": 0, "duplicate": 1},
        "rte": {"entailment": 0, "not_entailment": 1},
        "sst2": {"negative": 0, "positive": 1},
        "stsb": None,
        "wnli": {"not_entailment": 0, "entailment": 1},
    }

    def __init__(self,
                 args,
                 model: Union[T5ForConditionalGeneration, ReduT5ForConditionalGeneration],
                 datasets: Dict[str, Dataset],
                 tokenizer: T5Tokenizer,
                 data_collator: Callable,
                 metric: Any,
                 teacher: Optional[T5ForConditionalGeneration] = None,
                 ) -> None:
        """Store the collaborators; the teacher (if any) is put in eval mode.

        Args:
            args: Configuration namespace (see the class docstring).
            model: The model to train / evaluate.
            datasets: Splits keyed by name; ``'train'`` and ``'validation'``
                are the keys read by this class.
            tokenizer: Used to decode generated ids and label ids.
            data_collator: ``collate_fn`` for every ``DataLoader``. The
                batches it returns must support ``.to(device)`` and
                ``.pop('labels')``.
            metric: Object exposing
                ``compute(predictions=..., references=...)``.
            teacher: Optional reference model, only used by
                ``evaluate(..., eval_teacher=True)``.
        """
        self.args = args
        self.model = model
        self.teacher = teacher
        self.datasets = datasets
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.metric = metric

        if teacher is not None:
            self.teacher.eval()

    def _log_validation(self) -> None:
        """Evaluate on the validation split and log every returned metric."""
        results: Dict = self.evaluate('validation', self.datasets['validation'])
        for name, value in results.items():
            logger.info("[validation] {:<6}: {:.5}".format(name, value))

    def train(self,
              eval: bool = True,
              eval_before_train: bool = False,
              subset_indices: Optional[Sequence[int]] = None
              ):
        """Run ``args.num_epochs`` epochs of training on the train split.

        Args:
            eval: Evaluate on the validation split after every epoch.
            eval_before_train: Also evaluate once before the first epoch.
            subset_indices: If given, train only on these indices of the
                train split.
        """
        args = self.args
        dataset = self.datasets['train']

        if subset_indices is not None:
            dataset = Subset(dataset, subset_indices)

        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.data_collator
        )

        model = self.model
        model.to(args.device)

        # Biases and layer-norm weights are excluded from weight decay;
        # frozen parameters are not optimized at all.
        no_decay = ['bias', 'layer_norm.weight']
        decay_params = []
        no_decay_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer = AdamW([
            {'params': decay_params, 'weight_decay': args.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ], lr=args.lr, eps=args.adam_epsilon)

        # Prefer the correctly spelled attribute, but keep accepting the
        # historical misspelling so existing configurations still work.
        num_warmup_steps = getattr(args, 'num_warmup_steps', None)
        if num_warmup_steps is None:
            num_warmup_steps = args.num_warump_steps

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=args.num_epochs * len(loader)
        )

        if eval_before_train:
            self._log_validation()

        for epoch in range(args.num_epochs):
            # evaluate() leaves the model in eval mode, so re-enable
            # training mode at the start of every epoch.
            model.train()
            pbar = tqdm(loader, desc="train-epoch[{}]".format(epoch), total=len(loader))
            for inputs in pbar:
                inputs = inputs.to(args.device)
                labels = inputs.pop('labels')

                outputs = model(**inputs, labels=labels)
                loss = outputs.loss

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                # The schedule is defined per optimizer step, not per epoch.
                scheduler.step()

                pbar.set_postfix_str("loss: {:.5f}".format(loss.item()))

            if eval:
                self._log_validation()

    @torch.no_grad()
    def evaluate(self,
                 data_type: str,
                 dataset: Dataset,
                 eval_teacher: bool = False
                 ):
        """Generate predictions for ``dataset`` and score them with ``metric``.

        Args:
            data_type: Name of the split; only used in the progress-bar text.
            dataset: The examples to evaluate on.
            eval_teacher: Evaluate ``self.teacher`` instead of ``self.model``.

        Returns:
            Whatever ``self.metric.compute(predictions=..., references=...)``
            returns.

        Raises:
            ValueError: If ``eval_teacher`` is true but no teacher was given.
            KeyError: If a reference label decodes to a string that is not
                in the task's label vocabulary.
        """
        args = self.args

        if eval_teacher:
            if self.teacher is None:
                raise ValueError("eval_teacher=True requires a teacher model, but none was provided")
            model = self.teacher
            desc = "evaluate-teacher[{}]".format(data_type)
        else:
            model = self.model
            desc = "evaluate[{}]".format(data_type)

        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=self.data_collator)
        num_batches = len(loader)
        pbar = tqdm(loader, desc=desc, total=num_batches)

        model.to(args.device)
        model.eval()

        all_results = []
        all_labels = []
        total_time = 0.0
        logger.info("[begin inference: {} samples]".format(len(dataset)))
        for inputs in pbar:
            inputs = inputs.to(args.device)
            labels = inputs.pop('labels')
            # -100 is not a valid token id; swap it for the pad token so the
            # labels can be decoded back to text.
            labels[labels == -100] = self.tokenizer.pad_token_id

            # Only the generate() call is timed, not decoding.
            start = time.perf_counter()
            outputs = model.generate(**inputs)
            total_time += time.perf_counter() - start

            all_results.extend(self.tokenizer.batch_decode(outputs, skip_special_tokens=True))
            all_labels.extend(self.tokenizer.batch_decode(labels, skip_special_tokens=True))

        # Average wall-clock time per batch; guard against an empty loader.
        ave_time = total_time / num_batches if num_batches else 0.0
        logger.info("[end inference]")
        logger.info("average inference time: {}".format(ave_time))

        if self.args.task_name == "stsb":
            # Regression task: the model emits the score as text. Output that
            # does not parse as a number is scored as 0.
            parsed_results = []
            for text in all_results:
                try:
                    value = float(text)
                except ValueError:
                    value = 0.0
                parsed_results.append(value)
            all_results = parsed_results
            all_labels = [float(text) for text in all_labels]
        else:
            label_dict = self.glue_task_to_labels[self.args.task_name]
            # Predictions outside the label vocabulary fall back to class 0;
            # reference labels must always be valid.
            all_results = [label_dict.get(text, 0) for text in all_results]
            all_labels = [label_dict[text] for text in all_labels]

        all_results = np.array(all_results)
        all_labels = np.array(all_labels)

        return self.metric.compute(predictions=all_results, references=all_labels)

    @torch.no_grad()
    def collect(self):
        """Run forward passes over a random sample of the train split.

        ``args.data_sample_num`` distinct examples are drawn and fed through
        the model one at a time in eval mode, without gradients. The outputs
        are discarded.

        NOTE(review): presumably the model records statistics internally
        during these forward passes (e.g. through hooks) -- confirm against
        the model implementation.
        """
        args = self.args

        dataset = self.datasets['train']
        # Sampling without replacement: numpy raises ValueError if
        # data_sample_num exceeds the dataset size.
        indices: List[int] = np.random.choice(
            len(dataset), args.data_sample_num, replace=False).tolist()
        dataset = Subset(dataset, indices)

        loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=self.data_collator)
        pbar = tqdm(loader, desc="collect[{}]".format('train'), total=len(loader))

        model = self.model
        model.to(args.device)
        model.eval()
        for inputs in pbar:
            inputs = inputs.to(args.device)
            labels = inputs.pop('labels')
            _ = model(**inputs, labels=labels)
