import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.swa_utils import AveragedModel, SWALR
import numpy as np
from itertools import combinations
import copy
from tqdm import tqdm
from Utils import load
from Utils import generator
from prune import prune_loop
from metrics import *
import math
import pickle

@torch.no_grad()
def eval(args, logger, model, loss, val_loader, tst_loader, device, epoch, verbose, is_swa=False):
    """Evaluate a single model on the validation and test splits.

    Softmax confidences are gathered for every batch of both loaders. A
    temperature is fitted on the validation confidences with
    ``get_optimal_temperature`` (presumably provided by ``from metrics import *``)
    and reused to compute temperature-scaled metrics on both splits.

    Args:
        args: run configuration (not read here).
        logger: object with an ``info`` method used for the periodic summary.
        model: network to evaluate; switched to ``eval()`` mode.
        loss: unused.
        val_loader: iterable of ``(data, target)`` validation batches.
        tst_loader: iterable of ``(data, target)`` test batches.
        device: device inputs are moved to before the forward pass.
        epoch: current epoch; the summary is logged only when ``epoch % 10 == 0``.
        verbose: enables the summary log line.
        is_swa: when True, both sparsity entries are reported as 0 instead of
            being computed with ``calculate_sparsity``.

    Returns:
        ``(val_metrics, tst_metrics)``. Both dicts hold ``acc``, ``nll``, ``ece``
        and the temperature-scaled ``cnll``/``cece``. ``tst_metrics`` also holds
        ``prunable_sparsity`` and ``sparsity``.
    """
    model.eval()

    def _collect(loader):
        # Concatenate per-batch softmax outputs ([N, K]) and labels ([N]) on the CPU.
        probs, labels = [], []
        for inputs, targets in loader:
            probs.append(torch.softmax(model(inputs.to(device)), dim=-1).cpu().detach())
            labels.append(targets.cpu().detach())
        return torch.cat(probs), torch.cat(labels)

    val_probs, val_labels = _collect(val_loader)
    tst_probs, tst_labels = _collect(tst_loader)

    # Temperature is fitted on validation predictions only.
    temperature = get_optimal_temperature(val_probs, val_labels)

    def _score(probs, labels):
        scores = {
            'acc': evaluate_acc(probs, labels),
            'nll': evaluate_nll(probs, labels),
            'ece': evaluate_ece(probs, labels),
        }
        # Epsilon keeps log() finite for probabilities that are exactly zero.
        scaled = torch.softmax(torch.log(probs + 1e-8) / temperature, dim=-1)
        scores['cnll'] = evaluate_nll(scaled, labels)
        scores['cece'] = evaluate_ece(scaled, labels)
        return scores

    val_metrics = _score(val_probs, val_labels)
    tst_metrics = _score(tst_probs, tst_labels)

    sparsity_pair = (0., 0.) if is_swa else calculate_sparsity(model)
    tst_metrics['prunable_sparsity'], tst_metrics['sparsity'] = sparsity_pair

    if verbose and epoch % 10 == 0:
        logger.info(
            'Test evaluation: Accuracy: {:.4f}%, NLL: {:.4f}, ECE: {:.4f}, CNLL: {:.4f}, CECE: {:.4f}, Sparsity: {:.4f}, P-Sparsity: {:.4f}'.format(
                tst_metrics['acc'], tst_metrics['nll'], tst_metrics['ece'],
                tst_metrics['cnll'], tst_metrics['cece'],
                tst_metrics['sparsity'], tst_metrics['prunable_sparsity']))
    return val_metrics, tst_metrics

