import numpy as np
from scipy.stats import spearmanr
import torch
from transformers import EvalPrediction


# Small positive constant added to / clamped into denominators and log arguments
# (Sinkhorn normalisation, NeuralNDCG's IDCG division, ListMLE's log) to avoid 0.
DEFAULT_EPS: float = 1.0e-10


def _get_ranks(x: torch.Tensor) -> torch.Tensor:
    tmp = x.argsort()
    ranks = torch.zeros_like(tmp)
    ranks[tmp] = torch.arange(len(x), device=x.device)
    return ranks


def compute_ndcg(predictions: torch.Tensor, labels: torch.Tensor, k: int) -> tuple:
    """Compute NDCG (Normalized Discounted Cumulative Gain) of the ranking induced by predictions.

    Labels are RMSD values, where lower is better, so each label is converted to a gain
    ``exp(-label / 3)``. Items are ranked by descending prediction (higher is better).
    The torch or numpy code path is chosen by the type of ``labels``; ``predictions`` is
    expected to be of the same kind.

    Args:
        predictions: Shape (N, )
        labels: Shape (N, )
        k: Cut-off for the truncated NDCG. ``None`` means all N elements; values larger
            than N are clamped to N.

    Returns:
        Tuple ``(ndcg_full, ndcg_at_k, ndcg_at_1)``. For torch inputs every entry is a 0-d
        tensor on ``predictions.device``; for numpy inputs every entry is a float scalar.
        A score is 0 when its ideal DCG is not positive (e.g. empty input, or gains that
        underflow to zero).
    """
    n = len(predictions)
    # k=None is documented as "all elements"; min(None, n) would raise TypeError.
    k = n if k is None else min(k, n)
    is_torch = isinstance(labels, torch.Tensor)

    if is_torch:
        # Convert RMSD to gains (lower RMSD -> larger gain); 3 sets the gain scale.
        gains = torch.exp(-labels / 3)  # maps [0, max_rmsd] -> [1, 0]
        # Higher prediction means better (lower) RMSD, so rank by descending prediction.
        pred_sorted_gains = gains[torch.argsort(predictions, descending=True)]
        # Ideal ranking: gains in descending order (i.e. true RMSD ascending).
        ideal_sorted_gains = torch.sort(gains, descending=True)[0]
        positions = torch.arange(1, n + 1, dtype=torch.float, device=predictions.device)
        discount = torch.log2(positions + 1)
    else:
        predictions = np.asarray(predictions)
        labels = np.asarray(labels)
        gains = np.exp(-labels / 3)  # maps [0, max_rmsd] -> [1, 0]
        pred_sorted_gains = gains[np.argsort(-predictions)]
        ideal_sorted_gains = -np.sort(-gains)
        positions = np.arange(1, n + 1, dtype=np.float32)
        discount = np.log2(positions + 1)

    def ndcg_at(cutoff):
        """NDCG over the first ``cutoff`` ranked positions; 0 when the ideal DCG is not positive."""
        dcg_value = (pred_sorted_gains[:cutoff] / discount[:cutoff]).sum()
        idcg_value = (ideal_sorted_gains[:cutoff] / discount[:cutoff]).sum()
        if idcg_value > 0:
            return dcg_value / idcg_value
        return torch.tensor(0.0, device=predictions.device) if is_torch else 0.0

    # ndcg@1 uses the same helper (discount[0] == 1), so it is safe for empty input and
    # is a tensor for torch inputs -- callers invoke .cpu() on every returned value.
    return ndcg_at(n), ndcg_at(k), ndcg_at(1)


def _get_average_ranks(x: torch.Tensor) -> torch.Tensor:
    """Return float32 0-based ranks of the 1-D tensor ``x``; tied values share the mean of their ranks."""
    sorted_values, order = torch.sort(x)
    _, group_of_sorted, group_sizes = torch.unique_consecutive(
        sorted_values, return_inverse=True, return_counts=True)
    group_end = torch.cumsum(group_sizes, dim=0)   # exclusive end position of each tie group
    group_start = group_end - group_sizes           # first sorted position of each tie group
    group_rank = (group_start + group_end - 1).to(torch.float32) / 2
    ranks = torch.empty(x.shape[0], dtype=torch.float32, device=x.device)
    ranks[order] = group_rank[group_of_sorted]
    return ranks


def spearman_correlation(x: torch.Tensor, y: torch.Tensor):
    """Compute the Spearman rank correlation between 2 1-D vectors.

    Ties are handled with average ranks and the coefficient is the Pearson correlation of
    those ranks, which matches ``scipy.stats.spearmanr``. Without ties this equals the classic
    ``1 - 6 * sum(d^2) / (n (n^2 - 1))``. The result is NaN when either input is constant
    (including N == 1).

    Args:
        x: Shape (N, )
        y: Shape (N, )

    Returns:
        0-d float32 tensor on the inputs' device.
    """
    x_rank = _get_average_ranks(x)
    y_rank = _get_average_ranks(y)

    x_centered = x_rank - x_rank.mean()
    y_centered = y_rank - y_rank.mean()
    numerator = torch.sum(x_centered * y_centered)
    denominator = torch.sqrt(torch.sum(x_centered ** 2) * torch.sum(y_centered ** 2))
    return numerator / denominator


