from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from itertools import combinations


# Every unordered class pair (i, j) with i < j for a 100-class problem, in a
# random order. Sampling all 100*99/2 pairs without replacement is simply a
# shuffle of the sorted pair list.
pair_num = int(100 * 99 // 2)
c_pair_ = random.sample(sorted(combinations(list(range(100)), 2)), pair_num)
c_pair = torch.tensor(c_pair_)  # shape (pair_num, 2), dtype int64
# Only move the table to the GPU when one exists, so importing this module
# does not fail on CPU-only machines.
if torch.cuda.is_available():
    c_pair = c_pair.cuda()


def normalize(logit, eps=1e-7):
    """Standardize ``logit`` along its last dimension.

    Subtracts the per-row mean and divides by the per-row standard deviation
    (``Tensor.std`` default, i.e. the unbiased estimator) of the last
    dimension. ``eps`` is added to the standard deviation so constant rows
    map to zeros instead of dividing by zero.

    Args:
        logit: tensor whose last dimension is standardized. A last dimension
            of size 1 yields NaN, because the unbiased std is undefined there.
        eps: small constant added to the denominator.

    Returns:
        Tensor with the same shape as ``logit``.
    """
    mean = logit.mean(dim=-1, keepdim=True)
    stdv = logit.std(dim=-1, keepdim=True)
    return (logit - mean) / (eps + stdv)


class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network (Hinton et al.).

    Temperature-softened KL divergence KL(teacher || student), averaged over
    the batch (``reduction='batchmean'``) and scaled by ``T * T``.
    """

    def __init__(self):
        super(DistillKL, self).__init__()

    def forward(self, y_s, y_t, temp):
        """Compute the distillation loss.

        Args:
            y_s: student logits; softmax is taken over dim 1.
            y_t: teacher logits, same shape as ``y_s``.
            temp: temperature, either a Python number or a tensor. A tensor
                is moved to ``y_s``'s device, and gradients still flow
                through it.

        Returns:
            Scalar tensor: ``KL(softmax(y_t/T) || softmax(y_s/T)) * T * T``.
        """
        # Follow the logits' device instead of assuming the current CUDA
        # device; plain numbers need no conversion.
        T = temp.to(y_s.device) if torch.is_tensor(temp) else temp
        log_p_student = F.log_softmax(y_s / T, dim=1)
        p_teacher = F.softmax(y_t / T, dim=1)
        return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * T * T


class RankLoss(nn.Module):
    """Pairwise ranking-agreement loss between student and teacher logits.

    For every unordered class pair (i, j) with i < j, the logit gap
    ``d = z[..., j] - z[..., i]`` is mapped to a soft sign
    ``tanh(beta * d / 2)``. This is algebraically identical to
    ``1 - 2 / (1 + exp(beta * d))``, but it does not overflow for large gaps.
    The loss is the negated mean, over samples and pairs, of the product of
    the student and teacher soft signs, divided by ``T``. It is lower when
    the student orders class pairs the same way as the teacher.

    Args:
        logit_stand: if True, standardize ``logits / temprature`` with
            ``normalize`` before ranking.
        topk: stored but not used by ``forward``. Kept for interface
            compatibility.
        beta: sharpness of the soft sign.
        T: divisor applied to the agreement score (the default 0.0125
            multiplies it by 80).
    """

    def __init__(self, logit_stand = True, topk = 0, beta = 4, T = 0.0125):
        super(RankLoss, self).__init__()
        self.logit_stand = logit_stand
        self.topk = topk  # NOTE(review): unused; top-k pair selection is not implemented
        self.beta = beta
        self.T = T

    @staticmethod
    def _pair_gaps(logits):
        """Return ``logits[..., j] - logits[..., i]`` for every class pair i < j."""
        num_classes = logits.shape[-1]
        first, second = torch.triu_indices(
            num_classes, num_classes, offset=1, device=logits.device)
        return logits[..., second] - logits[..., first]

    def forward(self, logits_student_in, logits_teacher_in, temprature):
        """Compute the ranking loss.

        Args:
            logits_student_in: student logits with classes on the last dim
                (presumably (batch, num_classes); confirm with callers).
            logits_teacher_in: teacher logits, same shape.
            temprature: divisor applied to both logit tensors. NOTE(review):
                with ``logit_stand=True`` and a positive temperature, this
                division is cancelled by standardization, apart from eps.

        Returns:
            Scalar loss tensor.
        """
        logits_student = logits_student_in / temprature
        logits_teacher = logits_teacher_in / temprature
        if self.logit_stand:
            logits_student = normalize(logits_student)
            logits_teacher = normalize(logits_teacher)

        # Pairs are derived from the actual class count, so any number of
        # classes works. The mean is order-invariant, so pair order is irrelevant.
        half_beta = self.beta / 2
        student_score = torch.tanh(self._pair_gaps(logits_student) * half_beta)
        teacher_score = torch.tanh(self._pair_gaps(logits_teacher) * half_beta)

        score = student_score * teacher_score / self.T
        return -score.mean()  # agreement is maximized, so negate it to get a loss
