import torch
import numpy as np
from scipy import stats
from sklearn.metrics import *
from sklearn.metrics.pairwise import *
from sklearn.preprocessing import normalize

# eps = np.finfo(np.float64).eps
# Lower clip bound: KL_div, Clark and Canberra apply torch.clip(..., eps, 1)
# to both distributions so that log() and the divisions stay finite when an
# entry is zero.
# NOTE(review): 1e-3 is far larger than machine epsilon (the commented-out
# alternative above); presumably deliberate, but confirm — any probability
# below 1e-3 is raised to 1e-3, which shifts those metric values.
eps = 1e-3

def proj(Y):
    """Project every row of the 2-D array ``Y`` onto the probability simplex.

    Sort-based Euclidean projection:
    1. Sort each row in decreasing order.
    2. Form the running thresholds ``(cumsum - 1) / (position + 1)``.
    3. Count how many sorted entries exceed their threshold; that count is
       the support size ``r``.
    4. Shift the original row down by the ``r``-th threshold and clamp at zero.

    Returns a new array of the same shape whose rows are non-negative and
    sum to one.
    """
    n_rows, n_cols = Y.shape
    descending = np.flip(np.sort(Y, axis=1), axis=1)
    # Reciprocal of the 1-based rank, multiplied in (not divided) on purpose.
    inv_rank = 1 / (np.arange(n_cols) + 1)
    thresholds = (np.cumsum(descending, axis=1) - 1) * inv_rank
    support = np.count_nonzero(descending > thresholds, axis=1)
    # Pick the threshold at position support-1 in each row, kept as a column.
    theta = np.take_along_axis(thresholds, (support - 1).reshape(n_rows, 1), axis=1)
    return np.maximum(Y - theta, 0)


def KL_div(Y, Y_hat):
    """Summed Kullback-Leibler divergence KL(Y || Y_hat) between matching rows.

    Both inputs are clipped to ``[eps, 1]`` first (``eps`` is the module-level
    floor). This keeps ``log`` finite, so every term of
    ``Y * (log Y - log Y_hat)`` is finite for finite or infinite inputs.

    Args:
        Y: target distributions; divergence is summed over dim 1.
        Y_hat: predicted distributions, same shape as ``Y``.

    Returns:
        0-d tensor: per-row divergences summed over all rows.
    """
    Y = torch.clip(Y, eps, 1)
    Y_hat = torch.clip(Y_hat, eps, 1)
    # The former isinf() check (with debug prints and a bare `raise`) could
    # never fire after clipping, and `.any()` forced a device sync per call.
    kl = torch.sum(Y * (torch.log(Y) - torch.log(Y_hat)), 1)
    return kl.sum()


def Cheby(Y, Y_hat):
    """Summed Chebyshev distance: the largest absolute gap per row, added up.

    The maximum is taken over the last dimension; the per-row maxima are then
    summed into a 0-d tensor.
    """
    largest_gap, _ = (Y - Y_hat).abs().max(dim=-1)
    return largest_gap.sum()


def Clark(Y, Y_hat):
    """Summed Clark distance between matching rows of ``Y`` and ``Y_hat``.

    Both inputs are clipped to ``[eps, 1]`` (module-level ``eps``) to avoid
    dividing by zero.  Per row this is
    ``sqrt(sum(((p - q) / (p + q)) ** 2))`` over the last dimension; the
    per-row values are then summed into a 0-d tensor.
    """
    p = torch.clip(Y, eps, 1)
    q = torch.clip(Y_hat, eps, 1)
    ratio = (p - q).pow(2) / (p + q).pow(2)
    return ratio.sum(dim=-1).sqrt().sum()
    
def Canberra(Y, Y_hat):
    """Summed Canberra distance between matching rows of ``Y`` and ``Y_hat``.

    Inputs are clipped to ``[eps, 1]`` (module-level ``eps``) so the
    denominator is never zero.  Per row: ``sum(|p - q| / (p + q))`` over the
    last dimension; the per-row values are summed into a 0-d tensor.
    """
    p = torch.clip(Y, eps, 1)
    q = torch.clip(Y_hat, eps, 1)
    per_row = ((p - q).abs() / (p + q)).sum(dim=-1)
    return per_row.sum()

