"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Dict, Tuple, Callable
import torch.nn as nn
import torch.nn.functional as F
import torch

from common.modules.classifier import Classifier as ClassifierBase
from dalib.modules.grl import WarmStartGradientReverseLayer
from ..modules.entropy import entropy

class MDD(nn.Module):
    r"""Pair of classifier heads used to compute a margin-disparity style objective.

    ``classifer2`` is always created and acts as the auxiliary head.  ``classifer1``
    is the reference head; when ``training_classifer`` is true it is not created
    and the logits passed to :meth:`forward` (``g_s`` / ``g_t``) are used instead.

    Args:
        in_features (int): dimension of the features fed to the heads.
        num_classes (int): number of output logits of each head.
        margin (float): factor multiplied with the source-domain term. Default: 4
        training_classifer (bool): if true, ``classifer1`` is ``None`` and the
            externally supplied logits ``g_s`` / ``g_t`` take its place.
        mlp (bool): if true each head is ``Linear(in_features, 512) -> ReLU ->
            Dropout(0.5) -> Linear(512, num_classes)``; otherwise a single linear layer.
        entropy_eval (bool): if true, evaluation mode returns the mean of
            ``entropy(softmax(y2_t))`` instead of the disagreement gap.
    """

    def __init__(self, in_features, num_classes, margin=4, training_classifer=False,
                 mlp=False, entropy_eval=False):
        super(MDD, self).__init__()
        self.margin = margin

        # classifer1 is built before classifer2 so that seeded parameter
        # initialisation matches the historical construction order.
        if training_classifer:
            self.classifer1 = None
        else:
            self.classifer1 = self._build_head(in_features, num_classes, mlp)
        self.classifer2 = self._build_head(in_features, num_classes, mlp)
        self.entropy_eval = entropy_eval

    @staticmethod
    def _build_head(in_features, num_classes, mlp):
        """Create one head: a 512-unit MLP when ``mlp`` is true, else a linear layer."""
        if mlp:
            return nn.Sequential(
                nn.Linear(in_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes))
        return nn.Linear(in_features, num_classes)

    def source_discrepancy(self, y, y_adv):
        """Cross-entropy of ``y_adv`` against the arg-max class of ``y`` (batch mean)."""
        _, prediction = y.max(dim=1)
        return F.cross_entropy(y_adv, prediction)

    def target_discrepancy(self, y, y_adv):
        """Batch mean of ``shift_log(1 - softmax(y_adv))`` taken at the arg-max class of ``y``."""
        _, prediction = y.max(dim=1)
        # nll_loss negates the gathered value, so the leading minus restores its sign.
        return -F.nll_loss(shift_log(1. - F.softmax(y_adv, dim=1)), prediction)

    def forward(self, g_s, f_s, g_t, f_t, label_s=None, label_t=None, d_label=None, training=True):
        """Compute the training loss or an evaluation statistic.

        Args:
            g_s, g_t: source / target logits, used only when ``classifer1`` is ``None``.
            f_s, f_t: source / target features fed to the heads.
            label_s, label_t, d_label: accepted but not used by this method.
            training (bool): selects the training loss (``True``) or the
                evaluation statistic (``False``).

        Returns:
            * ``training=True``: ``margin * source_discrepancy - target_discrepancy``.
            * ``training=False`` and ``entropy_eval``: mean of ``entropy(softmax(y2_t))``.
            * ``training=False`` otherwise: fraction of target samples on which the
              two heads' arg-max predictions differ, minus the same fraction on
              the source samples.
        """
        if self.classifer1 is None:
            y1_s = g_s
            y1_t = g_t
        else:
            y1_s = self.classifer1(f_s)
            y1_t = self.classifer1(f_t)
        y2_s = self.classifer2(f_s)
        y2_t = self.classifer2(f_t)

        if training:
            # The losses are only needed here; evaluation no longer computes
            # (and discards) them.
            source_loss = self.margin * self.source_discrepancy(y1_s, y2_s)
            target_loss = -self.target_discrepancy(y1_t, y2_t)
            return source_loss + target_loss

        if self.entropy_eval:
            return entropy(F.softmax(y2_t, dim=1)).mean()

        pred1_s = torch.argmax(y1_s, 1)
        pred2_s = torch.argmax(y2_s, 1)
        pred1_t = torch.argmax(y1_t, 1)
        pred2_t = torch.argmax(y2_t, 1)
        discrepancy_s = torch.sum((pred1_s != pred2_s).float()) / float(pred1_s.size(0))
        discrepancy_t = torch.sum((pred1_t != pred2_t).float()) / float(pred1_t.size(0))
        return discrepancy_t - discrepancy_s


