from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from lambda_ac.nn.common import MLP, DoubleHeadMLP, combine_ensemble_actions
from lambda_ac.rl_types import Encoder, EncoderModelNetwork, EncoderOutput, FeatureInput
from lambda_ac.util.distributions import BoundedGaussian, DiracDistribution, Distribution


class Model(EncoderModelNetwork):
    """Base latent model predicting next features, a reward and a done value.

    This class builds the reward and done heads. Subclasses are expected to
    assign ``feature_head`` and ``combination_layer`` and to implement
    ``_combine_hidden`` and ``_create_feature_distribution``; the base
    versions of those two methods raise ``NotImplementedError``.
    """

    # ``reward_head`` and ``done_head`` are created in ``__init__`` below;
    # ``feature_head`` and ``combination_layer`` are only declared here.
    feature_head: nn.Module
    reward_head: nn.Module
    done_head: nn.Module
    combination_layer: nn.Module

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Store the configuration and build the reward and done heads.

        Args:
            action_dim: Added to ``feature_dim`` to get the head input size.
            feature_dim: Added to ``action_dim`` to get the head input size.
            hidden_dim: Hidden width of both heads.
            hidden_layers_core: Stored on the instance; not used here.
            hidden_layers_head: Number of hidden layers of both heads.
            normalize_head_input: Passed as ``normalize_input`` to both heads.
            encoder: Passed to the parent constructor.
            spectral_norm: Accepted but not used by this constructor; both
                heads are built with ``apply_spectral_norm=False``.
            predict_done: If false, ``forward`` overwrites the done output
                with a constant.
        """
        super().__init__(encoder)
        self.action_dim = action_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.hidden_layers_core = hidden_layers_core
        self.hidden_layers_head = hidden_layers_head
        self.predict_done = predict_done

        # Both scalar heads share one configuration.
        scalar_head_config = dict(
            input_dim=feature_dim + action_dim,
            output_dim=1,
            hidden_layers=hidden_layers_head,
            hidden_dim=hidden_dim,
            normalize_input=normalize_head_input,
            apply_spectral_norm=False,
            batch_norm=False,
        )
        # Reward head first, then done head, so parameters are created in
        # the same order as before.
        self.reward_head = MLP(**scalar_head_config)
        self.done_head = MLP(**scalar_head_config)

    def _register_networks(self):
        """Gather the four sub-networks into ``self.latent_network``.

        Reads ``combination_layer`` and ``feature_head``, so it can only be
        called once a subclass has assigned them.
        """
        members = [
            self.combination_layer,
            self.feature_head,
            self.reward_head,
            self.done_head,
        ]
        self.latent_network = nn.ModuleList(members)

    def _combine_hidden(
        self, x: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(head_input, next_hidden)``; implemented by subclasses."""
        raise NotImplementedError()

    def _create_feature_distribution(self, x) -> Distribution:
        """Wrap the feature head output in a distribution; see subclasses."""
        raise NotImplementedError()

    def forward(
        self, features: FeatureInput, actions
    ) -> Tuple[EncoderOutput, torch.Tensor, torch.Tensor]:
        """Run one prediction step.

        Args:
            features: Provides ``encoded`` and ``hidden``.
            actions: Combined with ``features.encoded`` through
                ``combine_ensemble_actions``.

        Returns:
            A tuple ``(EncoderOutput, reward, done)`` where the
            ``EncoderOutput`` holds the feature distribution and the next
            hidden value.
        """
        combined = combine_ensemble_actions(features.encoded, actions)
        core, hidden_next = self._combine_hidden(combined, features.hidden)

        feature_out = self.feature_head(core)
        reward = self.reward_head(core)
        done = self.done_head(core)
        if not self.predict_done:
            # Overwrite every entry in place with a constant.
            # NOTE(review): presumably a large negative logit meaning
            # "never done" - confirm against the consumer of this output.
            done[:] = -100.

        prediction = EncoderOutput(
            self._create_feature_distribution(feature_out), hidden_next
        )
        return prediction, reward, done

    def sample(self, features, actions):
        """Return the three outputs of ``forward`` for the given inputs."""
        prediction, reward, done = self.forward(features, actions)
        return prediction, reward, done

    def roll(self) -> None:
        """Do nothing."""
        return None


