from collections import OrderedDict
import torch.nn.functional as F
import torch
import torch.nn as nn
import random
import os
import torch.backends.cudnn as cudnn
import numpy as np
import  math


def accuracy_mse(predict, target, dataset, scale=100.):
    """Return the MSE between de-normalized, rescaled predictions and targets.

    Both tensors are passed through ``dataset.denormalize`` and multiplied by
    ``scale`` before ``F.mse_loss`` is applied. The prediction is detached
    first, so no gradient flows through the returned value on that side.
    """
    rescaled_prediction = scale * dataset.denormalize(predict.detach())
    rescaled_target = scale * dataset.denormalize(target)
    return F.mse_loss(rescaled_prediction, rescaled_target)

class AverageMeterGroup:
    """Average meter group for multiple average meters.

    Meters are created lazily, one per key seen in ``update``, and can be read
    back either by item access (``group["loss"]``) or by attribute access
    (``group.loss``).
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data, n=1):
        """Update (creating if needed) one meter per key of ``data``.

        Parameters
        ----------
        data : dict
            Maps a meter name to the new value for that meter.
        n : int
            Weight forwarded to each meter's ``update``.
        """
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v, n=n)

    def __getattr__(self, item):
        # __getattr__ only runs after normal attribute lookup has failed.
        # Read ``meters`` through __dict__ so that a missing ``meters``
        # (bare instance, copy/unpickle in progress) cannot re-enter this
        # method and recurse without bound.
        meters = self.__dict__.get("meters")
        if meters is None or item not in meters:
            # AttributeError (not KeyError) keeps hasattr() and
            # getattr(obj, name, default) working as expected.
            raise AttributeError(
                "'%s' object has no attribute or meter '%s'" % (type(self).__name__, item)
            )
        return meters[item]

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return "  ".join(str(v) for v in self.meters.values())

    def summary(self):
        """Return the ``summary()`` strings of all meters joined by two spaces."""
        return "  ".join(v.summary() for v in self.meters.values())


class AverageMeter:
    """Track the latest value and the running weighted average of a quantity."""

    def __init__(self, name, fm=':f'):
        """
        Create a meter and reset its statistics.

        Parameters
        ----------
        name : str
            Name to display.
        fm : str
            Format spec (including the leading colon) used when printing
            the values, e.g. ``':f'``.
        """
        self.name = name
        self.fm = fm
        self.reset()

    def reset(self):
        """Zero the last value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(['{name} {val', self.fm, '} ({avg', self.fm, '})'])
        return template.format(**vars(self))

    def summary(self):
        """Return ``"<name>: <avg>"`` with the average formatted by ``fm``."""
        template = ''.join(['{name}: {avg', self.fm, '}'])
        return template.format(**vars(self))


