import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import pdb
import numpy as np

from ._base import Distiller

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Mix each sample of a batch with a randomly chosen partner sample.

    Args:
        x: input batch; the first dimension is the batch dimension.
        y: targets, indexable along the same batch dimension as ``x``.
        alpha: Beta distribution parameter.  ``lam`` is drawn from
            ``Beta(alpha, alpha)`` when ``alpha > 0``; otherwise ``lam`` is
            fixed to 1 and the mixed batch equals ``x``.
        use_cuda: when true, the permutation index is moved to the device
            that holds ``x`` (a no-op for CPU inputs).  When false, the
            index is left on the CPU.

    Returns:
        A tuple ``(mixed_x, y_b, lam, index)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[index]``, ``y_b`` is ``y[index]`` (the
        targets of the partner samples), ``lam`` is the mixing coefficient
        and ``index`` is the permutation that was applied.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Draw the permutation on the CPU so the random stream does not depend
    # on where ``x`` lives, then follow ``x`` instead of calling ``.cuda()``
    # unconditionally: that call fails on CPU-only hosts and may pick the
    # wrong GPU when ``x`` is not on the current CUDA device.
    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_b = y[index]
    return mixed_x, y_b, lam, index


def get_modulating_factor(feature, feature_mixed, lam, index):
    """Cosine similarity between interpolated clean features and mixed features.

    For every sample ``i`` the clean features are interpolated the same way
    the inputs were mixed, ``lam * feature[i] + (1 - lam) * feature[index[i]]``,
    and compared with ``feature_mixed[i]`` after both are L2-normalised.

    Args:
        feature: features of the un-mixed batch; flattened to
            ``(batch, -1)`` internally.
        feature_mixed: features of the mixed batch with the same batch size;
            flattened to ``(batch, -1)`` internally.
        lam: mixing coefficient used to build the mixed batch.
        index: permutation of batch indices used to build the mixed batch.

    Returns:
        Tensor of shape ``(batch, 1)`` holding one similarity per sample.
    """
    batch_size = feature.shape[0]
    feature = feature.reshape(batch_size, -1)
    feature_mixed = feature_mixed.reshape(batch_size, -1)

    interpolated = lam * feature + (1 - lam) * feature[index]
    # F.normalize clamps the norm from below, so an all-zero feature row
    # gives a similarity of 0 instead of NaN.
    interpolated = F.normalize(interpolated, dim=1)
    feature_mixed = F.normalize(feature_mixed, dim=1)

    # Only the matching (i, i) pairs are needed, so take the row-wise dot
    # product instead of building the full batch x batch similarity matrix
    # and gathering its diagonal.
    return (interpolated * feature_mixed).sum(dim=1, keepdim=True)


def hcl_loss(fstudent, fteacher):
    """Per-sample multi-scale MSE between paired student/teacher feature maps.

    For every pair of 4-D feature maps the full-resolution squared error is
    averaged per sample.  The maps are then average-pooled to 4x4, 2x2 and
    1x1 grids (skipping any grid that is not strictly smaller than the map
    height); each pooled error is averaged per sample and added with a
    weight that halves at every scale actually used.  The sum is divided by
    the total weight, and the results of all pairs are added together.

    Args:
        fstudent: iterable of student feature maps.
        fteacher: iterable of teacher feature maps, paired with ``fstudent``
            by position.

    Returns:
        Tensor of shape ``(batch, 1)`` with one loss value per sample, or
        the float ``0.0`` when there are no pairs.
    """
    total = 0.0
    for student_map, teacher_map in zip(fstudent, fteacher):
        batch, _, height, _ = student_map.shape

        # Full-resolution term, reduced to one value per sample.
        squared_err = F.mse_loss(student_map, teacher_map, reduction="none")
        per_sample = squared_err.view(batch, -1).mean(dim=-1).view(-1, 1)

        weight = 1.0
        weight_sum = 1.0
        for grid in (4, 2, 1):
            if grid >= height:
                continue
            weight /= 2.0
            pooled_student = F.adaptive_avg_pool2d(student_map, (grid, grid))
            pooled_teacher = F.adaptive_avg_pool2d(teacher_map, (grid, grid))
            pooled_err = (
                F.mse_loss(pooled_student, pooled_teacher, reduction="none") * weight
            )
            pooled_err = pooled_err.view(pooled_student.shape[0], -1)
            per_sample = per_sample + pooled_err.mean(dim=-1).view(-1, 1)
            weight_sum += weight

        total = total + per_sample / weight_sum
    return total


