"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch.nn as nn
import torch
import torch.nn.functional as F


class MCD(nn.Module):
    r"""Pair of classifiers trained with a Maximum-Classifier-Discrepancy style objective.

    Both classifiers are applied to the same source/target features. In training mode
    :meth:`forward` returns one scalar loss that combines:

    - source cross-entropy for both classifiers;
    - an entropy term on the target predictions (see :func:`entropy`), weighted by 0.01;
    - the negative target discrepancy.

    In evaluation mode it returns the fraction of target samples on which the two
    classifiers' argmax predictions disagree.

    Args:
        in_features (int): Dimension of the input features ``f_s`` / ``f_t``
        num_classes (int): Number of classes
        training_classifer (bool, optional): If True, no first classifier is created and the
            externally supplied predictions ``g_s`` / ``g_t`` are used in its place. Default: False
        mlp (bool, optional): If True, each created classifier is
            ``Linear(in_features, 512) -> ReLU -> Dropout(0.5) -> Linear(512, num_classes)``;
            otherwise a single ``Linear(in_features, num_classes)``. Default: False

    .. note::
        The attribute names ``classifer1`` / ``classifer2`` (sic) are kept unchanged because they
        determine the ``state_dict`` keys.
    """

    def __init__(self, in_features, num_classes, training_classifer=False, mlp=False):
        super(MCD, self).__init__()
        # classifer1 is still built before classifer2 so parameter initialisation consumes the
        # RNG in the same order as before.
        if training_classifer:
            self.classifer1 = None
        else:
            self.classifer1 = self._make_classifier(in_features, num_classes, mlp)
        self.classifer2 = self._make_classifier(in_features, num_classes, mlp)

    @staticmethod
    def _make_classifier(in_features, num_classes, mlp):
        """Build one classifier: a 512-d hidden-layer MLP when ``mlp`` else a single linear layer."""
        if mlp:
            return nn.Sequential(
                nn.Linear(in_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes))
        return nn.Linear(in_features, num_classes)

    def forward(self, g_s, f_s, g_t, f_t, label_s=None, label_t=None, d_label=None, training=True):
        r"""Compute the training loss or the evaluation-time disagreement rate.

        Args:
            g_s (torch.Tensor): External source predictions. Used in place of the first
                classifier's source output only when the module was built with
                ``training_classifer=True``; otherwise ignored.
            f_s (torch.Tensor): Source features fed to the classifiers.
            g_t (torch.Tensor): External target predictions, used like ``g_s``.
            f_t (torch.Tensor): Target features fed to the classifiers.
            label_s (torch.Tensor, optional): Source labels. Required when ``training=True``.
            label_t: Unused; kept for interface compatibility.
            d_label: Unused; kept for interface compatibility.
            training (bool, optional): Selects which quantity is returned. Default: True

        Returns:
            If ``training`` is True, the scalar loss
            ``CE(y1_s, label_s) + CE(y2_s, label_s)
            + 0.01 * (entropy(p1_t) + entropy(p2_t))
            - classifier_discrepancy(p1_t, p2_t)``,
            where ``p*_t`` are the target outputs after softmax over dim 1.
            Otherwise, a scalar tensor holding the fraction of target samples whose argmax
            predictions differ between the two classifiers.

        Raises:
            ValueError: If ``training`` is True and ``label_s`` is None.
        """
        if training and label_s is None:
            raise ValueError("label_s is required when training=True")
        if self.classifer1 is None:
            # External predictions stand in for the first classifier. They are treated as
            # logits: cross-entropy on the source side, softmax on the target side.
            y1_s, y1_t = g_s, g_t
        else:
            y1_s = self.classifer1(f_s)
            y1_t = self.classifer1(f_t)
        y2_s = self.classifer2(f_s)
        y2_t = self.classifer2(f_t)
        # Only the target outputs are normalised; source outputs go to cross_entropy as logits.
        # Softmax over dim 1 assumes (batch, num_classes) outputs.
        p1_t, p2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
        if training:
            return F.cross_entropy(y1_s, label_s) + F.cross_entropy(y2_s, label_s) + \
                (entropy(p1_t) + entropy(p2_t)) * 0.01 - classifier_discrepancy(p1_t, p2_t)
        pred1_t = torch.argmax(p1_t, 1)
        pred2_t = torch.argmax(p2_t, 1)
        return torch.sum((pred1_t != pred2_t).float()) / float(pred1_t.size(0))


