from typing import Callable, Dict

import torch
import torch.nn as nn

from nxcl.config import ConfigDict as CfgNode

from .layers import *


# Public API of this module: only the builder is exported via star-import;
# SoftmaxClassifier itself has to be imported by name.
__all__ = [
    "build_softmax_classifier",
]


class SoftmaxClassifier(nn.Module):
    """Linear classification head returning logits and their softmax.

    A single linear layer (``self.fc``) maps input features to
    ``num_classes`` logits; ``forward`` returns the logits together with
    their softmax ("confidences") and log-softmax ("log_confidences").
    """

    def __init__(
            self,
            feature_dim: int,
            num_classes: int,
            use_bias: bool,
            linear: Callable[..., nn.Module] = Linear,
            **kwargs,
        ) -> None:
        """
        Args:
            feature_dim: number of input features (``in_features`` of the
                linear layer).
            num_classes: number of output logits (``out_features``).
            use_bias: passed to the linear layer as ``bias``.
            linear: linear layer class (or factory) to instantiate, called as
                ``linear(in_features=..., out_features=..., bias=..., **kwargs)``.
            **kwargs: extra constructor arguments forwarded to ``linear``.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.use_bias    = use_bias
        # Stores the layer class/factory itself, not an instance.
        self.linear      = linear

        self.fc = linear(
            in_features=self.feature_dim,
            out_features=self.num_classes,
            bias=self.use_bias,
            **kwargs
        )

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Compute class logits and their (log-)softmax.

        Args:
            x: input features; for a standard linear layer the last axis must
                have size ``feature_dim``.
            **kwargs: forwarded to ``self.fc`` (variant layers may take extra
                forward arguments).

        Returns:
            Dict with ``"logits"``, ``"confidences"`` (softmax) and
            ``"log_confidences"`` (log-softmax), normalized over the last axis.
        """
        logits = self.fc(x, **kwargs)
        # A linear layer writes its out_features to the last axis, so that is
        # the class axis for inputs of any rank; for the usual
        # (batch, feature_dim) input this equals dim=1.
        # NOTE(review): assumes the Linear variants in .layers follow the same
        # output layout as nn.Linear — confirm.
        return {
            "logits": logits,
            "confidences": torch.softmax(logits, dim=-1),
            "log_confidences": torch.log_softmax(logits, dim=-1),
        }


def build_softmax_classifier(cfg: CfgNode) -> nn.Module:
    """Build a ``SoftmaxClassifier`` from ``cfg.MODEL.CLASSIFIER.SOFTMAX_CLASSIFIER``.

    Reads ``LINEAR_LAYERS`` (``"Linear"`` or ``"Linear_Bezier"``),
    ``FEATURE_DIM``, ``NUM_CLASSES`` and ``USE_BIAS`` from that config node,
    then re-initializes the linear weight(s) with Kaiming-normal init
    (``mode="fan_out"``, ``nonlinearity="relu"``).

    Raises:
        NotImplementedError: if ``LINEAR_LAYERS`` names an unknown layer type.
    """
    classifier_cfg = cfg.MODEL.CLASSIFIER.SOFTMAX_CLASSIFIER

    # Linear layers may be replaced by their variants
    layer_name = classifier_cfg.LINEAR_LAYERS
    if layer_name == "Linear":
        linear_layers = Linear
    elif layer_name == "Linear_Bezier":
        linear_layers = Linear_Bezier
    else:
        raise NotImplementedError(
            f"Unknown MODEL.CLASSIFIER.SOFTMAX_CLASSIFIER.LINEAR_LAYERS: {layer_name}"
        )

    classifier = SoftmaxClassifier(
        feature_dim = classifier_cfg.FEATURE_DIM,
        num_classes = classifier_cfg.NUM_CLASSES,
        use_bias    = classifier_cfg.USE_BIAS,
        linear      = linear_layers,
    )

    # Initialize weights. Some linear variants hold their weight as an
    # nn.ParameterList of several tensors; treat a plain weight as a
    # one-element collection so every tensor gets the same init.
    fc_weight = classifier.fc.weight
    weights = fc_weight if isinstance(fc_weight, nn.ParameterList) else (fc_weight,)
    for weight in weights:
        nn.init.kaiming_normal_(weight, mode="fan_out", nonlinearity="relu")

    return classifier
