# flake8: noqa

from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
from tianshou.utils.net.common import MLP
from tianshou.utils.net.continuous import Critic
from torch import nn
from typing import Dict, List, Union, Tuple, Optional
from torch.nn import functional as F

class Swish(nn.Module):
    """Element-wise Swish (a.k.a. SiLU) activation: ``f(x) = x * sigmoid(x)``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(x)
        return x * gate

class EnsembleLinear(nn.Module):
    """Linear layer with an independent weight/bias pair for every ensemble member.

    ``weight`` has shape ``(num_ensemble, input_dim, output_dim)`` and ``bias`` has
    shape ``(num_ensemble, 1, output_dim)``. A snapshot of both tensors
    (``saved_weight`` / ``saved_bias``) is kept so that selected members can be
    stored with :meth:`update_save` and restored with :meth:`load_save`.

    :param input_dim: size of the last input dimension.
    :param output_dim: size of the last output dimension.
    :param num_ensemble: number of ensemble members.
    :param weight_decay: coefficient applied by :meth:`get_decay_loss`.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_ensemble: int,
        weight_decay: float = 0.0
    ) -> None:
        super().__init__()

        self.num_ensemble = num_ensemble

        weight_init = torch.zeros(num_ensemble, input_dim, output_dim)
        bias_init = torch.zeros(num_ensemble, 1, output_dim)
        self.register_parameter("weight", nn.Parameter(weight_init))
        self.register_parameter("bias", nn.Parameter(bias_init))

        # Only the weights get a random (truncated normal) init; biases stay zero.
        nn.init.trunc_normal_(self.weight, std=1 / (2 * input_dim ** 0.5))

        # Snapshot of the initial parameters. NOTE(review): these are registered as
        # regular (requires_grad=True) Parameters, so they show up in .parameters()
        # even though forward() never uses them.
        for snapshot_name, source in (("saved_weight", self.weight), ("saved_bias", self.bias)):
            self.register_parameter(snapshot_name, nn.Parameter(source.detach().clone()))

        self.weight_decay = weight_decay

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every member's affine map.

        A 2-D ``x`` of shape ``(batch, input_dim)`` is shared by all members; any
        other rank is contracted member-wise with the leading dimension of
        ``weight``. The result has shape ``(num_ensemble, batch, output_dim)``.
        """
        equation = "ij,bjk->bik" if x.dim() == 2 else "bij,bjk->bik"
        return torch.einsum(equation, x, self.weight) + self.bias

    def load_save(self) -> None:
        """Overwrite the live weight and bias of all members with the snapshot."""
        for live, snapshot in ((self.weight, self.saved_weight), (self.bias, self.saved_bias)):
            live.data.copy_(snapshot.data)

    def update_save(self, indexes: List[int]) -> None:
        """Copy the live parameters of the members listed in ``indexes`` into the snapshot."""
        for snapshot, live in ((self.saved_weight, self.weight), (self.saved_bias, self.bias)):
            snapshot.data[indexes] = live.data[indexes]

    def get_decay_loss(self) -> torch.Tensor:
        """Return ``weight_decay * 0.5 * sum(weight ** 2)``; the bias is not penalised."""
        squared_sum = self.weight.pow(2).sum()
        return self.weight_decay * (0.5 * squared_sum)

class EnsembleCostModel(nn.Module):
    """Ensemble of MLPs mapping ``(obs, act)`` to a sigmoid-squashed cost output.

    Every layer is an :class:`EnsembleLinear`, so all ``num_ensemble`` members are
    evaluated in a single batched pass. Hidden layers share one activation
    instance; the output layer has width 1 and is passed through a sigmoid.

    :param obs_dim: width of the flattened observation.
    :param action_dim: width of the flattened action; the network input width is
        ``obs_dim + action_dim``.
    :param hidden_dims: widths of the hidden layers.
    :param num_ensemble: number of ensemble members.
    :param num_elites: number of members initially marked as elites (indices
        ``0 .. num_elites - 1``).
    :param activation: activation class; instantiated once in ``__init__``.
    :param weight_decays: per-layer weight-decay coefficients: one per hidden layer
        followed by one for the output layer. Defaults to 0.0 for every layer. If
        the sequence is longer, the surplus entries before the last one are ignored.
    :param device: device the whole model is moved to at construction.
    :raises ValueError: if ``weight_decays`` has fewer entries than there are layers.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dims: Union[List[int], Tuple[int]],
        num_ensemble: int = 5,
        num_elites: int = 5,
        activation: Type[nn.Module] = nn.ReLU,
        weight_decays: Optional[Union[List[float], Tuple[float]]] = None,
        device: str = "cuda:0"
    ) -> None:
        super().__init__()

        self.num_ensemble = num_ensemble
        self.num_elites = num_elites
        self.device = torch.device(device)

        # A single activation instance is shared by every hidden layer.
        self.activation = activation()

        hidden_dims = [obs_dim + action_dim] + list(hidden_dims)
        # After prepending the input width, len(hidden_dims) equals the number of
        # hidden layers plus the output layer, i.e. the number of decay entries needed.
        num_layers = len(hidden_dims)
        if weight_decays is None:
            weight_decays = [0.0] * num_layers
        elif len(weight_decays) < num_layers:
            # zip() below would otherwise silently drop hidden layers.
            raise ValueError(
                f"weight_decays needs {num_layers} entries (one per hidden layer plus "
                f"one for the output layer), got {len(weight_decays)}"
            )

        module_list = []
        for in_dim, out_dim, weight_decay in zip(hidden_dims[:-1], hidden_dims[1:], weight_decays[:-1]):
            module_list.append(EnsembleLinear(in_dim, out_dim, num_ensemble, weight_decay))
        self.backbones = nn.ModuleList(module_list)

        self.output_layer = EnsembleLinear(
            hidden_dims[-1],
            1,
            num_ensemble,
            weight_decays[-1]
        )

        # Kept as a frozen Parameter so the elite indices are part of the state_dict.
        self.register_parameter(
            "elites",
            nn.Parameter(torch.tensor(list(range(0, self.num_elites))), requires_grad=False)
        )

        self.to(self.device)

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Evaluate every ensemble member on a batch.

        :param obs: batch of observations; converted to float32 on ``self.device``
            and flattened to ``(batch, -1)``.
        :param act: optional batch of actions, flattened and concatenated after
            ``obs``. When omitted, ``obs`` must already have width
            ``obs_dim + action_dim``.
        :return: tensor of shape ``(num_ensemble, batch, 1)`` with sigmoid outputs.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,  # type: ignore
            dtype=torch.float32,
        ).flatten(1)
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,  # type: ignore
                dtype=torch.float32,
            ).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        output = obs
        for layer in self.backbones:
            output = self.activation(layer(output))
        cost_prob = torch.sigmoid(self.output_layer(output))
        return cost_prob

    def load_save(self) -> None:
        """Restore every layer's parameters from its saved snapshot."""
        for layer in self.backbones:
            layer.load_save()
        self.output_layer.load_save()

    def update_save(self, indexes: List[int]) -> None:
        """Snapshot the parameters of the members in ``indexes`` in every layer."""
        for layer in self.backbones:
            layer.update_save(indexes)
        self.output_layer.update_save(indexes)

    def get_decay_loss(self) -> torch.Tensor:
        """Return the sum of the weight-decay penalties of all layers."""
        decay_loss = 0
        for layer in self.backbones:
            decay_loss += layer.get_decay_loss()
        decay_loss += self.output_layer.get_decay_loss()
        return decay_loss

    def set_elites(self, indexes: List[int]) -> None:
        """Replace the elite member indices.

        The new tensor is created on the same device as the current ``elites``
        parameter, so the module does not end up with mixed devices.
        """
        assert 0 < len(indexes) <= self.num_ensemble and max(indexes) < self.num_ensemble
        self.register_parameter(
            'elites',
            nn.Parameter(torch.tensor(indexes, device=self.elites.device), requires_grad=False)
        )

    def random_elite_idxs(self, batch_size: int) -> np.ndarray:
        """Sample ``batch_size`` elite indices uniformly, with replacement."""
        idxs = np.random.choice(self.elites.data.cpu().numpy(), size=batch_size)
        return idxs


