import torch
import numpy as np
import random
from itertools import combinations
import torch.nn.functional as F


def normalize(logit):
    """Standardize ``logit`` along its last dimension.

    Subtracts the per-row mean and divides by the per-row (unbiased,
    ``torch.std`` default) standard deviation plus a small epsilon so rows
    with zero spread do not divide by zero.

    Args:
        logit: tensor whose last dimension is standardized.

    Returns:
        Tensor of the same shape as ``logit``.
    """
    centered = logit - logit.mean(dim=-1, keepdim=True)
    spread = logit.std(dim=-1, keepdim=True) + 1e-7
    return centered / spread


def RankLoss(logits_student_in, logits_teacher_in, logit_stand = True, topk = 0, beta = 1, temprature = 2, T = 0.0125):
    """Pairwise rank-agreement loss between student and teacher logits.

    The classes are sorted by teacher logit in ascending order, and the first
    ``K`` (= 50) are kept, i.e. the classes with the *lowest* teacher logits.
    NOTE(review): a ``descending=True`` was commented out in an earlier
    version; confirm ascending selection is intended.

    For every pair ``i < j`` of the kept classes, the pairwise difference
    ``x[j] - x[i]`` is mapped to ``tanh(beta * diff / 2)``. This equals the
    original ``1 - 2 / (1 + exp(beta * diff))`` but does not overflow. The
    student and teacher scores are multiplied and divided by ``T``. The
    negated mean over batch and pairs is returned.

    Args:
        logits_student_in: student logits; assumed ``(batch, num_classes)``
            (sorted/gathered along dim 1) -- TODO confirm with callers.
        logits_teacher_in: teacher logits with the same shape.
        logit_stand: if True, z-score both inputs with ``normalize`` first.
        topk: currently unused; kept for interface compatibility.
        beta: slope applied to pairwise differences before the tanh.
        temprature: currently unused; kept for interface compatibility.
        T: divisor applied to the pairwise score products.

    Returns:
        Scalar tensor: ``-mean(student_score * teacher_score / T)``.
    """
    logits_student = normalize(logits_student_in) if logit_stand else logits_student_in
    logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in

    # Keep the K classes with the smallest teacher logits (ascending order).
    K = 50
    sort_index = torch.argsort(logits_teacher, dim=1)
    k_teacher = logits_teacher.gather(1, sort_index[:, :K])
    k_student = logits_student.gather(1, sort_index[:, :K])

    # All pairs (i, j) with i < j, in the same order itertools.combinations
    # would produce. The mean below is order-independent, so no shuffling
    # is needed.
    c = k_teacher.shape[-1]
    i_idx, j_idx = torch.triu_indices(c, c, offset=1, device=k_teacher.device)

    # Explicit (batch, pairs) differences; avoids the squeeze() shape
    # mismatch that broke broadcasting when there was only one pair.
    student_prank = k_student[:, j_idx] - k_student[:, i_idx]
    teacher_prank = k_teacher[:, j_idx] - k_teacher[:, i_idx]

    # 1 - 2 / (1 + exp(b*x)) == tanh(b*x / 2); tanh avoids exp overflow,
    # which produced NaN gradients for large differences.
    student_score = torch.tanh(0.5 * beta * student_prank)
    teacher_score = torch.tanh(0.5 * beta * teacher_prank)

    score = student_score * teacher_score / T

    # Correlation-like score is maximized, so negate it to get a loss.
    return -score.mean()