import torch
import numpy as np
import random
from itertools import combinations
import torch.nn.functional as F


def normalize(logit):
    """Standardize ``logit`` along its last dimension.

    Each slice along the last axis is shifted to zero mean and divided by
    its standard deviation (torch's default, Bessel-corrected ``std``).
    A small epsilon keeps constant slices finite: they map to all zeros.

    Args:
        logit: tensor of logits; statistics are taken over ``dim=-1``.

    Returns:
        Tensor of the same shape as ``logit``.
    """
    centered = logit - logit.mean(dim=-1, keepdim=True)
    spread = logit.std(dim=-1, keepdim=True)
    return centered / (spread + 1e-7)


def RankLoss(logits_student_in, logits_teacher_in, logit_stand = True, topk = 0, beta = 1, temprature = 5, T = 0.0125):
    """Pairwise rank-agreement loss between student and teacher logits.

    For every pair of classes (a, b), the logit difference ``x[b] - x[a]`` is
    squashed to (-1, 1) with ``tanh(beta * d / 2)``. This is exactly
    ``1 - 2 / (1 + exp(beta * d))``, but it cannot overflow. The student and
    teacher scores are multiplied, so pairs ordered the same way by both
    models contribute positively. The product is divided by ``T``, averaged,
    and negated so that agreement lowers the loss.

    Args:
        logits_student_in: student logits, expected shape (batch, classes).
        logits_teacher_in: teacher logits, same shape as the student's.
        logit_stand: if True, standardize logits with ``normalize`` after
            dividing by ``temprature``.
        topk: if > 0, only pairs among each sample's top-``topk`` teacher
            classes are used. Otherwise all class pairs are used.
        beta: sharpness of the tanh squashing.
        temprature: divisor applied to the raw logits.
        T: divisor applied to the pairwise agreement score.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if fewer than two classes take part in the ranking.
    """
    # NOTE(review): normalize() removes any positive scale, so when
    # logit_stand is True the temperature cancels out (it only nudges the
    # 1e-7 epsilon). Confirm whether normalize(logits) / temprature was meant.
    logits_student = normalize(logits_student_in / temprature) if logit_stand else logits_student_in / temprature
    logits_teacher = normalize(logits_teacher_in / temprature) if logit_stand else logits_teacher_in / temprature

    num_classes = logits_teacher.shape[-1]
    if topk > 0:
        # Per-sample class ids of the teacher's top-k logits: (batch, topk).
        _, class_idx = torch.topk(logits_teacher, topk, dim=-1)
    else:
        # Every class takes part; identical ids for every sample.
        class_idx = torch.arange(num_classes, device=logits_teacher.device).expand(
            logits_teacher.shape[0], num_classes
        )

    k = class_idx.shape[-1]
    if k < 2:
        raise ValueError("RankLoss needs at least two classes to form a pair")

    # All position pairs (a, b) with a < b, in itertools.combinations order.
    pos_a, pos_b = torch.triu_indices(k, k, offset=1, device=class_idx.device)
    # Class ids of each pair, per sample: (batch, num_pairs).
    idx_a = class_idx[:, pos_a]
    idx_b = class_idx[:, pos_b]

    # Pairwise logit differences x[b] - x[a]; both tensors are (batch, num_pairs).
    student_prank = logits_student.gather(-1, idx_b) - logits_student.gather(-1, idx_a)
    teacher_prank = logits_teacher.gather(-1, idx_b) - logits_teacher.gather(-1, idx_a)

    # tanh(z / 2) == 1 - 2 / (1 + exp(z)), without exp overflow / NaN grads.
    student_score = torch.tanh(student_prank * beta / 2)
    teacher_score = torch.tanh(teacher_prank * beta / 2)

    score = student_score * teacher_score / T

    # Correlation score multiplied by -1 so that agreement minimizes the loss.
    return -score.mean()