class MarginDisparityDiscrepancy(nn.Module):
    r"""The margin disparity discrepancy (MDD) proposed in `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) <https://arxiv.org/abs/1904.05801>`_.

    MDD can measure the distribution discrepancy in domain adaptation.

    The :math:`y^s` and :math:`y^t` are logits output by the main head on the source and target domain respectively.
    The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the adversarial head, which this
    module computes itself by applying ``adv_head`` to the gradient-reversed features.

    The definition can be described as:

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        -\gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} L_s (y^s, y_{adv}^s) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} L_t (y^t, y_{adv}^t),

    where :math:`\gamma` is a margin hyper-parameter, :math:`L_s` is the disparity defined on the source domain
    (:meth:`source_discrepancy`) and :math:`L_t` is the disparity defined on the target domain
    (:meth:`target_discrepancy`).

    Args:
        adv_head (callable): adversarial head; it is called on the concatenated source and
          target features and must return logits with classes along dim 1.
        margin (float): margin :math:`\gamma`. Default: 4
        reduction (str, optional): Specifies the reduction to apply to the output:
          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
          ``'mean'``: the sum of the output will be divided by the number of
          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``.
          Any other value leaves the per-sample loss unreduced.

    Inputs:
        - y_s: output :math:`y^s` by the main head on the source domain
        - f_s: source-domain features fed to the adversarial head
        - y_t: output :math:`y^t` by the main head on the target domain
        - f_t: target-domain features fed to the adversarial head
        - w_s (optional): instance weights for source domain
        - w_t (optional): instance weights for target domain

    .. note::
        The adversarial output is split into two equal chunks along dim 0, so ``f_s`` and
        ``f_t`` are expected to have the same batch size.

    Examples::

        >>> num_classes, feature_dim = 2, 8
        >>> batch_size = 10
        >>> adv_head = nn.Linear(feature_dim, num_classes)
        >>> loss = MarginDisparityDiscrepancy(adv_head, margin=4.)
        >>> # main-head output from source domain and target domain
        >>> y_s, y_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
        >>> # features from source domain and target domain
        >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss(y_s, f_s, y_t, f_t)
    """

    def __init__(self, adv_head, margin: Optional[float] = 4, reduction: Optional[str] = 'mean'):
        super(MarginDisparityDiscrepancy, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.adv_head = adv_head
        # Features pass through this layer before reaching the adversarial head.
        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,
                                                       auto_step=True)

    def source_discrepancy(self, y: torch.Tensor, y_adv: torch.Tensor):
        """Per-sample cross-entropy of ``y_adv`` against the arg-max class of ``y``."""
        _, prediction = y.max(dim=1)
        return F.cross_entropy(y_adv, prediction, reduction='none')

    def target_discrepancy(self, y: torch.Tensor, y_adv: torch.Tensor):
        """Per-sample ``shift_log(1 - softmax(y_adv))`` taken at the arg-max class of ``y``."""
        _, prediction = y.max(dim=1)
        # nll_loss negates the gathered value, so the leading minus restores its sign.
        return -F.nll_loss(shift_log(1. - F.softmax(y_adv, dim=1)), prediction, reduction='none')

    def forward(self, y_s, f_s, y_t, f_t, w_s = None, w_t = None):
        """Return ``-margin * L_s + L_t`` (optionally instance-weighted), reduced per ``self.reduction``."""
        f = torch.cat((f_s, f_t), dim=0)
        y_adv = self.adv_head(self.grl_layer(f))
        # First half belongs to the source batch, second half to the target batch.
        y_s_adv, y_t_adv = y_adv.chunk(2, dim=0)
        source_loss = -self.margin * self.source_discrepancy(y_s, y_s_adv)
        target_loss = self.target_discrepancy(y_t, y_t_adv)
        # Missing weights mean "weight 1"; skipping the multiplication gives the
        # same values without allocating a tensor of ones.
        if w_s is not None:
            source_loss = source_loss * w_s
        if w_t is not None:
            target_loss = target_loss * w_t

        loss = source_loss + target_loss
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

