import jax
import functools
import distrax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np


from typing import Sequence
from flax.linen.initializers import constant, orthogonal


class ScannedRNN(nn.Module):
    """GRU unrolled with ``nn.scan`` over axis 0, with per-step state resets.

    Parameters are broadcast across scan steps (``variable_broadcast="params"``),
    so every step uses the same GRU weights. The GRU width equals the feature
    width (``shape[1]``) of the per-step input.
    """

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Advance the GRU by one scanned step.

        Args:
            carry: hidden state carried over from the previous step.
            x: ``(ins, resets)`` for this step. Rows whose ``resets`` entry is
                true have their carry swapped for a fresh initial carry before
                the GRU update.

        Returns:
            ``(new_carry, output)`` exactly as returned by ``nn.GRUCell``.
        """
        step_inputs, reset_flags = x
        n_rows, n_features = step_inputs.shape[0], step_inputs.shape[1]
        # Fresh carry is built before the real cell, matching the original
        # construction order. NOTE(review): initialize_carry creates a
        # throwaway GRUCell; keep it first in case Flax assigns it an auto-name.
        blank_state = self.initialize_carry(n_rows, n_features)
        state = jnp.where(reset_flags[:, np.newaxis], blank_state, carry)
        gru = nn.GRUCell(features=n_features)
        return gru(state, step_inputs)

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return an initial GRU carry for ``batch_size`` rows of width ``hidden_size``."""
        # A fixed dummy key suffices: the default GRU carry init is all zeros.
        template = nn.GRUCell(features=hidden_size)
        key = jax.random.PRNGKey(0)
        return template.initialize_carry(key, (batch_size, hidden_size))


