import functools
from typing import Any, Dict, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict
from flax.training import train_state
from utils import Batch, target_update

from models import EnsembleActor, EnsembleDoubleCritic, EnsembleScalar


# Margin applied when clipping actions: every action returned by CEAgent's
# public sampling methods is clipped to
# [-max_action + ACT_DELTA, max_action - ACT_DELTA].
# NOTE(review): presumably keeps actions strictly inside the policy's bounds
# (e.g. for a tanh-squashed distribution) — confirm.
ACT_DELTA = 1e-3


class CEAgent:
    """Entropy-regularised actor-critic agent over an ensemble of members.

    The agent owns three Flax train states (ensemble actor, ensemble double
    critic, and one log-temperature per member) plus a target copy of the
    critic parameters. Besides the training update it offers three ways of
    producing actions: plain sampling (``sample_action``), softmax-over-Q
    member selection (``sample_softmax_eval_action``) and a "most different
    from the other members" exploration rule (``sample_expl_action``).

    Every jitted method marks ``self`` as a static argument, so the Python
    attributes read inside them (``ensemble_num``, ``repeat_num``, ``gamma``,
    ``tau``, ``max_action``, ...) are baked in when the method is traced.
    """

    def __init__(
        self,
        ensemble_num: int,
        obs_dim: int,
        act_dim: int,
        gamma: float = 0.99,
        hid_dim: int = 256,
        max_action: float = 1.0,
        seed: int = 42,
        tau: float = 0.005,
        lr: float = 3e-4,
        grad_clip: float = 20.0,
        repeat_num: int = 5,
    ):
        """Build the networks, optimisers and exploration lookup tables.

        Args:
            ensemble_num: Number of ensemble members.
            obs_dim: Observation dimensionality (used for the dummy init input).
            act_dim: Action dimensionality; also sets the target entropy to
                ``-act_dim / 2``.
            gamma: Discount factor used in the critic target.
            hid_dim: Hidden width forwarded to ``EnsembleActor``.
            max_action: Action bound forwarded to the actor and used when
                clipping returned actions.
            seed: Seed for the agent's PRNG key.
            tau: Coefficient handed to ``target_update`` for the target critic.
            lr: Adam learning rate shared by actor, critic and temperature.
            grad_clip: Global-norm gradient clip for actor and critic.
            repeat_num: Number of actions drawn per member in
                ``sample_expl_action``.
        """
        # Plain hyper-parameters.
        self.ensemble_num = ensemble_num
        self.obs_dim = obs_dim
        self.max_action = max_action
        self.gamma = gamma
        self.tau = tau
        self.repeat_num = repeat_num
        self.target_entropy = -act_dim / 2

        # One root key, immediately split for the two network initialisers.
        root_key = jax.random.PRNGKey(seed)
        self.rng, actor_key, critic_key = jax.random.split(root_key, 3)

        # Placeholder inputs that only fix the shapes seen by `init`.
        obs_stub = jnp.ones([1, obs_dim], dtype=jnp.float32)
        act_stub = jnp.ones([1, act_dim], dtype=jnp.float32)

        def clipped_adam():
            # Actor and critic share the same optimiser recipe.
            return optax.chain(optax.clip_by_global_norm(grad_clip),
                               optax.adam(lr))

        # Actor
        self.actor = EnsembleActor(ensemble_num=ensemble_num,
                                   act_dim=act_dim,
                                   hid_dim=hid_dim,
                                   max_action=max_action)
        actor_variables = self.actor.init(actor_key, obs_stub)
        self.actor_state = train_state.TrainState.create(
            apply_fn=self.actor.apply,
            params=actor_variables["params"],
            tx=clipped_adam())

        # Critic; the target network starts as the very same parameters.
        self.critic = EnsembleDoubleCritic(ensemble_num=ensemble_num)
        critic_variables = self.critic.init(critic_key, obs_stub, act_stub)
        critic_params = critic_variables["params"]
        self.critic_target_params = critic_params
        self.critic_state = train_state.TrainState.create(
            apply_fn=self.critic.apply,
            params=critic_params,
            tx=clipped_adam())

        # Per-member log-temperature, initialised to zero and trained with
        # un-clipped Adam.
        self.rng, alpha_key = jax.random.split(self.rng, 2)
        self.log_alpha = EnsembleScalar(np.zeros((ensemble_num, )))
        alpha_variables = self.log_alpha.init(alpha_key)
        self.alpha_state = train_state.TrainState.create(
            apply_fn=None,
            params=alpha_variables["params"],
            tx=optax.adam(lr))

        # Collaborative-exploration tables.
        # `masks[i]` is all ones except a zero at position i, `indices[i]` is
        # the member id shaped for `take_along_axis`.
        # `dice_masks` is not read anywhere inside this class.
        self.dice_masks = jnp.arange(ensemble_num).reshape(-1, 1, 1)
        self.masks = 1.0 - np.eye(ensemble_num)
        self.indices = np.arange(ensemble_num).reshape(-1, 1, 1)

    def _clip_to_bounds(self, action: np.ndarray) -> np.ndarray:
        """Clip ``action`` to ``[-max_action + ACT_DELTA, max_action - ACT_DELTA]``."""
        lower = -self.max_action + ACT_DELTA
        upper = self.max_action - ACT_DELTA
        return action.clip(lower, upper)

    @functools.partial(jax.jit, static_argnames="self")
    def _sample_softmax_eval_action(self,
                                    actor_params: FrozenDict,
                                    critic_params: FrozenDict,
                                    rng: Any,
                                    observation: np.ndarray):
        """Return the actor's actions and a member index drawn from softmax(Q).

        The Q-value of each member is the average of its two critic heads;
        the index is sampled with probabilities ``softmax(Q - max(Q))``.
        """
        # First actor output; `_sample_action` calls this the mean action.
        member_actions, _ = self.actor.apply({"params": actor_params},
                                             observation)

        # The critic is fed one copy of the observation per ensemble member.
        tiled_obs = jnp.repeat(observation,
                               repeats=self.ensemble_num,
                               axis=0)
        q_first, q_second = self.critic.apply({"params": critic_params},
                                              tiled_obs,
                                              member_actions)
        avg_q = (q_first + q_second) / 2
        avg_q = avg_q - avg_q.max()
        member_probs = jax.nn.softmax(avg_q)

        member_ids = jnp.arange(self.ensemble_num)
        chosen_member = jax.random.choice(rng, member_ids, p=member_probs)
        return member_actions, chosen_member

    def sample_softmax_eval_action(self, observation, eval_mode: bool = False):
        """Return all members' clipped actions plus a softmax-selected index.

        ``observation`` is always flattened to a single row. ``eval_mode`` is
        accepted but not used by this method.

        Returns:
            A ``(actions, idx)`` pair: the clipped actions as a NumPy array and
            the selected member index as a Python int.
        """
        self.rng, sample_rng = jax.random.split(self.rng)
        single_obs = observation.reshape(1, -1)
        raw_actions, member = self._sample_softmax_eval_action(
            self.actor_state.params,
            self.critic_state.params,
            sample_rng,
            single_obs,
        )
        clipped = self._clip_to_bounds(np.array(raw_actions))
        return clipped, member.item()

    @functools.partial(jax.jit, static_argnames="self")
    def _sample_expl_action(self, params: FrozenDict, rng: Any,
                            observation: np.ndarray, mask: np.ndarray,
                            indice: np.ndarray):
        """Pick, among one member's draws, the action furthest from the others.

        ``repeat_num`` actions are drawn from the actor distribution. The draws
        selected by ``indice`` along axis 1 are the candidates; each candidate
        is scored by the mean of ``mask`` times its squared distance to every
        draw, and the highest-scoring candidate is returned.
        """
        _, dist = self.actor.apply({"params": params}, observation)
        draws, _ = dist.sample_and_log_prob(seed=rng,
                                            sample_shape=self.repeat_num)

        # Candidates of the requested member: (repeat_num, act_dim)
        candidates = jnp.take_along_axis(draws, indices=indice, axis=1)
        candidates = candidates.squeeze(1)

        def masked_distance(candidate):
            # Squared distance from this candidate to every draw, weighted by
            # `mask` before averaging.
            sq_dist = jnp.square(draws - candidate).sum(-1)
            return (mask * sq_dist).mean()

        scores = jax.vmap(masked_distance)(candidates)
        best = jnp.argmax(scores).reshape(1, 1)
        winner = jnp.take_along_axis(candidates, indices=best, axis=0)
        return winner.squeeze()

    def sample_expl_action(self, observation: np.ndarray, idx: int):
        """Return a clipped exploration action for ensemble member ``idx``.

        The mask row passed down is ``self.masks[[idx]]``, i.e. ones everywhere
        except at ``idx``, so the member's own draws carry zero weight in the
        distance score.
        """
        self.rng, sample_rng = jax.random.split(self.rng)
        if observation.ndim == 1:
            observation = observation.reshape(1, -1)
        member_mask = self.masks[[idx]]
        member_index = self.indices[[idx]]
        raw_action = self._sample_expl_action(self.actor_state.params,
                                              sample_rng, observation,
                                              member_mask, member_index)
        return self._clip_to_bounds(np.asarray(raw_action))

    @functools.partial(jax.jit, static_argnames="self")
    def _sample_action(self, params: FrozenDict, rng: Any,
                       observation: np.ndarray) -> jnp.ndarray:
        """Return ``(mean_action, sampled_action)``, both times ``max_action``.

        NOTE(review): the other sampling paths in this class do not multiply by
        ``max_action`` — confirm whether the actor already scales its output.
        """
        mean_action, dist = self.actor.apply({"params": params}, observation)
        sampled_action, _ = dist.sample_and_log_prob(seed=rng)
        scale = self.max_action
        return mean_action * scale, sampled_action * scale

    def sample_action(self,
                      observation: np.ndarray,
                      eval_mode: bool = False) -> np.ndarray:
        """Return a clipped action: the mean if ``eval_mode``, else a sample.

        A 1-D observation is promoted to a single-row batch first.
        """
        self.rng, sample_rng = jax.random.split(self.rng)
        if observation.ndim == 1:
            observation = observation.reshape(1, -1)
        mean_action, sampled_action = self._sample_action(
            self.actor_state.params, sample_rng, observation)
        if eval_mode:
            chosen = mean_action
        else:
            chosen = sampled_action
        return self._clip_to_bounds(np.asarray(chosen))

    def actor_alpha_train_step(self, key: Any, observations: jnp.ndarray,
                               alpha_state: train_state.TrainState,
                               actor_state: train_state.TrainState,
                               critic_state: train_state.TrainState):
        """Take one gradient step on the actor and the temperatures.

        The per-sample loss is vmapped over axis 1 of ``observations`` (the
        batch axis, per the shape comments below), and gradients and log
        values are then averaged over that axis. The critic parameters are
        only read, never differentiated.

        Returns:
            ``(new_alpha_state, new_actor_state, log_info)`` where ``log_info``
            has keys ``actor_loss``, ``alpha_loss``, ``alpha`` and ``logp``.
        """
        critic_params = critic_state.params

        def loss_fn(alpha_params: FrozenDict, actor_params: FrozenDict,
                    rng: Any, observation: jnp.ndarray):
            # (E, obs_dim) => (E, act_dim)
            _, dist = self.actor.apply({"params": actor_params}, observation)
            action, logp = dist.sample_and_log_prob(seed=rng)

            # Temperature loss (E,): gradient reaches alpha only, the entropy
            # gap is treated as a constant.
            temperature = jnp.exp(
                self.log_alpha.apply({"params": alpha_params}))
            entropy_gap = jax.lax.stop_gradient(logp + self.target_entropy)
            alpha_loss = -temperature * entropy_gap

            # From here on the temperature is a constant (E,).
            detached_alpha = jax.lax.stop_gradient(temperature)

            # (E, obs_dim) & (E, act_dim) => (E,) & (E,)
            q_first, q_second = self.critic.apply(
                {"params": critic_params},
                observation,
                action,
            )
            pessimistic_q = jnp.minimum(q_first, q_second)

            actor_loss = detached_alpha * logp - pessimistic_q

            # Both losses are summed over the ensemble into one scalar.
            total_loss = (actor_loss + alpha_loss).sum()
            log_info = {
                "actor_loss": actor_loss.mean(),
                "alpha_loss": alpha_loss.mean(),
                "alpha": detached_alpha,
                "logp": logp.mean(),
            }
            return total_loss, log_info

        def batch_average(tree):
            return jax.tree_util.tree_map(
                lambda leaf: jnp.mean(leaf, axis=0), tree)

        # Parameters are shared across the batch; keys map over axis 0 and
        # observations over axis 1.
        per_sample_grad = jax.vmap(
            jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True),
            in_axes=(None, None, 0, 1),
        )
        batch_size = observations.shape[1]
        keys = jnp.stack(jax.random.split(key, num=batch_size))

        (_, log_info), grads = per_sample_grad(alpha_state.params,
                                               actor_state.params, keys,
                                               observations)
        alpha_grads, actor_grads = batch_average(grads)
        log_info = batch_average(log_info)

        new_alpha_state = alpha_state.apply_gradients(grads=alpha_grads)
        new_actor_state = actor_state.apply_gradients(grads=actor_grads)
        return new_alpha_state, new_actor_state, log_info

    def critic_train_step(self, batch: Batch, intrinsic_rewards: jnp.ndarray,
                          key: Any, alphas: jnp.ndarray,
                          actor_state: train_state.TrainState,
                          critic_state: train_state.TrainState,
                          critic_target_params: FrozenDict):
        """Take one gradient step on the critic and refresh its target.

        The TD target is ``gamma * discount * next_q + reward +
        intrinsic_reward`` with ``next_q = min(target Q1, target Q2) -
        alphas * logp``. The next action is sampled from the actor with the
        parameters in ``actor_state``, which are only read here. The per-sample
        loss is vmapped over axis 1 of every batch field and averaged.

        Returns:
            ``(new_critic_state, new_critic_target_params, log_info)`` where
            ``log_info`` has keys ``critic_loss`` and ``qs``.
        """
        actor_params = actor_state.params

        def loss_fn(params: FrozenDict, rng: Any, observation: jnp.ndarray,
                    action: jnp.ndarray, reward: jnp.ndarray,
                    intrinsic_reward: jnp.ndarray,
                    next_observation: jnp.ndarray, discount: jnp.ndarray):
            # (E, obs_dim) & (E, act_dim) ==> (E,)
            q1, q2 = self.critic.apply({"params": params}, observation, action)

            # (E, obs_dim) ==> (E, act_dim), (E,)
            _, next_dist = self.actor.apply({"params": actor_params},
                                            next_observation)
            next_action, next_logp = next_dist.sample_and_log_prob(seed=rng)
            target_q1, target_q2 = self.critic.apply(
                {"params": critic_target_params}, next_observation,
                next_action)
            entropy_term = alphas * next_logp
            next_q = jnp.minimum(target_q1, target_q2) - entropy_term

            # Target value (E,): extrinsic plus intrinsic reward.
            total_reward = reward + intrinsic_reward
            target_q = self.gamma * discount * next_q + total_reward

            # Squared TD error of both heads (E,).
            td_sq_1 = (q1 - target_q)**2
            td_sq_2 = (q2 - target_q)**2
            critic_loss = td_sq_1 + td_sq_2

            log_info = {
                "critic_loss": critic_loss,
                "qs": q1,
            }
            # The scalar objective is the mean over ensemble members.
            # NOTE(review): the actor loss sums over members instead —
            # confirm the intended scaling.
            return critic_loss.mean(), log_info

        def batch_average(tree):
            return jax.tree_util.tree_map(
                lambda leaf: jnp.mean(leaf, axis=0), tree)

        # Parameters are shared, keys map over axis 0, and the six data
        # arguments all map over axis 1 (the batch axis).
        per_sample_grad = jax.vmap(
            jax.value_and_grad(loss_fn, has_aux=True),
            in_axes=(None, 0) + (1, ) * 6,
        )
        batch_size = batch.actions.shape[1]
        keys = jnp.stack(jax.random.split(key, num=batch_size))

        (_, log_info), grads = per_sample_grad(
            critic_state.params,
            keys,
            batch.observations,
            batch.actions,
            batch.rewards,
            intrinsic_rewards,
            batch.next_observations,
            batch.discounts,
        )
        grads = batch_average(grads)
        log_info = batch_average(log_info)

        new_critic_state = critic_state.apply_gradients(grads=grads)
        new_critic_target_params = target_update(new_critic_state.params,
                                                 critic_target_params,
                                                 self.tau)
        return new_critic_state, new_critic_target_params, log_info

    @functools.partial(jax.jit, static_argnames="self")
    def train_step(self, batch: Batch, intrinsic_rewards: jnp.ndarray,
                   key: Any, alpha_state: train_state.TrainState,
                   actor_state: train_state.TrainState,
                   critic_state: train_state.TrainState,
                   critic_target_params: FrozenDict):
        """Run the actor/temperature step, then the critic step, under jit.

        Both steps read the incoming ``actor_state`` and ``critic_state``. The
        critic step uses the batch-averaged ``alpha`` logged by the actor
        step, i.e. the temperature before this step's update.

        Returns:
            ``(new_alpha_state, new_actor_state, new_critic_state,
            new_critic_target_params, log_info)``.
        """
        actor_key, critic_key = jax.random.split(key)

        actor_outputs = self.actor_alpha_train_step(actor_key,
                                                    batch.observations,
                                                    alpha_state, actor_state,
                                                    critic_state)
        new_alpha_state, new_actor_state, actor_log_info = actor_outputs

        critic_outputs = self.critic_train_step(batch, intrinsic_rewards,
                                                critic_key,
                                                actor_log_info["alpha"],
                                                actor_state, critic_state,
                                                critic_target_params)
        (new_critic_state, new_critic_target_params,
         critic_log_info) = critic_outputs

        # Critic entries win on a key clash; the two dicts share no keys.
        log_info = dict(actor_log_info)
        log_info.update(critic_log_info)
        return (new_alpha_state, new_actor_state, new_critic_state,
                new_critic_target_params, log_info)

    def update(self, batch: Batch, intrinsic_rewards: jnp.ndarray):
        """Advance the agent by one training step and return the log dict.

        Splits a fresh key off ``self.rng``, runs ``train_step`` and stores the
        returned train states and target parameters back on the instance.
        """
        self.rng, step_key = jax.random.split(self.rng, 2)
        step_outputs = self.train_step(batch, intrinsic_rewards, step_key,
                                       self.alpha_state, self.actor_state,
                                       self.critic_state,
                                       self.critic_target_params)
        (self.alpha_state, self.actor_state, self.critic_state,
         self.critic_target_params, log_info) = step_outputs
        return log_info
