import math
import torch
import torch.nn as nn
from functools import partial

from .kl_div import KLDivergence
from .dist_kd import DIST
from .dskd import DSKD
from copy import deepcopy
import logging
import copy
logger = logging.getLogger()
from lib.models.resnet import BasicBlock
from .Relakd.Relaloss import IntraImageMiniBatch,CriterionMiniBatchCrossImagePair,MemoryBasedCrossImagePair,StudentSegContrast
import torch.nn.functional as F

# Per-architecture distillation taps, keyed by model name.
#   'modules':  names of the two hooked sub-modules, in the order
#               [feature module, logits module].
#   'channels': channel count of each hooked output, in the same order.
# The 'dskd' method uses both entries (feature + logits); 'mse' uses only the
# first (feature) entry.
KD_MODULES = {
    # CIFAR wide-resnets
    'cifar_wrn_40_1': {'modules': ['relu', 'fc'], 'channels': [64, 100]},
    'cifar_wrn_40_2': {'modules': ['relu', 'fc'], 'channels': [128, 100]},
    'cifar_wrn_40_1_aux': {'modules': ['relu', 'fc'], 'channels': [64, 100]},
    'cifar_wrn_40_2_aux': {'modules': ['relu', 'fc'], 'channels': [128, 100]},
    # CIFAR resnets
    'cifar_resnet56': {'modules': ['layer3', 'fc'], 'channels': [64, 100]},
    'cifar_resnet20': {'modules': ['layer3', 'fc'], 'channels': [64, 100]},
    'cifar_resnet56_aux': {'modules': ['layer3', 'fc'], 'channels': [64, 100]},
    'cifar_resnet20_aux': {'modules': ['layer3', 'fc'], 'channels': [64, 100]},
    'cifar_resnet32x4': {'modules': ['layer3', 'fc'], 'channels': [256, 100]},
    'cifar_resnet32x4_aux': {'modules': ['layer3', 'fc'], 'channels': [256, 100]},
    'cifar_resnet8x4': {'modules': ['layer3', 'fc'], 'channels': [256, 100]},
    'cifar_resnet8x4_aux': {'modules': ['layer3', 'fc'], 'channels': [256, 100]},
    # 1000-class resnets
    'tv_resnet50': {'modules': ['layer4', 'fc'], 'channels': [2048, 1000]},
    'tv_resnet101': {'modules': ['layer4', 'fc'], 'channels': [2048, 1000]},
    'tv_resnet34': {'modules': ['layer4', 'fc'], 'channels': [512, 1000]},
    'tv_resnet18': {'modules': ['layer4', 'fc'], 'channels': [512, 1000]},
    'resnet34_aux': {'modules': ['layer4', 'fc'], 'channels': [512, 1000]},
    'resnet18_aux': {'modules': ['layer4', 'fc'], 'channels': [512, 1000]},
    'resnet18': {'modules': ['layer4', 'fc'], 'channels': [512, 1000]},
    # Other 1000-class backbones
    'tv_mobilenet_v2': {'modules': ['features.18', 'classifier'], 'channels': [1280, 1000]},
    'nas_model': {'modules': ['features.conv_out', 'classifier'], 'channels': [1280, 1000]},  # mbv2
    'timm_tf_efficientnet_b0': {'modules': ['conv_head', 'classifier'], 'channels': [1280, 1000]},
    'mobilenet_v1': {'modules': ['model.13', 'fc'], 'channels': [1024, 1000]},
    'timm_swin_large_patch4_window7_224': {'modules': ['norm', 'head'], 'channels': [1536, 1000]},
    'timm_swin_tiny_patch4_window7_224': {'modules': ['norm', 'head'], 'channels': [768, 1000]},
}

class DistillKL(nn.Module):
    """Temperature-scaled KL-divergence loss from
    "Distilling the Knowledge in a Neural Network".

    Args:
        T: softmax temperature applied to both student and teacher logits.
    """

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, y_s, y_t):
        """Return ``T**2 * KL(softmax(y_t / T) || softmax(y_s / T))``.

        Softmax is taken over dim 1 and the divergence uses the
        ``'batchmean'`` reduction.
        """
        temperature = self.T
        student_log_probs = F.log_softmax(y_s / temperature, dim=1)
        teacher_probs = F.softmax(y_t / temperature, dim=1)
        divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        # Scale by T^2, as the original expression does.
        return divergence * (temperature ** 2)

