from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict

from slimdqn.networks.architectures.dqn import DQNNet
from slimdqn.sample_collection.fixed_replay_buffer import FixedReplayBuffer
from slimdqn.sample_collection.replay_buffer import ReplayElement


class eSCQL:
    """Ensemble of shifted Q-heads trained with a TD loss plus a CQL regularizer.

    A single ``DQNNet`` outputs ``2 * n_ensemble_heads * n_actions`` values per state,
    reshaped to ``(batch, 2 * n_ensemble_heads, n_actions)`` with the head layout

        [barQ_0, ..., barQ_{K-1}, Q_1, ..., Q_K]    (K = n_ensemble_heads)

    The first K heads are bootstrapping targets: their outputs only enter the loss
    through ``jax.lax.stop_gradient``. The last K heads are the online networks.
    Online head Q_{i+1} is regressed onto the n-step target built from barQ_i.
    ``update_target_params`` shifts the window by copying the final-layer weights of
    the online heads into the target slots (barQ_i <- Q_{i+1}).

    The per-sample regularizer ``logsumexp_a Q(s, a) - Q(s, a_data)`` is weighted by
    ``alpha_cql`` and reported as ``bc_loss`` in the logs.
    """

    # Maximum number of per-network losses reported by ``update_target_params``.
    N_LOGGED_NETWORKS = 5

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        n_actions,
        n_ensemble_heads: int,
        features: list,
        layer_norm: bool,
        batch_norm: bool,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        target_update_frequency: int,
        alpha_cql: float,
        adam_eps: float = 0.0003125,
    ):
        """Build the network, initialize its parameters and the Adam optimizer.

        Args:
            key: PRNG key used by ``DQNNet.init``.
            observation_dim: shape of a single observation. It is passed to
                ``jnp.zeros`` to build the dummy input used for initialization.
            n_actions: number of discrete actions.
            n_ensemble_heads: number K of online heads (and of target heads).
            features: layer sizes forwarded to ``DQNNet``.
            layer_norm: forwarded to ``DQNNet``.
            batch_norm: forwarded to ``DQNNet``. When the network has batch norm, the
                updated batch statistics are written back into the parameters after
                every gradient step.
            architecture_type: forwarded to ``DQNNet``. It also selects the index of
                the final ``Dense`` layer that ``shift_params`` edits.
            learning_rate: Adam learning rate.
            gamma: discount factor. Targets are discounted by ``gamma**update_horizon``.
            update_horizon: n-step horizon.
            target_update_frequency: divisor applied to the accumulated losses when
                logging. Presumably this is the number of gradient steps between two
                ``update_target_params`` calls; confirm with the caller.
            alpha_cql: weight of the CQL regularizer.
            adam_eps: Adam epsilon.
        """
        self.n_ensemble_heads = n_ensemble_heads
        self.n_actions = n_actions
        # Index of the final Dense layer in DQNNet's parameter tree. Presumably the
        # first three ``features`` of non-"fc" architectures are conv layers that do
        # not create a Dense_* entry. NOTE(review): confirm against DQNNet.
        self.last_idx_mlp = len(features) if architecture_type == "fc" else len(features) - 3
        self.network = DQNNet(
            features, architecture_type, 2 * self.n_ensemble_heads * n_actions, layer_norm, batch_norm
        )

        # 2 * self.n_ensemble_heads = [\bar{Q_0}, ..., \bar{Q_K-1}, Q_1, ..., Q_K]
        def apply(params, state):
            # Training-mode forward pass: batch statistics are returned as mutable
            # state so that learn_on_batch can store them.
            q_values, batch_stats = self.network.apply(params, state, mutable=["batch_stats"])
            return q_values.reshape((-1, 2 * self.n_ensemble_heads, n_actions)), batch_stats

        self.network.apply_fn = apply
        self.params = self.network.init(key, jnp.zeros(observation_dim, dtype=jnp.float32))

        self.optimizer = optax.adam(learning_rate, eps=adam_eps)
        self.optimizer_state = self.optimizer.init(self.params)

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.target_update_frequency = target_update_frequency
        # Shape (n_ensemble_heads, 2). Column 0 holds summed TD losses, column 1
        # holds summed (unweighted) CQL regularizer values.
        self.cumulated_losses = np.zeros((self.n_ensemble_heads, 2))
        self.alpha_cql = alpha_cql

    @partial(jax.jit, static_argnames="self")
    def apply_multiple_updates(self, params, optimizer_state, batches):
        """Run one ``learn_on_batch`` step per batch inside a single ``lax.scan``.

        Args:
            params: network parameters.
            optimizer_state: Adam state matching ``params``.
            batches: sequence of batches sharing the same pytree structure.

        Returns:
            ``(params, optimizer_state, losses)``, where ``losses`` is the per-head
            ``(n_ensemble_heads, 2)`` loss array summed over all batches.
        """

        def scan_step(carry, batch):
            new_params, new_optimizer_state, step_losses = self.learn_on_batch(carry[0], carry[1], batch)
            return (new_params, new_optimizer_state), step_losses

        # Turn the sequence of batches into a single batch whose leaves have the
        # shape (n_batch, batch_size) + element_shape, so that scan can iterate it.
        stacked_batches = jax.tree.map(lambda *leaves: jnp.stack(leaves), *batches)
        (final_params, final_optimizer_state), losses = jax.lax.scan(
            scan_step, (params, optimizer_state), stacked_batches
        )
        return final_params, final_optimizer_state, losses.sum(axis=0)

    def n_updates_online_params(self, n_updates: int, replay_buffer: "FixedReplayBuffer"):
        """Sample ``n_updates`` batches and apply one gradient step per batch.

        The per-head losses, summed over the steps, are added to
        ``self.cumulated_losses`` for later logging.
        """
        # Presumably a sequence of n_updates batches; apply_multiple_updates stacks it.
        batches = replay_buffer.sample(n_updates)
        self.params, self.optimizer_state, losses = self.apply_multiple_updates(
            self.params, self.optimizer_state, batches
        )
        # Explicit device-to-host transfer before the in-place NumPy accumulation.
        self.cumulated_losses += np.asarray(losses)

    def update_target_params(self, **kwargs):
        """Shift the target window and return the losses accumulated since the last call.

        Extra keyword arguments are accepted and ignored.

        Returns:
            A dict containing the head-averaged ``td_loss``, the head-averaged
            ``alpha_cql``-weighted ``bc_loss``, and ``networks/{i}_loss``
            (TD + alpha_cql * regularizer) for the first ``N_LOGGED_NETWORKS``
            online heads. Every value is divided by ``target_update_frequency``.
            The loss accumulator is reset to zero afterwards.
        """
        # Window shift: \bar{Q_i} <- Q_{i+1}
        self.params = self.shift_params(self.params)

        td_sums = self.cumulated_losses[:, 0]
        bc_sums = self.cumulated_losses[:, 1]
        logs = {
            "td_loss": td_sums.mean() / self.target_update_frequency,
            "bc_loss": self.alpha_cql * bc_sums.mean() / self.target_update_frequency,
        }
        for idx_network in range(min(self.n_ensemble_heads, self.N_LOGGED_NETWORKS)):
            logs[f"networks/{idx_network}_loss"] = (
                td_sums[idx_network] + self.alpha_cql * bc_sums[idx_network]
            ) / self.target_update_frequency

        self.cumulated_losses = np.zeros_like(self.cumulated_losses)

        return logs

    @partial(jax.jit, static_argnames="self")
    def learn_on_batch(self, params: dict, optimizer_state, batch_samples):
        """Apply one Adam step on ``loss_on_batch``.

        ``params`` must be a mutable dict: when batch norm is enabled, the batch
        statistics returned by the forward pass overwrite ``params["batch_stats"]``.

        Returns:
            ``(params, optimizer_state, losses)``, where ``losses`` is the
            ``(n_ensemble_heads, 2)`` array of mean TD / regularizer values.
        """
        grad_loss, (losses, batch_stats) = jax.grad(self.loss_on_batch, has_aux=True)(params, batch_samples)
        updates, optimizer_state = self.optimizer.update(grad_loss, optimizer_state)
        params = optax.apply_updates(params, updates)
        if self.network.batch_norm:
            params["batch_stats"] = batch_stats["batch_stats"]

        return params, optimizer_state, losses

    def loss_on_batch(self, params: dict, samples):
        """Compute the training loss for one batch of transitions.

        The current and next states go through a single forward pass. The online
        heads Q_1..Q_K are evaluated on the current states and the target heads
        barQ_0..barQ_{K-1} on the next states.

        Returns:
            ``(loss, (losses, batch_stats))``:
            ``loss`` is the scalar per-head mean of ``td + alpha_cql * reg``, summed
            over heads. ``losses`` has shape ``(n_ensemble_heads, 2)``, with mean TD
            values in column 0 and mean regularizer values in column 1.
            ``batch_stats`` is the mutable state returned by the network.
        """
        batch_size = samples.state.shape[0]
        # all_q_values: (2 * batch_size, 2 * n_ensemble_heads, n_actions); batch_stats: dict
        all_q_values, batch_stats = self.network.apply_fn(params, jnp.concatenate((samples.state, samples.next_state)))
        online_q_values = all_q_values[:batch_size, self.n_ensemble_heads :]
        # Q-value of the dataset action for every online head: (batch_size, n_ensemble_heads)
        q_values = jax.vmap(lambda q_value, action: q_value[:, action])(online_q_values, samples.action)
        targets = jax.vmap(self.compute_target)(samples, all_q_values[batch_size:, : self.n_ensemble_heads])
        stop_grad_targets = jax.lax.stop_gradient(targets)

        # Both of shape (batch_size, n_ensemble_heads)
        td_losses = jnp.square(q_values - stop_grad_targets)
        bc_losses = jax.scipy.special.logsumexp(online_q_values, axis=-1) - q_values
        return (td_losses + self.alpha_cql * bc_losses).mean(axis=0).sum(), (
            jnp.array([td_losses.mean(axis=0), bc_losses.mean(axis=0)]).T,
            batch_stats,
        )

    def compute_target(self, sample: "ReplayElement", next_q_values: jax.Array):
        """Return the n-step bootstrapped target of one transition for every target head.

        Called through ``jax.vmap`` over the batch, so ``next_q_values`` has shape
        ``(n_ensemble_heads, n_actions)`` and the result has shape
        ``(n_ensemble_heads,)``. Terminal transitions (``is_terminal == 1``) keep
        only the reward.
        """
        return sample.reward + (1 - sample.is_terminal) * (self.gamma**self.update_horizon) * jnp.max(
            next_q_values, axis=-1
        )

    @partial(jax.jit, static_argnames="self")
    def shift_params(self, params):
        """Copy the online heads' final-layer weights into the target-head slots.

        The final ``Dense`` kernel has shape
        ``(last_feature, 2 * n_ensemble_heads * n_actions)`` and the bias has shape
        ``(2 * n_ensemble_heads * n_actions,)``. The second half of the output
        columns (Q_1..Q_K) overwrites the first half (barQ_0..barQ_{K-1}), so that
        barQ_i <- Q_{i+1}. ``params`` must be a mutable dict; the modified tree is
        returned.
        """
        n_target_outputs = self.n_ensemble_heads * self.n_actions
        last_dense = params["params"][f"Dense_{self.last_idx_mlp}"]

        kernel = last_dense["kernel"]
        last_dense["kernel"] = kernel.at[:, :n_target_outputs].set(kernel[:, n_target_outputs:])

        bias = last_dense["bias"]
        last_dense["bias"] = bias.at[:n_target_outputs].set(bias[n_target_outputs:])

        return params

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: dict, state: jnp.ndarray, key: jax.Array = None):
        """Return the greedy action for a single state.

        With ``key=None`` the first online head (Q_1) is used. Otherwise an online
        head is drawn uniformly at random with ``key``. The network is called with
        ``use_running_average=True``, presumably to put batch norm in inference mode.
        """
        idx_network = 0 if key is None else jax.random.randint(key, (), 0, self.n_ensemble_heads)
        q_values = self.network.apply(params, state, use_running_average=True).reshape(
            (2 * self.n_ensemble_heads, self.n_actions)
        )

        # Online heads start right after the n_ensemble_heads target heads.
        return jnp.argmax(q_values[self.n_ensemble_heads + idx_network])

    def get_model(self):
        """Return the current parameters wrapped as ``{"params": ...}``."""
        return {"params": self.params}