def classifier_discrepancy(predictions1: torch.Tensor, predictions2: torch.Tensor) -> torch.Tensor:
    r"""The `Classifier Discrepancy` in
    `Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR 2018) <https://arxiv.org/abs/1712.02560>`_.

    For predictions :math:`p_1` and :math:`p_2`, the classifier discrepancy is

    .. math::
        d(p_1, p_2) = \dfrac{1}{K} \sum_{k=1}^K | p_{1k} - p_{2k} |,

    where K is the number of classes. The absolute difference is averaged over every element
    of the inputs, i.e. over the batch as well as over the classes.

    Args:
        predictions1 (torch.Tensor): Classifier predictions :math:`p_1`, expected to be normalized
            class scores (e.g. softmax probabilities)
        predictions2 (torch.Tensor): Classifier predictions :math:`p_2`, expected to have the
            same shape as ``predictions1``

    Returns:
        A scalar tensor.
    """
    elementwise_gap = (predictions1 - predictions2).abs()
    return elementwise_gap.mean()


def entropy(predictions: torch.Tensor) -> torch.Tensor:
    r"""Entropy-style score of N predictions :math:`(p_1, p_2, ..., p_N)`, defined as:

    .. math::
        d(p_1, p_2, ..., p_N) = -\dfrac{1}{K} \sum_{k=1}^K \log \left( \dfrac{1}{N} \sum_{i=1}^N p_{ik} \right)

    where K is the number of classes. The predictions are first averaged over dim 0 (the N
    samples); the log of each class's mean probability is then averaged and negated.

    .. note::
        This entropy function is specifically used in MCD and different from the usual :meth:`~dalib.modules.entropy.entropy` function.

    Args:
        predictions (torch.Tensor): Classifier predictions, expected to be normalized class
            scores (e.g. softmax probabilities)
    """
    eps = 1e-6  # keeps the log finite when a class has zero mean probability
    class_mean = predictions.mean(dim=0)
    return (class_mean + eps).log().mean().neg()


class ImageClassifierHead(nn.Module):
    r"""Classifier Head for MCD.

    Applies ``pool_layer`` to the inputs, then the stack
    ``Dropout(0.5) -> Linear -> BatchNorm1d -> ReLU -> Dropout(0.5) -> Linear -> BatchNorm1d -> ReLU -> Linear``.

    Args:
        in_features (int): Dimension of input features
        num_classes (int): Number of classes
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        pool_layer (torch.nn.Module, optional): Module applied to the inputs before the head.
            Default: ``None``, which uses :class:`torch.nn.Identity`, so inputs must already be flat
            features. To feed 4-D feature maps, pass e.g.
            ``nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())``.

    Shape:
        - Inputs: :math:`(minibatch, F)` where F = `in_features` (with the default ``pool_layer``).
        - Output: :math:`(minibatch, C)` where C = `num_classes`.
    """

    def __init__(self, in_features: int, num_classes: int, bottleneck_dim: int = 1024,
                 pool_layer: Optional[nn.Module] = None):
        super(ImageClassifierHead, self).__init__()
        self.num_classes = num_classes
        # By default the head expects already-pooled, flat features.
        self.pool_layer = nn.Identity() if pool_layer is None else pool_layer
        self.head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(bottleneck_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, num_classes)
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``inputs`` (pooled by ``pool_layer`` first)."""
        return self.head(self.pool_layer(inputs))