import torch
import numpy as np
import torch
from utils.utils import *
from scipy.stats import norm


def evaluate_acc(confidences: torch.Tensor, true_labels: torch.Tensor) -> float:
    """Return the fraction of rows whose highest-scoring column matches the label.

    Args:
        confidences: Per-sample class scores; the maximum is taken over dim 1.
        true_labels: Ground-truth class indices, compared element-wise with
            the predicted indices.

    Returns:
        The mean of the correct/incorrect indicator as a Python float.
    """
    _, predicted = confidences.max(dim=1)
    hits = predicted == true_labels
    return hits.float().mean().item()


def evaluate_nll(confidences: torch.Tensor, true_labels: torch.Tensor, eps = 1e-12) -> float:
    """Return the mean negative log-likelihood of the true labels.

    Args:
        confidences: Per-sample class scores; presumably probabilities, since
            their logarithm is fed to ``nll_loss`` (confirm against callers).
        true_labels: Ground-truth class indices.
        eps: Small constant added before the logarithm so that a zero entry
            does not produce ``-inf``.

    Returns:
        The loss value as a Python float.
    """
    log_probs = (eps + confidences).log()
    loss = torch.nn.functional.nll_loss(log_probs, true_labels)
    return loss.item()

def evaluate_ece(confidences: torch.Tensor,
                 true_labels: torch.Tensor,
                 n_bins: int = 15) -> float:
    """Compute the expected calibration error (ECE) using equal-width bins.

    The top-class confidence of every sample is assigned to one of ``n_bins``
    half-open intervals ``(lower, upper]`` that partition ``[0, 1]``. For each
    non-empty bin, the absolute gap between mean confidence and accuracy is
    weighted by the fraction of samples in that bin, and the weighted gaps
    are summed.

    Args:
        confidences: Per-sample class scores; the maximum is taken over dim 1.
            Presumably probabilities in ``[0, 1]`` (confirm against callers);
            a top score of exactly 0 falls in no bin.
        true_labels: Ground-truth class indices.
        n_bins: Number of equal-width confidence bins.

    Returns:
        The ECE as a Python float.
    """
    pred_confidences, pred_labels = torch.max(confidences, dim=1)
    accuracies = pred_labels.eq(true_labels)

    # Plain Python floats for the bin edges, so the comparisons below work
    # regardless of which device the confidences live on.
    ticks = torch.linspace(0, 1, n_bins + 1).tolist()

    ece = torch.zeros(1, device=confidences.device)
    for bin_lower, bin_upper in zip(ticks[:-1], ticks[1:]):
        in_bin = pred_confidences.gt(bin_lower) & pred_confidences.le(bin_upper)
        prop_in_bin = in_bin.float().mean()
        # Empty bins contribute nothing (and would give NaN means).
        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = pred_confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def evaluate_test_measures(probs_list, targets_list, probs_v_list, targets_v_list, measures_name):
    """Compute the requested classification measures on the test predictions.

    Args:
        probs_list: Test-set predictions passed to each metric function.
        targets_list: Test-set labels passed to each metric function.
        probs_v_list: Not used by this function.
        targets_v_list: Not used by this function.
        measures_name: Container checked with ``in`` for the keys ``'acc'``,
            ``'nll'`` and ``'ece'``.

    Returns:
        A dict mapping each requested key to its metric value, in the order
        ``'acc'``, ``'nll'``, ``'ece'``.
    """
    # Thunks keep evaluation lazy: a metric only runs when it is requested.
    metric_thunks = {
        'acc': lambda: evaluate_acc(probs_list, targets_list),
        'nll': lambda: evaluate_nll(probs_list, targets_list),
        'ece': lambda: evaluate_ece(probs_list, targets_list),
    }
    return {
        key: compute()
        for key, compute in metric_thunks.items()
        if key in measures_name
    }

def evaluate_values(model, loader):
    """Run ``model`` over every batch of ``loader`` and gather outputs and targets.

    The model is switched to eval mode and all forward passes run under
    ``torch.no_grad()``. Batches are moved to the device of the model's first
    parameter, so the function works on CPU as well as on any GPU. If the
    model exposes no parameters, CUDA is used when available and CPU otherwise.

    Args:
        model: Module providing ``eval()``, ``parameters()`` and ``__call__``.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        A tuple ``(outputs, targets)``: the per-batch model outputs
        concatenated along dim 0 and the per-batch targets concatenated.

    Raises:
        RuntimeError: From ``torch.cat`` when ``loader`` yields no batches.
    """
    model.eval()

    # Follow the model's own device rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    outputs_list = []
    targets_list = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs_list.append(model(inputs))
            targets_list.append(targets)

        outputs_list = torch.cat(outputs_list, dim=0)
        targets_list = torch.cat(targets_list)

    return outputs_list, targets_list

def evaluate_rmse(output, y):
    """Return the root-mean-square error between ``y`` and ``output``.

    The squared differences are averaged over all elements before the square
    root is taken; the result is returned as a tensor, not a Python float.
    """
    residual = y - output
    mean_squared_error = (residual ** 2).mean()
    return mean_squared_error.sqrt()

def evaluate_nll_normal(output, y, sigma):
    """Return the mean negative log-likelihood of ``y`` under N(output, sigma^2).

    Args:
        output: Predicted means.
        y: Observed targets.
        sigma: Standard deviation; must be a tensor because ``torch.log`` is
            applied to it.

    Returns:
        The negated mean log-density as a tensor.
    """
    squared_error = (y - output) ** 2
    # Quadratic term of the Gaussian log-density.
    quadratic_term = -squared_error / (2 * sigma ** 2)
    # Log of the normalising constant 1 / (sigma * sqrt(2 * pi)).
    log_normaliser = -0.5 * torch.log(2 * torch.tensor(np.pi)) - torch.log(sigma)
    log_likelihood = quadratic_term + log_normaliser
    return -torch.mean(log_likelihood)


