from typing import Any, List, Optional, Tuple

import torch as th  # type:ignore
import torch.nn.functional as F  # type:ignore
from torch import nn


def _init_weights(m: nn.Module) -> None:
    """Initialise an ``nn.Linear`` module in place; leave other modules alone.

    Weights are drawn with Xavier (Glorot) normal initialisation and the bias,
    when the layer has one, is set to the constant 0.1. Intended to be passed
    to ``nn.Module.apply``, which calls it on every sub-module.

    Args:
        m: Any module; only instances of ``nn.Linear`` are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    # ``bias`` is None for layers built with ``bias=False``.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)


class PolicyNetwork(nn.Module):  # type:ignore
    """Feed-forward policy network.

    A stack of ``nn.Linear`` layers with a ReLU after every layer except the
    last one, whose raw output (one value per action) is returned.
    """

    def __init__(
        self,
        obs_shape: int,
        hidden_layers: List[int],
        n_actions: int,
    ):
        """Build the layer stack and initialise it with ``_init_weights``.

        Args:
            obs_shape: Number of input features of the first layer.
            hidden_layers: Output width of each hidden layer, in order. May
                be empty, in which case a single linear layer maps the
                observation directly to the action outputs.
            n_actions: Number of output features of the last layer.
        """
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each layer.
        widths = [obs_shape, *hidden_layers, n_actions]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )
        self.apply(_init_weights)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Run ``obs`` through the network and return the last layer's output.

        Args:
            obs: Tensor whose last dimension matches the first layer's input
                size.

        Returns:
            The un-activated output of the final linear layer.
        """
        *body, head = self.layers
        x = obs
        for layer in body:
            x = F.relu(layer(x))
        return head(x)

    def init_hidden(self) -> None:
        """Return ``None``; this network keeps no recurrent state."""
        return None


class RewardConditionedPolicyNetwork(nn.Module):  # type:ignore
    """Feed-forward policy network conditioned on an accumulated reward.

    The observation and the reward vector are concatenated along the last
    dimension and passed through a stack of ``nn.Linear`` layers with a ReLU
    after every layer except the last one.
    """

    def __init__(
        self,
        obs_shape: int,
        reward_dim: int,
        hidden_layers: List[int],
        n_actions: int,
    ):
        """Build the layer stack and initialise it with ``_init_weights``.

        Args:
            obs_shape: Number of observation features.
            reward_dim: Number of reward features appended to the observation.
            hidden_layers: Output width of each hidden layer, in order. May
                be empty, in which case a single linear layer maps the
                concatenated input directly to the action outputs.
            n_actions: Number of output features of the last layer.
        """
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each layer.
        widths = [obs_shape + reward_dim, *hidden_layers, n_actions]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )
        self.apply(_init_weights)

    def forward(self, obs: th.Tensor, acc_reward: th.Tensor) -> th.Tensor:
        """Return the last layer's output for ``obs`` joined with ``acc_reward``.

        Args:
            obs: Observation tensor.
            acc_reward: Reward tensor; must match ``obs`` in every dimension
                but the last so the two can be concatenated.

        Returns:
            The un-activated output of the final linear layer.
        """
        *body, head = self.layers
        x = th.cat((obs, acc_reward), dim=-1)
        for layer in body:
            x = F.relu(layer(x))
        return head(x)

    def init_hidden(self) -> None:
        """Return ``None``; this network keeps no recurrent state."""
        return None


class StateValueNetwork(nn.Module):  # type:ignore
    """Feed-forward network producing a single value per input.

    The observation and the reward vector are concatenated along the last
    dimension and passed through a stack of ``nn.Linear`` layers with a ReLU
    after every layer except the last one, which has one output feature.
    The layers keep PyTorch's default initialisation.
    """

    def __init__(
        self, obs_shape: int, reward_dim: int, hidden_layers: List[int]
    ):
        """Build the layer stack.

        Args:
            obs_shape: Number of observation features.
            reward_dim: Number of reward features appended to the observation.
            hidden_layers: Output width of each hidden layer, in order. May
                be empty, in which case a single linear layer maps the
                concatenated input directly to the scalar output.
        """
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each layer.
        widths = [obs_shape + reward_dim, *hidden_layers, 1]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )

    def forward(self, obs: th.Tensor, acc_reward: th.Tensor) -> th.Tensor:
        """Return the last layer's output for ``obs`` joined with ``acc_reward``.

        Args:
            obs: Observation tensor.
            acc_reward: Reward tensor; must match ``obs`` in every dimension
                but the last so the two can be concatenated.

        Returns:
            Tensor whose last dimension has size 1.
        """
        *body, head = self.layers
        # Join on the feature (last) axis so unbatched 1-D inputs work too;
        # for 2-D inputs this is the same axis as ``dim=1``.
        x = th.cat((obs, acc_reward), dim=-1)
        for layer in body:
            x = F.relu(layer(x))
        return head(x)

