from __future__ import annotations

import itertools
from abc import ABC, abstractmethod
from typing import Callable, Final, Generic, Literal, Optional, Sequence, Type, TypeVar

import torch as th

from .protocols import ModuleHasForwardTrain

# Names of the activation functions that can be selected by string.
_ActivationTypeStr = Literal["relu", "tanh", "sigmoid", "elu"]
# Lookup table from an activation name to the `th.nn` module class implementing it.
_ACTIVATION_TYPE_STR_TO_TYPE = {
    "relu": th.nn.ReLU,
    "tanh": th.nn.Tanh,
    "sigmoid": th.nn.Sigmoid,
    "elu": th.nn.ELU,
}

# Type variable bounded by `th.nn.Module`; used as the generic parameter of `SimpleMovingAveragedModel`.
ModuleT = TypeVar("ModuleT", bound=th.nn.Module)


def make_fcn_from_conf(
    in_features: int,
    out_features: int,
    hidden_sizes: Sequence[int],
    use_batch_norms: Sequence[bool],
    activation_types: Sequence[Optional[_ActivationTypeStr]],
    dropout_ps: Sequence[Optional[float]],
) -> th.nn.Sequential:
    """Build a fully connected network from per-layer configuration sequences.

    The four sequences are read in lockstep: entry `i` of each describes hidden layer `i`. They are translated into the `layer_specs` format and handed to `make_fcn`.

    Args:
        in_features (int): input feature size of the FCN
        out_features (int): output feature size of the FCN
        hidden_sizes (Sequence[int]): output size of each hidden linear layer
        use_batch_norms (Sequence[bool]): whether a `th.nn.BatchNorm1d` follows each hidden linear layer
        activation_types (Sequence[Optional[str]]): activation name for each hidden layer (a key of `_ACTIVATION_TYPE_STR_TO_TYPE`), or `None` for no activation
        dropout_ps (Sequence[Optional[float]]): dropout probability for each hidden layer, or `None` for no dropout

    Raises:
        ValueError: if any per-layer sequence differs in length from `hidden_sizes`
        KeyError: if an activation name is not a key of `_ACTIVATION_TYPE_STR_TO_TYPE`

    Returns:
        th.nn.Sequential: the network produced by `make_fcn`
    """
    n_layers = len(hidden_sizes)
    per_layer_options = {
        "use_batch_norms": use_batch_norms,
        "activation_types": activation_types,
        "dropout_ps": dropout_ps,
    }
    # Explicit check instead of `assert`: assertions are stripped under `-O`,
    # and `zip` below would then silently truncate a mismatched configuration.
    for option_name, option_values in per_layer_options.items():
        if len(option_values) != n_layers:
            raise ValueError(
                f"`{option_name}` must have one entry per hidden layer "
                f"({n_layers}), but got {len(option_values)}."
            )
    layer_specs = [
        (
            hsz,
            th.nn.BatchNorm1d if use_bn else None,
            None if act_type is None else _ACTIVATION_TYPE_STR_TO_TYPE[act_type],
            dropout_p,
        )
        for hsz, use_bn, act_type, dropout_p in zip(
            hidden_sizes, use_batch_norms, activation_types, dropout_ps
        )
    ]
    return make_fcn(in_features, out_features, layer_specs)


