from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import torch
import torch.nn as nn

from .order_loss import crps_loss, order_loss_v2, order_loss_v3, kl_loss, distance_loss, ranking_loss, LambdaRankLoss
from .quantile_loss import quantile_loss

class MultiLossFactory(nn.Module):
    """Pair a scaled L1 regression loss with a configurable ordering loss.

    ``cfg.loss`` is read as a mapping with keys:
      * ``'l1_loss_factor'``    -- multiplier applied to the L1 loss,
      * ``'type'``              -- name of a module-level loss constructor,
                                   called as ``constructor(cfg)``,
      * ``'order_loss_factor'`` -- multiplier applied to the ordering loss.

    ``forward`` returns ``(l1_loss, order_loss)`` as separate tensors.
    """

    def __init__(self, cfg):
        super().__init__()

        self.l1_loss = nn.L1Loss(reduction='mean')
        self.l1_loss_factor = cfg.loss['l1_loss_factor']

        self.loss_type = cfg.loss['type']
        self.order_loss = self._resolve_loss(self.loss_type)(cfg)
        self.order_loss_factor = cfg.loss['order_loss_factor']

    @staticmethod
    def _resolve_loss(name):
        """Return the module-level object called ``name`` (a loss constructor).

        Raises:
            NameError: if no such name exists in this module (same exception
                type the previous ``eval``-based lookup raised).
        """
        # A plain name lookup instead of eval() on a config string: identical
        # result for bare names, but config text is no longer executed as code.
        try:
            return globals()[name]
        except KeyError:
            raise NameError('Unknown order loss type: {!r}'.format(name)) from None

    def forward(self, y, y_gt):
        """Return ``(l1_loss, order_loss)``, each scaled by its config factor.

        Loss types whose name contains ``'v1'`` or ``'crps'`` are called with
        the predictions only; all others also receive ``y_gt``.
        """
        l1_loss = self.l1_loss(y, y_gt) * self.l1_loss_factor

        if 'v1' in self.loss_type or 'crps' in self.loss_type:
            order_loss = self.order_loss(y)
        else:
            order_loss = self.order_loss(y, y_gt)
        # Out-of-place scaling: an in-place `*=` would modify the tensor the
        # order loss returned, breaking backward() whenever that loss saved
        # its own output for the gradient (e.g. torch.exp).
        order_loss = order_loss * self.order_loss_factor

        return l1_loss, order_loss


class MultiLossFactoryGaussian(nn.Module):
    """MSE regression loss plus a cubic pairwise ordering term.

    ``forward`` returns ``(mse_loss, order_loss)`` as separate tensors.
    ``cfg`` is accepted for signature parity with the other factories in
    this module but is not read.
    """

    def __init__(self, cfg):
        super().__init__()

        self.mse_loss = nn.MSELoss(reduction='mean')
        # Margin used inside the cubic pairwise penalty in order_loss().
        self.c = 1.0
        # forward() divides the ordering term by 6 * sigma**2.
        self.sigma = 0.5

    def order_loss(self, ys, y_gts):
        '''Batch-averaged cubic pairwise ordering penalty.

        ys: b, n, 1 (predictions)
        y_gts: b, n, 1 (ground truth)
        The per-sample ``y - y.T`` relies on each sample being an (n, 1)
        column so that it broadcasts to an (n, n) pairwise matrix.

        Returns a 0-dim tensor on the device and in the dtype of ``ys``.

        NOTE(review): the penalty is an odd function and the two masks are
        transposes of each other. Entry (i, j) contributes (d - c)**3 and
        entry (j, i) contributes (-d + c)**3 = -(d - c)**3, so the pairs
        cancel and this term (and its gradient) is zero up to rounding.
        Confirm the intended formula.
        '''
        # Accumulate in the input's dtype/device: the former float32 CPU scalar
        # silently downcast float64 inputs.
        loss = ys.new_zeros(())
        c = self.c
        for y, y_gt in zip(ys, y_gts):
            gt_diff = y_gt - y_gt.T   # (n, n): gt_diff[i, j] = y_gt[i] - y_gt[j]
            dt_diff = y - y.T

            # Masks in the prediction dtype so the products do not promote.
            mask_g1 = (gt_diff > 0).to(dt_diff.dtype)
            mask_g2 = (gt_diff < 0).to(dt_diff.dtype)

            g1 = (dt_diff - c) ** 3
            g2 = (dt_diff + c) ** 3

            # Tied pairs (including the diagonal) are masked out of both terms.
            pair_loss = mask_g1 * g1 + mask_g2 * g2
            loss = loss + pair_loss.mean()

        # An empty batch gives 0 / 0 = nan, as before.
        return loss / ys.size(0)

    def forward(self, y, y_gt):
        '''Return ``(mse_loss, order_loss)``.

        y: b, n, 1
        y_gt: b, n, 1
        The ordering term is divided by 6 * sigma**2.
        '''
        mse_loss = self.mse_loss(y, y_gt)
        order_loss = self.order_loss(y, y_gt) / 6 / self.sigma / self.sigma
        return mse_loss, order_loss


class MultiLossFactoryCRPS(nn.Module):
    """L1 regression loss plus a pairwise-spread term over predictions.

    ``forward`` returns ``(l1_loss, order_loss)`` as separate tensors.
    ``cfg`` is accepted for signature parity with the other factories in
    this module but is not read.
    """

    def __init__(self, cfg):
        super().__init__()

        self.l1_loss = nn.L1Loss(reduction='mean')

    def crps(self, ys):
        '''Negative mean absolute pairwise difference, averaged over the batch.

        ys: b, n, 1 (each sample must be an (n, 1) column so ``y - y.T``
        broadcasts to an (n, n) pairwise matrix)

        Per sample this is -mean_{i,j} |y_i - y_j|, with the zero diagonal
        included in the mean. NOTE(review): presumably the spread term of an
        ensemble CRPS estimate; the usual estimator halves it -- confirm the
        missing 1/2 is intended.

        Returns a 0-dim tensor on the device and in the dtype of ``ys``.
        '''
        # Accumulate in the input's dtype/device: the former float32 CPU scalar
        # silently downcast float64 inputs.
        loss = ys.new_zeros(())
        for y in ys:
            dt_diff = y - y.T   # (n, n) pairwise prediction differences
            loss = loss - dt_diff.abs().mean()

        # An empty batch gives 0 / 0 = nan, as before.
        return loss / ys.size(0)

    def forward(self, y, y_gt):
        '''Return ``(l1_loss, order_loss)``; the second term ignores ``y_gt``.

        y: b, n, 1
        y_gt: b, n, 1
        '''
        l1_loss = self.l1_loss(y, y_gt)
        order_loss = self.crps(y)
        return l1_loss, order_loss


