import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Shared module-level loss/similarity modules.
# `reduction="mean"` is the supported spelling of the deprecated
# `size_average=True`: it averages the pointwise KL terms over every element
# (batch and class dimensions), exactly as before, without the deprecation
# warning at import time.
criterion_kl = nn.KLDivLoss(reduction="mean")
# Cosine similarity computed along dim 1 of the two inputs.
cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)


def normalize_feat(feat):
    """Return ``feat`` with each row scaled to unit L2 norm.

    ``feat`` must be a 2-D tensor; the normalization runs along dim 1.
    An ``AssertionError`` is raised for any other rank.
    """
    rank = len(feat.shape)
    assert rank == 2
    unit_rows = F.normalize(feat, p=2, dim=1)
    return unit_rows


def logit_pair_loss(logits_input, logits_target, T=1.0):
    """KL-based pairing loss between two sets of logits.

    Both logit tensors are divided by the temperature ``T`` and passed through
    a (log-)softmax over dim 1; the module-level ``criterion_kl`` then compares
    the log-probabilities of ``logits_input`` against the probabilities of
    ``logits_target``.
    """
    input_log_probs = F.log_softmax(logits_input / T, dim=1)
    target_probs = F.softmax(logits_target / T, dim=1)
    return criterion_kl(input_log_probs, target_probs)


class AlignLoss(nn.Module):
    """
    Regularization loss to align two latent representations.
    For ARAT, we use this loss to align the main model and the sub model, which shares all the layers except the BN layers.
    By default, ARAT takes asymmetric alignment, i.e., only aligning the main model to the sub model (x->y).

    Args:
        args: arguments, forwarded to ``get_pair_loss_func`` (the "kl" metric
            reads ``args.kd_temp``).
        loss_metric: loss function, including "cos-sim", "mse", "kl".
            By default, we use "cos-sim".
        align_type: alignment type, including "x->y", "y->x", "x->y,y->x", "x<->y".
            By default, we use "x->y".
            - "x->y": y is detached, so gradients only flow through x.
            - "y->x": x is detached, so gradients only flow through y.
            - "x->y,y->x": average of both detached directions.
            - "x<->y": no detach, gradients flow through both inputs.
        is_use_predictor: whether to use predictor
            By default, we use predictor, following SimSiam.
        feat_dim: feature dimension (input and output size of the predictor).
        hidden_dim: dimension of the hidden layer in the predictor.
            Defaults to 512, i.e. the same as the default ``feat_dim``; pass a
            smaller value (e.g. ``feat_dim // 4``) to get a bottleneck.
    """

    # Accepted alignment modes, and the subsets that route x (resp. y)
    # through a predictor head.
    _ALIGN_TYPES = ("x->y", "y->x", "x->y,y->x", "x<->y")
    _X_TO_Y_TYPES = ("x->y", "x->y,y->x", "x<->y")
    _Y_TO_X_TYPES = ("y->x", "x->y,y->x", "x<->y")

    def __init__(
        self,
        args,
        loss_metric="cos-sim",
        align_type="x->y",
        is_use_predictor=True,
        feat_dim=512,
        hidden_dim=512,
    ):
        """
        Store the configuration, build the pair loss function and, when
        ``is_use_predictor`` is true, the predictor head(s) required by
        ``align_type``. See the class docstring for the parameters.
        """
        super().__init__()
        self.args = args
        self.loss_metric = loss_metric
        self.align_type = align_type
        self.is_use_predictor = is_use_predictor
        assert align_type in self._ALIGN_TYPES

        # loss function
        self.loss_func = get_pair_loss_func(args, loss_metric)

        # projection
        if is_use_predictor:
            self.feat_dim = feat_dim
            self.hidden_dim = hidden_dim

            # build a 2-layer predictor.
            # Based on SimSiam:
            # https://github.com/facebookresearch/simsiam/blob/main/simsiam/builder.py
            def make_predictor():
                return nn.Sequential(
                    nn.Linear(feat_dim, hidden_dim, bias=False),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(inplace=True),  # hidden layer
                    nn.Linear(hidden_dim, feat_dim),
                )

            # Only create the heads that forward() will actually use.
            if align_type in self._X_TO_Y_TYPES:
                self.predictor_x_to_y = make_predictor()
            if align_type in self._Y_TO_X_TYPES:
                self.predictor_y_to_x = make_predictor()

    def forward(self, x, y):
        """
        Compute the alignment loss between ``x`` and ``y``.

        x: features from main model
        y: features from sub model

        Both inputs must have the same shape. When the predictor is enabled,
        4-D inputs are first global-average-pooled over their last two
        dimensions and flattened to 2-D, then passed through the predictor
        head(s) selected by ``align_type``.

        Returns the value produced by the configured pair loss function.

        Raises:
            AssertionError: if ``x`` and ``y`` differ in shape.
            ValueError: if ``self.align_type`` is not a supported mode.
        """
        assert x.shape == y.shape
        if self.is_use_predictor:
            if len(x.shape) == 4:
                # Global average pooling over the trailing two dims.
                x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
                y = F.avg_pool2d(y, y.size()[2:]).view(y.size(0), -1)

            if self.align_type in self._X_TO_Y_TYPES:
                x = self.predictor_x_to_y(x)
            if self.align_type in self._Y_TO_X_TYPES:
                y = self.predictor_y_to_x(y)

        if self.align_type == "x->y":
            # Stop-gradient on y: only x (and its predictor) is updated.
            return self.loss_func(x, y.detach())
        if self.align_type == "y->x":
            # Stop-gradient on x: only y (and its predictor) is updated.
            return self.loss_func(x.detach(), y)
        if self.align_type == "x->y,y->x":
            # NOTE(review): when the predictor is enabled, both x and y have
            # already been passed through their predictors here, so each
            # target is a *predicted* feature rather than the raw feature of
            # the other branch (as in SimSiam) -- confirm this is intended.
            forward_loss = self.loss_func(x, y.detach())
            backward_loss = self.loss_func(y, x.detach())
            return (forward_loss + backward_loss) / 2
        if self.align_type == "x<->y":
            return self.loss_func(x, y)

        # Previously an unknown mode fell through and silently returned None
        # (reachable if the attribute is changed after construction or the
        # constructor assert is stripped with -O).
        raise ValueError("Unsupported align_type: {!r}".format(self.align_type))


