import torch
import torch.nn.functional as F
import torch.nn as nn
from onpolicy.algorithms.utils.util import init, check
from onpolicy.algorithms.utils.cnn import CNNBase
from onpolicy.algorithms.utils.mlp import MLPBase
from onpolicy.algorithms.utils.rnn import RNNLayer
from onpolicy.algorithms.utils.act import ACTLayer
from onpolicy.algorithms.utils.popart import PopArt
from onpolicy.utils.util import get_shape_from_obs_space

class R_Actor(nn.Module):
    """
    Actor network class for MAPPO. Outputs actions given observations.

    Inputs to ``forward`` / ``evaluate_actions`` have their first two dimensions merged
    before being fed to the RNN / action layers; outputs are reshaped back to
    ``(batch, num_agents, ...)``.

    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param obs_space: (gym.Space) observation space.
    :param action_space: (gym.Space) action space.
    :param num_agents: (int) number of agents, used to restore the agent dimension of the outputs.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self, args, obs_space, action_space, num_agents, device=torch.device("cpu")):
        super(R_Actor, self).__init__()
        self.hidden_size = args.hidden_size
        self.n_rollout_threads = args.n_rollout_threads
        self._gain = args.gain
        self._use_orthogonal = args.use_orthogonal
        self._use_policy_active_masks = args.use_policy_active_masks
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.num_agents = num_agents

        # 3-dimensional observation shapes go through the CNN encoder, everything else through the MLP.
        obs_shape = get_shape_from_obs_space(obs_space)
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        self.base = base(args, obs_shape)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain, args)
        self.to(device)

    @staticmethod
    def _merge_leading_dims(x):
        """Collapse the first two dimensions of ``x`` into a single one."""
        return x.reshape(-1, *x.shape[2:])

    def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=None):
        """
        Compute actions from the given inputs.
        :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param deterministic: (list) flags forwarded unchanged to the action layer; defaults to
                              ``[False, False]`` when None.
                              NOTE(review): presumably one sample-vs-mode flag per action head - confirm against ACTLayer.

        :return actions: (torch.Tensor) actions to take, shaped (batch, num_agents, ...).
        :return action_log_probs: (torch.Tensor) log probabilities of taken actions, shaped (batch, num_agents, ...).
        :return rnn_states: (torch.Tensor) updated RNN hidden states (returned unchanged when no RNN is used).
        """
        # Build the default per call so no list object is shared between invocations.
        if deterministic is None:
            deterministic = [False, False]

        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if available_actions is not None:
            available_actions = self._merge_leading_dims(check(available_actions).to(**self.tpdv))

        actor_features = self.base(obs)
        rnn_batch = rnn_states.shape[0]
        batch = actor_features.shape[0]

        # Every downstream layer works on the merged (batch * agent) leading dimension.
        actor_features = self._merge_leading_dims(actor_features)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            actor_features, rnn_states = self.rnn(actor_features,
                                                  self._merge_leading_dims(rnn_states),
                                                  self._merge_leading_dims(masks))
            rnn_states = rnn_states.reshape(rnn_batch, self.num_agents, *rnn_states.shape[1:])

        actions, action_log_probs = self.act(actor_features, available_actions, deterministic)

        # Restore the agent dimension on the outputs.
        actions = actions.reshape(batch, self.num_agents, *actions.shape[1:])
        action_log_probs = action_log_probs.reshape(batch, self.num_agents, *action_log_probs.shape[1:])

        return actions, action_log_probs, rnn_states

    def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None, deterministic=None):
        """
        Compute log probability and entropy of given actions.
        :param obs: (torch.Tensor) observation inputs into network.
        :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN.
        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
        :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. Only forwarded to the
                             action layer when ``use_policy_active_masks`` is set and a mask was supplied.
        :param deterministic: (list) flags forwarded unchanged to the action layer; defaults to
                              ``[False, False]`` when None.

        :return actions: (torch.Tensor) actions returned by the action layer, shaped (batch, num_agents, ...).
        :return action_log_probs: (torch.Tensor) log probabilities, shaped (batch, num_agents, ...).
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        # Build the default per call so no list object is shared between invocations.
        if deterministic is None:
            deterministic = [False, False]

        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if available_actions is not None:
            available_actions = self._merge_leading_dims(check(available_actions).to(**self.tpdv))
        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)

        actor_features = self.base(obs)
        batch = actor_features.shape[0]
        num_agents = actor_features.shape[1]

        actor_features = self._merge_leading_dims(actor_features)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            actor_features, rnn_states = self.rnn(actor_features,
                                                  self._merge_leading_dims(rnn_states),
                                                  self._merge_leading_dims(masks))

        # The mask is optional in the signature, so only reshape it when it was actually provided.
        if self._use_policy_active_masks and active_masks is not None:
            flat_active_masks = self._merge_leading_dims(active_masks)
        else:
            flat_active_masks = None

        actions, action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features,
                                                                            self._merge_leading_dims(action),
                                                                            available_actions,
                                                                            active_masks=flat_active_masks,
                                                                            deterministic=deterministic)

        # Restore the agent dimension on the per-agent outputs; the entropy is returned as produced.
        actions = actions.reshape(batch, num_agents, *actions.shape[1:])
        action_log_probs = action_log_probs.reshape(batch, num_agents, *action_log_probs.shape[1:])

        return actions, action_log_probs, dist_entropy