class KDLoss():
    '''
    Knowledge-distillation loss wrapper.

    Calling an instance runs teacher and student on the same batch, reads the
    module outputs captured by forward hooks and returns
    ``ori_loss * ori_loss_weight + kd_loss * kd_loss_weight + aux_loss``.

    Supported ``kd_method`` values: ``'kd'``, ``'dist'``, ``'dist_t<tau>'``,
    ``'kdt<tau>'``, ``'dskd'`` and ``'mse'``.
    '''

    def __init__(
        self,
        student,
        teacher,
        student_name,
        teacher_name,
        ori_loss,
        kd_method='kdt4',
        ori_loss_weight=1.0,
        kd_loss_weight=1.0,
        kd_loss_kwargs={},
        addloss_name = None,
        addloss_name_2 = None,
        alpha = 1.,
        aux = False,
    ):
        """Build the distillation criterion and register forward hooks.

        Args:
            student: model being trained; called as ``student(x)`` (or
                ``student(x, train=True)`` when ``aux`` is true).
            teacher: model providing targets; put into eval mode here.
            student_name: key into ``KD_MODULES`` (used by 'dskd' and 'mse').
            teacher_name: key into ``KD_MODULES`` (used by 'dskd' and 'mse').
            ori_loss: callable ``(logits, targets) -> loss`` for the task loss.
            kd_method: distillation method name, see the class docstring.
            ori_loss_weight: multiplier of the task loss.
            kd_loss_weight: multiplier of the accumulated distillation loss.
            kd_loss_kwargs: options read with ``.get()`` for 'dskd'
                (``ae_channels``, ``use_ae``, ``tau``, ``cond``). The shared
                default dict is only read, never mutated.
            addloss_name: optional name of an extra feature loss ('dskd').
            addloss_name_2: optional name of a second extra feature loss.
            alpha: weight of the guidance term and of the extra losses.
            aux: when true, hooks are looked up under ``model.backbone`` and
                both models are called with ``train=True``.

        Raises:
            RuntimeError: if ``kd_method`` is not recognised.
            AttributeError: if a configured module name is not found.
        """
        self.student = student
        self.teacher = teacher
        self.ori_loss = ori_loss
        self.ori_loss_weight = ori_loss_weight
        self.kd_method = kd_method
        self.kd_loss_weight = kd_loss_weight
        self.alpha = alpha
        # Must be set for EVERY kd_method: __call__ and _register_forward_hook
        # read it unconditionally (it used to exist only for 'dskd').
        self.auxiliary = aux
        # dicts that store distillation outputs of student and teacher
        self._teacher_out = {}
        self._student_out = {}

        # init kd loss
        # module keys for distillation. '': output logits
        teacher_modules = ['',]
        student_modules = ['',]
        if kd_method == 'kd':
            self.kd_loss = KLDivergence(tau=4)
        elif kd_method == 'dist':
            self.kd_loss = DIST(beta=1, gamma=1, tau=1)
        elif kd_method.startswith('dist_t'):
            tau = float(kd_method[6:])
            self.kd_loss = DIST(beta=1, gamma=1, tau=tau)
        elif kd_method.startswith('kdt'):
            tau = float(kd_method[3:])
            self.kd_loss = KLDivergence(tau)
        elif kd_method == 'dskd':
            student_modules, teacher_modules = self._init_dskd(
                student_name, teacher_name, kd_loss_kwargs, addloss_name, addloss_name_2)
        elif kd_method == 'mse':
            # distillation on feature
            student_modules = KD_MODULES[student_name]['modules'][:1]
            student_channels = KD_MODULES[student_name]['channels'][:1]
            teacher_modules = KD_MODULES[teacher_name]['modules'][:1]
            teacher_channels = KD_MODULES[teacher_name]['channels'][:1]
            self.kd_loss = nn.MSELoss()
            self.align = nn.Conv2d(student_channels[0], teacher_channels[0], 1)
            self.align.cuda()
            # add align module to student for optimization
            self.student._align = self.align
        else:
            raise RuntimeError(f'KD method {kd_method} not found.')

        # register forward hooks on the paired student/teacher modules
        for student_module, teacher_module in zip(student_modules, teacher_modules):
            self._register_forward_hook(student, student_module, teacher=False)
            self._register_forward_hook(teacher, teacher_module, teacher=True)
        self.student_modules = student_modules
        self.teacher_modules = teacher_modules

        teacher.eval()
        self._iter = 0

    def _init_dskd(self, student_name, teacher_name, kd_loss_kwargs, addloss_name, addloss_name_2):
        """Create the per-module DSKD blocks and losses for ``kd_method='dskd'``.

        Returns:
            ``(student_modules, teacher_modules)`` taken from ``KD_MODULES``.
        """
        # get configs
        ae_channels = kd_loss_kwargs.get('ae_channels', 1024)
        use_ae = kd_loss_kwargs.get('use_ae', True)
        tau = kd_loss_kwargs.get('tau', 1)
        self.tau = tau
        cond = kd_loss_kwargs.get('cond', False)
        self.dclassifier = self._build_dclassifier() if cond else None

        kernel_sizes = [3, 1]  # distillation on feature and logits
        student_modules = KD_MODULES[student_name]['modules']
        student_channels = KD_MODULES[student_name]['channels']
        teacher_modules = KD_MODULES[teacher_name]['modules']
        teacher_channels = KD_MODULES[teacher_name]['channels']
        self.diff = nn.ModuleDict()
        self.kd_loss = nn.ModuleDict()
        self.add_loss = nn.ModuleDict()
        self.add_loss_2 = nn.ModuleDict()
        # NOTE(review): teacher module names are used as nn.ModuleDict keys;
        # some KD_MODULES names contain '.' (e.g. 'features.18'), which torch
        # appears to reject as a sub-module name - verify for those teachers.
        for tm, tc, sc, ks in zip(teacher_modules, teacher_channels, student_channels, kernel_sizes):
            on_feature = ks != 1  # kernel size 1 is the logits entry
            self.diff[tm] = DSKD(
                sc, tc, kernel_size=ks,
                use_ae=on_feature and use_ae,
                ae_channels=ae_channels,
                dclassifier=self.dclassifier if on_feature else None,
            )
            self.kd_loss[tm] = nn.MSELoss() if on_feature else KLDivergence(tau=tau)
            # Optional extra losses exist only for the feature entry.
            self.add_loss[tm] = self._make_add_loss(addloss_name, on_feature)
            self.add_loss_2[tm] = self._make_add_loss(addloss_name_2, on_feature)
        # attach the DSKD modules to the student (mirrors the `_align`
        # attachment done for 'mse')
        self.student._diff = self.diff
        self.diff.cuda()
        return student_modules, teacher_modules

    def _build_dclassifier(self):
        """Return a frozen avg-pool + fc head built from a deep copy of the teacher."""
        _teacher = deepcopy(self.teacher)

        class ResNetTail(nn.Module):
            def __init__(self, original_resnet, auxiliary):
                super(ResNetTail, self).__init__()
                # NOTE(review): the pooling kernel is fixed to 8 here rather
                # than reusing the teacher's own avgpool - confirm this
                # matches the feature-map size of every teacher used.
                self.avgpool = nn.AvgPool2d(kernel_size=8)
                if auxiliary:
                    self.fc = original_resnet.backbone.fc
                else:
                    self.fc = original_resnet.fc

            def forward(self, x):
                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)
                return x

        dclassifier = ResNetTail(_teacher, self.auxiliary)
        for param in dclassifier.parameters():
            param.requires_grad = False
        return dclassifier

    @staticmethod
    def _make_add_loss(name, on_feature):
        """Instantiate the extra loss called ``name``; None for logits or unknown names."""
        if not on_feature:
            return None
        if name == 'IntraImageMiniBatch':
            return IntraImageMiniBatch()
        if name == 'CriterionMiniBatchCrossImagePair':
            return CriterionMiniBatchCrossImagePair(temperature=1.)
        if name == 'MemoryBasedCrossImagePair':
            return MemoryBasedCrossImagePair()
        if name == 'StudentSegContrast':
            return StudentSegContrast()
        return None

    def __call__(self, x, targets):
        """Run both models on ``x`` and return the combined training loss.

        Args:
            x: input batch forwarded to teacher and student.
            targets: labels passed to ``ori_loss`` (and to DSKD as
                ``target_class`` in the conditional / auxiliary paths).

        Returns:
            ``ori_loss * ori_loss_weight + kd_loss * kd_loss_weight + aux_loss``.
        """
        # The teacher forward fills self._teacher_out through the hooks.
        with torch.no_grad():
            if self.auxiliary:
                _logits, ts_logits, t_aux_feats = self.teacher(x, train=True)
            else:
                self.teacher(x)

        # compute ori loss of student
        if self.auxiliary:
            logits, ss_logits, s_aux_feats = self.student(x, train=True)
        else:
            logits = self.student(x)
        ori_loss = self.ori_loss(logits, targets)

        aux_loss = 0
        kd_loss = 0
        if self.auxiliary:
            criterion_div = DistillKL(3)
            for i in range(len(ss_logits)):
                aux_loss = aux_loss + criterion_div(ss_logits[i], ts_logits[i].detach())

        for tm, sm in zip(self.teacher_modules, self.student_modules):
            student_feat = None
            ddim_loss_d = None
            # transform student feature
            if self.kd_method == 'dskd':
                if self.diff[tm].cond:
                    # conditional DSKD is unpacked as a 5-tuple
                    (self._student_out[sm], student_feat_ori, refined_feat_guidance,
                     self._teacher_out[tm], ddim_loss_d) = self.diff[tm](
                        self._reshape_BCHW(self._student_out[sm]),
                        self._reshape_BCHW(self._teacher_out[tm]),
                        target_class=targets)
                    student_feat = student_feat_ori
                else:
                    self._student_out[sm], refined_feat_guidance, self._teacher_out[tm] = self.diff[tm](
                        self._reshape_BCHW(self._student_out[sm]),
                        self._reshape_BCHW(self._teacher_out[tm]))
                    assert refined_feat_guidance is None, "The condition here guides that the denoising feature values should be empty"
                # NOTE(review): the logits module is detected by the literal
                # name 'fc', and `diff1`, `diff2`, ... are read below but are
                # never created in this class - they must be attached
                # externally before using aux=True with 'dskd'.
                if self.auxiliary and tm != 'fc':
                    if self.diff[tm].cond:
                        for i in range(len(s_aux_feats)):
                            aux_diff = getattr(self, 'diff' + str(i + 1))[tm]
                            (s_aux_feats[i], aux_student_feat_ori, _refined_feat_guidance,
                             t_aux_feats[i], aux_ddim_loss_d) = aux_diff(
                                self._reshape_BCHW(s_aux_feats[i]),
                                self._reshape_BCHW(t_aux_feats[i]),
                                target_class=targets)
                            __loss = self.kd_loss[tm](_refined_feat_guidance, aux_student_feat_ori)
                            kd_loss = kd_loss + aux_ddim_loss_d + __loss * self.alpha
                    else:
                        for i in range(len(s_aux_feats)):
                            aux_diff = getattr(self, 'diff' + str(i + 1))[tm]
                            s_aux_feats[i], _refined_feat_guidance, t_aux_feats[i] = aux_diff(
                                self._reshape_BCHW(s_aux_feats[i]),
                                self._reshape_BCHW(t_aux_feats[i]),
                                target_class=targets)
                            assert _refined_feat_guidance is None, "The condition here guides that the denoising feature values should be empty"
                            diff__loss = self.kd_loss[tm](s_aux_feats[i], t_aux_feats[i])
                            kd_loss = kd_loss + diff__loss

            if hasattr(self, 'align'):
                self._student_out[sm] = self.align(self._student_out[sm])

            # compute kd loss
            if isinstance(self.kd_loss, nn.ModuleDict):
                if student_feat is not None:
                    # conditional path: match the guidance feature against
                    # the original student feature
                    kd_loss_ = self.kd_loss[tm](refined_feat_guidance, student_feat) * self.alpha
                else:
                    kd_loss_ = self.kd_loss[tm](self._student_out[sm], self._teacher_out[tm])
                if self.add_loss[tm] is not None:
                    add_loss_ = self.add_loss[tm](self._student_out[sm], self._teacher_out[tm])
                    kd_loss_ = kd_loss_ + add_loss_ * self.alpha
                if self.add_loss_2[tm] is not None:
                    add_loss_2 = self.add_loss_2[tm](self._student_out[sm], self._teacher_out[tm])
                    kd_loss_ = kd_loss_ + add_loss_2 * self.alpha
            else:
                kd_loss_ = self.kd_loss(self._student_out[sm], self._teacher_out[tm])

            if self.kd_method == 'dskd':
                if ddim_loss_d is not None:
                    kd_loss = kd_loss + ddim_loss_d
            else:
                if self._iter % 50 == 0:
                    logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f}')
            kd_loss = kd_loss + kd_loss_

        # drop the captured activations so they are not kept between steps
        self._teacher_out = {}
        self._student_out = {}

        self._iter += 1
        return ori_loss * self.ori_loss_weight + kd_loss * self.kd_loss_weight + aux_loss

    def _register_forward_hook(self, model, name, teacher=False):
        """Attach ``_forward_hook`` to ``model`` (``name == ''``) or to the
        sub-module called ``name``.

        With ``self.auxiliary`` the sub-module is searched under
        ``model.backbone``.

        Raises:
            AttributeError: if no sub-module is called ``name``.
        """
        hook = partial(self._forward_hook, name=name, teacher=teacher)
        if name == '':
            # use the output of model
            model.register_forward_hook(hook)
            return
        root = model.backbone if self.auxiliary else model
        module = next((m for k, m in root.named_modules() if k == name), None)
        if module is None:
            # Same exception type as the former implicit failure on None,
            # but with an actionable message.
            raise AttributeError(
                f'Module {name!r} not found in {type(root).__name__}; cannot register forward hook.')
        module.register_forward_hook(hook)

    def _forward_hook(self, module, input, output, name, teacher=False):
        """Store ``output`` under ``name`` in the teacher or student dict.

        A tuple contributes its first element and a one-element list is
        unwrapped. Tensors are stored untouched: testing ``len(output) == 1``
        on a tensor would strip the batch dimension of a batch of size 1.
        """
        if isinstance(output, tuple):
            captured = output[0]
        elif isinstance(output, list) and len(output) == 1:
            captured = output[0]
        else:
            captured = output
        if teacher:
            self._teacher_out[name] = captured
        else:
            self._student_out[name] = captured

    def _reshape_BCHW(self, x):
        """
        Reshape a 2d (B, C) or 3d (B, N, C) tensor to 4d BCHW format.

        Tensors of any other rank are returned unchanged.
        """
        if x.dim() == 2:
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        elif x.dim() == 3:
            # swin [B, N, C]; N is taken to be a square number of tokens
            B, N, C = x.shape
            H = W = int(math.sqrt(N))
            x = x.transpose(-2, -1).reshape(B, C, H, W)
        return x
    