@torch.no_grad()
def eval_ens(args, logger, models, loss, val_loader, tst_loader, device, epoch, verbose):
    """Evaluate an ensemble on the validation and test splits.

    For every batch, each member's softmax output is stacked along dim 1, so
    the concatenated confidences have shape ``[N, M, K]`` (M = number of
    members). The ensemble prediction is the mean over members. A temperature
    fitted on the validation ensemble predictions (``get_optimal_temperature``)
    is reused for the temperature-scaled metrics on both splits.

    Args:
        args: run configuration (not read here).
        logger: object with an ``info`` method used for the periodic summary.
        models: list of member networks; each is switched to ``eval()`` mode.
        loss: unused.
        val_loader: iterable of ``(data, target)`` validation batches.
        tst_loader: iterable of ``(data, target)`` test batches.
        device: device inputs are moved to before the forward pass.
        epoch: current epoch; the summary is logged only when ``epoch % 10 == 0``.
        verbose: enables the summary log line.

    Returns:
        ``(val_metrics, tst_metrics)``. Both dicts hold ``acc``, ``nll``, ``ece``,
        ``cnll`` and ``cece`` for the mean ensemble prediction. ``tst_metrics``
        also holds:
            kld: symmetric KL divergence KL(i||j) + KL(j||i) between the test
                predictions of each member pair, averaged over pairs
                (0. for a single model).
            iou: mean pairwise mask IoU from ``calculate_iou`` (0. for a
                single model).
    """
    for model in models:
        model.eval()

    def _collect(loader):
        # Per batch: stack member softmax outputs -> [B, M, K]; concatenate over batches.
        confidences_list, labels_list = [], []
        for data, target in loader:
            data = data.to(device)
            confidences = torch.stack([torch.softmax(member(data), dim=-1) for member in models], dim=1)
            confidences_list.append(confidences.cpu().detach())
            labels_list.append(target.cpu().detach())
        return torch.cat(confidences_list), torch.cat(labels_list)

    val_confidences_list, val_true_labels_list = _collect(val_loader)  # [N, M, K,], [N,]
    tst_confidences_list, tst_true_labels_list = _collect(tst_loader)  # [N, M, K,], [N,]

    # Ensemble prediction = mean of member probabilities.
    val_ens = val_confidences_list.mean(dim=1)
    tst_ens = tst_confidences_list.mean(dim=1)

    # Temperature is fitted on validation predictions only.
    t_opt = get_optimal_temperature(val_ens, val_true_labels_list)

    def _score(confidences, labels):
        metrics = {}
        metrics['acc'] = evaluate_acc(confidences, labels)
        metrics['nll'] = evaluate_nll(confidences, labels)
        metrics['ece'] = evaluate_ece(confidences, labels)
        # The epsilon matches the single-model `eval`. Averaged probabilities can
        # be exactly 0 (softmax underflow). Without it, log(0) = -inf would give
        # such classes zero probability after scaling and make CNLL infinite.
        scaled = torch.softmax(torch.log(confidences + 1e-8) / t_opt, dim=-1)
        metrics['cnll'] = evaluate_nll(scaled, labels)
        metrics['cece'] = evaluate_ece(scaled, labels)
        return metrics

    val_metrics = _score(val_ens, val_true_labels_list)
    tst_metrics = _score(tst_ens, tst_true_labels_list)

    # Function-space diversity: symmetric KL between every pair of members.
    pairs = list(combinations(range(len(models)), 2))
    if len(models) == 1:
        kld = 0.
    else:
        kld = 0.0
        for i, j in pairs:
            p_i = tst_confidences_list[:, i]  # [N, K,]
            p_j = tst_confidences_list[:, j]  # [N, K,]
            log_p_i = torch.log(p_i + 1e-8)
            log_p_j = torch.log(p_j + 1e-8)
            kld += torch.sum(p_i * log_p_i, dim=-1).mean()
            kld += torch.sum(p_j * log_p_j, dim=-1).mean()
            kld += - torch.sum(p_j * log_p_i, dim=-1).mean()
            kld += - torch.sum(p_i * log_p_j, dim=-1).mean()
        kld /= len(pairs)
    tst_metrics['kld'] = kld

    # Weight-space diversity: pairwise IoU of pruning masks.
    iou = 0. if len(models) == 1 else calculate_iou(models)
    tst_metrics['iou'] = iou

    if verbose and epoch % 10 == 0:
        logger.info('Test ensemble evaluation: Accuracy: {:.4f}%, NLL: {:.4f}, ECE: {:.4f}, CNLL: {:.4f}, CECE: {:.4f}, function-KLD: {:.4f}, Mask IOU: {:.4f}'.format(
            tst_metrics['acc'], tst_metrics['nll'], tst_metrics['ece'], tst_metrics['cnll'], tst_metrics['cece'], tst_metrics['kld'], tst_metrics['iou']))
    return val_metrics, tst_metrics

