from functools import partial
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

from issac.algorithms.architectures.simbav2 import (
    SimbaV2CriticNet,
    SimbaV2ActorNet,
    update_g_g_max_stats,
    normalize_matrix,
    update_mean_var_stats,
)
from issac.sample_collection.replay_buffer import ReplayBuffer, ReplayElement


class iSSimbaV2:
    """SAC-style agent with an iterated, distributional SimbaV2 critic.

    The critic has ``n_bellman_iterations + 1`` heads, each producing logits over the
    support ``self.critic.bins``. Head 0 serves as the target head; each online head
    ``k`` (``1..n_bellman_iterations``) is trained on the projected Bellman target built
    from head ``k - 1`` (see ``critic_loss``). After every update, the target head is
    softly moved towards head 1 (see ``soft_shift_params``). An entropy coefficient is
    learned alongside the actor.
    """

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        action_dim,
        learning_rate_init: float,
        learning_rate_end: float,
        learning_rate_decay_steps: float,
        gamma: float,
        update_horizon: int,
        tau: float,
        batch_norm: bool,
        n_bellman_iterations: int,
        loss_discount: float,
    ):
        """Build the networks, optimizers, and bookkeeping state.

        Args:
            key: PRNG key, split into the actor and critic initialization keys.
            observation_dim: shape of a single observation (used for dummy init inputs).
            action_dim: action dimensionality; also sets the target entropy to
                ``-action_dim / 2``.
            learning_rate_init: initial learning rate of the linear schedule shared by
                all three Adam optimizers.
            learning_rate_end: final learning rate of that schedule.
            learning_rate_decay_steps: number of steps of the linear schedule.
            gamma: discount factor.
            update_horizon: n-step horizon; the bootstrap is discounted by
                ``gamma ** update_horizon``.
            tau: rate of the soft target-head update, also used as the EMA rate of the
                logged losses.
            batch_norm: forwarded to ``SimbaV2CriticNet``.
            n_bellman_iterations: number of online critic heads (the critic gets one
                extra target head).
            loss_discount: head ``k`` is weighted by ``loss_discount ** k`` in the critic
                loss; the normalized weights are used in the actor loss.
        """
        actor_key, critic_key = jax.random.split(key)

        obs = jnp.zeros(observation_dim, dtype=jnp.float32)
        action = jnp.zeros(action_dim, dtype=jnp.float32)

        # Critic
        self.critic = SimbaV2CriticNet(batch_norm, n_bellman_iterations + 1)
        self.critic_params = self.critic.init(critic_key, obs, action)
        self.critic_params["params"] = self.weight_normalization(self.critic_params["params"])
        self.critic_optimizer = optax.adam(
            optax.linear_schedule(learning_rate_init, learning_rate_end, learning_rate_decay_steps)
        )
        self.critic_optimizer_state = self.critic_optimizer.init(self.critic_params)

        # Actor
        self.actor = SimbaV2ActorNet(action_dim)
        self.actor_params = self.actor.init(actor_key, obs, jax.random.PRNGKey(0))
        self.actor_params["params"] = self.weight_normalization(self.actor_params["params"])
        self.actor_optimizer = optax.adam(
            optax.linear_schedule(learning_rate_init, learning_rate_end, learning_rate_decay_steps)
        )
        self.actor_optimizer_state = self.actor_optimizer.init(self.actor_params)

        # Entropy coefficient, optimized in log space.
        self.log_ent_coef = jnp.log(jnp.asarray(0.01, dtype=jnp.float32))
        self.entropy_optimizer = optax.adam(
            optax.linear_schedule(learning_rate_init, learning_rate_end, learning_rate_decay_steps)
        )
        self.entropy_optimizer_state = self.entropy_optimizer.init(self.log_ent_coef)
        self.target_entropy = -np.float32(action_dim) / 2

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.tau = tau

        # Exponential moving averages of the losses, used only for logging.
        # One entry per online critic head.
        self.cumulated_critic_losses = np.zeros(n_bellman_iterations, dtype=np.float32)
        self.cumulated_actor_loss = 0
        self.cumulated_entropy_loss = 0

        self.loss_discounts = jnp.power(loss_discount, jnp.arange(n_bellman_iterations)).astype(jnp.float32)
        self.normalized_loss_discounts = self.loss_discounts / self.loss_discounts.sum()

    def update_online_params(self, replay_buffer: ReplayBuffer, key):
        """Sample a batch from ``replay_buffer`` and apply one learning step in place."""
        batch_samples = replay_buffer.sample()

        (
            self.critic_params,
            self.actor_params,
            self.log_ent_coef,
            self.critic_optimizer_state,
            self.actor_optimizer_state,
            self.entropy_optimizer_state,
            self.cumulated_critic_losses,
            self.cumulated_actor_loss,
            self.cumulated_entropy_loss,
        ) = self.learn_on_batch(
            self.critic_params,
            self.actor_params,
            self.log_ent_coef,
            self.critic_optimizer_state,
            self.actor_optimizer_state,
            self.entropy_optimizer_state,
            self.cumulated_critic_losses,
            self.cumulated_actor_loss,
            self.cumulated_entropy_loss,
            batch_samples,
            key,
        )

    @partial(jax.jit, static_argnames="self")
    def learn_on_batch(
        self,
        critic_params,
        actor_params,
        log_ent_coef,
        critic_optimizer_state,
        actor_optimizer_state,
        entropy_optimizer_state,
        cumulated_critic_losses,
        cumulated_actor_loss,
        cumulated_entropy_loss,
        batch_samples,
        key,
    ):
        """Run one jitted critic, actor, and entropy-coefficient update on a batch.

        Returns:
            The updated ``(critic_params, actor_params, log_ent_coef,
            critic_optimizer_state, actor_optimizer_state, entropy_optimizer_state,
            cumulated_critic_losses, cumulated_actor_loss, cumulated_entropy_loss)``,
            in the same order as the corresponding inputs.
        """
        critic_key, actor_key = jax.random.split(key, 2)

        # Update critic. The differentiated value is the discount-weighted sum of the
        # per-head losses. The per-head losses themselves come back through the aux
        # output, so that each entry of `cumulated_critic_losses` tracks its own head.
        (_, (critic_losses, batch_stats)), critic_grads = jax.value_and_grad(
            self._critic_losses_on_batch, has_aux=True
        )(critic_params, actor_params, log_ent_coef, batch_samples, critic_key)
        # NOTE(review): the optimizer updates every leaf of critic_params, including the
        # running statistics collections. This relies on those leaves receiving zero
        # gradient — confirm against the SimbaV2 normalization layers.
        critic_updates, critic_optimizer_state = self.critic_optimizer.update(critic_grads, critic_optimizer_state)
        critic_params = optax.apply_updates(critic_params, critic_updates)
        critic_params["params"] = self.weight_normalization(critic_params["params"])

        # Update batch_stats (only present when the critic has batch normalization).
        if "batch_stats" in batch_stats:
            critic_params["batch_stats"] = batch_stats["batch_stats"]

        # Update actor against the freshly updated critic.
        (actor_loss, entropy), actor_grads = jax.value_and_grad(self.actor_loss_on_batch, has_aux=True)(
            actor_params, critic_params, log_ent_coef, batch_samples, actor_key
        )
        actor_updates, actor_optimizer_state = self.actor_optimizer.update(actor_grads, actor_optimizer_state)
        actor_params = optax.apply_updates(actor_params, actor_updates)
        actor_params["params"] = self.weight_normalization(actor_params["params"])

        # Update entropy coefficient
        entropy_loss, entropy_grads = jax.value_and_grad(self.entropy_loss)(log_ent_coef, entropy)
        entropy_updates, entropy_optimizer_state = self.entropy_optimizer.update(
            entropy_grads, entropy_optimizer_state, log_ent_coef
        )
        log_ent_coef = optax.apply_updates(log_ent_coef, entropy_updates)

        # Update critic target head
        critic_params = self.soft_shift_params(critic_params)

        # EMA of the losses for logging; critic_losses has one entry per online head.
        cumulated_critic_losses = (1 - self.tau) * cumulated_critic_losses + self.tau * critic_losses
        cumulated_actor_loss = (1 - self.tau) * cumulated_actor_loss + self.tau * actor_loss
        cumulated_entropy_loss = (1 - self.tau) * cumulated_entropy_loss + self.tau * entropy_loss

        return (
            critic_params,
            actor_params,
            log_ent_coef,
            critic_optimizer_state,
            actor_optimizer_state,
            entropy_optimizer_state,
            cumulated_critic_losses,
            cumulated_actor_loss,
            cumulated_entropy_loss,
        )

    def critic_loss_on_batch(self, critic_params, actor_params, log_ent_coef, samples, key):
        """Return ``(weighted critic loss, batch_stats)`` for a batch of transitions.

        The loss is the dot product of the batch-averaged per-head losses with
        ``self.loss_discounts``. ``batch_stats`` holds the collections mutated by the
        critic forward pass.
        """
        loss, (_, batch_stats) = self._critic_losses_on_batch(critic_params, actor_params, log_ent_coef, samples, key)
        return loss, batch_stats

    def _critic_losses_on_batch(self, critic_params, actor_params, log_ent_coef, samples, key):
        """Compute the weighted critic loss together with the per-head losses.

        Returns:
            ``(weighted_loss, (per_head_losses, batch_stats))``.
            - ``per_head_losses`` has one batch-averaged entry per online head.
            - ``weighted_loss`` is the dot product of ``per_head_losses`` with
              ``self.loss_discounts``.
        """
        # Next actions and their log-probabilities from the current policy.
        next_actions, next_log_probs = self.actor.apply(actor_params, samples.next_state, key)
        batch_size = samples.state.shape[0]

        # Current and next state-action pairs are evaluated in a single forward pass.
        # shape (2 x batch_size, n_bellman_iterations + 1, n_bins) | (batch_stats)
        all_logits_q_values, batch_stats = partial(self.critic.apply, mutable=["batch_stats"])(
            critic_params,
            jnp.concatenate((samples.state, samples.next_state)),
            jnp.concatenate((samples.action, next_actions)),
        )
        logit_q_value = all_logits_q_values[:batch_size]
        logit_next_q_value = all_logits_q_values[batch_size:]

        # Per-sample losses of shape (batch_size, n_bellman_iterations), averaged over the batch.
        losses = jax.vmap(self.critic_loss, in_axes=(None, 0, 0, 0, 0, None))(
            critic_params, samples, logit_q_value, logit_next_q_value, next_log_probs, log_ent_coef
        ).mean(axis=0)

        return jnp.dot(losses, self.loss_discounts), (losses, batch_stats)

    def critic_loss(self, params, sample, logit_q_value, logit_next_q_value, next_log_prob, log_ent_coef):
        """Cross-entropy loss of each online head for a single transition (used under vmap).

        Head ``k`` (``1..n_bellman_iterations``) is trained on the categorical projection,
        onto ``self.critic.bins``, of the Bellman target built from head ``k - 1``.

        Returns:
            Array of shape (n_bellman_iterations,), one loss per online head.
        """
        # Target distributions come from heads 0..K-1; predictions from heads 1..K.
        next_prob_q_value = nn.softmax(logit_next_q_value[:-1], axis=-1)
        log_prob_q_value = nn.log_softmax(logit_q_value[1:], axis=-1)

        # Shift every support point by the (entropy-regularized) Bellman update, then
        # clip into the support range.
        target_location_ = self.compute_target(params, sample, self.critic.bins, jnp.exp(log_ent_coef), next_log_prob)
        target_location = jnp.clip(target_location_, self.critic.bins[0], self.critic.bins[-1])

        def projection(bin_location):
            # Distance (in bin widths) from each shifted target location to this bin. shape (n_bins,)
            distances_to_bin = jnp.abs(target_location - bin_location) / (self.critic.bins[1] - self.critic.bins[0])
            # Same distances for every target head. shape (n_bellman_iterations, n_bins)
            repeated_distances_to_bin = jnp.repeat(distances_to_bin[jnp.newaxis, :], self.critic.n_heads - 1, axis=0)
            # Clip the maximum distance to 1 to weight the probabilities. shape (n_bellman_iterations,)
            return jnp.sum((1 - jnp.minimum(repeated_distances_to_bin, 1)) * next_prob_q_value, axis=1)

        # shape (n_bellman_iterations, n_bins); no gradient flows through the target.
        target_prob = jax.lax.stop_gradient(jax.vmap(projection, out_axes=1)(self.critic.bins))
        loss = -jnp.sum(target_prob * log_prob_q_value, axis=1)

        return loss

    def compute_target(
        self, params, samples: ReplayElement, next_q_values: jax.Array, ent_coef: float, next_log_probs: jax.Array
    ):
        """Return ``r_norm + (1 - done) * gamma**n * (next_q_values - ent_coef * next_log_probs)``.

        The reward is normalized through the critic's ``normalize_reward`` method.
        ``critic_loss`` passes the bin support as ``next_q_values``, so this shifts every
        support point.
        """
        normalized_reward = self.critic.apply(params, samples.reward, method=self.critic.normalize_reward)
        return normalized_reward + (jnp.int32(1) - samples.is_terminal) * (self.gamma**self.update_horizon) * (
            next_q_values - ent_coef * next_log_probs
        )

    def actor_loss_on_batch(self, actor_params, critic_params, log_ent_coef, samples, key):
        """Return ``(actor loss, entropy estimate)`` for a batch.

        The loss is ``mean(alpha * log_pi - Q)``, where ``Q`` averages the expected values
        of the online heads with ``self.normalized_loss_discounts``. The entropy estimate
        is ``-mean(log_pi)`` and feeds the entropy-coefficient update.
        """
        actions, log_probs = self.actor.apply(actor_params, samples.state, key)

        # shape (batch_size, n_bellman_iterations + 1, n_bins)
        q_bins_logits = self.critic.apply(critic_params, samples.state, actions, use_running_average=True)
        q_bins_probs = nn.softmax(q_bins_logits, axis=-1)
        # Expected value of each online head (target head 0 excluded). shape (batch_size, n_bellman_iterations)
        q_values = q_bins_probs[:, 1:] @ self.critic.bins

        losses = jnp.exp(log_ent_coef) * log_probs - jnp.dot(q_values, self.normalized_loss_discounts)
        return losses.mean(), -log_probs.mean()

    def entropy_loss(self, log_ent_coef, entropy):
        """Entropy-coefficient loss ``alpha * (entropy - target_entropy)``, with ``alpha = exp(log_ent_coef)``."""
        return jnp.exp(log_ent_coef) * (entropy - self.target_entropy)

    def soft_shift_params(self, params):
        """Softly move the target head (head 0) of the last layer towards head 1.

        Mutates ``params`` in place and returns it.
        """
        # Shift the last weight matrix with shape (last_feature, (1 + n_bellman_iterations) x n_bins)
        # Reminder: 1 + self.n_bellman_iterations = [\bar{Q_0}, Q_1, ..., Q_K]
        # Here we shift: \bar{Q_0} <- \tau Q_1 + (1 - \tau) \bar{Q_0}
        kernel = params["params"]["last_layer"]["kernel"]
        params["params"]["last_layer"]["kernel"] = kernel.at[:, : self.critic.n_bins].set(
            optax.incremental_update(
                kernel[:, self.critic.n_bins : 2 * self.critic.n_bins], kernel[:, : self.critic.n_bins], self.tau
            )
        )

        # Shift the last bias vector with shape ((1 + n_bellman_iterations) x n_bins)
        bias = params["params"]["last_layer"]["bias"]
        params["params"]["last_layer"]["bias"] = bias.at[: self.critic.n_bins].set(
            optax.incremental_update(
                bias[self.critic.n_bins : 2 * self.critic.n_bins], bias[: self.critic.n_bins], self.tau
            )
        )

        return params

    @partial(jax.jit, static_argnames="self")
    def sample_action(self, state, actor_params, key):
        """Sample an action from the policy (the log-probability is discarded)."""
        return self.actor.apply(actor_params, state, key)[0]

    def update_observation_statistics(self, state):
        """Update the running observation statistics with ``state``.

        The actor's ``running_obs_stats`` is then pointed at the critic's dict, so both
        networks share the same object.
        """
        self.critic_params["running_obs_stats"]["RSObservationNorm_0"] = update_mean_var_stats(
            state.squeeze(), self.critic_params["running_obs_stats"]["RSObservationNorm_0"]
        )
        self.actor_params["running_obs_stats"] = self.critic_params["running_obs_stats"]

    def update_reward_statistics(self, reward, episode_end):
        """Update the running return statistics used for reward normalization."""
        # Update discounted return statistics
        self.critic_params["running_reward_stats"].update(
            update_g_g_max_stats(reward, episode_end, self.gamma, self.critic_params["running_reward_stats"])
        )
        # Update variance statistics
        self.critic_params["running_reward_stats"].update(
            update_mean_var_stats(
                self.critic_params["running_reward_stats"]["G"], self.critic_params["running_reward_stats"]
            ),
        )

    def get_logs(self):
        """Return the smoothed training losses.

        ``train/critic_loss`` is the mean over heads. Per-head critic losses are reported
        for at most the first 5 heads.
        """
        logs = {
            "train/critic_loss": self.cumulated_critic_losses.mean(),
            "train/actor_loss": self.cumulated_actor_loss,
            "train/entropy_loss": self.cumulated_entropy_loss,
            "train/entropy_coef": np.exp(self.log_ent_coef),
        }
        for idx_network in range(min(len(self.cumulated_critic_losses), 5)):
            logs[f"networks/{idx_network}_critic_loss"] = self.cumulated_critic_losses[idx_network]
        return logs

    def get_model(self):
        """Return the critic parameters, actor parameters, and log entropy coefficient."""
        return {"critic": self.critic_params, "actor": self.actor_params, "log_ent_coef": self.log_ent_coef}

    def weight_normalization(self, params):
        """Normalize weight matrices with ``normalize_matrix``; mutates and returns ``params``.

        - ``Dense*`` layers: every leaf is normalized.
        - ``last_layer*`` layers: each 2-D kernel is normalized separately per head;
          other leaves (e.g. the bias) are left unchanged.
        """
        for layer_name in params.keys():
            if layer_name.startswith("Dense"):
                # Every leaf of these layers is normalized, so this relies on them having
                # no bias. NOTE(review): confirm against the SimbaV2 layer definitions.
                params[layer_name] = jax.tree.map(normalize_matrix, params[layer_name])
            if layer_name.startswith("last_layer"):
                # Normalize heads independently: view the kernel as
                # (features, n_heads, n_bins) and normalize each head's slice.
                # NOTE(review): critic head sizes are used even when this is called on
                # the actor's params — confirm the actor has no 2-D "last_layer" kernel.
                params[layer_name] = jax.tree.map(
                    lambda w: (
                        jax.vmap(normalize_matrix, in_axes=1, out_axes=1)(
                            w.reshape((-1, self.critic.n_heads, self.critic.n_bins))
                        ).reshape((-1, self.critic.n_heads * self.critic.n_bins))
                        if w.ndim == 2
                        else w
                    ),
                    params[layer_name],
                )

        return params
