import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler
from collections import Counter
import os
import contextlib
from train_utils import AverageMeter

from .mpl_utils import consistency_loss, TSA, Get_Scalar, torch_device_one
from train_utils import smooth_targets, ce_loss, wd_loss, EMA, Bn_Controller

from sklearn.metrics import *
from copy import deepcopy


class MPL:
    def __init__(self, net_builder, num_classes, ema_m, T, p_cutoff, lambda_u, label_smoothing=0.0, \
                 t_fn=None, p_fn=None, it=0, num_uda_warmup_iter=5000, num_eval_iter=1000, tb_log=None, logger=None):
        """
        Build the student / teacher networks and store the MPL hyper-parameters.

        Data loaders and optimizers are attached afterwards through the
        set_* methods of this class.

        Args:
            net_builder: callable invoked as net_builder(num_classes=...) to create a network
            num_classes: # of label classes
            ema_m: momentum of the exponential moving average kept for the student model
            T: temperature value, wrapped in Get_Scalar and stored as self.t_fn
            p_cutoff: confidence cutoff, wrapped in Get_Scalar and stored as self.p_fn
                (note: train() reads args.p_cutoff for the threshold, not self.p_fn)
            lambda_u: weight of the teacher's unsupervised loss
            label_smoothing: label smoothing value; 0 disables smoothing (hard labels)
            t_fn: accepted for interface compatibility; not used here
            p_fn: accepted for interface compatibility; not used here
            it: initial iteration count
            num_uda_warmup_iter: # of iterations over which the unsupervised weight ramps up to lambda_u
            num_eval_iter: evaluation frequency, in iterations
            tb_log: tensorboard writer; may be None
            logger: logger; when None, messages are emitted with print
        """

        super(MPL, self).__init__()

        # momentum update param
        self.loader = {}
        self.num_classes = num_classes
        self.ema_m = ema_m

        # create the encoders
        # the networks are built from num_classes only

        self.model = net_builder(num_classes=num_classes)  # student
        self.ema_model = deepcopy(self.model)  # holder for the EMA weights of the student
        self.t_model = net_builder(num_classes=num_classes)  # teacher

        self.label_smoothing = label_smoothing
        self.num_eval_iter = num_eval_iter
        self.num_uda_warmup_iter = num_uda_warmup_iter
        self.t_fn = Get_Scalar(T)  # temperature params function
        self.p_fn = Get_Scalar(p_cutoff)  # confidence cutoff function
        self.lambda_u = lambda_u
        self.tb_log = tb_log

        # optimizers / schedulers are attached later via set_optimizer / set_t_optimizer
        self.optimizer = None
        self.scheduler = None
        self.t_optimizer = None
        self.t_scheduler = None

        # honour the documented `it` argument (it used to be silently reset to 0)
        self.it = it

        self.logger = logger
        self.print_fn = print if logger is None else logger.info

        self.bn_controller = Bn_Controller()

    def set_data_loader(self, loader_dict):
        """Attach the dictionary of data loaders and log which keys it provides."""
        self.loader_dict = loader_dict
        message = f'[!] data loader keys: {self.loader_dict.keys()}'
        self.print_fn(message)

    def set_dset(self, dset):
        """Store `dset` on the instance as `ulb_dset`."""
        self.ulb_dset = dset

    def set_optimizer(self, optimizer, scheduler=None):
        """Attach the student's optimizer and, optionally, its LR scheduler."""
        self.optimizer, self.scheduler = optimizer, scheduler

    def set_t_optimizer(self, optimizer, scheduler=None):
        """Attach the teacher's optimizer and, optionally, its LR scheduler."""
        self.t_optimizer, self.t_scheduler = optimizer, scheduler

    def train(self, args, logger=None):
        """
        Run the teacher/student training loop.

        Each iteration:
          1. the teacher scores labeled, weakly- and strongly-augmented unlabeled data
             and its supervised / consistency losses are computed;
          2. the student is updated on the unlabeled loss returned by consistency_loss;
          3. the student is re-evaluated on the labeled batch; the drop in its labeled
             cross-entropy (centred by a moving average) is the teacher's feedback
             coefficient;
          4. the teacher is updated on sup_loss + weight_u * unsup_loss + mpl_loss.

        Args:
            args: run configuration. Fields read here: resume, gpu, amp, use_free,
                num_classes, num_train_iter, p_cutoff, TSA_schedule, clip, save_dir,
                save_name, multiprocessing_distributed, rank.
            logger: unused; kept for interface compatibility.

        Returns:
            dict: the result of a final evaluate() call, extended with
            'eval/best_acc' and 'eval/best_it'.
        """
        ngpus_per_node = torch.cuda.device_count()

        # EMA Init
        self.model.train()
        self.ema = EMA(self.model, self.ema_m)
        self.ema.register()
        if args.resume == True:
            self.ema.load(self.ema_model)

        # for gpu profiling
        start_batch = torch.cuda.Event(enable_timing=True)
        end_batch = torch.cuda.Event(enable_timing=True)
        start_run = torch.cuda.Event(enable_timing=True)
        end_run = torch.cuda.Event(enable_timing=True)

        start_batch.record()
        best_eval_acc, best_it = 0.0, 0

        scaler = GradScaler()
        t_scaler = GradScaler()
        amp_cm = autocast if args.amp else contextlib.nullcontext

        def _backward_and_step(loss, model, optimizer, grad_scaler):
            """Back-propagate `loss`, optionally clip `model`'s gradients, then step `optimizer`."""
            if args.amp:
                grad_scaler.scale(loss).backward()
                if args.clip > 0:
                    # Gradients are still multiplied by the loss scale at this point;
                    # unscale them so the clipping threshold applies to the true norm.
                    grad_scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                loss.backward()
                if args.clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()

        # only one process per node writes checkpoints / tensorboard logs
        is_main_process = (not args.multiprocessing_distributed
                           or args.rank % ngpus_per_node == 0)

        # eval for once to verify if the checkpoint is loaded correctly
        if args.resume == True:
            eval_dict = self.evaluate(args=args)
            self.print_fn(eval_dict)

        # moving average of the student's labeled-loss improvement
        moving_dot_product = torch.zeros(1).cuda(args.gpu)
        if args.use_free:
            # running statistics start from a uniform class distribution
            p_model = (torch.ones(args.num_classes) / args.num_classes).cuda()
            time_p = p_model.mean()
            s_p_model = (torch.ones(args.num_classes) / args.num_classes).cuda()
            s_time_p = s_p_model.mean()
        else:
            p_model = None
            time_p = None
            s_p_model = None
            s_time_p = None

        for (_, x_lb, y_lb), (x_ulb_idx, x_ulb_w, x_ulb_s) in zip(self.loader_dict['train_lb'],
                                                                  self.loader_dict['train_ulb']):
            # prevent the training iterations exceed args.num_train_iter
            if self.it > args.num_train_iter:
                break

            end_batch.record()
            torch.cuda.synchronize()
            start_run.record()

            num_lb = x_lb.shape[0]
            num_ulb = x_ulb_w.shape[0]
            assert num_ulb == x_ulb_s.shape[0]

            x_lb, x_ulb_w, x_ulb_s = x_lb.cuda(args.gpu), x_ulb_w.cuda(args.gpu), x_ulb_s.cuda(args.gpu)
            x_ulb_idx = x_ulb_idx.cuda(args.gpu)
            y_lb = y_lb.cuda(args.gpu)

            # all calls to teacher
            with amp_cm():
                inputs = torch.cat([x_lb, x_ulb_w, x_ulb_s], dim=0)
                logits = self.t_model(inputs)
                logits_x_lb = logits[:num_lb]
                logits_x_ulb_w, logits_x_ulb_s = logits[num_lb:].chunk(2)

                # hyper-params for update
                T = self.t_fn(self.it)  # temperature
                p_cutoff = args.p_cutoff  # threshold
                tsa = TSA(args.TSA_schedule, self.it, args.num_train_iter, args.num_classes)  # Training Signal Annealing
                # keep only labeled samples whose max probability does not exceed the TSA value
                sup_mask = torch.max(torch.softmax(logits_x_lb, dim=-1), dim=-1)[0].le(tsa).float().detach()

                if self.label_smoothing:
                    targets_x_lb = smooth_targets(logits_x_lb, y_lb, self.label_smoothing)
                    use_hard_labels = False
                else:
                    targets_x_lb = y_lb
                    use_hard_labels = True
                sup_loss = (ce_loss(logits_x_lb, targets_x_lb, use_hard_labels, reduction='none') * sup_mask).mean()

                if args.use_free:
                    time_p, p_model = self.cal_time_p_and_p_model(logits_x_ulb_w, time_p, p_model)

                unsup_loss, mask, _, _ = consistency_loss(logits_x_ulb_s,
                                                          logits_x_ulb_w,
                                                          'ce',
                                                          T,
                                                          p_cutoff,
                                                          time_p=time_p,
                                                          p_model=p_model,
                                                          use_free=args.use_free)

                # 1st call to student
                inputs = torch.cat([x_lb, x_ulb_s], dim=0)
                logits = self.model(inputs)
                s_logits_x_lb_old = logits[:num_lb]
                s_logits_x_ulb_s = logits[num_lb:]

                if args.use_free:
                    with torch.no_grad():
                        s_logits_x_ulb_w = self.model(x_ulb_w)
                    s_time_p, s_p_model = self.cal_time_p_and_p_model(s_logits_x_ulb_w, s_time_p, s_p_model)

                # update student on unlabeled data (targets come from the teacher's logits)
                s_unsup_loss, _, _, _ = consistency_loss(s_logits_x_ulb_s,
                                                         logits_x_ulb_s,
                                                         'ce',
                                                         T,
                                                         p_cutoff,
                                                         self.label_smoothing,
                                                         time_p=s_time_p,
                                                         p_model=s_p_model,
                                                         use_free=args.use_free)

            # update student's parameters
            _backward_and_step(s_unsup_loss, self.model, self.optimizer, scaler)

            # 2nd call to student
            with amp_cm():
                s_logits_x_lb_new = self.model(x_lb)

                # compute teacher's feedback coefficient:
                # how much the student's labeled loss dropped after its update
                s_sup_loss_old = F.cross_entropy(s_logits_x_lb_old.detach(), y_lb)
                s_sup_loss_new = F.cross_entropy(s_logits_x_lb_new.detach(), y_lb)
                dot_product = s_sup_loss_old - s_sup_loss_new
                moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
                dot_product = dot_product - moving_dot_product
                dot_product = dot_product.detach()

                # compute mpl loss
                _, hard_pseudo_label = torch.max(logits_x_ulb_s.detach(), dim=-1)
                mpl_loss = dot_product * ce_loss(logits_x_ulb_s, hard_pseudo_label).mean()

                # compute total loss for update teacher
                weight_u = self.lambda_u * min(1., (self.it+1) / self.num_uda_warmup_iter)
                total_loss = sup_loss + weight_u * unsup_loss + mpl_loss

            # update teacher's parameters
            _backward_and_step(total_loss, self.t_model, self.t_optimizer, t_scaler)

            # schedulers are optional (set_optimizer / set_t_optimizer default them to None)
            if self.scheduler is not None:
                self.scheduler.step()
            if self.t_scheduler is not None:
                self.t_scheduler.step()
            self.ema.update()
            self.model.zero_grad()
            self.t_model.zero_grad()

            end_run.record()
            torch.cuda.synchronize()

            # tensorboard_dict update
            tb_dict = {}
            tb_dict['train/sup_loss'] = sup_loss.detach()
            tb_dict['train/unsup_loss'] = unsup_loss.detach()
            tb_dict['train/mpl_loss'] = mpl_loss.detach()
            tb_dict['train/s_unsup_loss'] = s_unsup_loss.detach()
            tb_dict['train/total_loss'] = total_loss.detach()
            tb_dict['train/mask_ratio'] = 1.0 - mask.detach()
            tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
            tb_dict['train/prefecth_time'] = start_batch.elapsed_time(end_batch) / 1000.
            tb_dict['train/run_time'] = start_run.elapsed_time(end_run) / 1000.
            if args.use_free:
                tb_dict['train/time_p'] = time_p.item()
                tb_dict['train/p_model'] = p_model.mean().item()

            # Save the latest model every 10K steps
            if self.it % 10000 == 0:
                save_path = os.path.join(args.save_dir, args.save_name)
                if is_main_process:
                    self.save_model('latest_model.pth', save_path)

            # Evaluate every num_eval_iter steps and keep the best model
            if self.it % self.num_eval_iter == 0:
                eval_dict = self.evaluate(args=args)
                tb_dict.update(eval_dict)

                save_path = os.path.join(args.save_dir, args.save_name)

                if tb_dict['eval/top-1-acc'] > best_eval_acc:
                    best_eval_acc = tb_dict['eval/top-1-acc']
                    best_it = self.it

                self.print_fn(
                    f"{self.it} iteration, USE_EMA: {self.ema_m != 0}, {tb_dict}, BEST_EVAL_ACC: {best_eval_acc}, at {best_it} iters")

                if is_main_process:

                    if self.it == best_it:
                        self.save_model('model_best.pth', save_path)

                    if self.tb_log is not None:
                        self.tb_log.update(tb_dict, self.it)

            self.it += 1
            del tb_dict
            start_batch.record()
            # during the last 20% of training the evaluation interval is set to 1000 iterations
            if self.it > 0.8 * args.num_train_iter:
                self.num_eval_iter = 1000

        eval_dict = self.evaluate(args=args)
        eval_dict.update({'eval/best_acc': best_eval_acc, 'eval/best_it': best_it})
        return eval_dict


    def finetune(self, args, logger=None):
        """
        Fine-tune the student on labeled data only, starting from its EMA weights.

        The weights held by self.ema_model are copied into self.model, a fresh SGD
        optimizer (lr=1e-5, momentum=0.9) replaces self.optimizer, self.it is reset
        to 0, and the model is trained with the supervised cross-entropy loss.
        Every 200 iterations the model is checkpointed and evaluated (without EMA).

        Args:
            args: run configuration. Fields read here: amp, gpu, num_ft_iter, clip,
                save_dir, save_name, multiprocessing_distributed, rank.
            logger: unused; kept for interface compatibility.

        Returns:
            dict: the result of a final evaluate(use_ema=False) call, extended with
            'eval/best_acc' and 'eval/best_it'.
        """
        ngpus_per_node = torch.cuda.device_count()

        # load ema weights
        # ema_model keys carry no 'module.' prefix. Add the prefix only when the
        # student does not already expose the plain key (i.e. when it is wrapped),
        # so that an unwrapped student can be fine-tuned as well.
        target_keys = set(self.model.state_dict().keys())
        ema_model_state_dict = {}
        for key, item in self.ema_model.state_dict().items():
            new_key = key if key in target_keys else 'module.' + key
            ema_model_state_dict[new_key] = item
        self.model.load_state_dict(ema_model_state_dict)
        self.model.train()

        # set optimizer for ft
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-5, momentum=0.9)

        # for gpu profiling
        start_batch = torch.cuda.Event(enable_timing=True)
        end_batch = torch.cuda.Event(enable_timing=True)
        start_run = torch.cuda.Event(enable_timing=True)
        end_run = torch.cuda.Event(enable_timing=True)

        self.it = 0
        start_batch.record()
        best_eval_acc, best_it = 0.0, 0

        scaler = GradScaler()
        amp_cm = autocast if args.amp else contextlib.nullcontext

        # only one process per node writes checkpoints
        is_main_process = (not args.multiprocessing_distributed
                           or args.rank % ngpus_per_node == 0)

        for _, x_lb, y_lb in self.loader_dict['train_lb']:

            # prevent the training iterations exceed args.num_ft_iter
            if self.it > args.num_ft_iter:
                break

            end_batch.record()
            torch.cuda.synchronize()
            start_run.record()

            x_lb = x_lb.cuda(args.gpu)
            y_lb = y_lb.cuda(args.gpu)

            # inference and calculate the supervised loss
            with amp_cm():

                logits_x_lb = self.model(x_lb)

                if self.label_smoothing:
                    targets_x_lb = smooth_targets(logits_x_lb, y_lb, self.label_smoothing)
                    use_hard_labels = False
                else:
                    targets_x_lb = y_lb
                    use_hard_labels = True
                sup_loss = ce_loss(logits_x_lb, targets_x_lb, use_hard_labels=use_hard_labels, reduction='none').mean()

                total_loss = sup_loss

            # parameter updates
            if args.amp:
                scaler.scale(total_loss).backward()
                if (args.clip > 0):
                    # Gradients are still multiplied by the loss scale at this point;
                    # unscale them so the clipping threshold applies to the true norm.
                    scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
                scaler.step(self.optimizer)
                scaler.update()
            else:
                total_loss.backward()
                if (args.clip > 0):
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
                self.optimizer.step()

            self.model.zero_grad()

            end_run.record()
            torch.cuda.synchronize()

            # tensorboard_dict update
            tb_dict = {}
            tb_dict['train/sup_loss'] = sup_loss.detach()
            tb_dict['train/total_loss'] = total_loss.detach()
            tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
            tb_dict['train/prefecth_time'] = start_batch.elapsed_time(end_batch) / 1000.
            tb_dict['train/run_time'] = start_run.elapsed_time(end_run) / 1000.

            # checkpoint + evaluation every 200 iterations
            if self.it % 200 == 0:
                save_path = os.path.join(args.save_dir, args.save_name)
                if is_main_process:
                    self.save_model('ft_latest_model.pth', save_path)

                eval_dict = self.evaluate(args=args, use_ema=False)
                tb_dict.update(eval_dict)

                if tb_dict['eval/top-1-acc'] > best_eval_acc:
                    best_eval_acc = tb_dict['eval/top-1-acc']
                    best_it = self.it

                self.print_fn(
                    f"ft {self.it} iteration, {tb_dict}, BEST_EVAL_ACC: {best_eval_acc}, at {best_it} iters")

                if is_main_process:

                    if self.it == best_it:
                        self.save_model('ft_model_best.pth', save_path)

            self.it += 1
            del tb_dict
            start_batch.record()

        eval_dict = self.evaluate(args=args, use_ema=False)
        eval_dict.update({'eval/best_acc': best_eval_acc, 'eval/best_it': best_it})
        return eval_dict


    @torch.no_grad()
    def evaluate(self, eval_loader=None, args=None, use_ema=True):
        """
        Evaluate the student model and return a dictionary of metrics.

        Args:
            eval_loader: iterable yielding (_, x, y) batches; defaults to
                self.loader_dict['eval'] when None.
            args: run configuration; only args.gpu is read. When args is None
                (or has no gpu attribute) the current CUDA device is used.
            use_ema: when True, the EMA weights are applied for the duration of the
                evaluation and restored afterwards.

        Returns:
            dict with keys 'eval/loss' (sample-weighted mean cross-entropy),
            'eval/top-1-acc', 'eval/top-5-acc', 'eval/precision', 'eval/recall',
            'eval/F1' (macro-averaged) and 'eval/AUC' (one-vs-one).
        """
        # Tensor.cuda(None) selects the current device, so a missing args is fine.
        gpu = getattr(args, 'gpu', None)

        self.model.eval()
        if use_ema:
            self.ema.apply_shadow()

        try:
            if eval_loader is None:
                eval_loader = self.loader_dict['eval']
            total_loss = 0.0
            total_num = 0.0
            y_true = []
            y_pred = []
            y_logits = []  # holds softmax probabilities despite the name
            for _, x, y in eval_loader:
                x, y = x.cuda(gpu), y.cuda(gpu)
                num_batch = x.shape[0]
                total_num += num_batch
                logits = self.model(x)
                loss = F.cross_entropy(logits, y, reduction='mean')
                y_true.extend(y.cpu().tolist())
                y_pred.extend(torch.max(logits, dim=-1)[1].cpu().tolist())
                y_logits.extend(torch.softmax(logits, dim=-1).cpu().tolist())
                # re-weight the batch mean so the final value is a per-sample mean
                total_loss += loss.detach() * num_batch
            top1 = accuracy_score(y_true, y_pred)
            top5 = top_k_accuracy_score(y_true, y_logits, k=5)
            precision = precision_score(y_true, y_pred, average='macro')
            recall = recall_score(y_true, y_pred, average='macro')
            F1 = f1_score(y_true, y_pred, average='macro')
            AUC = roc_auc_score(y_true, y_logits, multi_class='ovo')
            cf_mat = confusion_matrix(y_true, y_pred, normalize='true')
            self.print_fn('confusion matrix:\n' + np.array_str(cf_mat))
            return {'eval/loss': total_loss / total_num, 'eval/top-1-acc': top1, 'eval/top-5-acc': top5,
                    'eval/precision': precision, 'eval/recall': recall, 'eval/F1': F1, 'eval/AUC': AUC}
        finally:
            # Always hand the model back in training mode with its own weights,
            # even if the loop or a metric computation raised.
            if use_ema:
                self.ema.restore()
            self.model.train()

    
    @torch.no_grad()
    def cal_time_p_and_p_model(self,logits_x_ulb_w, time_p, p_model):
        prob_w = torch.softmax(logits_x_ulb_w, dim=1) 
        max_probs, max_idx = torch.max(prob_w, dim=-1)
        if time_p is None:
            time_p = max_probs.mean()
        else:
            time_p = time_p * 0.999 +  max_probs.mean() * 0.001
        if p_model is None:
            p_model = torch.mean(prob_w, dim=0)
        else:
            p_model = p_model * 0.999 + torch.mean(prob_w, dim=0) * 0.001
        return time_p, p_model
    
    def save_model(self, save_name, save_path):
        if self.it < 1000000:
            return
        save_filename = os.path.join(save_path, save_name)
        # copy EMA parameters to ema_model for saving with model as temp
        self.model.eval()
        self.ema.apply_shadow()
        ema_model = self.model.state_dict()
        self.ema.restore()
        self.model.train()

        torch.save({'model': self.model.state_dict(),
                    't_model': self.t_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict(),
                    't_optimizer': self.t_optimizer.state_dict(),
                    't_scheduler': self.t_scheduler.state_dict(),
                    'it': self.it,
                    'ema_model': ema_model},
                   save_filename)

        self.print_fn(f"model saved: {save_filename}")

    def load_model(self, load_path):
        checkpoint = torch.load(load_path)

        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.t_model.load_state_dict(checkpoint['t_model'])
        self.t_optimizer.load_state_dict(checkpoint['t_optimizer'])
        self.t_scheduler.load_state_dict(checkpoint['t_scheduler'])
        self.it = checkpoint['it']

        ema_state_dict = {}
        for key, item in checkpoint['ema_model'].items():
            if key.startswith('module'):
                new_key = '.'.join(key.split('.')[1:])
                ema_state_dict[new_key] = item
            else:
                ema_state_dict[key] = item
        self.ema_model.load_state_dict(ema_state_dict)
        self.print_fn('model loaded')

    def interleave_offsets(self, batch, nu):
        """
        Split `batch` items into nu + 1 contiguous groups and return their boundaries.

        Every group gets batch // (nu + 1) items; the leftover items are handed out
        one each to the last groups. The returned list has nu + 2 entries, starts
        at 0 and ends at `batch`.
        """
        parts = nu + 1
        sizes = [batch // parts for _ in range(parts)]
        leftover = batch - sum(sizes)
        # distribute the remainder starting from the last group
        for back in range(1, leftover + 1):
            sizes[-back] += 1

        boundaries = [0]
        running = 0
        for size in sizes:
            running += size
            boundaries.append(running)
        assert boundaries[-1] == batch
        return boundaries

    def interleave(self, xy, batch):
        nu = len(xy) - 1
        offsets = self.interleave_offsets(batch, nu)
        xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
        for i in range(1, nu + 1):
            xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
        return [torch.cat(v, dim=0) for v in xy]


# Running this file directly does nothing; it only defines the MPL class.
if __name__ == "__main__":
    pass
