from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

import torch
import torch.nn as nn

class LambdaRankLoss(nn.Module):
    """LambdaRank-style listwise loss built from pairwise "lambda" weights.

    For every pair (i, j) with ``rel_i > rel_j`` a weight
    ``|gain_i - gain_j| * |discount_i - discount_j| * sigmoid(s_j - s_i)``
    is computed, accumulated into one lambda per item, and the returned
    scalar is ``-sum(scores * lambdas) / batch`` with the lambdas held
    constant, so that ``dLoss/dscores == -lambdas / batch``.
    """

    def __init__(self, cfg):
        super().__init__()
        # cfg is accepted but not used; the module has no learnable params.

    def forward(self, scores: torch.Tensor, relevances: torch.Tensor):
        """
        Args:
          scores      : Tensor of shape (batch, list_size) or
                        (batch, list_size, C) — predicted scores. For 3-D
                        input only channel 0 is used.
          relevances  : Tensor with the same layout as ``scores`` —
                        relevance labels (integer or float).

        Returns:
          loss        : scalar — backpropagatable surrogate loss whose
                        gradient w.r.t. ``scores`` is ``-lambdas / batch``.
        """
        # Accept both the documented 2-D layout and the trailing-channel
        # 3-D layout that this method historically required.
        if scores.dim() == 3:
            scores = scores[:, :, 0]
        if relevances.dim() == 3:
            relevances = relevances[:, :, 0]
        batch, list_size = scores.size()

        # The lambdas are the *prescribed gradients*; they must be constants
        # with respect to autograd, otherwise the gradient of the surrogate
        # below is not the lambdas (and can even change sign for badly
        # mis-ordered pairs).
        with torch.no_grad():
            rel = relevances.to(scores.dtype)

            # 1) Pair mask: 1 where item i is strictly more relevant than j.
            more_relevant = (rel.unsqueeze(2) > rel.unsqueeze(1)).to(scores.dtype)

            # 2) |gain_i - gain_j| with gain = 2^rel - 1          (B, L, L)
            gain = 2.0 ** rel - 1.0
            gain_delta = (gain.unsqueeze(2) - gain.unsqueeze(1)).abs()

            # |discount_i - discount_j| with discount = 1 / log2(pos + 2)
            # NOTE(review): the list index is used as the rank position,
            # which matches the true ranking only if each list is already
            # ordered by score — confirm against the caller.
            positions = torch.arange(
                list_size, device=scores.device, dtype=scores.dtype
            )
            discount = 1.0 / torch.log2(positions + 2.0)           # (L,)
            discount_delta = (
                discount.unsqueeze(1) - discount.unsqueeze(0)
            ).abs()                                                # (L, L)

            swap_delta = gain_delta * discount_delta               # (B, L, L)

            # 3) RankNet pair weight sigmoid(s_j - s_i): close to 1 when the
            #    more relevant item i is scored below j, close to 0 when the
            #    pair is already ordered correctly.
            score_diff = scores.unsqueeze(2) - scores.unsqueeze(1)  # (B, L, L)
            pair_lambda = swap_delta * torch.sigmoid(-score_diff) * more_relevant

            # 4) Each item gains the weights of pairs it should win and
            #    loses the weights of pairs it should lose.     (B, L)
            lambdas = pair_lambda.sum(dim=2) - pair_lambda.sum(dim=1)

        # 5) Inner product of scores with the constant lambdas.
        return torch.sum(-scores * lambdas) / batch