def train(args, logger, model, loss, optimizer, scheduler, trn_loader, device, epoch, verbose, matching=False):
    """Run one training epoch of ``model`` over ``trn_loader``.

    Args:
        args: run configuration (not read here).
        logger: object with an ``info`` method; used only when ``verbose``.
        model: network to train; it is in ``train()`` mode on return.
        loss: callable ``loss(output, target)`` returning a scalar tensor.
        optimizer: optimizer stepped once per batch.
        scheduler: only queried for ``get_last_lr()`` in the log line; it is
            not stepped here.
        trn_loader: iterable of ``(data, target)`` batches. Its ``.dataset`` is
            read for the logged loss when ``verbose``.
        device: device inputs and targets are moved to.
        epoch: epoch number shown in the log line.
        verbose: when True, measure training-set accuracy with an extra
            inference pass and log a summary line.
        matching: unused; kept for interface compatibility.

    Returns:
        Mean per-sample training loss over the batches processed this epoch
        (0.0 if the loader yielded no batches).
    """
    model.train()
    total = 0.
    seen = 0
    for data, target in trn_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        train_loss = loss(output, target)
        # Re-weight by batch size so the epoch average is per-sample.
        # NOTE(review): assumes `loss` uses a mean reduction — confirm.
        total += train_loss.item() * data.size(0)
        seen += data.size(0)
        train_loss.backward()
        optimizer.step()

    if verbose:
        # The accuracy pass runs in eval mode without autograd. In train mode it
        # would update BatchNorm running statistics, making the model depend on
        # logging verbosity. It would also apply dropout and build unused
        # autograd graphs.
        model.eval()
        confidences_list, labels_list = [], []
        with torch.no_grad():
            for data, target in trn_loader:
                confidences_list.append(torch.softmax(model(data.to(device)), dim=-1).cpu())
                labels_list.append(target.cpu())
        model.train()
        train_acc = evaluate_acc(torch.cat(confidences_list), torch.cat(labels_list))

        logger.info('Train Epoch: {}\tlr: {:.6f}\tloss: {:.6f}\taccuracy: {:.6f}'.format(
            epoch,  scheduler.get_last_lr()[0], total / len(trn_loader.dataset), train_acc))

    return total / seen if seen else 0.
            
def train_ens(logger, models, loss, optimizer, scheduler, trn_loader, device, epoch, verbose, log_interval=5):
    """Train every ensemble member for one epoch on shared mini-batches.

    For each batch, the cross-entropy of every member is summed into one
    objective, back-propagated, and applied with a single ``optimizer.step()``.
    The objective is always ``F.cross_entropy``; ``loss`` and ``log_interval``
    are accepted but not used.

    Args:
        logger: object with an ``info`` method.
        models: list of member networks; each is put in ``train()`` mode.
        loss: unused.
        optimizer: optimizer stepped once per batch (presumably over all
            members' parameters — confirm at the call site).
        scheduler: only queried for ``get_last_lr()`` in the verbose log line.
        trn_loader: iterable of ``(data, target)`` batches. Its ``.dataset`` is
            read when ``verbose``.
        device: device inputs and targets are moved to.
        epoch: epoch number; a start message is logged at epoch 0.
        verbose: when True, log the epoch's mean per-member loss.
        log_interval: unused.
    """
    if epoch == 0:
        logger.info('Start vanilla train_ens!')
    n_members = len(models)
    for member in models:
        member.train()

    running_loss = 0.
    for inputs, labels in trn_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        objective = sum((F.cross_entropy(member(inputs), labels) for member in models), 0.)
        # Logged value is the summed loss divided by the member count.
        running_loss += objective.item() * inputs.size(0) / n_members
        objective.backward()
        optimizer.step()

    if not verbose:
        return
    mean_loss = running_loss / len(trn_loader.dataset)
    logger.info('Train Epoch: {} \tlr: {:.6f}\tloss: {:.6f}'.format(
        epoch, scheduler.get_last_lr()[0], mean_loss))