class ActorCriticRNN(nn.Module):
    """GRU actor-critic with an optional auxiliary model of another agent's actions.

    ``__call__(hidden, x)`` returns ``(hidden, pi, value, other_pi)``:
    ``pi`` and ``other_pi`` are ``distrax.Categorical`` distributions and
    ``value`` is the critic output with its trailing singleton axis squeezed.

    Expected layout of ``x`` (selected by the flags):
      * ``other_agent_prediction`` and ``env_has_avail_actions``:
        exactly ``(obs, dones, past_5_sa_pairs, avail_actions)``
      * ``other_agent_prediction`` only: ``(obs, dones, past_5_sa_pairs, ...)``
      * ``env_has_avail_actions`` only: ``(obs, dones, avail_actions, ...)``
      * neither: ``(obs, dones, ...)``
    ``past_5_sa_pairs`` must support ``['obs']`` and ``['action']`` lookups.
    """

    action_dim: int  # number of discrete actions; used as a layer width / vocab size

    gru_hidden_dim_size: int = 256  # width of the extra encoder layers (and so the GRU)
    fc_dim_size: int = 256

    embedding_layers: int = 1  # extra Dense layers after the first encoder projection
    actor_layers: int = 4
    critic_layers: int = 4

    other_agent_prediction: bool = False
    env_has_avail_actions: bool = False

    use_layernorm: bool = False

    @nn.compact
    def __call__(self, hidden, x):
        # ---- unpack inputs ----------------------------------------------
        past_5_sa_pairs = None
        avail_actions = None
        if self.other_agent_prediction:
            if self.env_has_avail_actions:
                obs, dones, past_5_sa_pairs, avail_actions = x
            else:
                obs, dones, past_5_sa_pairs = x[0], x[1], x[2]
        elif self.env_has_avail_actions:
            # Previously this path never bound avail_actions, so the masking
            # step below raised NameError. NOTE(review): assumes callers pack
            # the mask in the third slot -- confirm against call sites.
            obs, dones, avail_actions = x[0], x[1], x[2]
        else:
            obs, dones = x[0], x[1]

        # ---- shared observation encoder ---------------------------------
        embedding = nn.Dense(
            self.fc_dim_size * 2, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        if self.use_layernorm:
            embedding = nn.LayerNorm()(embedding)
        embedding = nn.relu(embedding)

        for _ in range(self.embedding_layers):
            embedding = nn.Dense(
                self.gru_hidden_dim_size, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(embedding)
            if self.use_layernorm:
                embedding = nn.LayerNorm()(embedding)
            embedding = nn.relu(embedding)

        # Recurrent core; rows flagged in ``dones`` get a fresh hidden state.
        hidden, embedding = ScannedRNN()(hidden, (embedding, dones))

        # ---- model of the other agent -----------------------------------
        if self.other_agent_prediction:
            past_actions = past_5_sa_pairs['action']
            b_size, num_time = past_actions.shape[0], past_actions.shape[1]

            other_graph_vals = nn.Dense(
                self.fc_dim_size, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(past_5_sa_pairs['obs'].reshape((b_size, num_time, -1)))
            other_graph_vals = nn.relu(other_graph_vals)

            embeddings = nn.Embed(num_embeddings=self.action_dim, features=self.fc_dim_size)(
                past_actions.astype(jnp.int32)
            )
            # Flatten every axis after (batch, time) into one feature axis.
            embeddings = embeddings.reshape((b_size, num_time, -1))
            other_actor_mean = jnp.concatenate([other_graph_vals, embeddings], axis=-1)

            # Four 64-wide hidden layers; the last uses tanh instead of leaky_relu.
            # Module creation order matches the original unrolled layers.
            prediction_other = other_actor_mean
            for act_fn in (nn.leaky_relu, nn.leaky_relu, nn.leaky_relu, nn.tanh):
                prediction_other = nn.Dense(
                    64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
                )(prediction_other)
                if self.use_layernorm:
                    prediction_other = nn.LayerNorm()(prediction_other)
                prediction_other = act_fn(prediction_other)

            prediction_other = nn.Dense(
                self.action_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(prediction_other)
            # L2-normalise the logits; the epsilon avoids division by zero.
            prediction_other = prediction_other / jnp.sqrt(
                jnp.sum(prediction_other**2, axis=-1, keepdims=True) + 1e-10
            )

            other_pi = distrax.Categorical(logits=prediction_other)
            # The actor consumes the prediction but does not backprop into it.
            actor_embedding = jnp.concatenate(
                [embedding, jax.lax.stop_gradient(prediction_other)], axis=-1
            )
        else:
            # Placeholder distribution so the return signature is mode-independent.
            other_pi = distrax.Categorical(logits=jnp.zeros((self.action_dim,)))
            actor_embedding = embedding

        # NOTE(review): actor/critic hidden layers use orthogonal(2), unlike the
        # orthogonal(np.sqrt(2)) used elsewhere in this module; kept unchanged to
        # preserve existing initialisation -- confirm this is intentional.

        # ---- actor ------------------------------------------------------
        actor_mean = actor_embedding
        for _ in range(self.actor_layers):
            actor_mean = nn.Dense(
                self.fc_dim_size, kernel_init=orthogonal(2), bias_init=constant(0.0)
            )(actor_mean)
            if self.use_layernorm:
                actor_mean = nn.LayerNorm()(actor_mean)
            actor_mean = nn.relu(actor_mean)

        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)

        if self.env_has_avail_actions:
            # Push unavailable actions' logits down by 1e10 (assumes a 0/1 mask).
            unavail_actions = 1 - avail_actions
            actor_mean = actor_mean - (unavail_actions * 1e10)

        pi = distrax.Categorical(logits=actor_mean)

        # ---- critic (reads the GRU output, not the other-agent prediction) --
        critic = embedding
        for _ in range(self.critic_layers):
            critic = nn.Dense(
                self.fc_dim_size, kernel_init=orthogonal(2), bias_init=constant(0.0)
            )(critic)
            if self.use_layernorm:
                critic = nn.LayerNorm()(critic)
            critic = nn.relu(critic)

        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)

        return hidden, pi, jnp.squeeze(critic, axis=-1), other_pi


class ActorWithConditionalCritic(nn.Module):
    """Feed-forward actor-critic whose critic is conditioned on a teammate id.

    ``x`` is ``(obs, teammate_id)`` or ``(obs, teammate_id, avail_actions)``.
    The actor reads only ``obs``. The critic reads ``obs`` concatenated with
    ``teammate_id`` along the last axis. Returns ``(pi, value)``, where ``pi``
    is a ``distrax.Categorical`` and ``value`` has its trailing axis squeezed.
    """

    action_dim: int  # number of discrete actions (width of the logits layer)
    activation: str = "tanh"  # "relu" selects ReLU; any other value falls back to tanh
    fc_hidden_dim: int = 64

    @nn.compact
    def __call__(self, x):
        activation = nn.relu if self.activation == "relu" else nn.tanh

        # Unpack once; avail_actions stays None when no mask is supplied.
        avail_actions = None
        if len(x) == 3:
            obs, teammate_id, avail_actions = x
        else:
            obs, teammate_id = x

        obs_with_teammate_id = jnp.concatenate([obs, teammate_id], axis=-1)

        # Actor (created first so submodule auto-names are unchanged).
        actor_mean = nn.Dense(
            self.fc_hidden_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.fc_hidden_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)

        if avail_actions is not None:
            # Push unavailable actions' logits down by 1e10 (assumes a 0/1 mask).
            unavail_actions = 1 - avail_actions
            actor_mean = actor_mean - (unavail_actions * 1e10)

        pi = distrax.Categorical(logits=actor_mean)

        # Critic, conditioned on the teammate id.
        critic = nn.Dense(
            self.fc_hidden_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(obs_with_teammate_id)
        critic = activation(critic)
        critic = nn.Dense(
            self.fc_hidden_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)

        return pi, jnp.squeeze(critic, axis=-1)
