import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from sklearn import metrics
import pdb

def model_rel_dist(m1, m2, mode='rel'):
    """
    Computes the L2 distance between the parameters of two nn.Modules of the same architecture.

    All parameters (in ``parameters()`` order) are treated as one flat vector per model.

    Args:
        m1: reference module; its norm is the denominator in 'rel' mode.
        m2: module whose ``parameters()`` match ``m1``'s in count, order and shape.
        mode: 'rel' -> ||w1 - w2|| / ||w1||;  'abs' -> ||w1 - w2||.

    Returns:
        Tuple ``(dist, ||w1||, ||w2||)`` of 0-dim tensors.

    Raises:
        ValueError: if ``mode`` is not 'rel'/'abs' or the modules have different
            numbers of parameter tensors.
    """
    if mode not in ('rel', 'abs'):
        raise ValueError(f"mode must be 'rel' or 'abs', got {mode!r}")
    params1 = list(m1.parameters())
    params2 = list(m2.parameters())
    if len(params1) != len(params2):
        raise ValueError(f"modules have different numbers of parameter tensors: "
                         f"{len(params1)} vs {len(params2)}")
    normsq1, normsq2, diffsq = 0., 0., 0.
    for p1, p2 in zip(params1, params2):
        normsq1 += torch.sum(p1 ** 2)
        normsq2 += torch.sum(p2 ** 2)
        # Summing squared differences directly avoids ||a||^2 + ||b||^2 - 2<a, b>,
        # which can round to a tiny negative value (sqrt -> NaN) for near-identical models.
        diffsq += torch.sum((p1 - p2) ** 2)
    dist = torch.sqrt(diffsq)
    if mode == 'rel':
        dist = dist / torch.sqrt(normsq1)
    return dist, torch.sqrt(normsq1), torch.sqrt(normsq2)


def interpolate_model(m1, m2, mt, t=0.5):
    """
    Writes the linear interpolation ``(1 - t) * m1 + t * m2`` into ``mt`` in place.

    Parameters are paired positionally via ``parameters()``; batch-norm running
    statistics (running_mean / running_var) are paired positionally over all
    batch-norm layers found by ``modules()``. All three modules are assumed to share
    the same architecture.

    Args:
        m1, m2: source modules (not modified).
        mt: target module, overwritten in place.
        t: interpolation coefficient; t=0 reproduces m1, t=1 reproduces m2.
    """
    # Cover every batch-norm flavour, not just BatchNorm2d, so stats of 1d/3d/sync
    # BN layers are not silently left at mt's previous values.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    with torch.no_grad():
        for p1, p2, pt in zip(m1.parameters(), m2.parameters(), mt.parameters()):
            pt.copy_((1 - t) * p1 + t * p2)
        # Copy batchnorm running statistics
        bns1 = [bn for bn in m1.modules() if isinstance(bn, bn_types)]
        bns2 = [bn for bn in m2.modules() if isinstance(bn, bn_types)]
        bnst = [bn for bn in mt.modules() if isinstance(bn, bn_types)]
        for bn1, bn2, bnt in zip(bns1, bns2, bnst):
            # Layers built with track_running_stats=False have running_mean=None.
            if bnt.running_mean is None or bn1.running_mean is None or bn2.running_mean is None:
                continue
            bnt.running_mean.copy_((1 - t) * bn1.running_mean + t * bn2.running_mean)
            bnt.running_var.copy_((1 - t) * bn1.running_var + t * bn2.running_var)


def predictive_entropy(probs):
    """
    Returns the mean Shannon entropy (in nats) over the rows of ``probs``.

    Entropy is summed over dim 1 and averaged over the remaining rows. ``torch.xlogy``
    is used so zero-probability entries contribute 0 (the limit of p*log p) instead
    of ``0 * log(0) = NaN``.
    """
    return -torch.xlogy(probs, probs).sum(dim=1).mean()


