from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict
from jax import Array

from minto.networks.architectures.iqn import IQNNet
from minto.networks.dqn import DQN
from minto.sample_collection.replay_buffer import ReplayBuffer, ReplayElement


class IQN(DQN):
    """Quantile-regression (IQN-style) agent layered on the ``DQN`` base class.

    The online network is an ``IQNNet``; it is called as
    ``network.apply(params, state, key, n_quantiles)`` and returns a pair
    ``(quantile_values, quantiles)``. According to the shape comments kept
    in this file, ``quantile_values`` is ``(n_quantiles, n_actions)`` and
    ``quantiles`` is ``(n_quantiles)`` -- the network itself is defined
    elsewhere, so treat those shapes as assumptions to confirm there.

    NOTE(review): ``DQN.__init__`` is intentionally *not* called here (the
    original code did not call it either); every attribute used by the
    methods below is set directly in ``__init__``.
    """

    def __init__(
        self,
        key: Array,
        observation_dim,
        n_actions,
        features: list,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        data_to_update: int,
        target_update_frequency: int,
        adam_eps: float = 3.125e-4,
        target_function: str = "default",
        layer_norm: bool = False,
        n_quantiles_policy: int = 32,
        n_quantiles: int = 64,
        n_quantiles_target: int = 64,
    ):
        """Build the network, the optimizer and the logging accumulators.

        Args:
            key: JAX PRNG key. It is split into three independent keys: one
                for parameter initialisation, one for the quantile sampling
                done during initialisation, and one kept as
                ``self.network_key`` for the training updates.
            observation_dim: Shape passed to ``jnp.zeros`` to build the dummy
                observation used for parameter initialisation.
            n_actions: Forwarded to ``IQNNet``.
            features: Forwarded to ``IQNNet``.
            architecture_type: Forwarded to ``IQNNet``.
            learning_rate: Learning rate of the Adam optimizer.
            gamma: Discount factor; targets use ``gamma**update_horizon``.
            update_horizon: Exponent applied to ``gamma`` in the targets.
            data_to_update: A gradient step is taken when
                ``step % data_to_update == 0``.
            target_update_frequency: Stored on the instance; not used by the
                methods defined in this class.
            adam_eps: ``eps`` of the Adam optimizer.
            target_function: ``"default"`` (``compute_target``) or ``"min"``
                (``compute_target_min``).
            layer_norm: Forwarded to ``IQNNet``.
            n_quantiles_policy: Number of quantiles averaged to pick actions.
            n_quantiles: Number of quantiles requested from the online
                network in the loss.
            n_quantiles_target: Number of quantiles used to build the
                bootstrap targets.

        Raises:
            KeyError: If ``target_function`` is not a known target name.
        """
        print("Layer Norm?", layer_norm)

        self.n_quantiles_policy = n_quantiles_policy
        self.n_quantiles = n_quantiles
        self.n_quantiles_target = n_quantiles_target

        # A PRNG key must not be consumed more than once: derive separate
        # keys for parameter init, quantile sampling and the training stream.
        init_key, quantile_key, self.network_key = jax.random.split(key, 3)

        self.network = IQNNet(
            features, architecture_type, n_actions, layer_norm=layer_norm
        )
        self.params = self.network.init(
            init_key,
            jnp.zeros(observation_dim, dtype=jnp.float32),
            quantile_key,
            self.n_quantiles_policy,
        )

        self.optimizer = optax.adam(learning_rate, eps=adam_eps)
        self.optimizer_state = self.optimizer.init(self.params)
        # The target network starts as a copy of the online network.
        self.target_params = self.params

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.data_to_update = data_to_update
        self.target_update_frequency = target_update_frequency

        # Running sums of the metrics returned by ``learn_on_batch``.
        self.cumulated_info = {
            "loss": 0,
            "grad_norm": 0,
            "param_norm": 0,
            "online_fraction": 0,
            "q_next_online": 0,
            "q_next_target": 0,
            "q_value": 0,
            "target": 0,
        }

        # Per-update history; a metric ``k`` is appended to ``f"{k}_all"``.
        self.listed_info = {"online_fraction_all": []}

        target_functions = {
            "default": self.compute_target,
            "min": self.compute_target_min,
        }
        try:
            self.compute_target_fn = target_functions[target_function]
        except KeyError:
            raise KeyError(
                f"Unknown target_function {target_function!r}; "
                f"expected one of {sorted(target_functions)}"
            ) from None

    def update_online_params(self, step: int, replay_buffer: ReplayBuffer):
        """Take one gradient step every ``data_to_update`` steps.

        When ``step`` is a multiple of ``self.data_to_update``, a batch is
        sampled from ``replay_buffer``, ``learn_on_batch`` is run on it and
        the returned metrics are folded into ``self.cumulated_info`` and
        ``self.listed_info``. Otherwise nothing happens.

        Args:
            step: Current step counter.
            replay_buffer: Buffer whose ``sample()`` provides the batch.
        """
        if step % self.data_to_update != 0:
            return

        batch_samples = replay_buffer.sample()
        # Advance the key stream so each update samples fresh quantiles.
        self.network_key, key = jax.random.split(self.network_key)
        self.params, self.optimizer_state, info = self.learn_on_batch(
            self.params,
            self.target_params,
            self.optimizer_state,
            batch_samples,
            key,
        )

        # Cumulate the info; only metrics with a registered slot are kept.
        for name, value in info.items():
            if name in self.cumulated_info:
                self.cumulated_info[name] += value
            listed_name = f"{name}_all"
            if listed_name in self.listed_info:
                self.listed_info[listed_name].append(value.item())

    @partial(jax.jit, static_argnames="self")
    def learn_on_batch(
        self,
        params: FrozenDict,
        params_target: FrozenDict,
        optimizer_state,
        batch_samples,
        key: Array,
    ):
        """Compute the batch loss gradient and apply one optimizer update.

        Args:
            params: Online parameters (differentiated).
            params_target: Target parameters.
            optimizer_state: Current optax optimizer state.
            batch_samples: Batch handed to ``loss_on_batch``.
            key: PRNG key for this update.

        Returns:
            A tuple ``(params, optimizer_state, info)`` with the updated
            parameters and optimizer state, and the metric dictionary of
            ``loss_on_batch`` extended with ``"grad_norm"``, ``"param_norm"``
            (norm of the *updated* parameters) and ``"loss"``.
        """
        (loss, info), grad_loss = jax.value_and_grad(self.loss_on_batch, has_aux=True)(
            params, params_target, batch_samples, key
        )
        updates, optimizer_state = self.optimizer.update(grad_loss, optimizer_state)
        params = optax.apply_updates(params, updates)

        # Add loss and extra info.
        info["grad_norm"] = optax.global_norm(grad_loss)
        info["param_norm"] = optax.global_norm(params)
        info["loss"] = loss

        return params, optimizer_state, info

    def loss_on_batch(
        self, params: FrozenDict, params_target: FrozenDict, samples, key: Array
    ):
        """Average the per-sample loss and metrics over a batch.

        ``self.loss`` is vmapped over the leading axis of ``samples`` with
        one PRNG key per sample (``samples.state.shape[0]`` keys).

        Returns:
            A tuple ``(mean_loss, info)`` where every entry of ``info`` is
            averaged over the batch axis.
        """
        batch_size = samples.state.shape[0]
        sample_keys = jax.random.split(key, batch_size)

        losses, info = jax.vmap(self.loss, in_axes=(None, None, 0, 0))(
            params, params_target, samples, sample_keys
        )
        # Average the info over the batch.
        info = {name: value.mean(axis=0) for name, value in info.items()}
        return losses.mean(), info

    def loss(
        self,
        params: FrozenDict,
        params_target: FrozenDict,
        sample: ReplayElement,
        key: Array,
    ):
        """Quantile Huber loss (threshold 1) for a single transition.

        Args:
            params: Online parameters.
            params_target: Target parameters.
            sample: Transition providing ``state``, ``action`` and whatever
                the configured target function reads.
            key: PRNG key; split into independent online / target keys.

        Returns:
            A tuple ``(loss, info)`` where ``info`` holds the metrics of the
            target function plus ``"q_value"`` and ``"target"``.
        """
        # Independent keys for the online and the target forward passes;
        # a key that has been split must not be consumed again.
        online_key, target_key = jax.random.split(key)

        # output (n_quantiles_target)
        targets, info = self.compute_target_fn(
            params_target, params, sample, target_key
        )

        # output (n_quantiles, n_actions) | (n_quantiles)
        q_quantiles_values, quantiles = self.network.apply(
            params, sample.state, online_key, self.n_quantiles
        )

        # output (n_quantiles)
        q_quantiles = q_quantiles_values[:, sample.action]

        # Cross difference between every target and every online quantile.
        # output (n_quantiles_target, n_quantiles)
        bellman_errors = targets[:, jnp.newaxis] - q_quantiles[jnp.newaxis]

        # Huber loss: quadratic for |error| <= 1, linear beyond.
        # output (n_quantiles_target, n_quantiles)
        abs_bellman_errors = jnp.abs(bellman_errors)
        huber_losses = jnp.where(
            abs_bellman_errors <= 1,
            0.5 * bellman_errors**2,
            abs_bellman_errors - 0.5,
        )

        # Indicator of negative errors; a constant w.r.t. the parameters.
        negative_error_mask = jax.lax.stop_gradient(
            (bellman_errors < 0).astype(jnp.float32)
        )

        # Mapping over the target quantiles: each row is weighted by
        # |quantile - 1{error < 0}|.
        # output (n_quantiles_target, n_quantiles)
        quantile_losses = jax.vmap(
            lambda quantile, error_mask, huber_loss: jnp.abs(quantile - error_mask)
            * huber_loss,
            in_axes=(None, 0, 0),
        )(quantiles, negative_error_mask, huber_losses)

        # Sum over the online quantiles, average over the target quantiles.
        loss = jnp.mean(jnp.sum(quantile_losses, axis=1))

        # Logging info.
        info["q_value"] = jnp.mean(q_quantiles)
        info["target"] = jnp.mean(targets)

        return loss, info

    def compute_target(
        self,
        params_target: FrozenDict,
        params_online: FrozenDict,
        sample: ReplayElement,
        key: Array,
    ):
        """Bootstrap targets computed from the target network only.

        A single forward pass requests ``n_quantiles_policy +
        n_quantiles_target`` quantiles: the first ``n_quantiles_policy`` are
        averaged to pick the greedy next action, the remaining ones form the
        bootstrap values. ``params_online`` is unused; it is accepted so all
        target functions share one signature.

        Returns:
            A tuple ``(targets, info)`` with ``targets`` of shape
            ``(n_quantiles_target)`` and ``info = {"q_next_target": ...}``.
        """
        n_policy = self.n_quantiles_policy

        # output (n_quantiles_policy + n_quantiles_target, n_actions)
        next_quantile_values, _ = self.network.apply(
            params_target,
            sample.next_state,
            key,
            n_policy + self.n_quantiles_target,
        )

        # output (n_actions)
        next_q_policy_values = jnp.mean(next_quantile_values[:n_policy], axis=0)
        next_action = jnp.argmax(next_q_policy_values)

        # output (n_quantiles_target)
        next_q_quantiles = next_quantile_values[n_policy:, next_action]
        targets = (
            sample.reward
            + (1 - sample.is_terminal)
            * self.gamma**self.update_horizon
            * next_q_quantiles
        )
        return targets, {"q_next_target": jnp.mean(next_q_quantiles)}

    def compute_target_min(
        self,
        params_target: FrozenDict,
        params_online: FrozenDict,
        sample: ReplayElement,
        key: Array,
    ):
        """Bootstrap targets taking, per action, the lower of two estimates.

        Both the target and the online network are evaluated on the next
        state (with independent keys); the online output is wrapped in
        ``stop_gradient``. For every action the network with the smaller
        mean over the policy quantiles is selected, the greedy action is the
        arg-max of those minima, and the bootstrap quantiles come from the
        network selected for that action.

        Returns:
            A tuple ``(targets, info)`` with ``targets`` of shape
            ``(n_quantiles_target)`` and ``info`` holding
            ``"online_fraction"``, ``"q_next_target"`` and
            ``"q_next_online"``.
        """
        n_policy = self.n_quantiles_policy
        n_total = n_policy + self.n_quantiles_target

        target_key, online_key = jax.random.split(key)

        # output (n_quantiles_policy + n_quantiles_target, n_actions)
        next_target_quantile_values, _ = self.network.apply(
            params_target, sample.next_state, target_key, n_total
        )
        next_online_quantile_values, _ = self.network.apply(
            params_online, sample.next_state, online_key, n_total
        )
        # The online estimate only serves to build the target.
        next_online_quantile_values = jax.lax.stop_gradient(
            next_online_quantile_values
        )

        # Index 0 along the first axis is the target network, 1 the online.
        # output (2, n_quantiles_policy, n_actions)
        next_q_policy_quantiles = jnp.stack(
            (
                next_target_quantile_values[:n_policy],
                next_online_quantile_values[:n_policy],
            ),
            axis=0,
        )

        # output (2, n_actions)
        next_q_policy_values = jnp.mean(next_q_policy_quantiles, axis=1)
        # Network with the smaller estimate for each action, (n_actions).
        idx = jnp.argmin(next_q_policy_values, axis=0)
        next_action = jnp.argmax(
            next_q_policy_values[idx, jnp.arange(next_q_policy_values.shape[-1])]
        )

        # output (2, n_quantiles_target, n_actions)
        next_q_quantiles = jnp.stack(
            (
                next_target_quantile_values[n_policy:],
                next_online_quantile_values[n_policy:],
            ),
            axis=0,
        )
        # output (n_quantiles_target)
        next_q_quantiles = next_q_quantiles[idx[next_action], :, next_action]

        targets = (
            sample.reward
            + (1 - sample.is_terminal)
            * self.gamma**self.update_horizon
            * next_q_quantiles
        )
        info = {
            # TODO: check -- this averages the selection over *all* actions,
            # not only the chosen ``next_action``.
            "online_fraction": jnp.mean(idx),
            "q_next_target": jnp.mean(
                next_target_quantile_values[n_policy:, next_action]
            ),
            "q_next_online": jnp.mean(
                next_online_quantile_values[n_policy:, next_action]
            ),
        }
        return targets, info

    # TODO: FIX!
    def compute_target_min_abs_anchors(
        self,
        params_target: FrozenDict,
        params_online: FrozenDict,
        sample: ReplayElement,
        key: Array,
    ):
        """Experimental target that picks one network by absolute mean value.

        NOTE(review): flagged ``TODO: FIX!`` by the original author and not
        registered in ``compute_target_fn``; the logic is kept as found.

        Each network is asked for ``n_quantiles_target + n_quantiles_policy +
        n_quantiles_target`` quantiles. The first ``n_quantiles_target`` rows
        ("anchors") are averaged over all their entries; the network whose
        mean has the smaller absolute value is used both for action selection
        (next ``n_quantiles_policy`` rows) and for the bootstrap quantiles
        (remaining rows).

        Returns:
            A tuple ``(targets, info)`` with the same ``info`` keys as
            ``compute_target_min``.
        """
        n_anchor = self.n_quantiles_target
        n_policy = self.n_quantiles_policy
        policy_end = n_anchor + n_policy
        n_total = policy_end + self.n_quantiles_target

        target_key, online_key = jax.random.split(key)

        next_target_quantile_values, _ = self.network.apply(
            params_target, sample.next_state, target_key, n_total
        )
        next_online_quantile_values, _ = self.network.apply(
            params_online, sample.next_state, online_key, n_total
        )
        next_online_quantile_values = jax.lax.stop_gradient(
            next_online_quantile_values
        )

        # 0 -> target network, 1 -> online network (a single scalar choice).
        idx = jnp.argmin(
            jnp.stack(
                (
                    jnp.abs(jnp.mean(next_target_quantile_values[:n_anchor])),
                    jnp.abs(jnp.mean(next_online_quantile_values[:n_anchor])),
                ),
                axis=0,
            ),
            axis=0,
        )
        use_target = idx == 0

        next_q_policy_quantiles = jnp.where(
            use_target,
            next_target_quantile_values[n_anchor:policy_end],
            next_online_quantile_values[n_anchor:policy_end],
        )

        # output (n_actions)
        next_q_policy_values = jnp.mean(next_q_policy_quantiles, axis=0)
        next_action = jnp.argmax(next_q_policy_values)

        next_target_q_quantiles = next_target_quantile_values[
            policy_end:, next_action
        ]
        next_online_q_quantiles = next_online_quantile_values[
            policy_end:, next_action
        ]
        next_q_quantiles = jnp.where(
            use_target, next_target_q_quantiles, next_online_q_quantiles
        )

        # output (n_quantiles_target)
        targets = (
            sample.reward
            + (1 - sample.is_terminal)
            * self.gamma**self.update_horizon
            * next_q_quantiles
        )
        info = {
            "online_fraction": jnp.mean(idx),
            "q_next_target": jnp.mean(next_target_q_quantiles),
            "q_next_online": jnp.mean(next_online_q_quantiles),
        }
        return targets, info

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: Array):
        """Return the greedy action for a single state.

        The network is queried for ``n_quantiles_policy`` quantiles, which
        are averaged over the quantile axis before taking the arg-max.
        """
        q_quantiles, _ = self.network.apply(params, state, key, self.n_quantiles_policy)
        q_values = jnp.mean(q_quantiles, axis=0)
        return jnp.argmax(q_values)