import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from mask_metric import masked_mae, masked_mse


def Loss_all(teacher_HR, teacher_fore, student_HR, student_fore, label, m, batch_size, temperature):
    """
    Total training loss for a student that produces one output per missing rate.

    The teacher outputs and the label carry no missing-rate axis. They are
    broadcast along a new trailing axis of size ``m`` so they line up with the
    student tensors, whose last axis indexes the ``m`` missing-rate variants.
    The result is

        L_final = L_pre + L_HD + L_RD + L_CL

    where
      * L_pre: masked MAE between student forecasts and the label (null value 0.0),
      * L_HD:  masked MSE between student and teacher hidden representations,
      * L_RD:  masked MSE between student and teacher forecasts (null value 0.0),
      * L_CL:  ContrastiveLoss averaged over every unordered pair (i, j), i < j,
               of missing-rate variants of the student hidden representation
               (0.0 when m < 2, i.e. there are no pairs).

    :param teacher_HR: Hidden representation of the teacher model (no missing-rate axis)
    :param teacher_fore: Forecasting results of the teacher model (no missing-rate axis)
    :param student_HR: Hidden representation of the student model; last axis has size m
    :param student_fore: Forecasting results of the student model; last axis has size m
    :param label: Ground-truth targets (no missing-rate axis)
    :param m: The number of missing rates
    :param batch_size: Batch size, forwarded to ContrastiveLoss for flattening
    :param temperature: Temperature forwarded to ContrastiveLoss
    :return: Loss
    """
    # Add a trailing missing-rate axis and broadcast it (a view, no copy) to size m.
    teacher_HR = teacher_HR.unsqueeze(-1).expand(*teacher_HR.shape, m)
    teacher_fore = teacher_fore.unsqueeze(-1).expand(*teacher_fore.shape, m)
    label = label.unsqueeze(-1).expand(*label.shape, m)

    # NOTE(review): L_HD uses masked_mse's default null value while the other two
    # terms pass 0.0 — presumably intentional since hidden activations may be 0; confirm.
    L_HD = masked_mse(student_HR, teacher_HR)

    L_RD = masked_mse(student_fore, teacher_fore, 0.0)

    L_pre = masked_mae(student_fore, label, 0.0)

    # Contrastive term over all unordered pairs of missing-rate variants.
    L_CL = 0.0
    for i in range(m - 1):
        HR_1 = student_HR[..., i]
        for j in range(i + 1, m):
            HR_2 = student_HR[..., j]
            L_CL += ContrastiveLoss(HR_1, HR_2, batch_size, temperature)

    # Average over the m*(m-1)/2 pairs; with fewer than two variants there is nothing to average.
    num_pairs = m * (m - 1) // 2
    if num_pairs > 0:
        L_CL = L_CL / num_pairs

    L_final = L_pre + L_HD + L_RD + L_CL

    return L_final


def ContrastiveLoss(HR_1, HR_2, batch_size, temperature):
    """
    NT-Xent (normalized temperature-scaled cross-entropy) loss between two views.

    Row k of ``HR_1`` and row k of ``HR_2`` (after flattening) form a positive
    pair. Every other row of either view is a negative for both of them.

    :param HR_1: First view; flattened to ``(batch_size, -1)``
    :param HR_2: Second view; must flatten to the same shape as ``HR_1``
    :param batch_size: Number of samples (rows after flattening) in each view
    :param temperature: Divisor applied to the cosine similarities before softmax
    :return: Scalar tensor: cross-entropy summed over all ``2 * batch_size``
        anchors, divided by ``2 * batch_size``
    """
    HR_1 = HR_1.reshape(batch_size, -1)
    HR_2 = HR_2.reshape(batch_size, -1)

    # Stack both views and L2-normalise so the dot product is the cosine similarity.
    z = F.normalize(torch.cat((HR_1, HR_2), dim=0), dim=1)
    similarity_matrix = torch.matmul(z, z.T)  # (2B, 2B)

    # Positives: (i, i+B) on the +B diagonal for anchors of the first view,
    # (i+B, i) on the -B diagonal for anchors of the second view.
    sim_ij = torch.diag(similarity_matrix, batch_size)
    sim_ji = torch.diag(similarity_matrix, -batch_size)
    positives = torch.cat([sim_ij, sim_ji], dim=0).view(2 * batch_size, 1)

    # Tiling the diagonal-free mask 2x2 removes each anchor's self-similarity
    # and its positive, leaving 2B - 2 negatives per row.
    negatives_mask = get_correlated_mask(batch_size, device=z.device).repeat(2, 2)
    negatives = similarity_matrix[negatives_mask].view(2 * batch_size, -1)

    # Column 0 holds the positive, so the target class is 0 for every anchor.
    logits = torch.cat((positives, negatives), dim=1) / temperature
    labels = torch.zeros(2 * batch_size, dtype=torch.long, device=z.device)

    loss = F.cross_entropy(logits, labels, reduction="sum")
    return loss / (2 * batch_size)


def get_correlated_mask(batch_size, device=None):
    """
    Build a ``(batch_size, batch_size)`` boolean mask that is False on the main diagonal.

    :param batch_size: Side length of the square mask
    :param device: Optional device to allocate the mask on (None uses PyTorch's default device)
    :return: Bool tensor, True everywhere except the main diagonal
    """
    mask = torch.ones((batch_size, batch_size), dtype=torch.bool, device=device)
    return mask.fill_diagonal_(False)