class DoubleCritic(nn.Module):
    """Twin (double) critic for continuous action spaces.

    Two independent branches, ``preprocess_net1 -> last1`` and
    ``preprocess_net2 -> last2``, each map the concatenated ``(obs, act)`` input to
    one value. ``last2`` starts as a deep copy of ``last1``.

    :param preprocess_net1: preprocess network of the first branch; it is called as
        ``features, hidden = preprocess_net1(x)`` and only ``features`` is used.
    :param preprocess_net2: preprocess network of the second branch, same contract.
    :param hidden_sizes: hidden layer sizes of the MLP head that follows each
        preprocess network. Defaults to an empty sequence (the head is then a single
        linear layer).
    :param device: device used for the input tensors in :meth:`forward`; also
        passed to the MLP heads.
    :param int preprocess_net_output_dim: output width of the preprocess networks;
        only used when ``preprocess_net1`` has no ``output_dim`` attribute.
    :param linear_layer: linear layer class used by the MLP heads. Defaults to
        nn.Linear.
    :param bool flatten_input: whether the MLP heads flatten their input. Defaults
        to True.

    For advanced usage (how to customize the network), please refer to tianshou's \
        `build_the_network tutorial <https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network>`_.

    .. seealso::

        Please refer to tianshou's `Net <https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.net.common.Net>`_
        class as an instance of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net1: nn.Module,
        preprocess_net2: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess1, self.preprocess2 = preprocess_net1, preprocess_net2
        self.output_dim = 1
        head_in = getattr(preprocess_net1, "output_dim", preprocess_net_output_dim)
        self.last1 = MLP(
            head_in,  # type: ignore
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )
        # The second head begins as an exact copy of the first.
        self.last2 = deepcopy(self.last1)

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> list:
        """Mapping: (s, a) -> logits -> Q(s, a), evaluated by both branches.

        :return: ``[q1, q2]``, the outputs of the first and second branch.
        """
        joint = torch.as_tensor(
            obs, device=self.device, dtype=torch.float32  # type: ignore
        ).flatten(1)
        if act is not None:
            act_flat = torch.as_tensor(
                act, device=self.device, dtype=torch.float32  # type: ignore
            ).flatten(1)
            joint = torch.cat([joint, act_flat], dim=1)

        q_values = []
        for preprocess, head in ((self.preprocess1, self.last1), (self.preprocess2, self.last2)):
            features, _ = preprocess(joint)
            q_values.append(head(features))
        return q_values

    def predict(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, list]:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: the element-wise minimum of the two branch outputs, and the list
            ``[q1, q2]`` itself (used for the Bellman backup)."""
        q_list = self(obs, act, info)
        first, second = q_list
        return torch.min(first, second), q_list