def expected_calibration_error(probs, targets, num_partitions=15, verbose=False):
    """
    Computes the expected calibration error (ECE) with equal-width confidence bins.

    Each sample is placed in the bin ``(m / num_partitions, (m + 1) / num_partitions]``
    that contains its top-class probability. The ECE is the sum over non-empty bins of
    ``|accuracy - mean confidence| * bin_size / n``.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        num_partitions: number of equal-width bins on (0, 1].
        verbose: if True, print per-bin statistics.

    Returns:
        0-dim tensor holding the ECE (zero when no bin is populated, e.g. empty input).
    """
    confs, _ = torch.max(probs, 1)
    n_test = probs.shape[0]
    # Start from a tensor rather than a Python float so the return type does not
    # depend on whether any bin is populated (callers can always use .item()).
    ece = probs.new_zeros(())
    for m in range(num_partitions):
        lower_prob, upper_prob = m / num_partitions, (m + 1) / num_partitions
        inds = torch.bitwise_and(confs > lower_prob, confs <= upper_prob)
        B_m = torch.sum(inds)
        if B_m == 0:
            continue
        acc = 1. * torch.sum(torch.argmax(probs[inds, :], dim=1) == targets[inds]) / B_m
        conf = torch.mean(confs[inds])
        ece = ece + torch.abs(acc - conf) * B_m / n_test
        if verbose:
            print(f"Block [{m}/{num_partitions}]: B_m={B_m}, "
                  f"acc={acc:.4f}, conf={conf:.4f}, ece={torch.abs(acc - conf):.4f}")
    return ece


def kl_ece(probs, targets, num_partitions=15, verbose=False):
    """
    KL-divergence variant of the equal-width-bin calibration error.

    Samples are binned by top-class probability exactly as in the standard ECE. For
    each non-empty bin with accuracy ``acc`` strictly between 0 and 1, the method adds
    ``B_m / n * [acc * (log acc - mean log conf)
                 + (1 - acc) * (log(1 - acc) - mean log(1 - conf))]``.

    Bins whose accuracy is exactly 0 or 1 are skipped: the closed form would
    evaluate ``0 * log 0`` there. This drops their (generally nonzero) contribution.
    NOTE(review): confirm that dropping these bins is intended.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        num_partitions: number of equal-width bins on (0, 1].
        verbose: if True, print per-bin statistics.

    Returns:
        0-dim tensor (zero when no bin contributes, e.g. empty input).
    """
    confs, _ = torch.max(probs, 1)
    n_test = probs.shape[0]
    ece = probs.new_zeros(())
    for m in range(num_partitions):
        lower_prob, upper_prob = m / num_partitions, (m + 1) / num_partitions
        inds = torch.bitwise_and(confs > lower_prob, confs <= upper_prob)
        B_m = torch.sum(inds)
        if B_m == 0:
            continue
        acc = 1. * torch.sum(torch.argmax(probs[inds, :], dim=1) == targets[inds]) / B_m
        if acc == 0. or acc == 1.:
            continue
        bin_confs = confs[inds]
        kl = (acc * (torch.log(acc) - torch.log(bin_confs).mean()) +
              (1. - acc) * (torch.log(1. - acc) - torch.log(1. - bin_confs).mean()))
        ece = ece + kl * B_m / n_test
        if verbose:
            conf = torch.mean(bin_confs)
            print(f"Block [{m}/{num_partitions}]: B_m={B_m}, "
                  f"acc={acc:.4f}, conf={conf:.4f}, kl={kl:.4f}")
    return ece


def ece_per_class(probs, targets, num_classes=100, num_partitions=15):
    """
    Computes the ECE separately on the samples of each true class.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        num_classes: classes ``0 .. num_classes - 1`` are evaluated.
        num_partitions: number of confidence bins passed to the ECE.

    Returns:
        List of Python floats, one per class; a class with no samples yields 0.0.
    """
    eces = []
    for cls in range(num_classes):
        mask = (targets == cls)
        ece = expected_calibration_error(probs[mask, :], targets[mask],
                                         num_partitions=num_partitions)
        # float() accepts both a 0-dim tensor and a plain Python number; .item() would
        # fail if the ECE comes back as the float 0. for a class with no samples.
        eces.append(float(ece))
    return eces