def shift_log(x: torch.Tensor, offset: float = 1e-6) -> torch.Tensor:
    r"""
    First shift, then clamp from above at 1, then calculate log, which can be described as:

    .. math::
        y = \log(\min(x+\text{offset}, 1)) = \min(\log(x+\text{offset}), 0)

    Used to avoid the gradient explosion problem in log(x) function when x=0.

    Args:
        x (torch.Tensor): input tensor
        offset (float, optional): offset size. Default: 1e-6

    Returns:
        torch.Tensor: tensor of the same shape as ``x``; no element is greater than 0.

    .. note::
        When the input tensor falls in [0., 1.] the output tensor falls in [log(offset), 0]
    """
    # Clamping at 1 keeps the shifted value from producing a positive log.
    return torch.log(torch.clamp(x + offset, max=1.))


class ImageClassifier(ClassifierBase):
    """Classifier made of a backbone, a linear bottleneck and an optional MLP head.

    Args:
        backbone: feature extractor; its ``out_features`` attribute sets the
            bottleneck input width.
        num_classes (int): number of output classes.
        bottleneck_dim (int): output width of the bottleneck. Default: 256
        mlp_classifier (bool): if true, a ``Linear -> ReLU -> Dropout(0.5) -> Linear``
            head with hidden size ``width`` is passed to the base class; otherwise
            ``None`` is passed as the head (presumably the base class then builds
            its own default head -- confirm against ``ClassifierBase``).
        width (int): hidden size of the MLP head. Default: 1024
        **kwargs: forwarded unchanged to ``ClassifierBase.__init__``.
    """

    def __init__(self, backbone, num_classes, bottleneck_dim=256,
                 mlp_classifier=False, width=1024, **kwargs):
        # The bottleneck operates directly on the backbone output: no pooling,
        # flattening or dropout layer is included here.
        bottleneck_layers = [
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        ]
        bottleneck = nn.Sequential(*bottleneck_layers)

        head = None
        if mlp_classifier:
            hidden = nn.Linear(bottleneck_dim, width)
            activation = nn.ReLU()
            dropout = nn.Dropout(0.5)
            output = nn.Linear(width, num_classes)
            head = nn.Sequential(hidden, activation, dropout, output)

        super().__init__(backbone, num_classes, bottleneck, bottleneck_dim, head, **kwargs)

class ClassifierHead(nn.Module):
    """Stand-alone classification head applied to bottleneck features.

    Args:
        num_classes (int): number of output logits.
        bottleneck_dim (int): width of the incoming features. Default: 256
        mlp_classifier (bool): if true the head is
            ``Linear(bottleneck_dim, width) -> ReLU -> Dropout(0.5) -> Linear(width, num_classes)``;
            otherwise it is a single ``Linear(bottleneck_dim, num_classes)``.
        width (int): hidden size of the MLP variant. Default: 1024
    """

    def __init__(self, num_classes, bottleneck_dim=256,
                 mlp_classifier=False, width=1024):
        super().__init__()
        if not mlp_classifier:
            self.head = nn.Linear(bottleneck_dim, num_classes)
            return

        layers = [
            nn.Linear(bottleneck_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, num_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x`` and return its output."""
        logits = self.head(x)
        return logits

    def get_parameters(self, lr_multi_D=1.0) -> List[Dict]:
        """Return a single parameter group holding all parameters, with ``"lr"`` set to ``lr_multi_D``."""
        group = {"params": self.parameters(), "lr": lr_multi_D}
        return [group]
        