class MultiKDLoss():
    '''
    Knowledge-distillation loss wrapper.

    Calling an instance runs teacher and student on the same batch, reads the
    module outputs captured by forward hooks and returns
    ``ori_loss * ori_loss_weight + kd_loss * kd_loss_weight``.

    Supported ``kd_method`` values: ``'kd'``, ``'dist'``, ``'dist_t<tau>'``,
    ``'kdt<tau>'``, ``'dskd'`` and ``'mse'``.
    '''

    def __init__(
        self,
        student,
        teacher,
        student_name,
        teacher_name,
        ori_loss,
        kd_method='kdt4',
        ori_loss_weight=1.0,
        kd_loss_weight=1.0,
        kd_loss_kwargs={},
        addloss_name = None,
        addloss_name_2 = None,
        alpha = 1.
    ):
        """Build the distillation criterion and register forward hooks.

        Args:
            student: model being trained; called as ``student(x)``.
            teacher: model providing targets; put into eval mode here.
            student_name: key into ``KD_MODULES`` (used by 'dskd' and 'mse').
            teacher_name: key into ``KD_MODULES`` (used by 'dskd' and 'mse').
            ori_loss: callable ``(logits, targets) -> loss`` for the task loss.
            kd_method: distillation method name, see the class docstring.
            ori_loss_weight: multiplier of the task loss.
            kd_loss_weight: multiplier of the accumulated distillation loss.
            kd_loss_kwargs: options read with ``.get()`` for 'dskd'
                (``ae_channels``, ``use_ae``, ``tau``). The shared default
                dict is only read, never mutated.
            addloss_name: optional name of an extra feature loss ('dskd').
            addloss_name_2: optional name of a second extra feature loss.
            alpha: weight of the extra losses.

        Raises:
            RuntimeError: if ``kd_method`` is not recognised.
            AttributeError: if a configured module name is not found.
        """
        self.student = student
        self.teacher = teacher
        self.ori_loss = ori_loss
        self.ori_loss_weight = ori_loss_weight
        self.kd_method = kd_method
        self.kd_loss_weight = kd_loss_weight
        self.alpha = alpha
        # dicts that store distillation outputs of student and teacher
        self._teacher_out = {}
        self._student_out = {}

        # init kd loss
        # module keys for distillation. '': output logits
        teacher_modules = ['',]
        student_modules = ['',]
        if kd_method == 'kd':
            self.kd_loss = KLDivergence(tau=4)
        elif kd_method == 'dist':
            self.kd_loss = DIST(beta=1, gamma=1, tau=1)
        elif kd_method.startswith('dist_t'):
            tau = float(kd_method[6:])
            self.kd_loss = DIST(beta=1, gamma=1, tau=tau)
        elif kd_method.startswith('kdt'):
            tau = float(kd_method[3:])
            self.kd_loss = KLDivergence(tau)
        elif kd_method == 'dskd':
            # get configs
            ae_channels = kd_loss_kwargs.get('ae_channels', 1024)
            use_ae = kd_loss_kwargs.get('use_ae', True)
            tau = kd_loss_kwargs.get('tau', 1)

            # Report the configuration through the logger instead of a bare
            # print so it follows the run's logging setup.
            logger.info('DSKD kwargs: %s', kd_loss_kwargs)
            kernel_sizes = [3, 1]  # distillation on feature and logits
            student_modules = KD_MODULES[student_name]['modules']
            student_channels = KD_MODULES[student_name]['channels']
            teacher_modules = KD_MODULES[teacher_name]['modules']
            teacher_channels = KD_MODULES[teacher_name]['channels']
            self.diff = nn.ModuleDict()
            self.kd_loss = nn.ModuleDict()
            self.add_loss = nn.ModuleDict()
            self.add_loss_2 = nn.ModuleDict()
            # NOTE(review): teacher module names are used as nn.ModuleDict
            # keys; some KD_MODULES names contain '.' (e.g. 'features.18'),
            # which torch appears to reject as a sub-module name - verify.
            for tm, tc, sc, ks in zip(teacher_modules, teacher_channels, student_channels, kernel_sizes):
                on_feature = ks != 1  # kernel size 1 is the logits entry
                self.diff[tm] = DSKD(sc, tc, kernel_size=ks, use_ae=on_feature and use_ae, ae_channels=ae_channels)
                self.kd_loss[tm] = nn.MSELoss() if on_feature else KLDivergence(tau=tau)
                # Optional extra losses exist only for the feature entry.
                self.add_loss[tm] = self._make_add_loss(addloss_name, on_feature)
                self.add_loss_2[tm] = self._make_add_loss(addloss_name_2, on_feature)

            self.diff.cuda()
            # add diff module to student for optimization
            self.student._diff = self.diff

        elif kd_method == 'mse':
            # distillation on feature
            student_modules = KD_MODULES[student_name]['modules'][:1]
            student_channels = KD_MODULES[student_name]['channels'][:1]
            teacher_modules = KD_MODULES[teacher_name]['modules'][:1]
            teacher_channels = KD_MODULES[teacher_name]['channels'][:1]
            self.kd_loss = nn.MSELoss()
            self.align = nn.Conv2d(student_channels[0], teacher_channels[0], 1)
            self.align.cuda()
            # add align module to student for optimization
            self.student._align = self.align
        else:
            raise RuntimeError(f'KD method {kd_method} not found.')

        # register forward hooks on the paired student/teacher modules
        for student_module, teacher_module in zip(student_modules, teacher_modules):
            self._register_forward_hook(student, student_module, teacher=False)
            self._register_forward_hook(teacher, teacher_module, teacher=True)
        self.student_modules = student_modules
        self.teacher_modules = teacher_modules

        teacher.eval()
        self._iter = 0

    @staticmethod
    def _make_add_loss(name, on_feature):
        """Instantiate the extra loss called ``name``; None for logits or unknown names."""
        if not on_feature:
            return None
        if name == 'IntraImageMiniBatch':
            return IntraImageMiniBatch()
        if name == 'CriterionMiniBatchCrossImagePair':
            return CriterionMiniBatchCrossImagePair(temperature=1.)
        if name == 'MemoryBasedCrossImagePair':
            return MemoryBasedCrossImagePair()
        if name == 'StudentSegContrast':
            return StudentSegContrast()
        return None

    def __call__(self, x, targets):
        """Run both models on ``x`` and return the combined training loss.

        Args:
            x: input batch forwarded to teacher and student.
            targets: labels passed to ``ori_loss``.

        Returns:
            ``ori_loss * ori_loss_weight + kd_loss * kd_loss_weight``.
        """
        # Only the hooked activations are needed from the teacher; its return
        # value is not used.
        with torch.no_grad():
            self.teacher(x)

        # compute ori loss of student
        logits = self.student(x)
        ori_loss = self.ori_loss(logits, targets)

        kd_loss = 0
        log_now = self._iter % 50 == 0

        for tm, sm in zip(self.teacher_modules, self.student_modules):
            if self.kd_method == 'dskd':
                # NOTE(review): this unpacks a 4-tuple (student, teacher,
                # diff_loss, ae_loss); confirm that DSKD.forward really
                # returns this shape.
                self._student_out[sm], self._teacher_out[tm], diff_loss, ae_loss = self.diff[tm](
                    self._reshape_BCHW(self._student_out[sm]),
                    self._reshape_BCHW(self._teacher_out[tm]))
            if hasattr(self, 'align'):
                self._student_out[sm] = self.align(self._student_out[sm])

            # compute kd loss
            if isinstance(self.kd_loss, nn.ModuleDict):
                kd_loss_ = self.kd_loss[tm](self._student_out[sm], self._teacher_out[tm])
                if self.add_loss[tm] is not None:
                    add_loss_ = self.add_loss[tm](self._student_out[sm], self._teacher_out[tm])
                    kd_loss_ = kd_loss_ + add_loss_ * self.alpha
                if self.add_loss_2[tm] is not None:
                    add_loss_2 = self.add_loss_2[tm](self._student_out[sm], self._teacher_out[tm])
                    kd_loss_ = kd_loss_ + add_loss_2 * self.alpha
            else:
                kd_loss_ = self.kd_loss(self._student_out[sm], self._teacher_out[tm])

            if self.kd_method == 'dskd':
                if ae_loss is not None:
                    kd_loss = kd_loss + diff_loss + ae_loss
                    if log_now:
                        logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f} Diff loss: {diff_loss.item():.4f} AE loss: {ae_loss.item():.4f} ')
                else:
                    kd_loss = kd_loss + diff_loss
                    if log_now:
                        logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f} Diff loss: {diff_loss.item():.4f} ')
            else:
                if log_now:
                    logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f}')
            kd_loss = kd_loss + kd_loss_

        # drop the captured activations so they are not kept between steps
        self._teacher_out = {}
        self._student_out = {}

        self._iter += 1
        return ori_loss * self.ori_loss_weight + kd_loss * self.kd_loss_weight

    def _register_forward_hook(self, model, name, teacher=False):
        """Attach ``_forward_hook`` to ``model`` (``name == ''``) or to the
        sub-module called ``name``.

        Raises:
            AttributeError: if no sub-module is called ``name``.
        """
        hook = partial(self._forward_hook, name=name, teacher=teacher)
        if name == '':
            # use the output of model
            model.register_forward_hook(hook)
            return
        module = next((m for k, m in model.named_modules() if k == name), None)
        if module is None:
            # Same exception type as the former implicit failure on None,
            # but with an actionable message.
            raise AttributeError(
                f'Module {name!r} not found in {type(model).__name__}; cannot register forward hook.')
        module.register_forward_hook(hook)

    def _forward_hook(self, module, input, output, name, teacher=False):
        """Store ``output`` under ``name`` in the teacher or student dict.

        Only a one-element list/tuple is unwrapped. Tensors are stored
        untouched: testing ``len(output) == 1`` on a tensor would strip the
        batch dimension of a batch of size 1.
        """
        if isinstance(output, (list, tuple)) and len(output) == 1:
            output = output[0]
        if teacher:
            self._teacher_out[name] = output
        else:
            self._student_out[name] = output

    def _reshape_BCHW(self, x):
        """
        Reshape a 2d (B, C) or 3d (B, N, C) tensor to 4d BCHW format.

        Tensors of any other rank are returned unchanged.
        """
        if x.dim() == 2:
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        elif x.dim() == 3:
            # swin [B, N, C]; N is taken to be a square number of tokens
            B, N, C = x.shape
            H = W = int(math.sqrt(N))
            x = x.transpose(-2, -1).reshape(B, C, H, W)
        return x