import abc
from typing import Tuple, Dict
import chex
from functools import partial
import jax
import jax.numpy as jnp

from src.agents.actors import ActorCriticRNN, ActorWithConditionalCritic

# from src.agents.mlp_actor_critic import ActorCritic
# from src.agents.mlp_actor_critic import ActorWithDoubleCritic
# from src.agents.mlp_actor_critic import ActorWithConditionalCritic
# from src.agents.s5_actor_critic import S5ActorCritic, StackedEncoderModel, init_S5SSM, make_DPLR_HiPPO
# from src.agents.rnn_actor_critic import RNNActorCritic, ScannedRNN


class AgentPolicy(abc.ABC):
    """Abstract interface that agent policies are expected to implement."""

    def __init__(self, action_dim, obs_dim):
        """Record the sizes of the action and observation spaces.

        Args:
            action_dim: int, dimension of the action space
            obs_dim: int, dimension of the observation space
        """
        self.obs_dim = obs_dim
        self.action_dim = action_dim

    @abc.abstractmethod
    @partial(jax.jit, static_argnums=(0,))
    def get_action(self, params, obs, done, avail_actions, hstate, rng,
                   aux_obs=None, env_state=None, test_mode=False) -> Tuple[int, chex.Array]:
        """Compute an action only (no value estimate, no distribution).

        The stub is wrapped in ``jax.jit`` with ``self`` (argument 0) treated
        as static, and is marked abstract, so subclasses must override it.

        Args:
            params (dict): The parameters of the policy.
            obs (chex.Array): The observation.
            done (chex.Array): The done flag.
            avail_actions (chex.Array): The available actions.
            hstate (chex.Array): The hidden state.
            rng (jax.random.PRNGKey): The random key.
            aux_obs (chex.Array): an optional auxiliary vector to append to
                the observation.
            env_state (chex.Array): The environment state.
            test_mode (bool): evaluation-mode flag; its meaning is defined by
                the implementing subclass.
        Returns:
            Tuple[int, chex.Array]: the action and the new hidden state.
        """
        return None

    @partial(jax.jit, static_argnums=(0,))
    def get_action_value_policy(self, params, obs, done, avail_actions, hstate, rng,
                                aux_obs=None, env_state=None) -> Tuple[int, chex.Array, chex.Array, chex.Array]:
        """Compute the action together with the value and the policy.

        This default is not abstract; it is a jitted stub that returns
        ``None``, so subclasses that need it must override it.

        Args:
            params (dict): The parameters of the policy.
            obs (chex.Array): The observation.
            done (chex.Array): The done flag.
            avail_actions (chex.Array): The available actions.
            hstate (chex.Array): The hidden state.
            rng (jax.random.PRNGKey): The random key.
            aux_obs (chex.Array): an optional auxiliary vector to append to
                the observation.
            env_state (chex.Array): The environment state.
        Returns:
            Tuple[int, chex.Array, chex.Array, chex.Array]:
                the action, value, policy, and new hidden state.
        """
        return None

    def init_hstate(self, batch_size, aux_info: dict = None) -> chex.Array:
        """Build the hidden state used at the start of an episode.

        The base implementation returns ``None``.

        Args:
            batch_size: int, the batch size of the hidden state
            aux_info: any auxiliary information needed to initialize the
                hidden state at the start of an episode (e.g. the agent id).
        Returns:
            chex.Array: the initialized hidden state
        """
        pass

    def init_params(self, rng) -> Dict:
        """Build the initial policy parameters.

        The base implementation returns ``None``.
        """
        pass


