import math
import torch
from torch.nn import functional as F
import torch.nn as nn

# from .config import ConfigBase


class LocalLoss(torch.nn.Module):
    """Distillation loss between student and teacher dense logits.

    Each model output passed to :meth:`forward` must be a two-item sequence
    whose second item is the dense-logits tensor; the first item is ignored
    here. Three-dimensional dense logits are reshaped into a square 2-D map
    with :meth:`align_cnn_logits`. When both maps are 4-D and their spatial
    sizes differ, the larger map is average-pooled down to the other one
    before the loss is computed.
    """

    def __init__(
        self,
        criterion: torch.nn.Module

    ):
        super().__init__()
        # Kept as an attribute for callers; no method of this class uses it.
        self.criterion = criterion

    def forward(
        self,
        stu_outputs,
        labels,
        epoch,
        max_epoch,
        tea_outputs,
        loss_type='hard',
        temperature=1.0,
    ):
        """Compute the local loss from student and teacher outputs.

        Args:
            stu_outputs: Two-item sequence; the second item is the student's
                dense-logits tensor.
            labels: Unused; accepted for interface compatibility.
            epoch: Unused; accepted for interface compatibility.
            max_epoch: Unused; accepted for interface compatibility.
            tea_outputs: Two-item sequence; the second item is the teacher's
                dense-logits tensor.
            loss_type: ``'hard'`` or ``'soft'``; see :meth:`get_loss_local`.
            temperature: Softmax temperature for the ``'soft'`` loss.

        Returns:
            The scalar loss tensor produced by :meth:`get_loss_local`.

        Raises:
            TypeError: If either output is a bare tensor, because then there
                are no dense logits to compare.
            NotImplementedError: If ``loss_type`` is not recognised.
        """
        stu_dense_logits = self._unpack_dense_logits(stu_outputs, "stu_outputs")
        tea_dense_logits = self._unpack_dense_logits(tea_outputs, "tea_outputs")

        # Spatial alignment only makes sense for (N, C, H, W) maps; tensors
        # of any other rank are handed to the loss unchanged.
        if stu_dense_logits.dim() == 4 and tea_dense_logits.dim() == 4:
            stu_hw = tuple(stu_dense_logits.shape[2:])
            tea_hw = tuple(tea_dense_logits.shape[2:])
            if stu_hw != tea_hw:
                # Tuple comparison orders by height first and width second,
                # so the map with the larger height is the one pooled down.
                if stu_hw > tea_hw:
                    stu_dense_logits = self.align_feature(
                        input=stu_dense_logits, size=tea_hw
                    )
                else:
                    tea_dense_logits = self.align_feature(
                        input=tea_dense_logits, size=stu_hw
                    )

        return self.get_loss_local(
            stu_dense_logits,
            tea_dense_logits,
            local_loss_type=loss_type,
            temperature=temperature,
        )

    def _unpack_dense_logits(self, outputs, arg_name):
        """Return the dense logits held in ``outputs`` as a spatial map."""
        if isinstance(outputs, torch.Tensor):
            raise TypeError(
                f"{arg_name} must be a (outputs, dense_logits) pair, "
                "got a single tensor without dense logits"
            )
        _, dense_logits = outputs
        if dense_logits.dim() == 3:
            dense_logits = self.align_cnn_logits(dense_logits)
        return dense_logits

    def align_feature(self, input=None, size=5):
        """Adaptive-average-pool ``input`` to the requested spatial size.

        Args:
            input: Tensor accepted by ``F.adaptive_avg_pool2d``.
            size: Either one int (square output) or an ``(H, W)`` pair.

        Returns:
            The pooled tensor.
        """
        if isinstance(size, int):
            target = (size, size)
        else:
            target = tuple(int(s) for s in size)
        return F.adaptive_avg_pool2d(input, target)

    def align_cnn_logits(self, stu_dense_logits):
        """Reshape ``(N, M, C)`` logits into a square ``(N, C, m, m)`` map.

        Args:
            stu_dense_logits: 3-D tensor whose middle dimension ``M`` must be
                a perfect square (``M == m * m``).

        Returns:
            A tensor of shape ``(N, C, m, m)``.

        Raises:
            ValueError: If ``M`` is not a perfect square.
        """
        N, M, C = stu_dense_logits.shape
        # Exact integer square root; avoids a float round-trip via a tensor.
        m = math.isqrt(M)
        if m * m != M:
            raise ValueError(
                f"cannot reshape {M} positions into a square map; "
                "the second dimension must be a perfect square"
            )
        stu_dense_logits = stu_dense_logits.permute(0, 2, 1).reshape(N, C, m, m)
        return stu_dense_logits

    def get_loss_local(
        self,
        stu_dense_logits,
        teacher_dense_logits,
        local_loss_type='hard',
        temperature=1.0,
    ):
        """Compute the loss between student and teacher dense logits.

        Args:
            stu_dense_logits: Student logits; dimension 1 is treated as the
                class dimension.
            teacher_dense_logits: Teacher logits with a matching shape.
            local_loss_type: ``'hard'`` applies cross-entropy against the
                teacher's argmax over dimension 1. ``'soft'`` applies the KL
                divergence between temperature-scaled log-softmax
                distributions, multiplied by ``temperature ** 2``.
            temperature: Positive softmax temperature for the soft loss.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: If ``temperature`` is not positive for the soft loss.
            NotImplementedError: If ``local_loss_type`` is not recognised.
        """
        if local_loss_type == "hard":
            return F.cross_entropy(
                stu_dense_logits,
                teacher_dense_logits.argmax(dim=1),
            )
        if local_loss_type == "soft":
            if temperature <= 0:
                raise ValueError(
                    f"temperature must be positive, got {temperature}"
                )
            # 'batchmean' divides the summed divergence by the size of
            # dimension 0 only; other dimensions are summed, not averaged.
            return F.kl_div(
                F.log_softmax(stu_dense_logits / temperature, dim=1),
                F.log_softmax(teacher_dense_logits / temperature, dim=1),
                reduction='batchmean',
                log_target=True,
            ) * (temperature * temperature)
        raise NotImplementedError(local_loss_type)

    