class crps_loss(nn.Module):
    """Mean absolute pairwise difference between the rows of ``y``.

    Labelled "class-ranked probability scores" by the original author.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def forward(self, y):
        """Return the mean of ``|y_i - y_j|`` over all pairs with ``i < j``.

        Args:
          y: torch.Tensor of shape (n, 1).

        Returns:
          Scalar tensor. When there are fewer than two rows there are no
          pairs; a zero that is still attached to the autograd graph is
          returned instead of the NaN that the mean of an empty selection
          would produce.
        """
        diff = y - y.T   # shape (n, n), diff[i, j] = y_i - y_j
        # Strict upper triangle keeps each unordered pair exactly once.
        mask_upper = torch.triu(torch.ones_like(diff), diagonal=1).bool()
        pair_diffs = diff[mask_upper]
        if pair_diffs.numel() == 0:
            return diff.sum() * 0.0
        return pair_diffs.abs().mean()

class order_loss_v2(nn.Module):
    r"""Pairwise ordering loss (original note: "Take integral of |z-\delta|").

    For each sample, every ordered pair (i, j) falls into one of four cases
    given by the sign of the ground-truth difference and the sign of the
    predicted difference; each case applies its own quadratic penalty with
    margin ``self.c``. Pairs whose ground-truth or predicted difference is
    exactly zero contribute nothing.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.c = 1.0

    def forward(self, ys, y_gts):
        """
        Args:
          ys:    torch.Tensor of shape (b, n, 1) — predictions.
          y_gts: torch.Tensor of shape (b, n, 1) — ground truth.

        Returns:
          Scalar tensor: per-sample mean over the (n, n) pair grid,
          averaged over the batch.
        """
        c = self.c
        loss = torch.zeros((), device=ys.device)
        for y, y_gt in zip(ys, y_gts):
            gt_diff = y_gt - y_gt.T   # shape (n, n)
            dt_diff = y - y.T         # shape (n, n)

            gt_pos = (gt_diff > 0).float()
            gt_neg = (gt_diff < 0).float()
            dt_pos = (dt_diff > 0).float()
            dt_neg = (dt_diff < 0).float()

            half_sq = 0.5 * dt_diff * dt_diff
            to_plus_c = 0.5 * (dt_diff - c) * (dt_diff - c)
            to_minus_c = 0.5 * (dt_diff + c) * (dt_diff + c)

            # One term per sign quadrant. The last quadrant (gt < 0, dt < 0)
            # previously reused the (gt > 0, dt < 0) mask, so it was never
            # penalised and the third quadrant was penalised twice.
            pair_loss = (
                gt_pos * dt_pos * (to_plus_c + half_sq)
                + gt_neg * dt_pos * (to_minus_c - half_sq)
                + gt_pos * dt_neg * (to_plus_c - half_sq)
                + gt_neg * dt_neg * (to_minus_c + half_sq)
            )
            loss = loss + pair_loss.mean()

        return loss / ys.size(0)

class order_loss_v3(nn.Module):
    r"""Placeholder for a second ordering-loss variant.

    Original note: "Take integral of 1/4b(b+|z-\delta|)exp(-|z-\delta|/b)".
    The forward pass is not implemented yet.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def forward(self, y, y_gt):
        """Not implemented: ignores both inputs and returns ``None``.

        Args:
          y:    torch.Tensor of shape (n, 1).
          y_gt: torch.Tensor of shape (n, 1).
        """
        return None

class kl_loss(nn.Module):
    """Placeholder for a KL distance between different poses.

    The forward pass is not implemented yet.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def forward(self, y, y_gt):
        """Not implemented: ignores both inputs and returns ``None``.

        Args:
          y:    torch.Tensor of shape (n, 1).
          y_gt: torch.Tensor of shape (n, 1).
        """
        return None
    
class distance_loss(nn.Module):
    """L1 distance between different poses.

    Computed as the mean absolute pairwise difference between the rows
    of ``y``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def forward(self, y, y_gt):
        """Return the mean of ``|y_i - y_j|`` over all pairs with ``i < j``.

        Args:
          y:    torch.Tensor of shape (n, 1).
          y_gt: torch.Tensor of shape (n, 1).
                NOTE(review): accepted but not used by the computation —
                confirm whether the ground truth was meant to take part.

        Returns:
          Scalar tensor. When there are fewer than two rows there are no
          pairs; a zero that is still attached to the autograd graph is
          returned instead of the NaN that the mean of an empty selection
          would produce.
        """
        diff = y - y.T   # shape (n, n), diff[i, j] = y_i - y_j
        # Strict upper triangle keeps each unordered pair exactly once.
        mask_upper = torch.triu(torch.ones_like(diff), diagonal=1).bool()
        pair_diffs = diff[mask_upper]
        if pair_diffs.numel() == 0:
            return diff.sum() * 0.0
        return pair_diffs.abs().mean()

class ranking_loss(nn.Module):
    """RankNet pairwise loss.

    From ICML'05 "Learning to Rank using Gradient Descent".
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def forward(self, ys, y_gts):
        """
        Args:
          ys:    torch.Tensor of shape (b, n, 1) — predicted scores.
          y_gts: torch.Tensor of shape (b, n, 1) — ground-truth scores.

        Returns:
          Scalar tensor: cross entropy between the target pair probability
          ``sigmoid(gt_i - gt_j)`` and the predicted pair logit
          ``y_i - y_j``, averaged over the full (n, n) pair grid of each
          sample and then over the batch.
        """
        loss = torch.zeros((), device=ys.device)
        for y, y_gt in zip(ys, y_gts):
            dt_diff = y - y.T          # shape (n, n)
            gt_diff = y_gt - y_gt.T    # shape (n, n)

            # sigmoid / softplus are the overflow-safe forms of
            # exp(x) / (1 + exp(x)) and log(1 + exp(x)); the naive forms
            # give NaN / inf once a difference exceeds the exp() range.
            target_prob = torch.sigmoid(gt_diff)
            pair_loss = nn.functional.softplus(dt_diff) - target_prob * dt_diff
            loss = loss + pair_loss.mean()

        return loss / ys.size(0)