def ece_adaptive_binning(probs, targets, num_partitions=5, verbose=False):
    """
    Computes the ECE with equal-mass (adaptive) bins.

    Samples are sorted by top-class probability and split into ``num_partitions``
    consecutive chunks of ``n // num_partitions`` samples. The last chunk also takes
    the remainder. The ECE is the sum over non-empty chunks of
    ``|accuracy - mean confidence| * chunk_size / n``.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        num_partitions: number of equal-mass bins.
        verbose: if True, print per-bin statistics.

    Returns:
        0-dim tensor holding the ECE (zero for empty input).
    """
    confs, _ = torch.max(probs, 1)
    with torch.no_grad():
        _, sorted_inds = torch.sort(confs)
    n = probs.shape[0]
    inc = n // num_partitions
    ece = probs.new_zeros(())
    for i in range(num_partitions):
        j_lower = i * inc
        j_upper = (i + 1) * inc if i < num_partitions - 1 else n
        inds = sorted_inds[j_lower:j_upper]
        # Bin size is the number of samples in the chunk, not the sum of their indices.
        B_m = inds.numel()
        if B_m == 0:
            continue
        acc = 1. * torch.sum(torch.argmax(probs[inds, :], dim=1) == targets[inds]) / B_m
        conf = torch.mean(confs[inds])
        ece = ece + torch.abs(acc - conf) * B_m / n
        if verbose:
            print(f"Block [{i}/{num_partitions}]: B_m={B_m}, "
                  f"acc={acc:.4f}, conf={conf:.4f}, ece={torch.abs(acc - conf):.4f}")
    return ece

def temp_scale(probs, tau=1.0):
    """
    Applies temperature ``tau`` to a batch of probability rows.

    Every entry is raised to the power ``1 / tau`` and each row is renormalized over
    dim 1. ``tau`` may be a scalar or a tensor that broadcasts against ``probs``
    (e.g. one temperature per row with shape ``(n, 1)``).
    """
    exponent = 1. / tau
    powered = torch.pow(probs, exponent)
    row_totals = powered.sum(dim=1, keepdim=True)
    return powered / row_totals


def temp_scale_split_test(probs, targets,
                          cal_split=5000,
                          max_iter=100,
                          init_tau=1.5,
                          eta=0.1,
                          verbose=True):
    """
    Fits a single temperature on a calibration split and evaluates it on the rest.

    The first ``cal_split`` rows are used to fit ``tau`` by SGD on the NLL of
    ``temp_scale(probs, tau)``. The remaining rows are scaled with the fitted ``tau``.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        cal_split: number of leading rows used for calibration.
        max_iter: number of SGD steps.
        init_tau: initial temperature.
        eta: SGD learning rate.
        verbose: if True, print tau and loss at every step.

    Returns:
        Tuple ``(tau, scaled_nll_test, scaled_ece_test)``: the fitted temperature
        tensor, the NLL and the ECE on the held-out rows.
    """
    probs_cal, targets_cal = probs[:cal_split, :], targets[:cal_split]
    probs_test, targets_test = probs[cal_split:, :], targets[cal_split:]
    # Allocate tau where the data lives (the original hard-coded CUDA), and force a
    # floating dtype so an integer init_tau can still require grad.
    tau = torch.tensor(float(init_tau), device=probs.device, requires_grad=True)
    optimizer = optim.SGD([tau], lr=eta)
    # NLLLoss has no parameters or buffers, so it needs no device placement.
    nll_loss = nn.NLLLoss()
    for _ in range(max_iter):
        optimizer.zero_grad()
        scaled_probs_cal = temp_scale(probs_cal, tau)
        loss = nll_loss(torch.log(scaled_probs_cal), targets_cal)
        if verbose:
            print(f"Temp scaling tau={tau.item():.4f}, loss={loss.item():.4f}")
        loss.backward()
        optimizer.step()
    scaled_probs_test = temp_scale(probs_test, tau)
    scaled_nll_test = nll_loss(torch.log(scaled_probs_test), targets_test)
    scaled_ece_test = expected_calibration_error(scaled_probs_test, targets_test)
    return tau, scaled_nll_test, scaled_ece_test