class R_Critic(nn.Module):
    """
    Critic network class for MAPPO. Outputs value function predictions given centralized input (MAPPO) or
                            local observations (IPPO).
    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param cent_obs_space: (gym.Space) (centralized) observation space.
    :param num_agents: (int) number of agents, used to restore the agent dimension of the outputs.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self, args, cent_obs_space, num_agents, device=torch.device("cpu")):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.n_rollout_threads = args.n_rollout_threads
        self._use_orthogonal = args.use_orthogonal
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self._use_popart = args.use_popart
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.num_agents = num_agents

        # Index 0 -> Xavier uniform, index 1 -> orthogonal, selected by the use_orthogonal flag.
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[self._use_orthogonal]

        # 3-dimensional observation shapes go through the CNN encoder, everything else through the MLP.
        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        encoder_cls = CNNBase if len(cent_obs_shape) == 3 else MLPBase
        self.base = encoder_cls(args, cent_obs_shape)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        # Scalar value head: weights from the chosen initializer, bias set to zero.
        self.v_out = init(nn.Linear(self.hidden_size, 1),
                          weight_init,
                          lambda x: nn.init.constant_(x, 0))

        self.to(device)

    def forward(self, cent_obs, rnn_states, masks):
        """
        Compute value function predictions from the given inputs.
        :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

        :return values: (torch.Tensor) value function predictions, shaped (batch, num_agents, ...).
        :return rnn_states: (torch.Tensor) updated RNN hidden states (returned unchanged when no RNN is used).
        """
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        rnn_batch = rnn_states.shape[0]
        batch = cent_obs.shape[0]

        # Encode, then merge the first two dimensions so later layers see one leading dimension.
        encoded = self.base(cent_obs)
        flat_features = encoded.reshape(-1, *encoded.shape[2:])

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            flat_states = rnn_states.reshape(-1, *rnn_states.shape[2:])
            flat_masks = masks.reshape(-1, *masks.shape[2:])
            flat_features, flat_states = self.rnn(flat_features, flat_states, flat_masks)
            rnn_states = flat_states.reshape(rnn_batch, self.num_agents, *flat_states.shape[1:])

        flat_values = self.v_out(flat_features)
        values = flat_values.reshape(batch, self.num_agents, *flat_values.shape[1:])

        return values, rnn_states

class R_QMIX(nn.Module):
    """
    Action-conditioned value network with an optional QMIX-style mixing head.

    Each agent's centralized observation is concatenated with actions and scored by a shared
    encoder (+ optional RNN) and a scalar head. The per-agent scalars are then either mixed by a
    state-conditioned network whose hypernetwork weights are made non-negative with ``torch.abs``
    (``args.use_qmix``), or simply averaged over agents. The single joint value is broadcast back
    to every agent.

    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param cent_obs_space: (gym.Space) (centralized) observation space.
    :param action_space: (gym.Space) action space; only "Discrete" and "Box" are supported.
    :param num_agents: (int) number of agents.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self, args, cent_obs_space, action_space, num_agents, device=torch.device("cpu")):
        super(R_QMIX, self).__init__()
        self.hidden_size = args.hidden_size
        self.n_rollout_threads = args.n_rollout_threads
        self._use_orthogonal = args.use_orthogonal
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self._use_popart = args.use_popart
        self._use_qmix = args.use_qmix
        self._use_joint = args.use_joint
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.embed_dim = args.embed_dim
        self.num_agents = num_agents

        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]

        action_space_type = action_space.__class__.__name__
        if action_space_type == "Discrete":
            action_dim = action_space.n
        elif action_space_type == "Box":
            action_dim = action_space.shape[0]
        else:
            # Fail with a clear message instead of an unbound-variable error further down.
            raise NotImplementedError(
                "R_QMIX only supports Discrete and Box action spaces, got {}".format(action_space_type))
        if self._use_joint:
            # Joint mode feeds every agent the concatenation of all agents' actions.
            action_dim *= num_agents

        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        base = CNNBase if len(cent_obs_shape) == 3 else MLPBase
        self.state_dim = cent_obs_shape[0]
        # The encoder input is the observation concatenated with the action vector.
        self.base = base(args, [cent_obs_shape[0] + action_dim], hidden_size=self.hidden_size)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        if self._use_popart:
            self.v_out = init_(PopArt(self.hidden_size, 1, device=device))
        else:
            self.v_out = init_(nn.Linear(self.hidden_size, 1))

        if self._use_qmix:
            hypernet_embed = args.hypernet_embed
            # Hypernetworks map the state to the weights / bias of the two mixing layers.
            self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed),
                                           nn.ReLU(),
                                           nn.Linear(hypernet_embed, self.embed_dim * self.num_agents))
            self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed),
                                               nn.ReLU(),
                                               nn.Linear(hypernet_embed, self.embed_dim))
            self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

            # V(s) instead of a bias for the last layers
            self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                                   nn.ReLU(),
                                   nn.Linear(self.embed_dim, 1))

        self.to(device)

    def forward(self, cent_obs, rnn_states, masks, actions):
        """
        Compute the (mixed or averaged) joint value for the given observations and actions.
        :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.
        :param actions: (np.ndarray / torch.Tensor) actions concatenated to the observations along the last dim.

        :return values: (torch.Tensor) joint value repeated for every agent, shaped (batch, num_agents, 1).
        :return rnn_states: (torch.Tensor) updated RNN hidden states (returned unchanged when no RNN is used).
        """
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        # NOTE(review): squeeze(-1) drops a trailing singleton axis; presumably actions arrive with an
        # extra trailing dim - confirm this cannot collapse a genuine action dimension of size 1.
        actions = check(actions).to(**self.tpdv).squeeze(-1)
        masks = check(masks).to(**self.tpdv)
        b = rnn_states.shape[0]
        bs = actions.shape[0]
        if self._use_joint:
            # Flatten all agents' actions into one vector and hand the same vector to every agent.
            actions = actions.reshape(bs, 1, -1).repeat(1, self.num_agents, 1)

        x = torch.cat([cent_obs, actions], dim=-1)

        critic_features = self.base(x)
        # Merge the first two dimensions so the RNN / value head see one leading dimension.
        critic_features = critic_features.reshape(-1, *critic_features.shape[2:])
        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            critic_features, rnn_states = self.rnn(critic_features,
                                                   rnn_states.reshape(-1, *rnn_states.shape[2:]),
                                                   masks.reshape(-1, *masks.shape[2:]))
            rnn_states = rnn_states.reshape(b, self.num_agents, *rnn_states.shape[1:])

        values = self.v_out(critic_features)
        values = values.reshape(bs, self.num_agents, *values.shape[1:])

        if not self._use_qmix:
            # Without the mixer, the joint value is the plain mean over agents.
            return values.mean(dim=1, keepdim=True).repeat(1, self.num_agents, 1), rnn_states

        # Condition the mixer on the first agent's centralized observation. The slice is already
        # (batch, state_dim); no squeeze is applied so a state_dim of 1 keeps its feature axis.
        state = cent_obs[:, 0, :]
        values = values.view(bs, 1, self.num_agents)

        # First mixing layer: non-negative weights (abs) plus a state-dependent bias.
        w1 = torch.abs(self.hyper_w_1(state)).view(bs, self.num_agents, self.embed_dim)
        b1 = self.hyper_b_1(state).view(bs, 1, self.embed_dim)
        hidden = F.elu(torch.bmm(values, w1) + b1)

        # Final mixing layer: non-negative weights plus the state value V(s) in place of a bias.
        w_final = torch.abs(self.hyper_w_final(state)).view(bs, self.embed_dim, 1)
        v = self.V(state).view(bs, 1, 1)
        y = torch.bmm(hidden, w_final) + v
        return y.repeat(1, self.num_agents, 1), rnn_states