def set_seed(seed):
    """
        Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``,
        export it as ``PYTHONHASHSEED``, and put cuDNN into deterministic,
        non-benchmarking mode.
    """
    # Host-side random number generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # cuDNN stays enabled, but autotuning is switched off and deterministic
    # algorithms are requested.
    cudnn.enabled = True
    cudnn.benchmark = False

    # CUDA generators: current device first, then all devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def to_cuda(obj, device):
    """Recursively move every tensor inside ``obj`` to ``device``.

    Tuples, lists and dicts are rebuilt with their elements/values converted;
    ints, floats and strings are returned unchanged.

    Raises
    ------
    ValueError
        If ``obj`` (or a nested element) is of any other type.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: to_cuda(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_cuda(element, device) for element in obj]
    if isinstance(obj, tuple):
        return tuple(to_cuda(element, device) for element in obj)
    if not isinstance(obj, (int, float, str)):
        raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
    return obj

class BPRLoss(nn.Module):
    """Pairwise logistic (BPR-style) ranking loss over all ordered pairs.

    For every item ``i`` and every item ``j`` whose target is strictly larger
    than ``target[i]``, the term ``log(1 + exp(-(input[j] - input[i])))`` is
    accumulated. With ``exp_weighted=True`` each term is additionally scaled
    by ``(exp(target[i]) - 1) * (exp(target[j]) - 1)``.
    """

    def __init__(self, exp_weighted=False):
        super(BPRLoss, self).__init__()
        self.exp_weighted = exp_weighted

    def forward(self, input, target):
        """Compute the loss.

        Parameters
        ----------
        input : torch.Tensor
            Predicted scores; indexed along dim 0 together with ``target``.
        target : torch.Tensor
            Ground-truth scores aligned with ``input``.

        Returns
        -------
        torch.Tensor
            ``2 / N**2 * sum`` of the pair terms, or
            ``2 / (N * (e - 1))**2 * sum`` when ``exp_weighted`` is set.
        """
        N = input.size(0)
        total_loss = 0
        for i in range(N):
            # Items that should rank above item i.
            better = target > target[i]
            # softplus(-d) == log(1 + exp(-d)) without overflowing exp().
            terms = F.softplus(-(input[better] - input[i]))
            if self.exp_weighted:
                terms = (torch.exp(target[i]) - 1) * (torch.exp(target[better]) - 1) * terms
            total_loss += torch.sum(terms)
        if self.exp_weighted:
            return 2 / (N * (math.e - 1)) ** 2 * total_loss
        return 2 / N ** 2 * total_loss


def wpair2(input, target):
    """Vectorized exponentially weighted pairwise logistic loss.

    Targets are shifted so their minimum is zero. For each ordered pair
    ``(i, j)`` with ``target[i] > target[j]`` the term
    ``softplus(-(input[i] - input[j]))`` is weighted by
    ``(exp(t_i) - 1) * (exp(t_j) - 1)`` (shifted targets); the weighted sum is
    divided by ``N * (N - 1)``.
    """
    count = input.size(0)
    shifted = target - torch.min(target)
    shifted_col = shifted.unsqueeze(1)
    shifted_row = shifted.unsqueeze(0)

    # Keep only pairs whose first element has the strictly larger target.
    ordered = (shifted_col - shifted_row) > 0
    pair_weight = (torch.exp(shifted_col) - 1) * (torch.exp(shifted_row) - 1)
    pair_weight = pair_weight * ordered.float()

    score_gap = input.unsqueeze(1) - input.unsqueeze(0)
    pair_term = F.softplus(-score_gap)

    return torch.sum(pair_weight * pair_term) / (count * (count - 1))
def list_mle(y_pred, y_true, k=None):
    """ListMLE loss for a single list of scores.

    The predictions are reordered by descending ``y_true`` and the negative
    Plackett-Luce log-likelihood of that ordering is returned (summed over
    positions). The body indexes and flips along dim 0, i.e. it treats
    ``y_pred`` / ``y_true`` as 1-D tensors of equal length.

    Parameters
    ----------
    y_pred : torch.Tensor
        Predicted scores.
    y_true : torch.Tensor
        Ground-truth scores aligned with ``y_pred``.
    k : int or sequence of int, optional
        When given, the loss is computed on a random sub-list of ``k`` items
        drawn uniformly with replacement. An int, or a size tuple such as
        ``(k,)``, is accepted.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    if k is not None:
        # The original passed an int to ``size=`` and indexed ``shape[1]``,
        # neither of which works for the 1-D scores handled below.
        size = (k,) if isinstance(k, int) else tuple(k)
        sublist_indices = torch.randint(y_pred.shape[-1], size, device=y_pred.device)
        y_pred = y_pred[..., sublist_indices]
        y_true = y_true[..., sublist_indices]
    _, indices = y_true.sort(descending=True, dim=-1)
    pred_sorted_by_true = y_pred[indices]
    # Subtract the max before exponentiating for numerical stability.
    max_pred = pred_sorted_by_true.max()
    pred_scaled = pred_sorted_by_true - max_pred
    # Reverse cumulative sum: entry i holds sum_{j >= i} exp(pred_scaled[j]).
    cumsums = pred_scaled.exp().flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
    listmle_loss = torch.log(cumsums + 1e-10) - pred_scaled
    return listmle_loss.sum()