class ReviewKD_Ours(Distiller):
    """ReviewKD-style feature distiller with a mixup-based per-sample loss weight.

    The student's intermediate features are fused through a stack of ``ABF``
    modules and matched against the teacher's features with ``hcl_loss``.
    The per-sample sum of the cross-entropy and feature losses is then
    scaled by ``base_weight - s``, where ``s`` is the value returned by
    ``get_modulating_factor`` for the student's pooled features of the clean
    batch and of a mixed-up copy of it.
    """

    def __init__(self, student, teacher, cfg):
        """Read hyper-parameters from ``cfg`` and build the ABF stack.

        Args:
            student: student network, forwarded to the ``Distiller`` base.
            teacher: teacher network, forwarded to the ``Distiller`` base.
            cfg: config object providing the ``REVIEWKD.*`` and
                ``EXPERIMENT.*`` fields read below.
        """
        super(ReviewKD_Ours, self).__init__(student, teacher)
        self.shapes = cfg.REVIEWKD.SHAPES
        self.out_shapes = cfg.REVIEWKD.OUT_SHAPES
        in_channels = cfg.REVIEWKD.IN_CHANNELS
        out_channels = cfg.REVIEWKD.OUT_CHANNELS
        self.ce_loss_weight = cfg.REVIEWKD.CE_WEIGHT
        self.reviewkd_loss_weight = cfg.REVIEWKD.REVIEWKD_WEIGHT
        self.warmup_epochs = cfg.REVIEWKD.WARMUP_EPOCHS
        self.stu_preact = cfg.REVIEWKD.STU_PREACT
        # NOTE(review): stored but never used; the cap below is the literal
        # 512.  Left unchanged so existing checkpoints keep their shapes --
        # confirm whether the config value was meant to be used instead.
        self.max_mid_channel = cfg.REVIEWKD.MAX_MID_CHANNEL

        # our hyper-parameters
        self.base_weight = cfg.EXPERIMENT.BASE_WEIGHT
        self.mixed = cfg.EXPERIMENT.MIXED
        self.alpha = cfg.EXPERIMENT.ALPHA

        abfs = nn.ModuleList()
        mid_channel = min(512, in_channels[-1])
        last_idx = len(in_channels) - 1
        for idx, in_channel in enumerate(in_channels):
            # Every block except the one for the last entry fuses with a
            # residual coming from the previously processed block.
            abfs.append(
                ABF(in_channel, mid_channel, out_channels[idx], idx < last_idx)
            )
        # Reversed so that abfs[0] is the non-fusing block; forward_train
        # feeds it the last student feature first.
        self.abfs = abfs[::-1]

    def get_learnable_parameters(self):
        """Return the base class's learnable parameters plus the ABF parameters."""
        return super().get_learnable_parameters() + list(self.abfs.parameters())

    def get_extra_parameters(self):
        """Return the total number of scalar parameters held by the ABF stack."""
        return sum(p.numel() for p in self.abfs.parameters())

    def forward_train(self, image, target, **kwargs):
        """Compute student logits and the re-weighted distillation loss.

        Args:
            image: input batch, fed to the student (clean and mixed-up) and
                to the teacher (clean only).
            target: labels for ``image``, used by the cross-entropy term.
            **kwargs: ``kwargs["epoch"]`` drives the warm-up ramp of the
                feature loss and is required whenever ``warmup_epochs > 0``.

        Returns:
            Tuple ``(logits_student, loss)``: the student logits for the
            clean batch and the mean of the per-sample weighted loss.
        """
        # Keep the permutation index on the CPU for CPU inputs; relying on
        # mixup_data's default would call .cuda() and crash without a GPU.
        image_mixed, y_b, lam, index = mixup_data(
            image, target, alpha=self.alpha, use_cuda=image.is_cuda
        )

        logits_student, features_student = self.student(image)
        logits_student_mixed, features_student_mixed = self.student(image_mixed)

        with torch.no_grad():
            logits_teacher, features_teacher = self.teacher(image)

        # get features: intermediate maps followed by the pooled feature
        # expanded with two trailing singleton dimensions.
        feats_key = "preact_feats" if self.stu_preact else "feats"
        x = features_student[feats_key] + [
            features_student["pooled_feat"].unsqueeze(-1).unsqueeze(-1)
        ]
        x = x[::-1]

        # Run the ABF stack starting from the last feature; each output is
        # inserted at the front so ``results`` ends up in the original order.
        results = []
        out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0])
        results.append(out_features)
        for features, abf, shape, out_shape in zip(
            x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]
        ):
            out_features, res_features = abf(features, res_features, shape, out_shape)
            results.insert(0, out_features)
        features_teacher = features_teacher["preact_feats"][1:] + [
            features_teacher["pooled_feat"].unsqueeze(-1).unsqueeze(-1)
        ]

        # losses (kept per-sample, shape (batch, 1), until the final mean)
        if self.mixed:
            ce_a = F.cross_entropy(logits_student_mixed, target, reduction="none")
            ce_b = F.cross_entropy(logits_student_mixed, y_b, reduction="none")
            loss_ce = lam * ce_a.view(-1, 1) + (1 - lam) * ce_b.view(-1, 1)
        else:
            loss_ce = F.cross_entropy(
                logits_student, target, reduction="none"
            ).view(-1, 1)
        loss_ce = self.ce_loss_weight * loss_ce

        # Linear warm-up of the feature loss; a non-positive warmup_epochs
        # means "no warm-up" instead of raising ZeroDivisionError.
        if self.warmup_epochs > 0:
            warmup = min(kwargs["epoch"] / self.warmup_epochs, 1.0)
        else:
            warmup = 1.0
        loss_reviewkd = (
            self.reviewkd_loss_weight * warmup * hcl_loss(results, features_teacher)
        )

        modulating_factor = self.base_weight - get_modulating_factor(
            feature=features_student["pooled_feat"],
            feature_mixed=features_student_mixed["pooled_feat"],
            lam=lam,
            index=index,
        )
        loss = modulating_factor * (loss_ce + loss_reviewkd)
        return logits_student, loss.mean()