def compute_metrics_scoring(p: EvalPrediction):
    """Aggregate evaluation outputs into a metrics dictionary.

    ``p.predictions`` is unpacked as the 5-tuple
    ``(predictions, sample_correlations, ndcg, ndcg_top5, ndcg_top1)``, the same order that
    ``preprocess_logits_for_metrics_scoring`` in this module returns.

    Returns:
        dict with ``'correlation'`` (scipy Spearman correlation between ``p.label_ids`` and all
        predictions). When ``sample_correlations`` is not None, it also contains the means of the
        per-call values: ``'sample_correlation'``, ``'ndcg_full'``, ``'ndcg_top5'`` and
        ``'ndcg_top1'`` (presumably one value per evaluation batch -- confirm against the Trainer setup).
    """
    predictions, sample_correlations, ndcg, ndcg_top5, ndcg_top1 = p.predictions
    metrics = {'correlation': spearmanr(p.label_ids, predictions).correlation}

    if sample_correlations is None:
        return metrics

    metrics.update({
        'sample_correlation': sample_correlations.mean(),
        'ndcg_full': ndcg.mean(),
        'ndcg_top5': ndcg_top5.mean(),
        'ndcg_top1': ndcg_top1.mean(),
    })
    return metrics


def preprocess_logits_for_metrics_scoring(predictions, labels, model_objective):
    """Reduce one call's model outputs to the values consumed by ``compute_metrics_scoring``.

    The 5-tuple returned here matches what ``compute_metrics_scoring`` unpacks from
    ``p.predictions``. The function presumably serves as a Trainer
    ``preprocess_logits_for_metrics`` hook, with ``model_objective`` bound beforehand -- confirm at the call site.

    Args:
        predictions: model outputs; only ``predictions[0]`` (the per-item scores) is used.
        labels: ground-truth values for the same items (RMSD, lower is better).
        model_objective: accepted for interface compatibility; not used.

    Returns:
        Tuple ``(scores, spearman_corr, ndcg_full, ndcg_top5, ndcg_top1)``, all as CPU tensors.
    """
    scores = predictions[0]

    corr = spearman_correlation(labels, scores)
    ndcg, ndcg_top5, ndcg_top1 = compute_ndcg(scores, labels, k=5)

    # Guard against a metric coming back as a plain Python float, which has no .cpu().
    return tuple(torch.as_tensor(value).cpu() for value in (scores, corr, ndcg, ndcg_top5, ndcg_top1))


# **** Adapted from https://github.com/allegro/allRank/blob/master/allrank/models/losses/neuralNDCG.py

def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1] (callers pass ``y_pred.unsqueeze(-1)``)
    :param tau: temperature for the final softmax function
    :param mask: boolean mask of shape [batch_size, slate_length]; True marks padded elements
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length],
        in the dtype of ``s``; row r is a soft one-hot over the items for the r-th position of
        the descending order of ``s``
    """
    dev = s.device
    # Build helper tensors in the input dtype so the matmuls below never mix dtypes
    # (float32-only helpers failed for e.g. float64 inputs).
    dtype = s.dtype

    n = s.size(1)
    one = torch.ones((n, 1), dtype=dtype, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    # A_s[b, i, j] = |s_i - s_j|, zeroed for every pair involving a padded element.
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    # B[b, i, j] = sum_k A_s[b, i, k] (row sums repeated across the columns).
    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    # scaling[b, r - 1] = L_b + 1 - 2r for ranks r = 1..L_b, where L_b is the number of valid
    # items in slate b; 0 for padded positions. Vectorized over the batch; summing the 2-D
    # mask directly also works for slate_length == 1 (squeeze(-1) collapsed it to 1-D).
    valid_len = n - mask.sum(dim=1, keepdim=True)             # [batch_size, 1]
    ranks = torch.arange(1, n + 1, device=dev).unsqueeze(0)   # [1, slate_length]
    scaling = torch.where(ranks <= valid_len, valid_len + 1 - 2 * ranks, torch.zeros_like(ranks)).to(dtype)

    s = s.masked_fill(mask[:, :, None], 0.0)
    # C[b, i, r] = s_i * scaling[b, r]
    C = torch.matmul(s, scaling.unsqueeze(-2))

    # P_max[b, r, i] = scaling[b, r] * s_i - sum_k |s_i - s_k|
    P_max = (C - B).permute(0, 2, 1)
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat


def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling: alternately normalise the columns and rows of every matrix until both
    sum to (approximately) one, or until ``max_iter`` rounds have been run.
    :param mat: batch of square matrices, shape [N, M, M] where N is the batch size
    :param mask: optional boolean mask of shape [N, M]; True marks padded positions
    :param tol: stop once every row sum and every column sum is within ``tol`` of 1
    :param max_iter: maximum number of column+row normalisation rounds
    :return: a tensor of (approximately) doubly stochastic matrices; when ``mask`` is given,
        every row and column belonging to a padded position is zero
    """
    if mask is not None:
        pad_row = mask[:, :, None]
        pad_col = mask[:, None, :]
        # Zero every entry that touches a padded position, then set entries whose row AND
        # column are both padded to 1 (these are zeroed again after the loop).
        mat = mat.masked_fill(pad_col | pad_row, 0.0).masked_fill(pad_col & pad_row, 1.0)

    for _ in range(max_iter):
        # Column normalisation, then row normalisation; clamp keeps all-zero lines finite.
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)

        rows_converged = torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol
        if rows_converged and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
            break

    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)

    return mat