def pair_loss(outputs, labels, device='cpu'):
    """Mean pairwise logistic loss over all ordered pairs ``i != j``.

    For each pair the term is
    ``log(1 + exp(-(outputs[i] - outputs[j]) * sign(labels[i] - labels[j])))``.
    Diagonal entries are removed by subtracting ``tmp * eye`` and the sum is
    divided by ``N * (N - 1)``. The identity mask is created on CPU and moved
    to ``device`` unless ``device == 'cpu'``.
    """
    n_outputs = outputs.shape[0]

    # Square matrices whose row i is filled with outputs[i] / labels[i].
    score_grid = outputs.unsqueeze(1).repeat(1, n_outputs)
    label_grid = labels.unsqueeze(1).repeat(1, labels.shape[0])

    agreement = (score_grid - score_grid.t()) * torch.sign(label_grid - label_grid.t())
    pair_terms = torch.log(torch.exp(-agreement) + 1)

    identity = torch.eye(len(pair_terms))
    if device != 'cpu':
        identity = identity.to(device)
    off_diagonal = pair_terms - pair_terms * identity

    return torch.sum(off_diagonal) / (n_outputs * (n_outputs - 1))



def calculate_rank(scores):
    """Return the 1-based descending rank of every score.

    The highest score gets rank 1. An item whose score equals that of the item
    immediately before it in the sorted order receives that item's rank, and
    the next distinct score gets its position in the sorted order.
    """
    order = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
    ranks = [0] * len(scores)
    last_seen = None
    for position, idx in enumerate(order, start=1):
        score = scores[idx]
        if score != last_seen:
            ranks[idx] = position
        else:
            # Tie with the previous item in sorted order: share its rank.
            ranks[idx] = ranks[order[position - 2]]
        last_seen = score

    return ranks

def top_k_best_rank(predicted_scores, true_scores, k):
    """Return the best (smallest) true rank among the top-``k`` predictions.

    True ranks are 1-based with rank 1 for the largest ``true_scores`` entry.
    If no inspected rank is below ``99999999`` (e.g. ``k == 0``) that sentinel
    is returned.
    """
    predicted_order = np.argsort(predicted_scores)[::-1]
    # argsort of the descending order gives each item's 0-based rank.
    rank_of_item = np.argsort(np.argsort(-true_scores)) + 1
    best = 99999999
    for position in range(k):
        best = min(best, rank_of_item[predicted_order[position]])
    return best


def weighted_pair_loss(outputs, labels, weight_power=1.0, device='cpu'):
    """Pairwise logistic loss whose signed score gaps are scaled by per-item weights.

    Each item gets a weight ``(N - idx + 1) / N`` where ``idx`` comes from the
    second return value of ``torch.sort(labels)``. The signed score difference
    of a pair is multiplied by both items' weights before the logistic term
    ``log(1 + exp(-x))`` is applied. Diagonal entries are removed and the sum
    is divided by ``N * (N - 1)``.

    NOTE(review): ``weight_power`` is accepted but never used in this body.
    NOTE(review): ``torch.sort`` returns sorting indices, not per-item ranks;
    confirm whether ranks (argsort of these indices) were intended.
    """
    count = outputs.shape[0]

    _, sort_indices = torch.sort(labels)
    item_weights = (count - sort_indices.float() + 1) / count

    score_grid = outputs.unsqueeze(1).repeat(1, count)
    label_grid = labels.unsqueeze(1).repeat(1, count)
    weight_col = item_weights.unsqueeze(1)
    weight_row = item_weights.unsqueeze(0)

    agreement = (score_grid - score_grid.t()) * torch.sign(label_grid - label_grid.t())
    scaled = agreement * weight_col * weight_row
    pair_terms = torch.log(torch.exp(-scaled) + 1)

    identity = torch.eye(count)
    if device != 'cpu':
        identity = identity.to(device)
    off_diagonal = pair_terms - pair_terms * identity

    return torch.sum(off_diagonal) / (count * (count - 1))