class ABF(nn.Module):
    """Fusion block: 1x1 conv, optional gated merge with a residual, 3x3 conv.

    ``conv1`` maps the input to ``mid_channel`` channels.  When built with
    ``fuse=True`` the block also owns ``att_conv``, which produces two
    sigmoid gates used to blend the transformed input with an upsampled
    residual.  ``conv2`` maps the (possibly fused) features to
    ``out_channel`` channels.
    """

    def __init__(self, in_channel, mid_channel, out_channel, fuse):
        super().__init__()
        reduce_conv = nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False)
        self.conv1 = nn.Sequential(reduce_conv, nn.BatchNorm2d(mid_channel))

        project_conv = nn.Conv2d(
            mid_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.conv2 = nn.Sequential(project_conv, nn.BatchNorm2d(out_channel))

        # Two gate maps: one for the transformed input, one for the residual.
        self.att_conv = (
            nn.Sequential(
                nn.Conv2d(mid_channel * 2, 2, kernel_size=1),
                nn.Sigmoid(),
            )
            if fuse
            else None
        )

        for conv in (reduce_conv, project_conv):
            nn.init.kaiming_uniform_(conv.weight, a=1)  # pyre-ignore

    def forward(self, x, y=None, shape=None, out_shape=None):
        """Transform ``x``, optionally fuse it with residual ``y``, and project.

        Args:
            x: 4-D input feature map.
            y: residual features from the previous block; used only when
                the block was built with ``fuse=True``.
            shape: spatial size the residual is resized to before fusion.
            out_shape: spatial size of the returned tensors; the features
                are resized when their last dimension differs from it.

        Returns:
            Tuple ``(output, residual)``: the ``conv2`` projection and the
            features it was computed from.
        """
        batch, _, height, width = x.shape
        # transform student features
        fused = self.conv1(x)
        if self.att_conv is not None:
            # upsample residual features
            residual = F.interpolate(y, (shape, shape), mode="nearest")
            # fusion
            gates = self.att_conv(torch.cat([fused, residual], dim=1))
            gate_x = gates[:, 0].view(batch, 1, height, width)
            gate_res = gates[:, 1].view(batch, 1, height, width)
            fused = fused * gate_x + residual * gate_res
        # output
        if fused.shape[-1] != out_shape:
            fused = F.interpolate(fused, (out_shape, out_shape), mode="nearest")
        return self.conv2(fused), fused