class DoubleCriticCNN(nn.Module):
    """Double critic for flattened observations that embed a 5x5 single-channel grid.

    In :meth:`forward` the flattened observation is split into

    * ``obs[:, :4]`` -- leading vector features,
    * ``obs[:, 4:29]`` -- 25 values reshaped to a ``(1, 5, 5)`` image,
    * ``obs[:, 29:]`` -- trailing vector features.

    Each branch owns a conv encoder (``Conv1`` / ``Conv2``; ``Conv2`` starts as a deep
    copy of ``Conv1``) that maps the image to ``3 * 2 * 2 = 12`` features. A branch's
    input is ``[leading, trailing, encoded image, act]`` concatenated along dim 1
    (``act`` only when given); it goes through ``preprocess_net{1,2}`` and then an
    MLP head producing one value.

    :param preprocess_net1: preprocess network of the first branch; called as
        ``features, hidden = preprocess_net1(x)``.
    :param preprocess_net2: preprocess network of the second branch, same contract.
    :param hidden_sizes: hidden layer sizes of each MLP head. Defaults to an empty
        sequence (a single linear layer).
    :param device: device used for the input tensors in :meth:`forward`; also
        passed to the MLP heads.
    :param preprocess_net_output_dim: output width of the preprocess networks; only
        used when ``preprocess_net1`` has no ``output_dim`` attribute.
    :param linear_layer: linear layer class used by the MLP heads.
    :param flatten_input: forwarded to the MLP heads as ``flatten_input``.
    :param cost_clamp: when True, both outputs are clamped to ``[0, 300]``.
    """

    def __init__(
        self,
        preprocess_net1: nn.Module,
        preprocess_net2: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True,
        cost_clamp: bool = False
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess1, self.preprocess2 = preprocess_net1, preprocess_net2
        self.output_dim = 1
        head_in = getattr(preprocess_net1, "output_dim", preprocess_net_output_dim)
        self.last1 = MLP(
            head_in,  # type: ignore
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )
        self.last2 = deepcopy(self.last1)

        # (1, 5, 5) -Conv2d(k=2)-> (3, 4, 4) -MaxPool2d(2)-> (3, 2, 2) -ReLU-> 12 features
        encoder_layers = [
            nn.Conv2d(1, 3, kernel_size=2),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        self.Conv1 = nn.Sequential(*encoder_layers)
        self.Conv2 = deepcopy(self.Conv1)
        self.cost_clamp = cost_clamp

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> list:
        """Mapping: (s, a) -> logits -> Q(s, a), evaluated by both branches.

        :return: ``[q1, q2]``; each clamped to ``[0, 300]`` when ``cost_clamp`` is set.
        """
        flat_obs = torch.as_tensor(
            obs, device=self.device, dtype=torch.float32  # type: ignore
        ).flatten(1)
        grid = flat_obs[:, 4:4 + 25].reshape(-1, 1, 5, 5)
        leading, trailing = flat_obs[:, :4], flat_obs[:, 29:]

        flat_act = None
        if act is not None:
            flat_act = torch.as_tensor(
                act, device=self.device, dtype=torch.float32  # type: ignore
            ).flatten(1)

        branches = (
            (self.Conv1, self.preprocess1, self.last1),
            (self.Conv2, self.preprocess2, self.last2),
        )
        q_values = []
        for encoder, preprocess, head in branches:
            image_code = encoder(grid).flatten(start_dim=1, end_dim=-1)
            parts = [leading, trailing, image_code]
            if flat_act is not None:
                parts.append(flat_act)
            features, _ = preprocess(torch.cat(parts, dim=1))
            q = head(features)
            if self.cost_clamp:
                q = torch.clamp(q, min=0.0, max=300.0)
            q_values.append(q)
        return q_values

    def predict(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, list]:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: the element-wise minimum of the two branch outputs, and the list
            ``[q1, q2]`` itself (used for the Bellman backup)."""
        q_list = self(obs, act, info)
        first, second = q_list
        return torch.min(first, second), q_list


class SingleCriticCNN(Critic):
    """Single critic for flattened observations that embed a 5x5 single-channel grid.

    Continuous-action critic with structure ``conv encoder + preprocess_net ---> 1
    (q value)``. It differs from tianshou's original Critic in that the output is a
    list, so the API is consistent with
    :class:`~fsrl.utils.net.continuous.DoubleCritic`.

    In :meth:`forward` the flattened observation is split into ``obs[:, :4]``,
    ``obs[:, 4:29]`` and ``obs[:, 29:]``. The middle slice is reshaped to a
    ``(1, 5, 5)`` image and encoded by ``self.Conv`` into ``3 * 2 * 2 = 12``
    features. The pieces are concatenated as ``[leading, trailing, encoded image]``
    and handed, together with ``act``, to tianshou's ``Critic.forward``.

    :param preprocess_net: a self-defined preprocess_net which outputs a flattened
        hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (the MLP then contains only a
        single linear layer).
    :param device: device passed to tianshou's Critic.
    :param int preprocess_net_output_dim: the output dimension of preprocess_net.
    :param linear_layer: use this module as linear layer. Defaults to nn.Linear.
    :param bool flatten_input: whether to flatten input data for the last layer.
        Defaults to True.
    :param bool cost_clamp: when True, the output is clamped to ``[0, 300]``.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True,
        cost_clamp: bool = False
    ) -> None:
        super().__init__(
            preprocess_net, hidden_sizes, device, preprocess_net_output_dim,
            linear_layer, flatten_input
        )
        # (1, 5, 5) -Conv2d(k=2)-> (3, 4, 4) -MaxPool2d(2)-> (3, 2, 2) -ReLU-> 12 features
        self.Conv = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=2),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.cost_clamp = cost_clamp

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> list:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: a one-element list ``[q]`` (clamped to ``[0, 300]`` when
            ``cost_clamp`` is set).
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,  # type: ignore
            dtype=torch.float32,
        ).flatten(1)
        img_feature = obs[:, 4:4 + 25].reshape(-1, 1, 5, 5)
        img_feature = self.Conv(img_feature).flatten(start_dim=1, end_dim=-1)
        other_feature_0 = obs[:, :4]
        other_feature_1 = obs[:, 29:]
        obs = torch.cat([other_feature_0, other_feature_1, img_feature], dim=1)
        logits = super().forward(obs, act, info)
        if self.cost_clamp:
            logits = torch.clamp(logits, min=0.0, max=300.0)
        return [logits]

    def predict(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, list]:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: the q value, and a one-element list containing it (mirrors the
            ``(q, [q1, q2])`` return of the double critics; used for Bellman backup)
        """
        q = self(obs, act, info)[0]
        return q, [q]

class SingleCritic(Critic):
    """Simple critic network for continuous action spaces.

    Structure: ``preprocess_net ---> 1 (q value)``. It differs from tianshou's
    original Critic in that the output is a list, so the API is consistent with
    :class:`~fsrl.utils.net.continuous.DoubleCritic`.

    :param preprocess_net: a self-defined preprocess_net which outputs a flattened
        hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (the MLP then contains only a
        single linear layer).
    :param device: device passed to tianshou's Critic.
    :param int preprocess_net_output_dim: the output dimension of preprocess_net.
    :param linear_layer: use this module as linear layer. Defaults to nn.Linear.
    :param bool flatten_input: whether to flatten input data for the last layer.
        Defaults to True.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True
    ) -> None:
        super().__init__(
            preprocess_net, hidden_sizes, device, preprocess_net_output_dim,
            linear_layer, flatten_input
        )

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> list:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: a one-element list ``[q]`` wrapping tianshou Critic's output.
        """
        logits = super().forward(obs, act, info)
        return [logits]

    def predict(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, list]:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: the q value, and a one-element list containing it (mirrors the
            ``(q, [q1, q2])`` return of the double critics; used for Bellman backup)
        """
        q = self(obs, act, info)[0]
        return q, [q]