def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1):
    """
    Discounted Cumulative Gain at one or more cut-off ranks.

    Items are ordered by descending ``y_pred``. The gain ``gain_function(y_true)`` of the item at
    0-based position i is weighted by ``1 / log2(i + 2)``, and the running sum is read off at
    every requested cut-off.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of 1-based cut-off ranks, each clamped to slate_length; if None,
        the full slate length is used
    :param gain_function: callable mapping the reordered ground truth labels to gains,
        default 2 ** x - 1
    :return: DCG values for each slate and cut-off, shape [batch_size, len(ats)]
    """
    slate_length = y_true.shape[1]
    cutoffs = [slate_length] if ats is None else [min(at, slate_length) for at in ats]

    _, order = torch.sort(y_pred, dim=-1, descending=True)
    true_in_pred_order = torch.gather(y_true, dim=1, index=order)

    positions = torch.arange(true_in_pred_order.shape[1], dtype=torch.float)
    discounts = (torch.tensor(1) / torch.log2(positions + 2.0)).to(device=true_in_pred_order.device)

    weighted = gain_function(true_in_pred_order) * discounts
    running_total = torch.cumsum(weighted[:, :max(cutoffs)], dim=1)

    last_index = torch.tensor([cutoff - 1 for cutoff in cutoffs], dtype=torch.long)
    return running_total[:, last_index]


def neuralNDCG(y_pred, y_true, temperature=1., k=None):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.

    Adapted for RMSD-style targets: ``y_true`` is converted to gains with ``exp(-y_true / 3)``
    (lower value -> larger gain). No padding is used; every slate element is treated as valid.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]; higher ranks first
    :param y_true: ground truth values, shape [batch_size, slate_length]; lower is better
    :param temperature: temperature for the NeuralSort algorithm
    :param k: rank at which the loss is truncated (None means the full slate)
    :return: loss value, a 0-d torch.Tensor equal to minus the mean NeuralNDCG@k
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    # All-False mask: no slate element is treated as padding.
    mask = torch.zeros_like(y_true, dtype=torch.bool)
    # Soft permutation matrices with a leading singleton dim: [1, batch_size, slate_length, slate_length]
    P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(P_hat.shape[0] // y_pred.shape[0], y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply to true labels, ie approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    y_true_masked = torch.exp(-y_true_masked / 3)  # maps [0, max_rmsd] -> [1, 0]

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    # Ideal DCG@k: the true gains ranked by themselves, with identity gain function.
    gains_true = torch.exp(-y_true / 3)
    idcg = dcg(gains_true, gains_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    # Sanity check: soft-permutation weights and exp() gains are non-negative, so NDCG must be too.
    assert not (ndcg < 0.).any(), "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0., device=dev)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG


def listMLE(y_pred, y_true, eps=DEFAULT_EPS):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth values, shape [batch_size, slate_length]; they are converted to
        gains with exp(-y_true / 3), so a lower value ranks higher
    :param eps: epsilon added inside the logarithm for numerical stability
    :return: loss value (mean over the batch of the per-slate negative log-likelihood), a torch.Tensor
    """
    gains = torch.exp(-y_true / 3)

    # Shuffle slate positions so that ties in the gains are broken at random.
    perm = torch.randperm(y_pred.shape[-1])
    scores = y_pred[:, perm]
    gains = gains[:, perm]

    # Reorder the predicted scores into the ground-truth ranking (highest gain first).
    _, true_order = torch.sort(gains, dim=-1, descending=True)
    scores = torch.gather(scores, dim=1, index=true_order)

    # Subtract the per-slate maximum before exponentiating, for numerical stability.
    scores = scores - scores.max(dim=1, keepdim=True).values
    # Reverse cumulative sum: tail_sums[:, i] = sum over j >= i of exp(scores[:, j]).
    tail_sums = scores.exp().flip(dims=[1]).cumsum(dim=1).flip(dims=[1])

    observation_loss = torch.log(tail_sums + eps) - scores
    return observation_loss.sum(dim=1).mean()
    