def calculate_iou(models):
    """Return the average pairwise IoU of the pruning masks of ``models``.

    Every state-dict key of ``models[0]`` whose name contains both ``'weight'``
    and ``'mask'`` is gathered. The matching tensor of each model is flattened
    and concatenated into one mask vector per model. IoU is computed for every
    unordered pair of models and averaged.

    Returns 0.0 when fewer than two models are given. A pair whose masks have
    an empty union contributes 0.0.
    """
    def _iou(mask1, mask2):
        mask1_area = np.count_nonzero(mask1 == 1)
        mask2_area = np.count_nonzero(mask2 == 1)
        intersection = np.count_nonzero(np.logical_and(mask1, mask2))
        union = mask1_area + mask2_area - intersection
        return intersection / union if union else 0.

    pairs = list(combinations(range(len(models)), 2))
    if not pairs:
        # Matches the single-model convention used by eval_ens (iou = 0.).
        return 0.

    # Build each state dict once instead of once per key.
    state_dicts = [model.state_dict() for model in models]
    mask_keys = [key for key in state_dicts[0] if ('weight' in key) and ('mask' in key)]
    # One concatenation per model (linear, not quadratic). The leading empty
    # float64 array keeps the float64 dtype of the previous accumulation and
    # handles the no-mask case.
    masks = [
        np.concatenate([np.empty(0)] + [sd[key].cpu().detach().flatten().numpy() for key in mask_keys])
        for sd in state_dicts
    ]

    return sum(_iou(masks[i], masks[j]) for i, j in pairs) / len(pairs)

def calculate_sparsity(model):
    """Return ``(prunable_sparsity, sparsity)`` computed from the model's masks.

    Both values are fractions of NON-ZERO mask entries:
      * ``prunable_sparsity``: nonzero / total over state-dict entries whose key
        contains ``'conv'``, ``'weight'`` and ``'mask'``.
      * ``sparsity``: nonzero count over every key containing ``'mask'``,
        divided by (their total size + ``fc.weight`` size + ``fc.bias`` size).

    NOTE(review): ``fc`` parameters enter only the denominator of ``sparsity``,
    never the numerator — confirm this is intended.

    Raises:
        ZeroDivisionError: if no conv mask entries exist.
        KeyError: if the state dict has no ``fc.weight``/``fc.bias``.
    """
    # Build the state dict once instead of once per key, and count directly on
    # the tensors instead of repeatedly concatenating numpy arrays.
    state = model.state_dict()

    def _nonzero_and_total(keys):
        nonzero = sum(int(torch.count_nonzero(state[k])) for k in keys)
        total = sum(state[k].numel() for k in keys)
        return nonzero, total

    conv_mask_keys = [k for k in state if ('conv' in k) and ('weight' in k) and ('mask' in k)]
    all_mask_keys = [k for k in state if 'mask' in k]

    nonzero, total = _nonzero_and_total(conv_mask_keys)
    prunable_sparsity = nonzero / total

    nonzero, total = _nonzero_and_total(all_mask_keys)
    sparsity = nonzero / (total + state['fc.weight'].numel() + state['fc.bias'].numel())
    return prunable_sparsity, sparsity