def A(mu, var):
    """Return ``2*sigma*pdf(mu/sigma) + mu*(2*cdf(mu/sigma) - 1)`` with ``sigma = sqrt(var)``.

    ``pdf`` and ``cdf`` are those of the standard normal distribution; the
    expression equals the mean of ``|X|`` for ``X ~ N(mu, var)``.

    The density and distribution values are computed by SciPy on a detached
    CPU copy, so no gradient flows through them; gradients only flow through
    the explicit ``sigma`` and ``mu`` factors. The result is placed on the
    device of ``var`` instead of unconditionally on CUDA.

    Args:
        mu: Mean; a tensor or a plain number such as ``0``.
        var: Variance tensor (``torch.sqrt`` is applied to it).

    Returns:
        A tensor holding the value described above.
    """
    EPS = torch.finfo(torch.float32).eps
    sigma = torch.sqrt(var)
    # EPS guards the division against sigma == 0.
    r = (mu / (sigma + EPS)).detach().cpu().numpy()
    # For 0-dim input SciPy returns NumPy scalars, which torch.from_numpy
    # rejects, so wrap the results back into arrays first.
    pdf = torch.from_numpy(np.asarray(norm.pdf(r))).float().to(sigma.device)
    two_cdf_minus_one = torch.from_numpy(
        np.asarray(2 * norm.cdf(r) - 1)
    ).float().to(sigma.device)
    A1 = 2 * sigma * pdf
    A2 = mu * two_cdf_minus_one
    return A1 + A2


def evaluate_crps(mu,y,var):
    """Return ``A(y - mu, var) - 0.5 * A(0, 2 * var)`` for a single observation.

    Presumably the closed-form CRPS of a Gaussian forecast N(mu, var), given
    that ``A`` returns the mean absolute value of a normal variable (confirm
    against the intended definition).

    Returns ``None`` when ``y`` has at least one dimension and more than one
    entry along its first dimension; only single observations are scored.
    """
    holds_several_values = y.dim() > 0 and len(y) > 1
    if holds_several_values:
        return None
    error_term = A(y - mu, var)
    spread_term = A(0, 2 * var)
    return error_term - 0.5 * spread_term


def nll_loss(outputs_s_list, outputs_t_list):
    """Return a Gaussian negative log-likelihood style loss, up to constants.

    The loss is ``sum((t - s)^2) / (2 * sigma_f^2) + B * log(sigma_f)``, where
    ``s`` is ``outputs_s_list[0]``, ``sigma_f`` is ``outputs_s_list[4]``,
    ``t`` is ``outputs_t_list`` and ``B`` is the size of its first dimension.

    Args:
        outputs_s_list: Indexable container; entry 0 holds the predictions and
            entry 4 the noise scale tensor ``sigma_f``.
        outputs_t_list: Target tensor.

    Returns:
        The loss tensor.
    """
    predictions = outputs_s_list[0]
    sigma_f = outputs_s_list[4]
    batch_size = outputs_t_list.shape[0]

    squared_error_sum = ((outputs_t_list - predictions) ** 2).sum()
    data_term = squared_error_sum / ((sigma_f ** 2) * 2)
    scale_penalty = batch_size * sigma_f.log()
    return data_term + scale_penalty

def Gaussian_kernel_matrix(Xi, Xj, sigma=1.0):
    """Return the Gaussian (RBF) kernel matrix between the rows of ``Xi`` and ``Xj``.

    Entry ``(a, b)`` is ``exp(-||Xi[a] - Xj[b]||^2 / (2 * sigma^2))``, with
    the pairwise Euclidean distances computed by ``torch.cdist``.

    Args:
        Xi: First set of row vectors.
        Xj: Second set of row vectors.
        sigma: Kernel bandwidth.

    Returns:
        The kernel matrix as a tensor.
    """
    squared_distances = torch.cdist(Xi, Xj, p=2) ** 2
    scaled = -squared_distances / (2.0 * sigma ** 2)
    return torch.exp(scaled)

def mmd_loss(source_reps, target_reps, sigmas=(1.0,)):
    """Return a squared maximum mean discrepancy estimate summed over bandwidths.

    For each bandwidth the estimate is
    ``mean(K(X, X)) - 2 * mean(K(X, Y)) + mean(K(Y, Y))`` with Gaussian
    kernels, where each mean runs over every entry of the kernel matrix
    (diagonal included).

    Args:
        source_reps: Source representations ``X``, one row per sample.
        target_reps: Target representations ``Y``, one row per sample.
        sigmas: Iterable of kernel bandwidths. The default is an immutable
            tuple so the shared default object can never be mutated.

    Returns:
        The summed estimate as a tensor, or the float ``0.0`` when ``sigmas``
        is empty.
    """
    mmd = 0.0
    for sigma in sigmas:
        KXX = Gaussian_kernel_matrix(source_reps, source_reps, sigma=sigma)
        KXY = Gaussian_kernel_matrix(source_reps, target_reps, sigma=sigma)
        KYY = Gaussian_kernel_matrix(target_reps, target_reps, sigma=sigma)
        mmd += KXX.mean() - 2 * KXY.mean() + KYY.mean()
    return mmd