class DeterministicOutput(Model):
    """Output mixin whose feature head is an ``MLP`` wrapped as a Dirac."""

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Initialise the parent, then build the feature head.

        ``normalize_head_input`` and ``spectral_norm`` configure the feature
        head built here; the remaining arguments are forwarded to the parent
        constructor.
        """
        # NOTE(review): the parent always receives normalize_head_input=False
        # here, while the feature head below uses the caller's value. Kept
        # as-is; confirm this asymmetry is intended.
        parent_kwargs = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=False,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **parent_kwargs)

        head_input_dim = feature_dim + action_dim
        self.feature_head = MLP(
            input_dim=head_input_dim,
            output_dim=feature_dim,
            hidden_layers=hidden_layers_head,
            hidden_dim=hidden_dim,
            normalize_input=normalize_head_input,
            apply_spectral_norm=spectral_norm,
        )

    def _create_feature_distribution(self, x: torch.Tensor) -> Distribution:
        """Wrap the feature head output ``x`` in a ``DiracDistribution``."""
        point_mass = DiracDistribution(x)
        return point_mass


class ProbabilisticOutput(Model):
    """Output mixin whose feature head is a ``DoubleHeadMLP``.

    The pair returned by the head is turned into a ``BoundedGaussian``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Initialise the parent with all arguments, then add the feature head."""
        parent_kwargs = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **parent_kwargs)

        head_input_dim = feature_dim + action_dim
        self.feature_head = DoubleHeadMLP(
            input_dim=head_input_dim,
            output_dim=feature_dim,
            hidden_layers=hidden_layers_head,
            hidden_dim=hidden_dim,
            normalize_input=normalize_head_input,
            normalize_last_layer=False,
            apply_spectral_norm=spectral_norm,
        )

    def _create_feature_distribution(
        self, x: Tuple[torch.Tensor, torch.Tensor]
    ) -> Distribution:
        """Unpack ``x`` into two tensors and build a ``BoundedGaussian``."""
        mu, sigma = x
        return BoundedGaussian(mu, sigma)


class RecurrentCombination(Model):
    """Combination mixin that updates the hidden value with a ``GRUCell``.

    The GRU output is used both as the input to the heads and as the next
    hidden value.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True
    ):
        """Initialise the parent with all arguments, then add the GRU cell."""
        super().__init__(
            action_dim,
            feature_dim,
            hidden_dim,
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            # Fixed: this keyword was misspelled as ``predcit_done``, which
            # made the parent constructor raise ``TypeError``.
            predict_done=predict_done
        )
        # NOTE(review): the cell takes and returns ``feature_dim`` features,
        # whereas the heads built in ``Model`` are sized for
        # ``feature_dim + action_dim`` inputs - confirm the shapes line up.
        self.combination_layer = nn.GRUCell(feature_dim, feature_dim)

    def _combine_hidden(
        self, x: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one GRU step and return its output twice.

        Args:
            x: Input to the GRU cell.
            hidden: Previous hidden value for the GRU cell.

        Returns:
            ``(next_hidden, next_hidden)``: the same tensor serves as head
            input and as the next hidden value.
        """
        next_hidden = self.combination_layer(x, hidden)
        return next_hidden, next_hidden


class SkipCombination(Model):
    """Combination mixin that merges input and hidden via a linear layer.

    The input and the hidden value are concatenated on the last dimension,
    passed through ELU, the linear layer, and ELU again. The result is used
    both as the input to the heads and as the next hidden value.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Initialise the parent with all arguments, then add the linear layer."""
        parent_kwargs = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **parent_kwargs)
        # NOTE(review): the layer maps ``2 * feature_dim`` inputs to
        # ``feature_dim`` outputs, whereas the heads built in ``Model`` are
        # sized for ``feature_dim + action_dim`` inputs - confirm the shapes.
        self.combination_layer = nn.Linear(feature_dim * 2, feature_dim)

    def _combine_hidden(
        self, x: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Merge ``x`` with ``hidden`` and return the result twice."""
        merged = F.elu(torch.cat([x, hidden], dim=-1))
        merged = F.elu(self.combination_layer(merged))
        return merged, merged


class IdentityConnection(Model):
    """Combination mixin that leaves the input and the hidden value untouched."""

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Initialise the parent with all arguments, then add an ``nn.Identity``."""
        parent_kwargs = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **parent_kwargs)
        # Assigned so ``_register_networks`` finds a ``combination_layer``;
        # ``_combine_hidden`` below never calls it.
        self.combination_layer = nn.Identity()

    def _combine_hidden(
        self, x: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``x`` and ``hidden`` unchanged as ``(head_input, next_hidden)``."""
        passthrough = (x, hidden)
        return passthrough


class DeterministicIdentityConnection(IdentityConnection, DeterministicOutput):
    """Concrete model combining ``IdentityConnection`` and ``DeterministicOutput``.

    The constructor passes every argument unchanged up the cooperative
    ``super().__init__`` chain and then calls ``_register_networks()``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Build both parents, then register the sub-networks."""
        forwarded = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **forwarded)
        # The parent constructors have assigned the combination layer and
        # all heads by now, so they can be collected.
        self._register_networks()