def make_fcn(
    in_features: int,
    out_features: int,
    layer_specs: Sequence[
        tuple[
            int,
            Optional[Type[th.nn.Module] | Callable[[int], th.nn.Module]],
            Optional[Type[th.nn.Module] | Callable[[], th.nn.Module]],
            Optional[float],
        ]
    ],
) -> th.nn.Sequential:
    """Make a fully connected network (FCN).

    For every entry of `layer_specs` the network gets, in this order: a linear layer, an optional normalization layer, an optional activation and an optional dropout layer. A final linear layer maps to `out_features`. When `layer_specs` is empty the network is that single linear layer from `in_features` to `out_features`.

    Args:
        in_features (int): input feature size of the FCN
        out_features (int): output feature size of the FCN
        layer_specs (Sequence[tuple]): a sequence of `(hidden_size, make_norm_func, make_activation_func, dropout_p)`. `make_norm_func` is called with the hidden size and `make_activation_func` with no arguments; either may be `None` to skip that layer. `dropout_p` is the dropout probability, or `None` for no dropout.

    Returns:
        th.nn.Sequential: the fully connected network
    """
    fcn = th.nn.Sequential()
    # Width of the features entering the next linear layer.
    prev_features = in_features
    for hsz, make_norm_func, make_activation_func, dropout_p in layer_specs:
        fcn.append(th.nn.Linear(prev_features, hsz))
        if make_norm_func is not None:
            fcn.append(make_norm_func(hsz))
        if make_activation_func is not None:
            fcn.append(make_activation_func())
        if dropout_p is not None:
            fcn.append(th.nn.Dropout(p=dropout_p))
        prev_features = hsz
    # Final output layer; also covers the case without any hidden layer.
    fcn.append(th.nn.Linear(prev_features, out_features))
    return fcn


class EnsembleModuleList(th.nn.ModuleList):
    """A `th.nn.ModuleList` whose forward pass runs every member on the same input."""

    def forward(self, inputs: th.Tensor, **kwargs) -> th.Tensor:
        """Evaluate all members and concatenate their outputs along a new dimension 1.

        Args:
            inputs (th.Tensor): tensor passed unchanged to every member
            **kwargs: extra keyword arguments forwarded to every member's `forward`

        Returns:
            th.Tensor: the member outputs, each given a new axis at position 1 and concatenated along that axis
        """
        member_outputs: list[th.Tensor] = [
            member.forward(inputs, **kwargs).unsqueeze(1) for member in self
        ]
        return th.cat(member_outputs, dim=1)


class SimpleMovingAveragedModel(th.optim.swa_utils.AveragedModel, Generic[ModuleT]):
    """Averaged copy of a model with a warm-up phase.

    While `opt_step <= n_warmup`, `update_parameters` copies the source model's tensors into the wrapped copy. Afterwards each call blends the source tensors into the copy with weight `1 / (opt_step - n_warmup + 1)`. `forward` and `state_dict` are redirected to the wrapped copy (`self.module`).
    """

    module: ModuleT
    use_buffers: bool

    n_warmup: int
    _opt_step: int = 0

    @property
    def opt_step(self) -> int:
        """Number of `update_parameters` calls counted so far (can be overwritten via the setter)."""
        return self._opt_step

    @opt_step.setter
    def opt_step(self, opt_step: int):
        assert opt_step >= 0
        self._opt_step = opt_step

    def __init__(
        self,
        model: th.nn.Module,
        n_warmup: int,
        use_buffers: bool = False,
    ) -> None:
        """Wrap `model` in an averaged copy.

        Args:
            model (th.nn.Module): the source model handed to `AveragedModel`
            n_warmup (int): last step index at which tensors are copied instead of averaged
            use_buffers (bool, optional): whether buffers are averaged together with the parameters. Defaults to False.
        """
        super().__init__(model, use_buffers=use_buffers)
        self.n_warmup = n_warmup
        # Expose the wrapped copy's forward/state_dict directly on this object.
        self.forward = self.module.forward
        self.state_dict = self.module.state_dict
        self._opt_step = 0

    @th.no_grad()
    def update_parameters(self, model: th.nn.Module):
        """Fold the current tensors of `model` into the averaged copy and advance `opt_step`.

        Args:
            model (th.nn.Module): the source model whose parameters (and buffers, if `use_buffers`) are read
        """
        # Tensors taking part in the averaging, paired positionally.
        if self.use_buffers:
            averaged_tensors = itertools.chain(
                self.module.parameters(), self.module.buffers()
            )
            source_tensors = itertools.chain(model.parameters(), model.buffers())
        else:
            averaged_tensors = self.module.parameters()
            source_tensors = model.parameters()

        step: int = self.opt_step
        in_warmup: bool = step <= self.n_warmup
        if not in_warmup:
            # Number of averaging updates already folded in, plus the warm-up copy.
            n_seen: int = step - self.n_warmup
            weight_avg: float = n_seen / (n_seen + 1)
            weight_src: float = 1 / (n_seen + 1)

        for avg_tensor, src_tensor in zip(averaged_tensors, source_tensors):
            src_local = src_tensor.detach().to(avg_tensor.device)
            target = avg_tensor.detach()
            if in_warmup:
                # During warm-up the averaged copy simply mirrors the source.
                target.copy_(src_local)
            else:
                target.copy_(weight_avg * target + weight_src * src_local)

        # When buffers are not averaged, keep them in sync with the source model.
        if not self.use_buffers:
            for own_buf, src_buf in zip(self.module.buffers(), model.buffers()):
                own_buf.detach().copy_(src_buf.detach().to(own_buf.device))
        self.opt_step += 1