class AgentPopulation:
    """Base class for a population of homogeneous agents.

    Every member shares the same policy object (``policy_cls``); members
    differ only in their parameters, which are passed in as a pytree whose
    leaves carry a leading ``pop_size`` axis.
    """

    def __init__(self, pop_size, policy_cls):
        """
        Args:
            pop_size: int, number of agents in the population
            policy_cls: the policy object shared by the whole population. Its
                ``get_action`` and ``init_hstate`` methods are called by this
                class.
        """
        self.pop_size = pop_size
        self.policy_cls = policy_cls

    def sample_agent_indices(self, n, rng):
        """Sample n indices from the population, with replacement.

        Args:
            n: int, number of indices to draw
            rng: random key
        Returns:
            integer array of shape (n,), each entry in [0, pop_size)
        """
        return jax.random.randint(rng, (n,), 0, self.pop_size)

    def gather_agent_params(self, pop_params, agent_indices):
        """Gather the parameters of the agents specified by agent_indices.

        Args:
            pop_params: pytree of parameters for the population of agents of
                shape (pop_size, ...).
            agent_indices: indices with shape (num_envs,), each in [0, pop_size)
        Returns:
            pytree with the same structure as pop_params whose leaves are
            indexed along their first axis, one entry per agent index.
        """
        def gather_leaf(leaf):
            # leaf is (pop_size, ...); the gathered result is (num_envs, ...)
            return jax.vmap(lambda idx: leaf[idx])(agent_indices)
        return jax.tree.map(gather_leaf, pop_params)

    def get_actions(self, pop_params, agent_indices, obs, done, hstate, rng,
                    env_state=None, aux_obs=None, test_mode=False):
        """Get the actions of the agents specified by agent_indices.

        The policy's ``get_action`` is vmapped over the leading axis of the
        gathered parameters, ``obs``, ``done``, ``hstate`` and a per-env split
        of ``rng``.

        Args:
            pop_params: pytree of parameters for the population of agents of
                shape (pop_size, ...).
            agent_indices: indices with shape (num_envs,), each in [0, pop_size)
            obs: observations with shape (num_envs, ...)
            done: done flags with shape (num_envs,)
            hstate: hidden state with shape (num_envs, ...) or None if policy
                doesn't use hidden state
            rng: random key; split into one key per environment
            env_state: environment state or None if policy doesn't use env
                state. Bound as a keyword argument, so it is NOT mapped over
                the env axis: each per-env call receives it unchanged.
            aux_obs: an optional auxiliary vector to append to the
                observation. Like env_state, passed whole to each per-env call.
            test_mode: forwarded unchanged to the policy's ``get_action``.
        Returns:
            actions: actions with shape (num_envs,)
            new_hstate: new hidden state with shape (num_envs, ...) or None
        """
        gathered_params = self.gather_agent_params(pop_params, agent_indices)
        # Read the leading axis directly. Squeezing first would collapse a
        # single-environment batch of shape (1,) to a 0-d array, which has no
        # axis 0 and made the one-env case raise IndexError.
        num_envs = agent_indices.shape[0]
        rngs_batched = jax.random.split(rng, num_envs)
        # Keyword arguments fixed through partial are closed over, so vmap
        # only maps the five positional arguments below.
        vmapped_get_action = jax.vmap(partial(self.policy_cls.get_action,
                                              aux_obs=aux_obs,
                                              env_state=env_state,
                                              test_mode=test_mode))
        actions, new_hstate = vmapped_get_action(
            gathered_params, obs, done, hstate,
            rngs_batched)
        return actions, new_hstate

    def init_hstate(self, n: int, aux_info: dict = None):
        """Initialize the hidden state for n members of the population.

        Delegates to the policy's ``init_hstate(n, aux_info)``.
        """
        return self.policy_cls.init_hstate(n, aux_info)


class ActorWithConditionalCriticPolicy:
    """Stateless policy wrapper around an ``ActorWithConditionalCritic`` network.

    Both action methods return ``None`` in the hidden-state slot.
    """

    def __init__(self, action_dim, obs_dim, pop_size, activation="tanh"):
        """Store the dimensions and build the underlying network.

        Args:
            action_dim: int, dimension of the action space
            obs_dim: int, dimension of the observation space
            pop_size: int, number of agents in the population that the critic
                was trained with
            activation: str, activation function to use
        """
        self.pop_size = pop_size
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.network = ActorWithConditionalCritic(
            action_dim, activation=activation)

    @partial(jax.jit, static_argnums=(0,))
    def get_action(self, params, obs, done, hstate, rng,
                   aux_obs=None, env_state=None, test_mode=False):
        """Pick an action from the actor head.

        ``done``, ``hstate``, ``aux_obs`` and ``env_state`` are accepted but
        not read here.

        Args:
            params: network parameters passed to ``network.apply``
            obs: observation; its last axis is replaced by ``pop_size`` to
                shape the placeholder agent id
            rng: random key used when sampling
            test_mode: when true the distribution's ``mode()`` is returned,
                otherwise a sample drawn with ``rng``
        Returns:
            (action, None) -- there is no hidden state.
        """
        # Only the critic consumes the one-hot agent id, so an all-zero
        # placeholder of the right shape is enough to run the actor.
        batch_dims = obs.shape[:-1]
        placeholder_id = jnp.zeros((*batch_dims, self.pop_size))
        policy_dist, _ = self.network.apply(params, (obs, placeholder_id))

        def greedy():
            return policy_dist.mode()

        def stochastic():
            return policy_dist.sample(seed=rng)

        chosen = jax.lax.cond(test_mode, greedy, stochastic)
        return chosen, None

    @partial(jax.jit, static_argnums=(0,))
    def get_action_value_policy(self, params, obs, done, hstate, rng,
                                aux_obs=None, env_state=None):
        """Sample an action and also return the value and the distribution.

        ``aux_obs`` carries the agent ids whose values should be predicted;
        it is fed to the network alongside ``obs``. ``done``, ``hstate`` and
        ``env_state`` are accepted but not read here.

        Returns:
            (action, value, pi, None) -- there is no hidden state.
        """
        network_inputs = (obs, aux_obs)
        policy_dist, value = self.network.apply(params, network_inputs)
        sampled = policy_dist.sample(seed=rng)
        return sampled, value, policy_dist, None

    def init_params(self, rng):
        """Initialize network parameters from all-zero dummy inputs.

        The dummy observation has shape (obs_dim,) and the dummy agent id has
        shape (pop_size,).
        """
        init_inputs = (
            jnp.zeros((self.obs_dim,)),
            jnp.zeros((self.pop_size,)),
        )
        return self.network.init(rng, init_inputs)