class DeterministicSkipConnection(SkipCombination, DeterministicOutput):
    """Concrete model combining ``SkipCombination`` and ``DeterministicOutput``.

    The constructor passes every argument unchanged up the cooperative
    ``super().__init__`` chain and then calls ``_register_networks()``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Build both parents, then register the sub-networks."""
        forwarded = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **forwarded)
        # The parent constructors have assigned the combination layer and
        # all heads by now, so they can be collected.
        self._register_networks()


class DeterministicRecurrentConnection(RecurrentCombination, DeterministicOutput):
    """Concrete model combining ``RecurrentCombination`` and ``DeterministicOutput``.

    The constructor passes every argument unchanged up the cooperative
    ``super().__init__`` chain and then calls ``_register_networks()``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Build both parents, then register the sub-networks."""
        forwarded = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **forwarded)
        # Collect the sub-networks assigned by the parent constructors.
        self._register_networks()


class ProbabilisticIdentityConnection(IdentityConnection, ProbabilisticOutput):
    """Concrete model combining ``IdentityConnection`` and ``ProbabilisticOutput``.

    The constructor passes every argument unchanged up the cooperative
    ``super().__init__`` chain and then calls ``_register_networks()``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Build both parents, then register the sub-networks."""
        forwarded = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **forwarded)
        # The parent constructors have assigned the combination layer and
        # all heads by now, so they can be collected.
        self._register_networks()


class ProbabilisticSkipConnection(SkipCombination, ProbabilisticOutput):
    """Concrete model combining ``SkipCombination`` and ``ProbabilisticOutput``.

    The constructor passes every argument unchanged up the cooperative
    ``super().__init__`` chain and then calls ``_register_networks()``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Build both parents, then register the sub-networks."""
        forwarded = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **forwarded)
        # The parent constructors have assigned the combination layer and
        # all heads by now, so they can be collected.
        self._register_networks()


class ProbabilisticRecurrentConnection(RecurrentCombination, ProbabilisticOutput):
    """Concrete model combining ``RecurrentCombination`` and ``ProbabilisticOutput``.

    The constructor passes every argument unchanged up the cooperative
    ``super().__init__`` chain and then calls ``_register_networks()``.
    """

    def __init__(
        self,
        action_dim: int,
        feature_dim: int,
        hidden_dim: int,
        hidden_layers_core: int,
        hidden_layers_head: int,
        normalize_head_input: bool,
        encoder: Encoder,
        spectral_norm: bool = False,
        predict_done: bool = True,
    ):
        """Build both parents, then register the sub-networks."""
        forwarded = dict(
            hidden_layers_core=hidden_layers_core,
            hidden_layers_head=hidden_layers_head,
            normalize_head_input=normalize_head_input,
            encoder=encoder,
            spectral_norm=spectral_norm,
            predict_done=predict_done,
        )
        super().__init__(action_dim, feature_dim, hidden_dim, **forwarded)
        # Collect the sub-networks assigned by the parent constructors.
        self._register_networks()
