from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .util_domain_adaptation import WarmStartGradientReverseLayer

# Names exported by a star-import of this module. Only `Classifier` is listed,
# so every other class defined in this file has to be imported by name.
__all__ = ['Classifier']


class Classifier(nn.Module):
    """A generic Classifier class for domain adaptation.

    Parameters:
        - **backbone** (class:`nn.Module` object): Any backbone to extract 1-d features from data.
          It must expose an integer attribute `out_features`.
        - **num_classes** (int): Number of classes
        - **bottleneck** (class:`nn.Module` object, optional): Any bottleneck layer. Use no bottleneck by default
        - **bottleneck_dim** (int, optional): Feature dimension of the bottleneck layer. Must be positive
          when `bottleneck` is given. Default: -1
        - **head** (class:`nn.Module` object, optional): Any classifier head. Use `nn.Linear` by default
        - **only_predictions** (bool, optional): If True, `forward` returns only the predictions even in
          training mode. Default: False

    .. note::
        Different classifiers are used in different domain adaptation algorithms to achieve better accuracy
        respectively, and we provide a suggested `Classifier` for different algorithms.
        Remember they are not the core of algorithms. You can implement your own `Classifier` and combine it with
        the domain adaptation algorithm in this algorithm library.

    .. note::
        `get_parameters` assigns the same relative learning rate (``lr_mult`` of 1.) to the backbone,
        the bottleneck and the head. If you have other optimization strategies, please over-ride
        `get_parameters`.

    Inputs:
        - **x** (tensor): input data fed to `backbone`

    Outputs: predictions, features
        - **predictions**: classifier's predictions
        - **features**: features after `bottleneck` layer and before `head` layer. Only returned when the
          module is in training mode and `only_predictions` is False; otherwise `predictions` is
          returned alone.

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - predictions: (minibatch, `num_classes`)
        - features: (minibatch, `features_dim`)

    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, only_predictions=False):
        super(Classifier, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.only_predictions = only_predictions
        if bottleneck is None:
            # Without a bottleneck the head consumes the backbone features directly.
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        else:
            self.bottleneck = bottleneck
            assert bottleneck_dim > 0, "bottleneck_dim must be positive when a bottleneck is given"
            self._features_dim = bottleneck_dim

        if head is None:
            self.head = nn.Linear(self._features_dim, num_classes)
        else:
            self.head = head

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run backbone, bottleneck and head on `x`.

        Returns `(predictions, features)` in training mode when `only_predictions`
        is False, and `predictions` alone otherwise.
        """
        f = self.backbone(x)
        # `reshape` instead of `view`: identical for contiguous tensors, but it also
        # handles backbones whose output is non-contiguous (where `view` raises).
        f = f.reshape(-1, self.backbone.out_features)
        f = self.bottleneck(f)
        predictions = self.head(f)
        if self.training and not self.only_predictions:
            return predictions, f
        return predictions

    def get_parameters(self) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer
        """
        params = [
            {"params": self.backbone.parameters(), "lr_mult": 1.},
            {"params": self.bottleneck.parameters(), "lr_mult": 1.},
            {"params": self.head.parameters(), "lr_mult": 1.},
        ]
        return params


class ImageClassifier(Classifier):
    """A `Classifier` whose bottleneck is Linear -> BatchNorm1d -> ReLU.

    Parameters:
        - **backbone** (class:`nn.Module` object): Backbone exposing an `out_features` attribute,
          which is used as the input width of the bottleneck's linear layer
        - **num_classes** (int): Number of classes
        - **bottleneck_dim** (int, optional): Output dimension of the bottleneck. Default: 256
        - **only_predictions** (bool, optional): Forwarded unchanged to `Classifier`. Default: False
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256,
                 only_predictions=False):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim,
                                              only_predictions=only_predictions)
        # Assigned only after the parent constructor has run: nn.Module requires
        # its __init__ to be called before attributes are set on the instance.
        # NOTE(review): `noise_level` is not read anywhere in this file — presumably
        # it is set/consumed by external code; confirm.
        self.noise_level = None