class ListNetLoss(nn.Module):
    """
    ListNet loss for learning-to-rank: cross entropy between the softmax of the
    true scores and the softmax of the predicted scores.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predicted_scores, true_scores):
        """Return the cross entropy (summed over the last dim, then averaged)."""
        target_distribution = F.softmax(true_scores, dim=-1)
        model_distribution = F.softmax(predicted_scores, dim=-1)
        # The 1e-12 offset guards log() against exact zeros.
        log_model = torch.log(model_distribution + 1e-12)
        per_list = -torch.sum(target_distribution * log_model, dim=-1)
        return torch.mean(per_list)

def warp(predicted_scores, true_scores):
    """WARP-style sampled pairwise ranking loss.

    For every item ``i`` other items are drawn without replacement (via
    Python's ``random`` module) until one is found whose predicted ordering
    relative to ``i`` disagrees with (or ties) the true ordering, or until all
    other items have been tried. If such an item is found, its signed
    prediction gap is added to the loss, weighted by
    ``log(floor((batch_size - 1) / num_trials))`` and by the shifted target
    of item ``i``.

    :param predicted_scores: predicted scores, indexed along dim 0
    :param true_scores: ground-truth scores aligned with ``predicted_scores``
    :return: accumulated loss divided by ``batch_size``
    """
    batch_size = predicted_scores.size(0)
    # At most every other item can be tried once.
    max_num_trials = batch_size - 1
    Y = float(batch_size)
    # Shift targets so the smallest one is slightly above zero.
    min_target = torch.min(true_scores) - 0.0001
    shifted_target = true_scores - min_target
    total_loss = torch.tensor(0.0, requires_grad=True, device=predicted_scores.device)

    for i in range(batch_size):
        pos_idx = i
        sample_score_margin = -1
        num_trials = 0
        # msk marks the items that are still available for sampling.
        msk = torch.ones(batch_size, dtype=torch.bool, device=predicted_scores.device)
        msk[pos_idx] = False
        neg_indices = torch.arange(batch_size, device=predicted_scores.device)[msk]
        neg_idx = 0
        # Keep sampling while the drawn pair is ordered consistently
        # (margin < 0) and untried items remain.
        while (sample_score_margin < 0) and (num_trials < max_num_trials):
            neg_idx = random.sample(list(neg_indices.cpu().numpy()), 1)[0]
            # Remove the drawn item so it is not sampled again for this i.
            msk[neg_idx] = False
            neg_indices = torch.arange(batch_size, device=predicted_scores.device)[msk]
            num_trials += 1
            # Non-negative when the true gap and the predicted gap have
            # opposite signs, or when either gap is zero.
            sample_score_margin = (true_scores[neg_idx] - true_scores[pos_idx]) * (predicted_scores[neg_idx] - predicted_scores[pos_idx]) * (-1)

        if sample_score_margin < 0:
            # No violating item was found for i: it contributes nothing.
            continue
        else:
            # Fewer trials before a violation means a larger weight.
            loss_weight = torch.log(
                torch.tensor(math.floor((Y - 1) / num_trials), dtype=torch.float32, device=predicted_scores.device))
            loss_weight = loss_weight * shifted_target[pos_idx]
            total_loss = total_loss + loss_weight * ( predicted_scores[neg_idx] - predicted_scores[pos_idx])*torch.sign(true_scores[neg_idx] - true_scores[pos_idx])*(-1)

    return total_loss / batch_size



def mp_loss(predicted_scores, true_scores, device):
    """Mean squared error between pairwise predicted gaps and pairwise true gaps.

    Both inputs are moved to ``device``. For every ordered pair ``(i, j)`` the
    term ``((pred[j] - pred[i]) - (true[j] - true[i])) ** 2`` is computed; the
    diagonal is subtracted from the total and the result is divided by
    ``N * (N - 1)``.
    """
    predicted_scores = predicted_scores.to(device)
    true_scores = true_scores.to(device)
    n = predicted_scores.size(0)
    grid = (n, n)

    predicted_gap = (
        predicted_scores.unsqueeze(0).expand(*grid)
        - predicted_scores.unsqueeze(1).expand(*grid)
    )
    true_gap = (
        true_scores.unsqueeze(0).expand(*grid)
        - true_scores.unsqueeze(1).expand(*grid)
    )

    squared_error = (predicted_gap - true_gap) ** 2
    off_diagonal_total = squared_error.sum() - squared_error.diagonal().sum()

    return off_diagonal_total / (n * (n - 1))


def margin_loss(predicted_scores, true_scores, max_compare_ratio=32, s_margin=0.5,
                do_limit=False, device='cuda:0'):
    """Margin ranking loss over all pairs of items with different true scores.

    Pairs are taken from the strict upper triangle of the true-score
    difference matrix, keeping those with a non-zero absolute difference. When
    ``do_limit`` is set and there are more than
    ``int(len(predicted_scores) * max_compare_ratio)`` such pairs, that many
    are sampled without replacement with ``np.random.choice``. The +1/-1 pair
    labels are moved to ``device`` before ``nn.MarginRankingLoss`` (margin
    ``s_margin``) is applied.
    """
    criterion = nn.MarginRankingLoss(margin=s_margin)
    pair_cap = math.inf
    if do_limit:
        pair_cap = int(len(predicted_scores) * max_compare_ratio)

    truth = true_scores.cpu().detach().numpy()
    gap = truth[:, None] - truth
    # Strict upper triangle: each unordered pair is considered once.
    pairs = np.where(np.triu(np.abs(gap), 1) > 0.0)
    n_pairs = len(pairs[0])

    if n_pairs > pair_cap:
        kept = np.random.choice(np.arange(n_pairs), pair_cap, replace=False)
        pairs = (pairs[0][kept], pairs[1][kept])

    # True where the row item has the larger true score.
    row_is_better = (gap > 0)[pairs]

    col_scores = predicted_scores[pairs[1]]
    row_scores = predicted_scores[pairs[0]]
    direction = torch.where(torch.from_numpy(row_is_better), torch.tensor(1), torch.tensor(-1)).to(device)
    return criterion(row_scores, col_scores, direction)


def lambdaLoss(y_pred, y_true, eps=1e-10, padded_value_indicator=-1, weighing_scheme=None, k=None, sigma=1., mu=10.,
               reduction="sum", reduction_log="binary"):
    """
    LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
    Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.

    NOTE(review): this docstring previously described the inputs as [batch_size, slate_length].
    The body gathers along dim 0, forms pairwise differences with ``[:, None] - [None, :]`` and
    builds square masks of side ``y_pred.size(0)``, which looks like a single 1-D slate per call.
    Confirm the intended input shape with callers.

    :param y_pred: predictions from the model (see the shape note above)
    :param y_true: ground truth labels, aligned with y_pred
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param weighing_scheme: a string naming one of the weighing scheme functions; it is looked up in this
        module's globals and called as ``scheme(G, D, mu, true_sorted_by_preds)``. None means a constant weight of 1.
    :param k: rank at which the loss is truncated
    :param sigma: score difference weight used in the sigmoid function
    :param mu: optional weight used in NDCGLoss2++ weighing scheme
    :param reduction: losses reduction method, could be either a sum or a mean
    :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
    :return: loss value, a torch.Tensor
    :raises ValueError: if reduction_log is not "natural"/"binary" or reduction is not "sum"/"mean"
    """
    device = y_pred.device
    # Work on copies: the padded entries are overwritten in place below.
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    # Apply padded mask
    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Sort predictions and true labels
    y_pred_sorted, indices_pred = y_pred.sort(descending=True)
    y_true_sorted, _ = y_true.sort(descending=True)

    # Gather true labels according to sorted predictions
    true_sorted_by_preds = torch.gather(y_true, dim=0, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, None] - true_sorted_by_preds[None, :]
    # Pairs involving a padded (-inf) label produce non-finite differences and are dropped.
    padded_pairs_mask = torch.isfinite(true_diffs)

    # Except for ndcgLoss1_scheme, keep only pairs whose first label is strictly larger.
    if weighing_scheme != "ndcgLoss1_scheme":
        padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

    # NDCG@k mask
    if k is not None:
        ndcg_at_k_mask = torch.zeros((y_pred.size(0), y_pred.size(0)), dtype=torch.bool, device=device)
        ndcg_at_k_mask[:k, :k] = 1
    else:
        ndcg_at_k_mask = torch.ones((y_pred.size(0), y_pred.size(0)), dtype=torch.bool, device=device)

    # Clamp the sorted labels for correct gains and ideal DCGs (maxDCGs)
    # (in place, and only after true_diffs was computed from the unclamped values)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Calculate gains, discounts and ideal DCGs per slate
    pos_idxs = torch.arange(1, y_pred.size(0) + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())
    # With k=None the [:k] slice keeps every position.
    maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:k]).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs

    # Apply appropriate weighing scheme
    if weighing_scheme is None:
        weights = 1.
    else:
        weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)

    # Calculate scores differences
    scores_diffs = (y_pred_sorted[:, None] - y_pred_sorted[None, :]).clamp(min=-1e8, max=1e8)
    # (-inf) - (-inf) between two padded items is NaN; neutralise those entries.
    scores_diffs.masked_fill_(torch.isnan(scores_diffs), 0.)
    weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
    if reduction_log == "natural":
        losses = torch.log(weighted_probas)
    elif reduction_log == "binary":
        losses = torch.log2(weighted_probas)
    else:
        raise ValueError("Reduction logarithm base can be either natural or binary")

    # Reduce only over pairs that survive both the padding/order mask and the top-k mask.
    if reduction == "sum":
        loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
    elif reduction == "mean":
        loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
    else:
        raise ValueError("Reduction method can be either sum or mean")

    return loss


def ndcgLoss1_scheme(G, D, *args):
    """NDCG-Loss1 weights: gain over discount, with a trailing pair axis.

    Accepts batched ``G`` (two axes) as before, and also 1-D ``G``/``D`` for a
    single slate, in which case the result has shape ``[n, 1]`` instead of
    ``[batch, n, 1]``. Extra positional arguments are ignored.
    """
    ratio = G / D
    if ratio.dim() == 1:
        # Single slate: add only the trailing axis so it broadcasts over [n, n].
        return ratio[:, None]
    return ratio[:, :, None]


def ndcgLoss2_scheme(G, D, *args):
    """NDCG-Loss2 weights: |1/D[|i-j|-1] - 1/D[|i-j|]| * |G_i - G_j| per pair.

    Accepts batched ``G`` of shape ``[batch, n]`` with ``D`` of shape
    ``[1, n]`` as before, and also 1-D ``G``/``D`` for a single slate, in
    which case the returned weights have shape ``[n, n]`` instead of
    ``[batch, n, n]``. Extra positional arguments are ignored.
    """
    unbatched = G.dim() == 1
    if unbatched:
        G = G[None, :]
    if D.dim() == 1:
        D = D[None, :]
    pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device)
    delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :])
    # On the diagonal delta_idxs - 1 is -1 and wraps to the last discount;
    # those entries are zeroed right after.
    deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) - torch.pow(torch.abs(D[0, delta_idxs]), -1.))
    deltas.diagonal().zero_()

    weights = deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :])
    return weights[0] if unbatched else weights


def lambdaRank_scheme(G, D, *args):
    """LambdaRank weights: |1/D_i - 1/D_j| * |G_i - G_j| for every pair.

    Accepts batched two-axis ``G``/``D`` as before, and also 1-D ``G``/``D``
    for a single slate, in which case the returned weights have shape
    ``[n, n]`` instead of ``[batch, n, n]``. Extra positional arguments are
    ignored.
    """
    unbatched = G.dim() == 1
    if unbatched:
        G = G[None, :]
    if D.dim() == 1:
        D = D[None, :]
    discount_gap = torch.abs(torch.pow(D[:, :, None], -1.) - torch.pow(D[:, None, :], -1.))
    gain_gap = torch.abs(G[:, :, None] - G[:, None, :])
    weights = discount_gap * gain_gap
    return weights[0] if unbatched else weights


def ndcgLoss2PP_scheme(G, D, *args):
    """NDCG-Loss2++ weights: ``mu * ndcgLoss2 + lambdaRank``, with ``mu = args[0]``.

    1-D ``G``/``D`` (a single slate) are promoted to a batch of one before the
    two component schemes are called, and the batch axis is removed again from
    the result; batched inputs behave as before.
    """
    unbatched = G.dim() == 1
    if unbatched:
        G = G[None, :]
    if D.dim() == 1:
        D = D[None, :]
    weights = args[0] * ndcgLoss2_scheme(G, D) + lambdaRank_scheme(G, D)
    return weights[0] if unbatched else weights


def rankNet_scheme(G, D, *args):
    """RankNet weighting: every pair gets the same constant weight of 1."""
    uniform_weight = 1.
    return uniform_weight


def rankNetWeightedByGTDiff_scheme(G, D, *args):
    """RankNet weights scaled by the absolute label difference of each pair.

    ``args[1]`` holds the labels. Two-axis labels give ``[batch, n, n]``
    weights as before; 1-D labels (a single slate) give ``[n, n]`` weights.
    ``G`` and ``D`` are not used.
    """
    labels = args[1]
    if labels.dim() == 1:
        return torch.abs(labels[:, None] - labels[None, :])
    return torch.abs(labels[:, :, None] - labels[:, None, :])


def rankNetWeightedByGTDiffPowed_scheme(G, D, *args):
    """RankNet weights scaled by the absolute difference of squared labels.

    ``args[1]`` holds the labels. Two-axis labels give ``[batch, n, n]``
    weights as before; 1-D labels (a single slate) give ``[n, n]`` weights.
    ``G`` and ``D`` are not used.
    """
    squared = torch.pow(args[1], 2)
    if squared.dim() == 1:
        return torch.abs(squared[:, None] - squared[None, :])
    return torch.abs(squared[:, :, None] - squared[:, None, :])


def average_rank_of_top_10(predicted_scores, true_scores):
    """Average predicted rank of the (up to) ten items with the highest true score.

    A predicted rank is the 1-based position of the item's predicted score in
    the descending list of all predicted scores; equal predicted scores share
    the position of their first occurrence.
    """
    by_truth = sorted(range(len(true_scores)), key=lambda i: true_scores[i], reverse=True)
    best_items = by_truth[:10]
    descending_predictions = sorted(predicted_scores, reverse=True)

    rank_total = 0
    for item in best_items:
        rank_total += descending_predictions.index(predicted_scores[item]) + 1
    return rank_total / len(best_items)

def rel_at_1(predicted_scores, true_scores, K):
    """Best true score among the top-``K`` predictions, relative to the overall best.

    Returns ``max(true score of the K highest-predicted items) / max(true_scores)``.
    """
    ranked = sorted(range(len(predicted_scores)), key=lambda i: predicted_scores[i], reverse=True)
    best_true_in_top_k = max(true_scores[i] for i in ranked[:K])
    return best_true_in_top_k / max(true_scores)


def get_best_acc(predicted_scores, true_scores, K):
    """Return the highest true score among the ``K`` items with the highest predictions."""
    ranked = sorted(range(len(predicted_scores)), key=lambda i: predicted_scores[i], reverse=True)
    return max(true_scores[i] for i in ranked[:K])



