import torch
import torch.nn as nn

from typing import Tuple
from torch_geometric.nn import Sequential
from torch_geometric.data import Data
from torch import Tensor

from gcbf.nn.mlp import MLP
from gcbf.nn.gnn import ControllerGNNLayer

from .utils import reparameterize, evaluate_log_pi


class StateIndependentPolicy(nn.Module):
    r"""
    Stochastic policy \pi(a|s)

    A ``ControllerGNNLayer`` (configured with ``output_dim=1024``) processes the
    graph, and an MLP maps its output concatenated with ``data.u_ref`` to
    ``action_dim`` values that serve as the policy mean. The log standard
    deviations are one learnable ``(1, action_dim)`` parameter, initialised to
    zero, that is not a function of the input.

    Args:
        node_dim: forwarded to ``ControllerGNNLayer``.
        edge_dim: forwarded to ``ControllerGNNLayer``.
        action_dim: width of the action output; also added to the MLP input
            width to account for the columns of ``data.u_ref``.
        phi_dim: forwarded to ``ControllerGNNLayer``.
    """

    def __init__(
            self,
            node_dim: int,
            edge_dim: int,
            action_dim: int,
            phi_dim: int
    ):
        super().__init__()

        # Width requested from the GNN layer; the MLP input is this plus action_dim.
        feat_dim = 1024

        self.feat_transformer = Sequential('x, edge_attr, edge_index', [
            (ControllerGNNLayer(node_dim=node_dim, edge_dim=edge_dim, output_dim=feat_dim, phi_dim=phi_dim),
             'x, edge_attr, edge_index -> x'),
        ])
        self.feat_2_action = MLP(
            in_channels=feat_dim + action_dim, out_channels=action_dim, hidden_layers=(512, 128, 32)
        )
        self.log_stds = nn.Parameter(torch.zeros(1, action_dim))

    def net(self, data: Data) -> Tensor:
        """
        Run the GNN + MLP pipeline on a graph.

        Args:
            data: graph providing ``x``, ``edge_attr``, ``edge_index`` and
                ``u_ref``; if it also has ``agent_mask``, the GNN output is
                indexed with it before being joined with ``u_ref``.

        Returns:
            Output of the action MLP (``action_dim`` values per row).
        """
        x = self.feat_transformer(data.x, data.edge_attr, data.edge_index)
        if hasattr(data, 'agent_mask'):
            # Keep only the rows selected by the mask (presumably the agent nodes).
            x = x[data.agent_mask]
        # u_ref is appended along the feature axis, so it must have one row per kept row of x.
        return self.feat_2_action(torch.cat([x, data.u_ref], dim=1))

    def forward(self, data: Data) -> Tensor:
        """
        Get the mean of the stochastic policy
        """
        return self.net(data)

    def sample(self, data: Data) -> Tuple[Tensor, Tensor]:
        """
        Sample actions given states

        Returns:
            The ``(action, log_pi)`` pair produced by ``reparameterize`` from the
            policy mean and ``self.log_stds`` (presumably a reparameterised
            Gaussian sample and its log-probability -- confirm in ``.utils``).
        """
        u = self.net(data)
        action, log_pi = reparameterize(u, self.log_stds)
        return action, log_pi

    def evaluate_log_pi(self, data: Data, actions: Tensor, num_agents: int) -> Tensor:
        """
        Evaluate the log(pi(a|s)) of the given actions

        Args:
            data: graph (or batch of graphs) passed to :meth:`net`.
            actions: actions to score; assumed to be laid out compatibly with
                the ``(-1, num_agents, action_dim)`` means -- TODO confirm.
            num_agents: size of the second axis the means are reshaped to.

        Returns:
            Result of the module-level ``evaluate_log_pi`` helper with its
            trailing axis squeezed if that axis has size one.
        """
        action_dim = self.log_stds.shape[1]
        u = self.net(data).view(-1, num_agents, action_dim)
        # The bare name resolves to the helper imported from .utils, not to this method.
        log_pis = evaluate_log_pi(u, self.log_stds, actions).squeeze(-1)
        return log_pis


class MAPPOCritic(nn.Module):
    """
    Critic network: GNN output joined with ``data.u_ref`` is mapped by an MLP
    to a single value per row.
    """

    def __init__(
            self,
            node_dim: int,
            edge_dim: int,
            action_dim: int,
            phi_dim: int
    ):
        super().__init__()

        # Width requested from the GNN layer; the MLP consumes it plus action_dim.
        gnn_width = 1024

        gnn_layer = ControllerGNNLayer(
            node_dim=node_dim,
            edge_dim=edge_dim,
            output_dim=gnn_width,
            phi_dim=phi_dim,
        )
        self.feat_transformer = Sequential(
            'x, edge_attr, edge_index',
            [(gnn_layer, 'x, edge_attr, edge_index -> x')],
        )
        self.feat_2_critic = MLP(
            in_channels=gnn_width + action_dim,
            out_channels=1,
            hidden_layers=(512, 128, 32),
        )

    def forward(self, data: Data) -> torch.Tensor:
        """
        Compute critic values for a graph.

        ``data`` must provide ``x``, ``edge_attr``, ``edge_index`` and ``u_ref``;
        when ``agent_mask`` is present, the GNN output is indexed with it first.
        """
        node_feats = self.feat_transformer(data.x, data.edge_attr, data.edge_index)
        # Restrict to the masked rows only when the graph carries a mask.
        kept_feats = node_feats[data.agent_mask] if hasattr(data, 'agent_mask') else node_feats
        critic_input = torch.cat([kept_feats, data.u_ref], dim=1)
        return self.feat_2_critic(critic_input)