class GRUPolicyNetwork(nn.Module):
    """Recurrent policy network built around a single ``nn.GRUCell``.

    At each step the cell receives the observation concatenated with a one-hot
    encoding of the previous action; a linear layer maps the new hidden state
    to one output per action.
    """

    def __init__(
        self,
        obs_shape: int,
        gru_hidden_size: int,
        n_actions: int,
    ):
        """Create the GRU cell and the hidden-to-output layer.

        Args:
            obs_shape: Number of observation features.
            gru_hidden_size: Size of the GRU hidden state.
            n_actions: Number of actions; sets both the one-hot width of the
                previous action and the number of outputs.
        """
        super().__init__()

        self.gru_hidden_size = gru_hidden_size
        self.input_size = obs_shape + n_actions
        self.n_actions = n_actions
        self.gru = nn.GRUCell(self.input_size, self.gru_hidden_size)
        self.h2o = nn.Linear(self.gru_hidden_size, self.n_actions)
        self.apply(_init_weights)

    def forward(
        self,
        obs: th.Tensor,
        prev_action: Optional[th.Tensor],
        hidden: th.Tensor,
    ) -> Tuple[th.Tensor, th.Tensor]:
        """Advance the GRU by one step.

        Args:
            obs: Observation tensor with the features in the last dimension.
            prev_action: Integer index tensor of the previous action(s), or
                ``None`` (e.g. at the first step), in which case an all-zero
                vector stands in for the one-hot encoding.
            hidden: Previous hidden state, passed to the GRU cell.

        Returns:
            A tuple ``(output, new_hidden)``: the linear read-out of the new
            hidden state and the new hidden state itself.
        """
        if prev_action is None:
            # Match the leading (batch) dimensions and device of ``obs`` so
            # the concatenation below works for any batch size.
            action_features = th.zeros(
                (*obs.shape[:-1], self.n_actions), device=obs.device
            )
        else:
            # ``one_hot`` yields int64; cast explicitly rather than relying
            # on ``cat`` to promote mixed dtypes.
            action_features = F.one_hot(prev_action, self.n_actions).to(
                obs.dtype
            )
        gru_input = th.cat((obs, action_features), dim=-1)
        new_hidden = self.gru(gru_input, hidden)
        return self.h2o(new_hidden), new_hidden

    def init_hidden(self) -> th.Tensor:
        """Return a fresh random hidden state of shape ``(1, gru_hidden_size)``.

        Values are drawn uniformly from ``[0, 1)`` and the tensor is created
        with ``requires_grad=True``. No device is specified, so it is
        allocated on PyTorch's default device.
        """
        return th.rand(1, self.gru_hidden_size, requires_grad=True)


class RewardConditionedGRUPolicyNetwork(nn.Module):
    """Recurrent, reward-conditioned policy built around one ``nn.GRUCell``.

    At each step the cell receives the observation, the accumulated reward and
    a one-hot encoding of the previous action, concatenated along the last
    dimension; a linear layer maps the new hidden state to one output per
    action.
    """

    def __init__(
        self,
        obs_shape: int,
        reward_dim: int,
        gru_hidden_size: int,
        n_actions: int,
    ):
        """Create the GRU cell and the hidden-to-output layer.

        Args:
            obs_shape: Number of observation features.
            reward_dim: Number of reward features.
            gru_hidden_size: Size of the GRU hidden state.
            n_actions: Number of actions; sets both the one-hot width of the
                previous action and the number of outputs.
        """
        super().__init__()
        self.gru_hidden_size = gru_hidden_size
        self.input_size = obs_shape + reward_dim + n_actions
        self.n_actions = n_actions
        self.gru = nn.GRUCell(self.input_size, self.gru_hidden_size)
        self.h2o = nn.Linear(self.gru_hidden_size, self.n_actions)
        self.apply(_init_weights)

    def forward(
        self,
        obs: th.Tensor,
        acc_reward: th.Tensor,
        prev_action: Optional[th.Tensor],
        hidden: th.Tensor,
    ) -> Tuple[th.Tensor, th.Tensor]:
        """Advance the GRU by one step.

        Args:
            obs: Observation tensor with the features in the last dimension.
            acc_reward: Reward tensor; must match ``obs`` in every dimension
                but the last so the two can be concatenated.
            prev_action: Integer index tensor of the previous action(s), or
                ``None`` (e.g. at the first step), in which case an all-zero
                vector stands in for the one-hot encoding.
            hidden: Previous hidden state, passed to the GRU cell.

        Returns:
            A tuple ``(output, new_hidden)``: the linear read-out of the new
            hidden state and the new hidden state itself.
        """
        if prev_action is None:
            # Match the leading (batch) dimensions and device of ``obs`` so
            # the concatenation below works for any batch size.
            action_features = th.zeros(
                (*obs.shape[:-1], self.n_actions), device=obs.device
            )
        else:
            # ``one_hot`` yields int64; cast explicitly rather than relying
            # on ``cat`` to promote mixed dtypes.
            action_features = F.one_hot(prev_action, self.n_actions).to(
                obs.dtype
            )

        gru_input = th.cat((obs, acc_reward, action_features), dim=-1)
        new_hidden = self.gru(gru_input, hidden)
        return self.h2o(new_hidden), new_hidden

    def init_hidden(self) -> th.Tensor:
        """Return a fresh random hidden state of shape ``(1, gru_hidden_size)``.

        Values are drawn uniformly from ``[0, 1)`` and the tensor is created
        with ``requires_grad=True``. No device is specified, so it is
        allocated on PyTorch's default device.
        """
        return th.rand(1, self.gru_hidden_size, requires_grad=True)
