from functools import partial
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

from issac.algorithms.architectures.simbav2 import (
    SimbaV2CriticNet,
    SimbaV2ActorNet,
    update_g_g_max_stats,
    weight_normalization,
    update_mean_var_stats,
)
from issac.sample_collection.replay_buffer import ReplayBuffer, ReplayElement


class SimbaV2:
    """SAC-style agent built on the SimbaV2 actor and distributional critic networks.

    The critic outputs logits over the fixed value support ``self.critic.bins`` and is trained with a
    cross-entropy loss against a projected (entropy-regularised) Bellman target. The actor minimises
    ``exp(log_ent_coef) * log_prob - Q`` and ``log_ent_coef`` is tuned so that the policy entropy tracks
    ``target_entropy``. After every gradient step the ``"params"`` collection of both networks is passed
    through ``weight_normalization``.

    Besides ``"params"``, the flax variable dicts carry non-trainable collections (``"running_obs_stats"``,
    ``"running_reward_stats"`` and, when present, ``"batch_stats"``). They are maintained by the
    ``update_*_statistics`` methods and by ``learn_on_batch`` and are never touched by the optimizers.
    """

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        action_dim,
        learning_rate_init: float,
        learning_rate_end: float,
        learning_rate_decay_steps: float,
        gamma: float,
        update_horizon: int,
        tau: float,
        batch_norm: bool,
    ):
        """Build the networks, their optimizers and the entropy-coefficient state.

        Args:
            key: PRNG key used to initialise the actor and critic variables.
            observation_dim: shape given to ``jnp.zeros`` to build the dummy observation used for init.
            action_dim: shape of the dummy action; also passed to ``SimbaV2ActorNet`` and used for
                ``target_entropy = -action_dim / 2`` (so presumably a scalar int — confirm).
            learning_rate_init: start value of the linear learning-rate schedule (shared by all optimizers).
            learning_rate_end: end value of that schedule.
            learning_rate_decay_steps: number of optimizer steps over which the schedule decays.
            gamma: discount factor.
            update_horizon: exponent of the bootstrap discount ``gamma ** update_horizon`` (n-step return).
            tau: Polyak factor of the critic target; also the smoothing factor of the logged losses.
            batch_norm: forwarded to ``SimbaV2CriticNet``.
        """
        actor_key, critic_key = jax.random.split(key)

        obs = jnp.zeros(observation_dim, dtype=jnp.float32)
        action = jnp.zeros(action_dim, dtype=jnp.float32)

        # Critic
        self.critic = SimbaV2CriticNet(batch_norm)
        self.critic_params = self.critic.init(critic_key, obs, action)
        self.critic_params["params"] = weight_normalization(self.critic_params["params"])
        # Shallow copy: the non-"params" collections are shared objects until they get reassigned.
        self.critic_target_params = self.critic_params.copy()
        self.critic_optimizer = optax.adam(
            optax.linear_schedule(learning_rate_init, learning_rate_end, learning_rate_decay_steps)
        )
        # Only the trainable "params" collection is optimised (running statistics must not get Adam updates).
        self.critic_optimizer_state = self.critic_optimizer.init(self.critic_params["params"])

        # Actor
        self.actor = SimbaV2ActorNet(action_dim)
        self.actor_params = self.actor.init(actor_key, obs, jax.random.PRNGKey(0))
        self.actor_params["params"] = weight_normalization(self.actor_params["params"])
        self.actor_optimizer = optax.adam(
            optax.linear_schedule(learning_rate_init, learning_rate_end, learning_rate_decay_steps)
        )
        self.actor_optimizer_state = self.actor_optimizer.init(self.actor_params["params"])

        # Entropy coefficient, stored in log space (initial coefficient 0.01).
        self.log_ent_coef = jnp.log(jnp.asarray(0.01, dtype=jnp.float32))
        self.entropy_optimizer = optax.adam(
            optax.linear_schedule(learning_rate_init, learning_rate_end, learning_rate_decay_steps)
        )
        self.entropy_optimizer_state = self.entropy_optimizer.init(self.log_ent_coef)
        self.target_entropy = -np.float32(action_dim) / 2

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.tau = tau

        # Exponential moving averages (factor tau) of the losses, for logging only.
        self.cumulated_critic_loss = 0
        self.cumulated_actor_loss = 0
        self.cumulated_entropy_loss = 0

    def update_online_params(self, replay_buffer: ReplayBuffer, key):
        """Sample one batch from ``replay_buffer`` and run one update of all learnable state.

        Args:
            replay_buffer: buffer whose ``sample()`` result is fed to ``learn_on_batch``.
            key: PRNG key, split between the critic and the actor updates.
        """
        batch_samples = replay_buffer.sample()

        (
            self.critic_params,
            self.critic_target_params,
            self.actor_params,
            self.log_ent_coef,
            self.critic_optimizer_state,
            self.actor_optimizer_state,
            self.entropy_optimizer_state,
            self.cumulated_critic_loss,
            self.cumulated_actor_loss,
            self.cumulated_entropy_loss,
        ) = self.learn_on_batch(
            self.critic_params,
            self.critic_target_params,
            self.actor_params,
            self.log_ent_coef,
            self.critic_optimizer_state,
            self.actor_optimizer_state,
            self.entropy_optimizer_state,
            self.cumulated_critic_loss,
            self.cumulated_actor_loss,
            self.cumulated_entropy_loss,
            batch_samples,
            key,
        )

    @partial(jax.jit, static_argnames="self")
    def learn_on_batch(
        self,
        critic_params,
        critic_target_params,
        actor_params,
        log_ent_coef,
        critic_optimizer_state,
        actor_optimizer_state,
        entropy_optimizer_state,
        cumulated_critic_loss,
        cumulated_actor_loss,
        cumulated_entropy_loss,
        batch_samples,
        key,
    ):
        """Pure (jitted) update step: critic, then actor, then entropy coefficient, then critic target.

        Gradients are taken with respect to the ``"params"`` collection only. The other variable
        collections are passed through unchanged, except ``"batch_stats"``, which is refreshed from the
        critic's forward pass.

        Returns:
            The updated versions of every state argument, in the same order as the parameters
            (``batch_samples`` and ``key`` excluded).
        """
        critic_key, actor_key = jax.random.split(key, 2)

        # Update critic (differentiate only w.r.t. the trainable "params" collection).
        def critic_loss_fn(params):
            return self.critic_loss_on_batch(
                {**critic_params, "params": params},
                critic_target_params,
                actor_params,
                log_ent_coef,
                batch_samples,
                critic_key,
            )

        (critic_loss, batch_stats), critic_grads = jax.value_and_grad(critic_loss_fn, has_aux=True)(
            critic_params["params"]
        )
        critic_updates, critic_optimizer_state = self.critic_optimizer.update(critic_grads, critic_optimizer_state)
        critic_params = {
            **critic_params,
            "params": weight_normalization(optax.apply_updates(critic_params["params"], critic_updates)),
        }

        # Update batch_stats: online and target critic both receive the statistics of this batch.
        if "batch_stats" in batch_stats:
            critic_params = {**critic_params, "batch_stats": batch_stats["batch_stats"]}
            critic_target_params = {**critic_target_params, "batch_stats": batch_stats["batch_stats"]}

        # Update actor (against the freshly updated critic), again w.r.t. "params" only.
        def actor_loss_fn(params):
            return self.actor_loss_on_batch(
                {**actor_params, "params": params}, critic_params, log_ent_coef, batch_samples, actor_key
            )

        (actor_loss, entropy), actor_grads = jax.value_and_grad(actor_loss_fn, has_aux=True)(actor_params["params"])
        actor_updates, actor_optimizer_state = self.actor_optimizer.update(actor_grads, actor_optimizer_state)
        actor_params = {
            **actor_params,
            "params": weight_normalization(optax.apply_updates(actor_params["params"], actor_updates)),
        }

        # Update entropy coefficient: gradient descent lowers it when entropy exceeds target_entropy.
        entropy_loss, entropy_grads = jax.value_and_grad(self.entropy_loss)(log_ent_coef, entropy)
        entropy_updates, entropy_optimizer_state = self.entropy_optimizer.update(
            entropy_grads, entropy_optimizer_state, log_ent_coef
        )
        log_ent_coef = optax.apply_updates(log_ent_coef, entropy_updates)

        # Update critic target by Polyak averaging: target <- tau * online + (1 - tau) * target.
        # NOTE(review): in this file critic_target_params is only consumed through critic.normalize_reward;
        # confirm that method actually reads "params", otherwise this averaging is unused.
        critic_target_params = {
            **critic_target_params,
            "params": optax.incremental_update(critic_params["params"], critic_target_params["params"], self.tau),
        }

        cumulated_critic_loss = (1 - self.tau) * cumulated_critic_loss + self.tau * critic_loss
        cumulated_actor_loss = (1 - self.tau) * cumulated_actor_loss + self.tau * actor_loss
        cumulated_entropy_loss = (1 - self.tau) * cumulated_entropy_loss + self.tau * entropy_loss

        return (
            critic_params,
            critic_target_params,
            actor_params,
            log_ent_coef,
            critic_optimizer_state,
            actor_optimizer_state,
            entropy_optimizer_state,
            cumulated_critic_loss,
            cumulated_actor_loss,
            cumulated_entropy_loss,
        )

    def critic_loss_on_batch(self, critic_params, critic_target_params, actor_params, log_ent_coef, samples, key):
        """Mean categorical TD loss over the batch, together with the mutated ``batch_stats`` collection.

        Current and next (state, action) pairs are passed through the online critic in one call
        (presumably so that, with batch norm, both halves share the same batch statistics). Next
        actions and their log-probabilities are sampled from the actor with ``key``.
        """
        next_actions, next_log_probs = self.actor.apply(actor_params, samples.next_state, key)
        batch_size = samples.state.shape[0]

        # shape (2 x batch_size, 1, n_bins) | (batch_stats)
        all_logits, batch_stats = self.critic.apply(
            critic_params,
            jnp.concatenate((samples.state, samples.next_state)),
            jnp.concatenate((samples.action, next_actions)),
            mutable=["batch_stats"],
        )
        all_logits = all_logits.squeeze(axis=1)
        logit_q_value = all_logits[:batch_size]
        # The next-state distribution is the bootstrap target: block its gradient so only the prediction
        # for (state, action) is fitted (semi-gradient TD), even though it comes from the online critic.
        logit_next_q_value = jax.lax.stop_gradient(all_logits[batch_size:])

        per_sample_losses = jax.vmap(self.critic_loss, in_axes=(None, 0, 0, 0, 0, None))(
            critic_target_params, samples, logit_q_value, logit_next_q_value, next_log_probs, log_ent_coef
        )
        return per_sample_losses.mean(), batch_stats

    def critic_loss(self, target_params, sample, logit_q_value, logit_next_q_value, next_log_prob, log_ent_coef):
        """Cross-entropy between the predicted bin distribution and the projected target, for ONE transition.

        Called under ``vmap`` from ``critic_loss_on_batch``. Every bin of the next-state distribution is
        moved through the Bellman backup (``compute_target``), clipped to ``[bins[0], bins[-1]]`` and its
        probability is split linearly between neighbouring bins. Only ``bins[1] - bins[0]`` is used as
        the bin width, so the support is assumed to be evenly spaced.
        """
        next_prob_q_value = nn.softmax(logit_next_q_value, axis=-1)
        log_prob_q_value = nn.log_softmax(logit_q_value, axis=-1)

        # Bellman-shifted location of every support bin, clipped to the support.
        # Presumably shape (n_bins,), assuming scalar reward / log-prob per transition.
        target_location_ = self.compute_target(
            target_params, sample, self.critic.bins, jnp.exp(log_ent_coef), next_log_prob
        )
        target_location = jnp.clip(target_location_, self.critic.bins[0], self.critic.bins[-1])
        bin_width = self.critic.bins[1] - self.critic.bins[0]

        def projection(bin_location):
            # Distance (in bin widths) from every shifted bin to `bin_location`.
            distances_to_bin = jnp.abs(target_location - bin_location) / bin_width
            # Linear interpolation weight 1 - distance, floored at 0 for shifted bins farther than one width.
            return jnp.sum((1 - jnp.minimum(distances_to_bin, 1)) * next_prob_q_value, axis=0)

        # Target probability mass on each support bin.
        target_prob = jax.vmap(projection, out_axes=0)(self.critic.bins)
        return -jnp.sum(target_prob * log_prob_q_value, axis=0)

    def compute_target(
        self, params, samples: ReplayElement, next_q_values: jax.Array, ent_coef: float, next_log_probs: jax.Array
    ):
        """Soft Bellman backup ``r_norm + (1 - done) * gamma**n * (next_q - ent_coef * next_log_prob)``.

        The reward is normalised with ``critic.normalize_reward`` using ``params``. In this file
        ``next_q_values`` is the critic support ``self.critic.bins``, so the result is the shifted support.
        """
        normalized_reward = self.critic.apply(params, samples.reward, method=self.critic.normalize_reward)
        return normalized_reward + (jnp.int32(1) - samples.is_terminal) * (self.gamma**self.update_horizon) * (
            next_q_values - ent_coef * next_log_probs
        )

    def actor_loss_on_batch(self, actor_params, critic_params, log_ent_coef, samples, key):
        """Return ``(mean(exp(log_ent_coef) * log_prob - Q), entropy estimate -mean(log_prob))``.

        Q is the expectation of the critic's bin distribution over ``self.critic.bins``. The critic is
        evaluated with ``use_running_average=True`` and no mutable collections.
        """
        actions, log_probs = self.actor.apply(actor_params, samples.state, key)

        # shape (batch_size, n_bins)
        q_bins_logits = self.critic.apply(critic_params, samples.state, actions, use_running_average=True).squeeze(
            axis=1
        )
        q_bins_probs = nn.softmax(q_bins_logits, axis=-1)
        q_values = q_bins_probs @ self.critic.bins

        losses = jnp.exp(log_ent_coef) * log_probs - q_values
        return losses.mean(), -log_probs.mean()

    def entropy_loss(self, log_ent_coef, entropy):
        """Loss whose gradient w.r.t. ``log_ent_coef`` drives the entropy towards ``target_entropy``."""
        return jnp.exp(log_ent_coef) * (entropy - self.target_entropy)

    @partial(jax.jit, static_argnames="self")
    def sample_action(self, state, actor_params, key):
        """Sample an action for ``state``; the log-probability returned by the actor is discarded."""
        return self.actor.apply(actor_params, state, key)[0]

    def update_observation_statistics(self, state):
        """Fold ``state`` into the running observation statistics and share them with all networks.

        The critic's ``running_obs_stats`` entry is updated in place, then the same dict object is
        assigned to the target critic and the actor so all three see identical statistics.
        """
        self.critic_params["running_obs_stats"]["RSObservationNorm_0"] = update_mean_var_stats(
            state.squeeze(), self.critic_params["running_obs_stats"]["RSObservationNorm_0"]
        )
        self.critic_target_params["running_obs_stats"] = self.critic_params["running_obs_stats"]
        self.actor_params["running_obs_stats"] = self.critic_params["running_obs_stats"]

    def update_reward_statistics(self, reward, episode_end):
        """Update the running reward statistics of the critic and share them with the target critic.

        First the discounted-return statistics (``update_g_g_max_stats``) are updated. Then the
        mean/variance statistics are updated from the new ``"G"`` entry.
        """
        # Update discounted return statistics
        self.critic_params["running_reward_stats"].update(
            update_g_g_max_stats(reward, episode_end, self.gamma, self.critic_params["running_reward_stats"])
        )
        # Update variance statistics
        self.critic_params["running_reward_stats"].update(
            update_mean_var_stats(
                self.critic_params["running_reward_stats"]["G"], self.critic_params["running_reward_stats"]
            ),
        )
        self.critic_target_params["running_reward_stats"] = self.critic_params["running_reward_stats"]

    def get_logs(self):
        """Return the smoothed losses and the current entropy coefficient, keyed for the logger."""
        return {
            "train/critic_loss": self.cumulated_critic_loss,
            "train/actor_loss": self.cumulated_actor_loss,
            "train/entropy_loss": self.cumulated_entropy_loss,
            "train/entropy_coef": np.exp(self.log_ent_coef),
        }

    def get_model(self):
        """Return the online critic and actor variables and the log entropy coefficient."""
        return {"critic": self.critic_params, "actor": self.actor_params, "log_ent_coef": self.log_ent_coef}