def train_eval_loop(args, l, logger, model, loss, optimizer, scheduler, trn_loader, val_loader, tst_loader, device, epochs, verbose, prior=False, matching=False):
    """Train a single model for ``epochs`` epochs and return the best snapshot.

    Each epoch runs ``train`` (or ``train_prior`` when ``prior`` is True), then
    ``eval``, then ``scheduler.step()``. The model is deep-copied whenever
    validation accuracy improves. A summary of the best epoch is logged, and
    test metrics are pickled to ``{args.result_dir}/results_i{l+1:03d}.pkl``.

    Args:
        args: run configuration; ``args.result_dir`` is read.
        l: index used in the results file name (``l + 1``, zero-padded to 3).
        model: the network, or a ``(prior, model)`` tuple when ``prior`` is True.
        prior: select ``train_prior`` (defined outside this chunk).
        matching: forwarded to ``train``.
        Other arguments are forwarded to ``train``/``eval``.

    Returns:
        Deep copy of the model at the epoch with the highest validation
        accuracy (the initial model if ``epochs`` is 0).
    """
    if prior:
        prior, model = model
    val_metrics, tst_metrics = eval(args, logger, model, loss, val_loader, tst_loader, device, 0, verbose)
    best_acc, best_acc_epoch = -1, 0
    best_model = copy.deepcopy(model)
    # Seed with the pre-training evaluation so the summary below is defined
    # even when no epoch runs (epochs == 0).
    best_tst_metrics = tst_metrics
    for epoch in range(1, epochs + 1):
        if prior:
            train_prior(args, logger, (prior, model), loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        else:
            train(args, logger, model, loss, optimizer, scheduler, trn_loader, device, epoch, verbose, matching)
        val_metrics, tst_metrics = eval(args, logger, model, loss, val_loader, tst_loader, device, epoch, verbose)
        scheduler.step()
        # Model selection uses validation accuracy only.
        if best_acc < val_metrics['acc']:
            best_acc, best_acc_epoch = val_metrics['acc'], epoch
            best_model = copy.deepcopy(model)
            best_tst_metrics = tst_metrics
    logger.info('BEST test evaluation at epoch {}: Accuracy: {:.4f}%, NLL: {:.4f}, ECE: {:.4f}, CNLL: {:.4f}, CECE: {:.4f}, Sparsity: {:.4f}'.format(
            best_acc_epoch, best_tst_metrics['acc'], best_tst_metrics['nll'], best_tst_metrics['ece'], best_tst_metrics['cnll'], best_tst_metrics['cece'], best_tst_metrics['sparsity']))

    # NOTE(review): this pickles the LAST epoch's test metrics, not
    # best_tst_metrics, while the log line above and the returned model refer
    # to the best epoch. Confirm which one downstream analysis expects.
    with open('{}/results_i{:03d}.pkl'.format(args.result_dir, l+1), 'wb') as f:
        pickle.dump(tst_metrics, f, pickle.HIGHEST_PROTOCOL)
    return best_model

def train_swa_eval_loop(args, l, logger, model, loss, optimizer, swa_model, scheduler, trn_loader, val_loader, tst_loader, device, epochs, verbose, model_prior=None):
    """Train ``model`` with stochastic weight averaging and save SWA test metrics.

    Epochs ``1..epochs`` run ``train``, or ``train_reg`` (defined outside this
    chunk) when ``model_prior`` is given. While ``epoch <= swa_start`` the
    scheduler is stepped; afterwards each epoch's weights are folded into
    ``swa_model`` via ``update_parameters``. BatchNorm statistics of
    ``swa_model`` are then recomputed on ``trn_loader``. Its test metrics are
    pickled to ``{args.result_dir}/results_i{l+1:03d}.pkl``.

    Args:
        args: run configuration; ``post_epochs``, ``swa_start`` and
            ``result_dir`` are read.
        l: index used in the results file name (``l + 1``, zero-padded to 3).
        swa_model: averaging wrapper exposing ``update_parameters`` (presumably
            ``torch.optim.swa_utils.AveragedModel``); moved to ``device``.
        model_prior: optional prior model forwarded to ``train_reg``.
        Other arguments are forwarded to ``train``/``eval``.
    """
    # Epoch after which averaging starts. NOTE(review): may be a float if
    # args.swa_start is a fraction; then `epoch == swa_start` only matches
    # when the product is integral.
    swa_start = args.post_epochs * args.swa_start

    eval(args, logger, model, loss, val_loader, tst_loader, device, 0, verbose)
    swa_model = swa_model.to(device)
    for epoch in range(1, epochs + 1):
        if model_prior is None:
            train(args, logger, model, loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        else:
            train_reg(args, logger, model, model_prior, loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        if epoch == swa_start:
            print("---------------------------------Here SWA start------------------------------------")
        if epoch > swa_start:
            swa_model.update_parameters(model)
        else:
            scheduler.step()

    # NOTE(review): if swa_start >= epochs, update_parameters is never called.
    # swa_model then presumably still holds the weights it was constructed with.
    torch.optim.swa_utils.update_bn(trn_loader, swa_model, device)
    print("--------------------------------------------------------------------------------------")
    print("------------------------------------After SWA-----------------------------------------")
    # Use `epochs`, not the loop variable: the loop variable is unbound when
    # epochs == 0, and for epochs >= 1 the two are equal here.
    val_metrics, tst_metrics = eval(args, logger, swa_model, loss, val_loader, tst_loader, device, epochs, verbose, is_swa=True)
    # save
    with open('{}/results_i{:03d}.pkl'.format(args.result_dir, l+1), 'wb') as f:
        pickle.dump(tst_metrics, f, pickle.HIGHEST_PROTOCOL)

def train_eval_loop_ens(args, logger, models, loss, optimizer, scheduler, trn_loader, val_loader, tst_loader, device, epochs, verbose, prior=False, self_distill=False, data_diversity=True, weight_diversity=False):
    """Train an ensemble for ``epochs`` epochs and return the best snapshot.

    The per-epoch training routine is chosen by the flags, in this priority:
    ``prior`` -> ``train_ens_prior``; ``self_distill`` -> ``train_ens_self_distill``;
    ``weight_diversity`` -> ``train_ens_WD``; ``data_diversity`` (default True)
    -> ``train_ens_DD``; otherwise ``train_ens``. Routines other than
    ``train_ens`` are defined outside this chunk.

    After each epoch the ensemble is evaluated with ``eval_ens`` and
    ``scheduler.step()`` is called. The members are deep-copied whenever
    validation accuracy improves.

    Args:
        models: list of member networks, or ``(prior, models)`` when ``prior``
            is True.
        Other arguments are forwarded to the training routine / ``eval_ens``.

    Returns:
        List of deep copies of the members from the epoch with the highest
        validation accuracy (the initial members if ``epochs`` is 0).
    """
    if prior:
        prior, models = models
    val_metrics, tst_metrics = eval_ens(args, logger, models, loss, val_loader, tst_loader, device, 0, verbose)
    best_acc, best_acc_epoch = -1, 0
    best_models = [copy.deepcopy(model) for model in models]
    # Seed with the pre-training evaluation so the summary below is defined
    # even when no epoch runs (epochs == 0).
    best_tst_metrics = tst_metrics
    for epoch in range(1, epochs + 1):
        if prior:
            train_ens_prior(args, logger, (prior, models), loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        elif self_distill:
            train_ens_self_distill(args, logger, models, loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        elif weight_diversity:
            train_ens_WD(logger, models, loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        elif data_diversity:
            train_ens_DD(logger, models, loss, optimizer, scheduler, trn_loader, device, epoch, verbose)
        else:
            train_ens(logger, models, loss, optimizer, scheduler, trn_loader, device, epoch, verbose)

        val_metrics, tst_metrics = eval_ens(args, logger, models, loss, val_loader, tst_loader, device, epoch, verbose)
        scheduler.step()
        # Model selection uses validation accuracy of the mean ensemble prediction.
        if best_acc < val_metrics['acc']:
            best_acc, best_acc_epoch = val_metrics['acc'], epoch
            best_models = [copy.deepcopy(model) for model in models]
            best_tst_metrics = tst_metrics
    logger.info('BEST test ensemble evaluation at epoch {}: Accuracy: {:.4f}%, NLL: {:.4f}, ECE: {:.4f}, CNLL: {:.4f}, CECE: {:.4f}, function-KLD: {:.4f}, Mask IOU: {:.4f}'.format(
            best_acc_epoch, best_tst_metrics['acc'], best_tst_metrics['nll'], best_tst_metrics['ece'], best_tst_metrics['cnll'], best_tst_metrics['cece'], best_tst_metrics['kld'], best_tst_metrics['iou']))
    return best_models