def adjust_scale(init_temp_scale, epoch,
                 method='const', cycle=-1):
    """
    Returns the temperature scale for a given (0-based) epoch.

    Args:
        init_temp_scale: starting scale.
        epoch: 0-based epoch index (shifted to 1-based internally).
        method: 'const' keeps ``init_temp_scale``. 'decay' holds it for the first
            half of ``cycle``, then interpolates linearly to 1.0 at epoch ``cycle``
            and stays at 1.0 afterwards.
        cycle: cycle length in epochs; must be positive for 'decay'.

    Raises:
        ValueError: for an unknown ``method`` or a non-positive ``cycle`` with 'decay'.
    """
    epoch = epoch + 1
    if method == 'const':
        return init_temp_scale
    if method == 'decay':
        if cycle <= 0:
            raise ValueError(f"'decay' requires cycle > 0, got {cycle}")
        # Compare against the exact midpoint: with cycle // 2 an odd cycle enters the
        # interpolation one epoch early, where the weight on init_temp_scale exceeds 1.
        if epoch < cycle / 2:
            return init_temp_scale
        if epoch < cycle:
            return (2 - 2*epoch/cycle) * init_temp_scale + (2*epoch/cycle - 1) * 1.0
        return 1.0
    raise ValueError(f"Unknown method {method!r}; expected 'const' or 'decay'")


def accuracy_against_confidence(probs, targets,
                                grid_size=0.1):
    """
    Accuracy of the samples whose top-class probability reaches each threshold.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        grid_size: spacing of thresholds in ``np.arange(0, 1, grid_size)``.

    Returns:
        Tuple ``(conf_thres, accs, counts)`` of NumPy arrays: the thresholds, the
        accuracy of samples with confidence >= threshold (NaN where none qualify),
        and how many samples qualify.
    """
    conf_thres = np.arange(0, 1, grid_size)
    accs = np.zeros_like(conf_thres)
    # np.int was removed in NumPy 1.24; the builtin int selects the same default integer dtype.
    counts = np.zeros_like(conf_thres, dtype=int)
    confs, preds = probs.max(dim=1)
    for i, conf in enumerate(conf_thres):
        inds = (confs >= conf)
        count = int(inds.sum().item())
        counts[i] = count
        if count == 0:
            # No sample reaches this threshold: accuracy is undefined.
            accs[i] = np.nan
        else:
            accs[i] = (preds[inds] == targets[inds]).sum().item() / count
    return conf_thres, accs, counts


def accuracy_against_counts(probs, targets, count_interval=100,
                            temps=None):
    """
    Accuracy of the ``count`` most-trusted samples for decreasing ``count``.

    Samples are ordered by top-class probability, or by descending ``temps`` when
    given, so the last entries of the ordering are the highest-confidence (or
    lowest-temperature) samples. For each ``count`` in ``n, n - count_interval, ...``,
    the accuracy of the last ``count`` samples is recorded. The same curve is built
    for an oracle ordering that puts the correctly classified samples last.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        count_interval: step between successive counts.
        temps: optional per-sample temperatures used to order samples instead of
            confidence.

    Returns:
        Tuple ``(count_thres, accs, score)`` where
        ``score = (auc_random - auc_conf) / (auc_random - auc_oracle_conf)``
        and the AUCs are ``sklearn.metrics.auc`` areas over ``count_thres``.
    """
    count_thres = np.arange(probs.shape[0], 0, -count_interval)
    confs, preds = probs.max(dim=1)
    correct = (preds == targets)

    def _accs_of_last(sorted_inds):
        # Accuracy of the last `count` entries of `sorted_inds`, for each count.
        # np.float was removed in NumPy 1.24; the builtin float is the same float64 dtype.
        out = np.zeros_like(count_thres, dtype=float)
        for i, count in enumerate(count_thres):
            out[i] = correct[sorted_inds[-count:]].sum().item() / count
        return out

    if temps is not None:
        sorted_inds = temps.squeeze().sort(descending=True)[1]
    else:
        sorted_inds = confs.sort()[1]
    accs = _accs_of_last(sorted_inds)
    auc_conf = metrics.auc(count_thres, accs)
    oracle_accs = _accs_of_last(correct.float().sort()[1])
    auc_oracle_conf = metrics.auc(count_thres, oracle_accs)
    random_accs = np.linspace(accs[0], 1., len(count_thres) + 1)[:-1]
    auc_random = metrics.auc(count_thres, random_accs)
    return count_thres, accs, (auc_random - auc_conf) / (auc_random - auc_oracle_conf)



