import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViTKDLoss(nn.Module):
    """Sum-of-squared-error distillation losses between student and teacher.

    ``preds_S`` and ``preds_T`` are each indexed as ``[0]`` ("low") and
    ``[1]`` ("high"). Every loss term has the form::

        sum((a - b) ** 2) / B * (1 / (12 * lamda))

    where ``B = preds_S[0].shape[1]``. A term whose ``lamda`` is 0 is not
    computed and is reported as the integer ``0``.
    """

    # Fixed divisor applied to every term together with its per-term lamda.
    # NOTE(review): presumably the number of ViT blocks/heads (12 in ViT-B);
    # confirm before changing.
    _LOSS_NORM = 12

    def __init__(self,
                 student_dims,
                 teacher_dims,
                 alpha_vitkd=0.00003,
                 beta_vitkd=0.000003,
                 lambda_vitkd=0.5,
                 ):
        """Store the configuration weights.

        Args:
            student_dims: Accepted for interface compatibility; not used here.
            teacher_dims: Accepted for interface compatibility; not used here.
            alpha_vitkd: Stored on the instance; not read by ``forward`` or
                ``forward1`` (both use the per-call ``lamda*`` weights).
            beta_vitkd: Stored on the instance; not read by this class.
            lambda_vitkd: Stored on the instance; not read by this class.
        """
        super().__init__()
        self.alpha_vitkd = alpha_vitkd
        self.beta_vitkd = beta_vitkd
        self.lambda_vitkd = lambda_vitkd

    def _weighted_sse(self, pred, target, batch, lamda):
        """Return ``sum((pred - target)**2) / batch / (12 * lamda)``, or 0.

        When ``lamda == 0`` the term is skipped entirely (no MSE is computed)
        and the integer ``0`` is returned instead of a tensor.
        """
        if lamda == 0:
            return 0
        sse = F.mse_loss(pred, target, reduction='sum')
        # Keep the original evaluation order so results are bit-identical.
        return sse / batch * (1 / (self._LOSS_NORM * lamda))

    def forward1(self, preds_S, preds_T, lamda3, lamda4, lamda5):
        """Alternative variant: element-wise loss on the raw low/high tensors.

        Not invoked by ``module(...)`` (PyTorch dispatches to ``forward``);
        must be called explicitly.

        Args:
            preds_S: Student outputs; ``[0]`` is "low", ``[1]`` is "high".
            preds_T: Teacher outputs, indexed the same way as ``preds_S``.
            lamda3: Weight divisor for the low term; 0 disables it.
            lamda4: Weight divisor for the high term; 0 disables it.
            lamda5: Unused; kept so the signature matches ``forward``.

        Returns:
            ``(loss_low, loss_high, 0)``; each loss is a scalar tensor or the
            integer 0 when its lamda is 0.
        """
        low_s, high_s = preds_S[0], preds_S[1]
        low_t, high_t = preds_T[0], preds_T[1]
        batch = low_s.shape[1]

        loss_lr = self._weighted_sse(low_s, low_t, batch, lamda3)
        loss_gen = self._weighted_sse(high_s, high_t, batch, lamda4)
        return loss_lr, loss_gen, 0

    def forward(self, preds_S, preds_T, lamda3, lamda4, lamda5):
        """Compare student and teacher after summing over each non-batch axis.

        Low and high tensors are concatenated along dim 0 and then permuted
        with ``(1, 0, 2, 3)``, so the inputs must be 4-D and agree on dims
        1-3. Per the original code, dim 1 of the inputs is treated as the
        batch axis (assumed layout (stack, batch, ., .) - TODO confirm).
        The three losses sum that stacked tensor over dims 1, 2 and 3
        respectively before taking the sum-of-squared error.

        Args:
            preds_S: Student outputs; ``[0]`` is "low", ``[1]`` is "high".
            preds_T: Teacher outputs, indexed the same way as ``preds_S``.
            lamda3: Weight divisor for the dim-1 term; 0 disables it.
            lamda4: Weight divisor for the dim-2 term; 0 disables it.
            lamda5: Weight divisor for the dim-3 term; 0 disables it.

        Returns:
            ``(loss_1, loss_2, loss_3)``; each is a scalar tensor or the
            integer 0 when its lamda is 0.
        """
        low_s, high_s = preds_S[0], preds_S[1]
        low_t, high_t = preds_T[0], preds_T[1]
        batch = low_s.shape[1]  # batch

        # Stack low/high along dim 0, then move that stacked axis to dim 1.
        teacher = torch.cat((low_t, high_t), dim=0).permute(1, 0, 2, 3)
        student = torch.cat((low_s, high_s), dim=0).permute(1, 0, 2, 3)

        # Teacher is passed first, matching the original argument order.
        loss_1 = self._weighted_sse(
            teacher.sum(dim=1), student.sum(dim=1), batch, lamda3)
        loss_2 = self._weighted_sse(
            teacher.sum(dim=2), student.sum(dim=2), batch, lamda4)
        loss_3 = self._weighted_sse(
            teacher.sum(dim=3), student.sum(dim=3), batch, lamda5)

        return loss_1, loss_2, loss_3