def get_pair_loss_func(args, metric, **kwargs):
    """Return a two-argument loss callable ``f(x, y)`` for ``metric``.

    Supported metrics:
        "cos-sim": one minus the mean cosine similarity along the last dim.
        "mse": mean squared error between ``x`` and ``y``.
        "kl": KL divergence between the temperature-softened softmax of ``x``
            (as log-probabilities) and of ``y`` over dim 1, with "batchmean"
            reduction, multiplied by the temperature twice. The temperature
            is read from ``args.kd_temp`` each time the loss is called.

    ``args`` is only accessed by the "kl" loss; extra ``kwargs`` are accepted
    and ignored. Any other ``metric`` raises ``NotImplementedError``.
    """
    # Reject unknown metrics up front.
    if metric not in ("cos-sim", "mse", "kl"):
        raise NotImplementedError

    def neg_cos_sim(x, y):
        similarity = F.cosine_similarity(x, y, dim=-1)
        return 1 - similarity.mean()

    def mse(x, y):
        return F.mse_loss(x, y)

    def kl_divergence(x, y):
        T = args.kd_temp
        log_probs_x = F.log_softmax(x / T, dim=1)
        probs_y = F.softmax(y / T, dim=1)
        divergence = F.kl_div(log_probs_x, probs_y, reduction="batchmean")
        return divergence * T * T

    dispatch = {
        "cos-sim": neg_cos_sim,
        "mse": mse,
        "kl": kl_divergence,
    }
    return dispatch[metric]