def optimal_individual_temp(probs, targets,
                            eta=0.01, init_tau=1.0,
                            max_iter=20000,
                            verbose=1000,
                            min_temp=0.2,
                            overwrite_correct_preds=False,
                            save_path=None):
    """
    Fits one temperature per sample by SGD on the summed NLL of the rescaled probabilities.

    Each raw parameter is mapped to ``relu6(raw) + min_temp``, so the effective
    temperatures lie in ``[min_temp, min_temp + 6]``.

    Args:
        probs: 2-D tensor of class probabilities, one row per sample.
        targets: integer class labels, one per row of ``probs``.
        eta: SGD learning rate.
        init_tau: initial raw parameter value for every sample.
        max_iter: number of SGD steps (0 evaluates the initial temperatures).
        verbose: print progress every ``verbose`` steps; <= 0 disables printing.
        min_temp: lower bound added after relu6.
        overwrite_correct_preds: if True, samples whose argmax equals the target get
            temperature ``min_temp`` in the returned result.
        save_path: if given, the final temperatures are written with ``torch.save``.

    Returns:
        Tuple ``(temps_transformed, loss, scaled_ece)``: the ``(n, 1)`` effective
        temperatures after the last optimizer step, the summed NLL at those
        temperatures, and the ECE of the rescaled probabilities.
    """
    n = probs.shape[0]
    # Allocate on the data's device instead of hard-coding CUDA.
    temps = torch.ones((n, 1), device=probs.device, requires_grad=True)
    with torch.no_grad():
        temps.fill_(init_tau)
    optimizer = optim.SGD([temps], lr=eta)
    nll_loss = nn.NLLLoss()
    for step in range(max_iter):
        optimizer.zero_grad()
        temps_transformed = F.relu6(temps) + min_temp
        scaled_probs = temp_scale(probs, temps_transformed)
        loss = n * nll_loss(torch.log(scaled_probs), targets)
        if verbose > 0 and step % verbose == 0:
            # Report the temperatures actually applied, not relu(raw).
            print(f"Avg temp={temps_transformed.mean().item():.4f}, loss={loss.item():.4f}")
        loss.backward()
        optimizer.step()
    # Re-evaluate after the loop so the returned values reflect the final optimizer
    # step (and exist even when max_iter == 0).
    with torch.no_grad():
        temps_transformed = F.relu6(temps) + min_temp
        if overwrite_correct_preds:
            correct_preds = (probs.argmax(dim=1) == targets)
            temps_transformed[correct_preds] = min_temp
        scaled_probs = temp_scale(probs, temps_transformed)
        loss = n * nll_loss(torch.log(scaled_probs), targets)
    scaled_ece = expected_calibration_error(scaled_probs, targets)
    if save_path is not None:
        torch.save(temps_transformed, save_path)
    return temps_transformed, loss, scaled_ece


def optimal_nll_rank_preserving(probs, targets):
    """
    Returns the mean over samples of ``log(1 + r)``, where ``r`` counts the classes
    whose probability is strictly greater than that of the target class.
    """
    row_idx = range(probs.shape[0])
    target_probs = probs[row_idx, targets].view([-1, 1])
    outranking = probs > target_probs
    ranks = outranking.sum(dim=1).float()
    return torch.log(ranks + 1).mean()


def process_ddp_statedict(state_dict):
    """
    Returns a shallow copy of ``state_dict`` with a leading ``'module.'`` stripped from keys.

    Each stripped entry gets a cloned tensor, while entries without the prefix are
    shared with the input. The input mapping is not modified. Presumably intended for
    checkpoints saved from a DistributedDataParallel-wrapped model; verify against
    callers.
    """
    prefix = 'module.'
    stripped = state_dict.copy()
    for name, tensor in state_dict.items():
        if not name.startswith(prefix):
            continue
        stripped[name[len(prefix):]] = tensor.clone()
        del stripped[name]
    return stripped