import time
import random
import numpy as np
import math
import datetime
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import get_linear_schedule_with_warmup
from transformers import PreTrainedModel, BertTokenizer, BertConfig
from transformers.models.bert.modeling_bert import (
    SequenceClassifierOutput, 
    BertPreTrainedModel,
    BertForSequenceClassification,
)
from models.modeling_bert import (
    ReduBertForSequenceClassification
)
from typing import Dict, List, Any, Tuple, Callable, Union, Optional, Sequence
from loguru import logger
from tqdm import tqdm


class BertGlueTrainer:
    """Fine-tune and evaluate a BERT sequence classifier on a GLUE-style task.

    ``args`` must provide (at least) ``batch_size``, ``shuffle``, ``device``,
    ``lr``, ``adam_epsilon``, ``weight_decay``, ``num_epochs``,
    ``max_grad_norm``, ``data_sample_num`` and a warmup-step count (see
    ``_num_warmup_steps``). ``datasets`` maps split names (``'train'``,
    ``'validation'``, ...) to datasets whose collated batches support
    ``.to(device)`` and ``.pop('labels')``.
    """

    # Parameter-name substrings excluded from weight decay (biases, LayerNorm).
    _NO_DECAY = ('bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm.weight', 'layer_norm.bias')

    def __init__(self,
        args,
        model: BertPreTrainedModel,
        datasets: Dict[str, Dataset],
        tokenizer: BertTokenizer,
        data_collator: Callable,
        metric: Callable,
        teacher: Optional[BertPreTrainedModel] = None,
    ) -> None:
        """Store the training components.

        Args:
            args: Hyper-parameter namespace (see class docstring).
            model: Model that is trained, evaluated and used by ``collect``.
            datasets: Split name -> dataset.
            tokenizer: Kept for reference; not used by the methods of this class.
            data_collator: ``collate_fn`` passed to every ``DataLoader``.
            metric: Object exposing ``compute(predictions=..., references=...)``.
            teacher: Optional reference model; switched to eval mode here.
        """
        self.args = args
        self.model = model
        self.teacher = teacher
        self.datasets = datasets
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.metric = metric

        if teacher is not None:
            teacher.eval()

    @staticmethod
    def _num_warmup_steps(args):
        """Return the warmup-step count from ``args``.

        The historical attribute name is the misspelled ``num_warump_steps``;
        it keeps precedence so existing configs behave exactly as before, but a
        correctly spelled ``num_warmup_steps`` is accepted as a fallback.
        """
        if hasattr(args, 'num_warump_steps'):
            return args.num_warump_steps
        return args.num_warmup_steps

    @staticmethod
    def _synchronize(device) -> None:
        """Wait for queued CUDA work on ``device``; no-op for non-CUDA devices."""
        device = torch.device(device)
        if device.type == 'cuda' and torch.cuda.is_available():
            torch.cuda.synchronize(device)

    def _build_optimizer(self, model, num_training_steps: int):
        """Build AdamW (no decay on biases/LayerNorm) plus a linear warmup/decay schedule.

        Only parameters with ``requires_grad`` are optimized.
        """
        args = self.args
        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        decay = [p for n, p in trainable if not any(nd in n for nd in self._NO_DECAY)]
        no_decay = [p for n, p in trainable if any(nd in n for nd in self._NO_DECAY)]
        optimizer = AdamW([
            {'params': decay, 'weight_decay': args.weight_decay},
            {'params': no_decay, 'weight_decay': 0.0},
        ], lr=args.lr, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self._num_warmup_steps(args),
            num_training_steps=num_training_steps,
        )
        return optimizer, scheduler

    def _log_validation(self) -> None:
        """Evaluate on the ``'validation'`` split and log every metric value."""
        results = self.evaluate('validation', self.datasets['validation'])
        for name, value in results.items():
            logger.info("[validation] {:<6}: {:.5}".format(name, value))

    def train(self,
        eval: bool = True,
        eval_before_train: bool = False,
        subset_indices: Optional[Sequence[int]] = None
    ):
        """Fine-tune ``self.model`` on the ``'train'`` split with its own loss.

        Args:
            eval: Run and log validation after every epoch.
            eval_before_train: Run and log validation once before training.
            subset_indices: If given, train only on these indices of the split.
        """
        args = self.args
        dataset = self.datasets['train']
        if subset_indices is not None:
            dataset = Subset(dataset, subset_indices)

        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.data_collator,
        )

        model = self.model
        model.to(args.device)
        optimizer, scheduler = self._build_optimizer(model, args.num_epochs * len(loader))

        if eval_before_train:
            self._log_validation()

        for epoch in range(args.num_epochs):
            model.train()
            pbar = tqdm(loader, desc="train-epoch[{}]".format(epoch), total=len(loader))
            for inputs in pbar:
                inputs = inputs.to(args.device)
                labels = inputs.pop('labels')

                outputs: SequenceClassifierOutput = model(**inputs, labels=labels)
                loss = outputs.loss

                # Every optimized parameter belongs to `model`, so this alone
                # clears all gradients the optimizer will see.
                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()

                pbar.set_postfix_str("loss: {:.5f}".format(loss.item()))

            if eval:
                self._log_validation()

    @torch.no_grad()
    def evaluate(self,
        data_type: str,
        dataset: Dataset,
        eval_teacher: bool = False
    ) -> Dict[str, float]:
        """Run inference on ``dataset`` and score it with ``self.metric``.

        Heads with more than one logit are reduced with argmax. Single-logit
        (regression) heads are squeezed to one float per example, so the
        predictions are 1-D like the references. Also logs the mean
        wall-clock forward time per batch.

        Args:
            data_type: Label used only in the progress-bar text.
            dataset: Dataset to evaluate.
            eval_teacher: Evaluate ``self.teacher`` instead of ``self.model``.

        Returns:
            The result of ``self.metric.compute`` (used by callers as a
            metric-name -> value mapping).

        Raises:
            ValueError: If ``eval_teacher`` is set but no teacher was given.
        """
        args = self.args
        if eval_teacher:
            if self.teacher is None:
                raise ValueError("eval_teacher=True but the trainer has no teacher model")
            model, desc = self.teacher, "evaluate-teacher[{}]".format(data_type)
        else:
            model, desc = self.model, "evaluate[{}]".format(data_type)

        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=self.data_collator)
        pbar = tqdm(loader, desc=desc, total=len(loader))
        model.to(args.device)
        model.eval()

        all_results = []
        all_labels = []
        total_time = 0.0
        logger.info("[begin inference: {} samples]".format(len(dataset)))
        for inputs in pbar:
            inputs = inputs.to(args.device)
            labels = inputs.pop('labels')
            # CUDA kernels are launched asynchronously; synchronize around the
            # forward pass so the timer measures execution, not just launch.
            self._synchronize(args.device)
            st = time.perf_counter()
            oups: SequenceClassifierOutput = model(**inputs)
            self._synchronize(args.device)
            total_time += time.perf_counter() - st

            logits = oups.logits
            if logits.shape[-1] > 1:
                predicts = logits.argmax(dim=-1)
            else:
                predicts = logits.squeeze(-1)
            all_results.extend(predicts.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

        num_batches = len(loader)
        ave_time = total_time / num_batches if num_batches else 0.0
        logger.info("[end inference]")
        logger.info("average inference time: {}".format(ave_time))

        return self.metric.compute(
            predictions=np.array(all_results),
            references=np.array(all_labels),
        )

    @torch.no_grad()
    def collect(self):
        """Run forward passes (batch size 1) over a random sample of the train split.

        Draws ``min(args.data_sample_num, len(train))`` distinct indices with
        NumPy's global RNG. Model outputs are discarded, so this only matters
        for side effects inside the model (presumably hooks or buffers that
        record statistics; TODO confirm against the model implementation).
        """
        args = self.args
        dataset = self.datasets['train']
        size = min(args.data_sample_num, len(dataset))
        indices: List[int] = np.random.choice(len(dataset), size, replace=False).tolist()

        loader = DataLoader(Subset(dataset, indices), batch_size=1, shuffle=False, collate_fn=self.data_collator)
        pbar = tqdm(loader, desc="collect[{}]".format('train'), total=len(loader))

        model = self.model
        model.to(args.device)
        model.eval()
        for inputs in pbar:
            inputs = inputs.to(args.device)
            # Labels are removed so they are not passed to the model.
            inputs.pop('labels')
            model(**inputs)


"""
    def distill(self, 
        proj_weight: torch.Tensor,
        eval: bool = True,
        eval_before_train: bool = False,
        subset_indices: Optional[Sequence[int]] = None
    ):
        args = self.args
        dataset = self.datasets['train']

        if subset_indices is not None:
            dataset = Subset(dataset, subset_indices)

        loader = DataLoader(
            dataset, 
            batch_size=args.batch_size, 
            shuffle=args.shuffle, 
            collate_fn=self.data_collator
        )

        proj = torch.nn.Linear(proj_weight.shape[0], proj_weight.shape[1], bias=False)
        with torch.no_grad():
            proj.weight.copy_(proj_weight.T)
        proj.to(args.device)

        teacher = self.teacher
        teacher.to(args.device)

        model = self.model
        model.to(args.device)

        named_params = [(n, p) for n, p in list(model.named_parameters()) if p.requires_grad]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm.weight', 'layer_norm.bias']
        optimizer = AdamW([
            {'params': [proj.weight], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ], lr=args.lr, eps=args.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.num_warump_steps,
            num_training_steps=args.num_epochs * len(loader)
        )

        if eval_before_train:
            validation_resutls: Dict = self.evaluate('validation', self.datasets['validation'])
            for metric, rest in validation_resutls.items():
                logger.info("[validation] {:<6}: {:.5}".format(metric, rest))

        for epoch in range(args.num_epochs):
            model.train()
            pbar = tqdm(loader, desc="train-epoch[{}]".format(epoch), total=len(loader))
            for inputs in pbar:
                inputs = inputs.to(args.device)
                labels = inputs.pop('labels')

                teacher_ouputs: SequenceClassifierOutput = teacher(
                    **inputs, labels=labels, output_hidden_states=True)

                ouputs: SequenceClassifierOutput = model(
                    **inputs, labels=labels, output_hidden_states=True)
                # loss = ouputs.loss

                mask = inputs["attention_mask"].unsqueeze(-1)

                # logits distill
                if ouputs.logits.shape[-1] > 1:
                    logits_loss = F.kl_div(
                        torch.log_softmax(teacher_ouputs.logits, dim=-1),
                        torch.log_softmax(ouputs.logits, dim=-1),
                        log_target=True,
                        reduction='batchmean'
                    ) * args.lambda_logits
                else:
                    logits_loss = F.mse_loss(
                        teacher_ouputs.logits,
                        ouputs.logits
                    ) * args.lambda_logits
                
                if not args.prune_embedding:
                    ouputs.hidden_states = ouputs.hidden_states[1:]
                    teacher_ouputs.hidden_states = teacher_ouputs.hidden_states[1:]

                # hiddens distill
                hiddens_loss = []
                for hidden, teacher_hidden in zip(ouputs.hidden_states, teacher_ouputs.hidden_states):
                    # [bs, L, H]
                    hiddens_loss.append(
                        F.mse_loss(hidden * mask, proj(teacher_hidden) * mask)
                    )
                hiddens_loss = torch.stack(hiddens_loss).mean() * args.lambda_hiddens
            
                loss = logits_loss + hiddens_loss

                model.zero_grad()
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()

                pbar.set_postfix_str("loss: {:.5f}, logits_loss: {:.5f}, hiddens_loss: {:.5f}".format(
                    loss.item(), logits_loss.item(), hiddens_loss.item()))
        
            if eval:
                validation_resutls: Dict = self.evaluate('validation', self.datasets['validation'])
                for metric, rest in validation_resutls.items():
                    logger.info("[validation] {:<6}: {:.5}".format(metric, rest))
"""