class SDT(th.nn.Module, ModuleHasForwardTrain):
    """Fast implementation of soft decision tree in PyTorch.

     Adapted from https://github.com/xuyxu/Soft-Decision-Tree

    Args:
        in_features (int): number of input features
        out_features (int): number of output responses
        depth (int, optional): the depth of the soft decision tree. Since SDT is a binary tree, setting `depth` to a large value will drastically increase the cost of both training and evaluation. Defaults to 5.
        lambda_coeff (float, optional): the coefficient of the regularization term in the training loss. Defaults to 1e-3.

    Attributes:
      internal_node_num_ (int):
        The number of internal nodes in the tree. Given the tree depth `d`, it is equal to :math:`2^d - 1`.
      leaf_node_num_ (int): The number of leaf nodes in the tree. Given the tree depth `d`, it equals to :math:`2^d`.
      penalties_l (list[float]): A list storing the layer-wise coefficients of the regularization term.
      inner_nodes (th.nn.Sequential): A container that simulates all internal nodes in the soft decision tree. The sigmoid activation function is concatenated to simulate the probabilistic routing mechanism.
      leaf_nodes (th.nn.Linear): Simulates all leaf nodes in the tree.
    """

    in_features: Final[int]
    out_features: Final[int]
    depth: Final[int]
    lambda_coeff: Final[float]

    internal_node_num_: Final[int]
    leaf_node_num_: Final[int]

    penalties_l: list[float]

    inner_nodes: th.nn.Sequential
    leaf_nodes: th.nn.Linear

    def __init__(
        self,
        in_features: int,
        out_features: int,
        depth: int = 5,
        lambda_coeff: float = 1e-3,
    ):
        """soft decision tree (SDT)

        Args:
            in_features (int): number of input features
            out_features (int): number of output responses
            depth (int, optional): the depth of the soft decision tree. Since SDT is a binary tree, setting `depth` to a large value will drastically increase the cost of both training and evaluation. Defaults to 5.
            lambda_coeff (float, optional): the coefficient of the regularization term in the training loss. Defaults to 1e-3.

        Raises:
            ValueError: if tree `depth` is not strictly positive
            ValueError: if `lambda_coeff` is negative
        """
        super().__init__()
        if depth <= 0:
            raise ValueError(
                f"The tree depth should be strictly positive, but got {depth} instead."
            )
        if lambda_coeff < 0:
            raise ValueError(
                f"The coefficient of the regularization term should not be negative, but got {lambda_coeff} instead."
            )
        # tree info
        self.in_features = in_features
        self.out_features = out_features
        self.depth = depth
        self.lambda_coeff = lambda_coeff
        self.internal_node_num_ = 2**self.depth - 1
        self.leaf_node_num_ = 2**self.depth
        # Different penalty coefficients for nodes in different layers:
        # the coefficient is halved with every layer further from the root.
        self.penalties_l = [
            self.lambda_coeff * (2 ** (-layer_idx)) for layer_idx in range(self.depth)
        ]
        # Initialize internal nodes and leaf nodes, the input dimension on
        # internal nodes is added by 1, serving as the bias.
        self.inner_nodes = th.nn.Sequential(
            th.nn.Linear(self.in_features + 1, self.internal_node_num_, bias=False),
            th.nn.Sigmoid(),
        )
        self.leaf_nodes = th.nn.Linear(
            self.leaf_node_num_, self.out_features, bias=False
        )

    def forward(self, xs: th.Tensor) -> th.Tensor:
        """Return only the predictions of `forward_train`, dropping the penalty info."""
        yhats: th.Tensor = self.forward_train(xs)[0]
        return yhats

    def forward_train(self, xs: th.Tensor) -> tuple[th.Tensor, dict[str, th.Tensor]]:
        """Compute predictions together with the regularization penalty.

        Args:
            xs (th.Tensor): input batch; everything after the first (batch) dimension is flattened

        Returns:
            tuple[th.Tensor, dict[str, th.Tensor]]: the predictions and a dict holding the scalar regularization term under the key `"penalty"`
        """
        mus: th.Tensor
        penalty_: th.Tensor
        mus, penalty_ = self._forward(xs)
        yhats: th.Tensor = self.leaf_nodes(mus)
        # When `x` is the training data, the model also returns the penalty
        # to compute the training loss.
        info: dict[str, th.Tensor] = {"penalty": penalty_}
        return yhats, info

    def _forward(self, xs: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """Compute the leaf path probabilities `(bsz, leaf_node_num_)` and the total penalty."""
        bsz: int = xs.shape[0]
        xs = self._data_augment(xs)
        # compute path probabilities: for every internal node the pair
        # (p, 1 - p) of routing to its two children -> (bsz, internal_node_num_, 2)
        path_prob: th.Tensor = self.inner_nodes(xs)
        path_prob = th.unsqueeze(path_prob, dim=2)
        path_prob = th.cat((path_prob, 1 - path_prob), dim=2)
        mu: th.Tensor = th.ones((bsz, 1, 1), device=xs.device)
        penalty_ = th.tensor(0.0, device=xs.device)
        # Iterate through internal nodes in each layer to compute the final path
        # probabilities and the regularization term. Layer `layer_idx` occupies
        # the index range [begin_idx, end_idx) and holds 2**layer_idx nodes.
        begin_idx: int = 0
        end_idx: int = 1
        for layer_idx in range(self.depth):
            path_prob_: th.Tensor = path_prob[:, begin_idx:end_idx, :]
            # Extract internal nodes in the current layer to compute the
            # regularization term
            penalty_ = penalty_ + self._cal_penalty(layer_idx, mu, path_prob_)
            mu = mu.view(bsz, -1, 1).repeat(1, 1, 2)
            # update path probabilities
            mu = mu * path_prob_
            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (layer_idx + 1)
        mu = mu.view(bsz, self.leaf_node_num_)
        return mu, penalty_

    def _cal_penalty(
        self, layer_idx: int, mu: th.Tensor, path_prob_: th.Tensor
    ) -> th.Tensor:
        """Compute the regularization term contributed by the internal nodes of one layer."""
        penalty = th.tensor(0.0, device=mu.device)
        batch_size: int = mu.shape[0]
        mu = mu.view(batch_size, 2**layer_idx)
        path_prob_ = path_prob_.view(batch_size, 2 ** (layer_idx + 1))
        # The coefficient only depends on the layer, not on the node.
        coeff: float = self.penalties_l[layer_idx]
        for node in range(2 ** (layer_idx + 1)):
            # `node // 2` is the parent whose path probability weights this branch.
            parent_mu: th.Tensor = mu[:, node // 2]
            alpha: th.Tensor = th.sum(path_prob_[:, node] * parent_mu, dim=0) / th.sum(
                parent_mu, dim=0
            )
            penalty -= 0.5 * coeff * (th.log(alpha) + th.log(1 - alpha))
        return penalty

    def _data_augment(self, x: th.Tensor) -> th.Tensor:
        """Add a constant input `1` onto the front of each sample."""
        bsz: int = x.shape[0]
        # `reshape` (rather than `view`) also accepts non-contiguous inputs.
        x = x.reshape(bsz, -1)
        bias: th.Tensor = th.ones((bsz, 1), device=x.device)
        x = th.cat((bias, x), 1)
        return x


class ProtoNet(th.nn.Module, ABC):
    """Abstract base for networks that compare inputs against learnable prototypes.

    The prototypes are stored as one parameter of shape `(in_features, out_features)`; subclasses define how inputs are compared to them in `forward`.
    """

    in_features: Final[int]
    out_features: Final[int]
    _prototypes: th.Tensor

    def __init__(self, in_features: int, out_features: int, *args, **kwargs) -> None:
        """Create randomly initialized prototypes.

        Args:
            in_features (int): size of the first prototype dimension
            out_features (int): size of the second prototype dimension
            *args, **kwargs: forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.in_features = in_features
        self.out_features = out_features
        initial_prototypes = th.randn(in_features, out_features)
        self._prototypes = th.nn.Parameter(initial_prototypes)

    @property
    def prototypes(self) -> th.Tensor:
        """The current prototype parameter."""
        return self._prototypes

    @prototypes.setter
    def prototypes(self, prototypes: th.Tensor):
        # The replacement must keep the shape of the existing prototypes.
        assert prototypes.shape == self._prototypes.shape
        # The tensor is wrapped in a fresh parameter that replaces the old one.
        replacement = th.nn.Parameter(prototypes)
        self._prototypes = replacement

    @abstractmethod
    def forward(self, xs: th.Tensor) -> th.Tensor: ...


class CosineProtoNet(ProtoNet):
    """`ProtoNet` scoring inputs by cosine similarity to each prototype."""

    def forward(self, xs: th.Tensor) -> th.Tensor:
        """Return the cosine similarity between every sample and every prototype.

        Args:
            xs (th.Tensor): input batch — assumed to be (n, in_features); TODO confirm against callers

        Returns:
            th.Tensor: similarities reduced over dimension 1 (the feature axis)
        """
        # Add a trailing axis to the inputs and a leading axis to the
        # prototypes so the two broadcast against each other.
        samples = xs.unsqueeze(2)
        protos = self.prototypes.unsqueeze(0)
        return th.nn.functional.cosine_similarity(samples, protos)


class CDistProtoNet(ProtoNet):
    """`ProtoNet` scoring inputs by their `th.cdist` distance to each prototype."""

    def forward(self, xs: th.Tensor) -> th.Tensor:
        """Return the pairwise distance between every sample and every prototype.

        Args:
            xs (th.Tensor): input batch — assumed to be (n, in_features); TODO confirm against callers

        Returns:
            th.Tensor: distances with the helper axis at position 1 removed
        """
        # (in_features, out_features) -> (1, out_features, in_features)
        protos: th.Tensor = self.prototypes.transpose(0, 1).unsqueeze(0)
        # Every sample becomes its own batch holding a single point.
        pairwise = th.cdist(xs[:, None, :], protos)
        return pairwise[:, 0, :]


class ClusterNet(th.nn.Module, ModuleHasForwardTrain, ABC):
    """Base class for networks built on cluster centroids and per-cluster values.

    Both parameters start uninitialized and are materialized on first assignment through the `centroids` / `values` setters. `forward` relies on a `forward_train` method that this class does not define itself.
    """

    in_features: Final[int]
    out_features: Final[int]

    _centroids: th.nn.Parameter
    _values: th.nn.Parameter

    @staticmethod
    def _overwrite_parameter(param: th.nn.Parameter, new_data: th.Tensor) -> None:
        """Materialize `param` if needed, then copy `new_data` into it in place."""
        if isinstance(param, th.nn.UninitializedParameter):
            param.materialize(new_data.shape)
        else:
            assert param.shape == new_data.shape
        # An in-place write to a leaf tensor that requires grad raises outside
        # of `no_grad`, so the copy is excluded from autograd tracking.
        with th.no_grad():
            param.copy_(new_data)

    @property
    def n_clusters(self) -> int:
        """Size of the first dimension of `centroids`."""
        return self.centroids.shape[0]

    @property
    def centroids(self) -> th.nn.Parameter:
        """The centroid parameter."""
        return self._centroids

    @centroids.setter
    def centroids(self, centroids: th.Tensor):
        self._overwrite_parameter(self._centroids, centroids)

    @property
    def values(self) -> th.nn.Parameter:
        """The per-cluster value parameter."""
        return self._values

    @values.setter
    def values(self, values: th.Tensor):
        self._overwrite_parameter(self._values, values)

    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
        """Create the network with uninitialized centroids and values.

        Args:
            in_features (int): stored as `in_features`
            out_features (int): stored as `out_features`
            *args, **kwargs: forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.in_features = in_features
        self.out_features = out_features
        # (n_clusters, n_covs)
        self._centroids = th.nn.UninitializedParameter()
        # (n_clusters, n_outputs)
        self._values = th.nn.UninitializedParameter()

    def forward(self, xs: th.Tensor) -> th.Tensor:
        """Return only the outputs of `forward_train`, dropping the info dict."""
        outs, _ = self.forward_train(xs)
        return outs


class CDistClusterNet(ClusterNet):
    """`ClusterNet` weighting clusters by a softmax over negated `th.cdist` distances."""

    def forward_train(self, xs: th.Tensor) -> tuple[th.Tensor, dict[str, th.Tensor]]:
        """Compute the outputs and the cluster weights.

        Args:
            xs (th.Tensor): input batch

        Returns:
            tuple[th.Tensor, dict[str, th.Tensor]]: the outputs and a dict holding the softmax cluster weights under the key `"ws_"`
        """
        # (n, n_clusters)
        distances: th.Tensor = self.compute_distance_to_clusters(xs)
        # (n, n_clusters): closer clusters receive larger weights
        weights: th.Tensor = th.softmax(-distances, dim=1)
        # (n, out_features)
        # NOTE(review): the weighted values are negated here — confirm this sign is intended.
        return -weights @ self.values, {"ws_": weights}

    def compute_distance_to_clusters(self, xs: th.Tensor) -> th.Tensor:
        """Return the `th.cdist` distance from every sample to every centroid, shape (n, n_clusters)."""
        # (n, 1, n_covs) against (1, n_clusters, n_covs) -> (n, 1, n_clusters)
        batched = th.cdist(xs[:, None, :], self.centroids[None, :, :])
        # (n, n_clusters)
        return batched[:, 0, :]

    def compute_distance_and_labels(self, xs: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """Return the distances to all centroids and the index of the nearest centroid per sample."""
        distances: th.Tensor = self.compute_distance_to_clusters(xs)
        return distances, th.argmin(distances, dim=1)

    def compute_cluster_labels(self, xs: th.Tensor) -> th.Tensor:
        """Return the index of the nearest centroid for every sample."""
        _, labels = self.compute_distance_and_labels(xs)
        return labels


class CosineClusterNet(ClusterNet):
    """`ClusterNet` weighting clusters by a softmax over cosine similarities."""

    def forward_train(self, xs: th.Tensor) -> tuple[th.Tensor, dict[str, th.Tensor]]:
        """Compute the outputs and the cluster weights.

        Args:
            xs (th.Tensor): input batch

        Returns:
            tuple[th.Tensor, dict[str, th.Tensor]]: the outputs and a dict holding the softmax cluster weights under the key `"ws_"`
        """
        # (n, n_clusters): similarity of every sample to every centroid
        similarities: th.Tensor = th.nn.functional.cosine_similarity(
            xs[:, :, None], self.centroids.T[None, :, :]
        )
        # (n, n_clusters): normalized cluster weights
        weights: th.Tensor = th.softmax(similarities, dim=1)
        # (n, out_features)
        # NOTE(review): the weighted values are negated here — confirm this sign is intended.
        return -weights @ self.values, {"ws_": weights}
