from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)

class quantile_loss(nn.Module):
    """KL-divergence loss between score-based and shape-based pair distributions.

    For each batch element two (n x m) matrices are built over all
    (train, test) pairs: a shape similarity and a score difference. Each row
    is turned into a distribution over the m test items with a softmax, and
    the loss is ``KL(score_prob || shape_prob)`` summed over rows and divided
    by the batch size (``reduction='batchmean'``).

    Args:
        cfg: configuration object; stored as ``self.cfg`` but not read by any
            method in this class.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def _shape_similarity(self, train_shape, test_shape):
        """Return exp(-mean point distance) for every (train, test) pair.

        Args:
            train_shape: tensor (b, n, k, c); only channels 0 and 1 (x, y)
                are used.
            test_shape: tensor (b, m, k, c); same channel convention.

        Returns:
            Tensor (b, n, m) with values in (0, 1].
        """
        # Broadcast to all pairs: (b, n, 1, k, c) - (b, 1, m, k, c) -> (b, n, m, k, c).
        diff = train_shape.unsqueeze(2) - test_shape.unsqueeze(1)
        diff_x = diff[..., 0]  # b, n, m, k
        diff_y = diff[..., 1]  # b, n, m, k
        sq_dist = diff_x * diff_x + diff_y * diff_y  # b, n, m, k

        # d/dx sqrt(x) is infinite at x == 0, so coincident points would
        # produce a NaN gradient (inf * 0). The "double where" keeps the
        # forward value exact (0 stays 0, NaN still propagates) while giving
        # those entries a zero gradient.
        is_zero = sq_dist == 0
        safe_sq = torch.where(is_zero, torch.ones_like(sq_dist), sq_dist)
        dist = torch.where(is_zero, torch.zeros_like(sq_dist), torch.sqrt(safe_sq))

        return torch.exp(-dist.mean(-1))  # b, n, m

    def _score_similarity(self, train_score, test_score):
        """Return the signed pairwise score difference ``train - test``.

        Args:
            train_score: tensor (b, n, 1).
            test_score: tensor (b, m, 1).

        Returns:
            Tensor (b, n, m) whose entry [:, i, j] is
            ``train_score[:, i] - test_score[:, j]``.
        """
        # (b, n, 1, 1) - (b, 1, m, 1) -> (b, n, m, 1); drop the trailing singleton.
        diff = train_score.unsqueeze(2) - test_score.unsqueeze(1)
        return diff[:, :, :, 0]

    def forward(self, train_shape, test_shape, train_gt_score, test_score):
        """Compute the KL loss between score- and shape-based pair distributions.

        Args:
            train_shape: (b, p, k, 2) per-point (x, y) values.
            test_shape: (b, p, k, 2) per-point (x, y) values.
            train_gt_score: (b, p, 1) scores.
            test_score: (b, p, 1) scores.

        Returns:
            Scalar tensor: sum of row-wise KL(score_prob || shape_prob)
            divided by b.
        """
        shape_sim = self._shape_similarity(train_shape, test_shape)  # b, p, p
        score_sim = self._score_similarity(train_gt_score, test_score)  # b, p, p

        # NOTE(review): softmax is invariant to adding a per-row constant, and
        # train_gt_score is constant along the softmax axis (dim=-1) of
        # score_sim. score_prob therefore equals softmax(-test_score) for
        # every row, i.e. train_gt_score does not influence the loss. If a
        # symmetric similarity (e.g. -|diff|) was intended, change
        # _score_similarity — confirm intent before relying on this loss.
        score_prob = F.softmax(score_sim, dim=-1)

        # log_softmax is the numerically stable form of softmax(...).log()
        # and is the log-probability input F.kl_div expects.
        shape_log_prob = F.log_softmax(shape_sim, dim=-1)
        return F.kl_div(shape_log_prob, score_prob, reduction='batchmean')