class ImageClassifierMDD(nn.Module):
    r"""Classifier for MDD.
    Parameters:
        - **backbone** (class:`nn.Module` object): Any backbone to extract 1-d features from data.
          It must expose an integer attribute `out_features`.
        - **num_classes** (int): Number of classes
        - **bottleneck_dim** (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        - **width** (int, optional): Feature dimension of the classifier head. Default: 1024

    .. note::
        Classifier for MDD has one backbone, one bottleneck, while two classifier heads.
        The first classifier head is used for final predictions.
        The adversarial classifier head is only used when calculating MarginDisparityDiscrepancy.

    .. note::
        Remember to call function `step()` after function `forward()` **during training phase**! For instance,

        >>> # x is inputs, classifier is an ImageClassifierMDD in training mode
        >>> outputs, outputs_adv = classifier(x)
        >>> classifier.step()

    Inputs:
        - **x** (Tensor): input data

    Outputs: (outputs, outputs_adv)
        - **outputs**: logits outputs by the main classifier
        - **outputs_adv**: logits outputs by the adversarial classifier. Only returned in training mode;
          in evaluation mode `outputs` is returned alone.

    Shapes:
        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.
        - outputs, outputs_adv: :math:`(minibatch, C)`, where C means the number of classes.

    """

    def __init__(self, backbone: nn.Module, num_classes: int,
                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024):
        super(ImageClassifierMDD, self).__init__()
        self.backbone = backbone
        # auto_step=False: the layer is advanced explicitly through `self.step()`.
        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000., auto_step=False)

        self.bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.bottleneck[0].weight.data.normal_(0, 0.005)
        self.bottleneck[0].bias.data.fill_(0.1)

        # The classifier head used for final predictions.
        self.head = nn.Sequential(
            nn.Linear(bottleneck_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, num_classes)
        )
        # The adversarial classifier head
        self.adv_head = nn.Sequential(
            nn.Linear(bottleneck_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, num_classes)
        )
        # Indices 0 and 3 are the two nn.Linear layers inside each head.
        # Initialisation order (head, then adv_head, per layer) is kept as-is so
        # seeded runs draw the same random numbers.
        for linear_idx in (0, 3):
            self.head[linear_idx].weight.data.normal_(0, 0.01)
            self.head[linear_idx].bias.data.fill_(0.0)
            self.adv_head[linear_idx].weight.data.normal_(0, 0.01)
            self.adv_head[linear_idx].bias.data.fill_(0.0)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute logits of the main head and, in training mode, of the adversarial head.

        Returns `(outputs, outputs_adv)` when `self.training` is True and
        `outputs` alone otherwise.
        """
        features = self.backbone(x)
        features = self.bottleneck(features)
        outputs = self.head(features)

        # we only evaluate adversarial discriminator during training
        if self.training:
            features_adv = self.grl_layer(features)
            outputs_adv = self.adv_head(features_adv)
            return outputs, outputs_adv

        return outputs

    def step(self):
        r"""Call step() each iteration during training.
        Will increase :math:`\lambda` in GRL layer.
        """
        self.grl_layer.step()

    def get_parameters(self) -> List[Dict]:
        """
        :return: A parameters list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer
        """
        params = [
            # The backbone gets a 10x smaller relative learning rate than the other parts.
            {"params": self.backbone.parameters(), "lr_mult": 0.1},
            {"params": self.bottleneck.parameters(), "lr_mult": 1.},
            {"params": self.head.parameters(), "lr_mult": 1.},
            {"params": self.adv_head.parameters(), "lr_mult": 1}
        ]
        return params


class ImageClassifierDANN(Classifier):
    """A `Classifier` for DANN with a Linear -> BatchNorm1d -> ReLU bottleneck.

    Parameters:
        - **backbone** (class:`nn.Module` object): Backbone exposing an `out_features` attribute,
          which is used as the input width of the bottleneck's linear layer
        - **num_classes** (int): Number of classes
        - **bottleneck_dim** (int, optional): Output dimension of the bottleneck. Default: 256
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256):
        in_dim = backbone.out_features
        # Layers are created in the same order as before so parameter
        # initialisation consumes the random stream identically.
        layers = [
            nn.Linear(in_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        ]
        super().__init__(
            backbone,
            num_classes,
            bottleneck=nn.Sequential(*layers),
            bottleneck_dim=bottleneck_dim,
        )