def Cosine(Y, Y_hat):
    """Summed paired cosine similarity between matching rows of ``Y`` and ``Y_hat``.

    Delegates to scikit-learn's ``paired_cosine_distances`` and sums
    ``1 - distance`` over rows.

    The tensors are detached before moving to NumPy. Converting a tensor with
    ``requires_grad=True`` to NumPy raises, which made the previous version
    fail on live model outputs.

    Returns:
        0-d CPU tensor of the default dtype.
    """
    Y_np = Y.detach().cpu().numpy()
    Y_hat_np = Y_hat.detach().cpu().numpy()
    similarity = (1 - paired_cosine_distances(Y_np, Y_hat_np)).sum()
    return torch.tensor(float(similarity), dtype=torch.get_default_dtype())


def Intersection(Y, Y_hat):
    """Summed histogram intersection: element-wise minimum, summed per row, then over rows.

    The per-row sum runs over the last dimension; the result is a 0-d tensor.
    """
    overlap = torch.minimum(Y, Y_hat)
    per_row = overlap.sum(dim=-1)
    return per_row.sum()
    

def spearman_rank(Y, Y_hat):
    """Summed Spearman rank correlation between matching rows of ``Y`` and ``Y_hat``.

    Each row pair is scored with ``scipy.stats.spearmanr``, and the
    statistics are added up. SciPy returns NaN for a constant row, and that
    NaN propagates into the total.

    Both tensors are detached and copied to the host once, up front. This
    avoids the NumPy-conversion error on ``requires_grad`` tensors and avoids
    one device transfer per row.

    Returns:
        0-d CPU tensor of the default dtype.
    """
    Y_np = Y.detach().cpu().numpy()
    Y_hat_np = Y_hat.detach().cpu().numpy()
    total = 0.0
    for row, row_hat in zip(Y_np, Y_hat_np):
        total += stats.spearmanr(row, row_hat).statistic
    return torch.Tensor([total]).reshape([])


def tau_Kendall_corr(Y, Y_hat):
    """Summed Kendall tau between matching rows of ``Y`` and ``Y_hat``.

    For every label pair ``(i, j)`` with ``i < j``, a pair is counted as
    discordant (``D``) when the strict ordering ``Y[:, i] > Y[:, j]``
    disagrees with ``Y_hat[:, i] > Y_hat[:, j]``. With ``P = k(k-1)/2``
    pairs, each row contributes ``(P - 2D) / P``, and the contributions are
    summed into a 0-d tensor.

    Ties are not treated specially: a tie simply yields ``False`` on its side
    of the comparison. With fewer than two labels the result is NaN (0 / 0).

    The work stays on ``Y``'s device; the old version forced ``.cuda()`` and
    crashed on CPU inputs or on a non-default GPU.
    """
    k = Y.shape[1]
    # All (i, j) index pairs with i < j, compared in one vectorized step
    # instead of a Python double loop.
    i_idx, j_idx = torch.triu_indices(k, k, offset=1, device=Y.device)
    true_order = Y[:, i_idx] > Y[:, j_idx]
    pred_order = Y_hat[:, i_idx] > Y_hat[:, j_idx]
    discordant = (true_order ^ pred_order).sum(dim=1).to(torch.get_default_dtype())
    tau = 2 * ((k - 1) * k / 2 - 2 * discordant) / ((k - 1) * k)
    return tau.sum()


def score(Y, Y_hat):
    """Evaluate every metric in this module on one batch.

    Metrics are evaluated in this order: Cheby, Clark, Canberra, KL_div,
    Intersection, spearman_rank, tau_Kendall_corr, Cosine.

    Returns:
        The tuple ``(cheby, clark, can, kl, cosine, inter, spear, tau)``.
    """
    evaluation_order = (
        Cheby,
        Clark,
        Canberra,
        KL_div,
        Intersection,
        spearman_rank,
        tau_Kendall_corr,
        Cosine,
    )
    cheby, clark, can, kl, inter, spear, tau, cosine = [
        metric(Y, Y_hat) for metric in evaluation_order
    ]
    # Output order differs from evaluation order: cosine comes fifth.
    return (cheby, clark, can, kl, cosine, inter, spear, tau)