from functools import partial
from typing import Dict

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict

from slimdqn.networks.architectures.dqn import MetaDQNNet
from slimdqn.networks.dqn import DQN
from slimdqn.sample_collection import IDX_RB
from slimdqn.sample_collection.replay_buffer import ReplayBuffer


class METADQN(DQN):
    """DQN agent that meta-learns its discount factor.

    The discount is parameterized as ``gamma = sigmoid(gamma_logit)``. On every
    update, one inner Adam step is taken on the Q-network parameters using the
    current ``gamma_logit``. The updated parameters are then evaluated on a second,
    independently sampled batch with a fixed validation discount
    (``gamma_validation``). The gradient of that validation loss with respect to
    ``gamma_logit`` is obtained by differentiating through the inner Adam step. It
    is applied to ``gamma_logit`` by a separate Adam optimizer.

    The Q-network (``MetaDQNNet``) also receives the discount logit as an input.
    That input is wrapped in ``stop_gradient``, so the meta-gradient only flows
    through the ``gamma ** update_horizon`` factor of the bootstrap target.
    """

    def __init__(
        self,
        q_key: jax.random.PRNGKey,
        observation_dim,
        n_actions,
        features: list,
        cnn: bool,
        learning_rate: float,
        meta_learning_rate: float,
        gamma_init: float,
        gamma_validation: float,
        update_horizon: int,
        update_to_data: int,
        target_update_frequency: int,
        loss_type: str = "huber",
        adam_eps: float = 1e-8,
    ):
        """Build the network, both optimizers and the initial discount logits.

        Args:
            q_key: PRNG key used to initialize the Q-network parameters.
            observation_dim: shape given to ``jnp.zeros`` to build the dummy
                observation used for network initialization.
            n_actions: forwarded to ``MetaDQNNet`` (presumably the number of
                discrete actions; the network output is indexed by action).
            features: forwarded to ``MetaDQNNet``.
            cnn: forwarded to ``MetaDQNNet``.
            learning_rate: Adam learning rate for the Q-network parameters.
            meta_learning_rate: Adam learning rate for ``gamma_logit``.
            gamma_init: initial discount; must lie strictly in (0, 1).
            gamma_validation: fixed discount used by the meta (validation) loss;
                must lie strictly in (0, 1).
            update_horizon: exponent applied to gamma in the bootstrap target.
            update_to_data: ``update_online_params`` only updates when
                ``step % update_to_data == 0``.
            target_update_frequency: stored on the instance; not read in this class.
            loss_type: stored on the instance; see the NOTE in ``loss``.
            adam_eps: ``eps`` of the inner Adam optimizer.

        Raises:
            ValueError: if ``gamma_init`` or ``gamma_validation`` is not in (0, 1).
        """
        self.q_key = q_key
        self.q_network = MetaDQNNet(features, cnn, n_actions)
        self.params = self.q_network.init(self.q_key, jnp.zeros(observation_dim, dtype=jnp.float32), 0.9)
        self.target_params = self.params.copy()
        self.gamma_logit = self._logit(gamma_init, "gamma_init")
        self.gamma_logit_validation = self._logit(gamma_validation, "gamma_validation")

        # eps_root=1e-9 so that the gradient over adam does not output nans
        self.optimizer = optax.adam(learning_rate, eps=adam_eps, eps_root=1e-9)
        self.optimizer_state = self.optimizer.init(self.params)

        # The meta optimizer is never differentiated through, so it needs no eps_root.
        self.meta_optimizer = optax.adam(meta_learning_rate)
        self.meta_optimizer_state = self.meta_optimizer.init(self.gamma_logit)

        self.update_horizon = update_horizon
        self.update_to_data = update_to_data
        self.target_update_frequency = target_update_frequency
        self.loss_type = loss_type

    @staticmethod
    def _logit(gamma: float, name: str):
        """Return ``log(gamma / (1 - gamma))``, the inverse of the sigmoid.

        Raises:
            ValueError: unless ``0 < gamma < 1``. Outside that range the formula
                yields NaN, +/-inf or a ZeroDivisionError. A NaN or inf logit
                would silently corrupt every target computed from it.
        """
        # Written as `not (a < x < b)` so that NaN (which fails every comparison) is rejected too.
        if not 0.0 < gamma < 1.0:
            raise ValueError(f"{name} must lie strictly in (0, 1), got {gamma}.")
        return jnp.log(gamma / (1 - gamma))

    def update_online_params(self, step: int, replay_buffer: ReplayBuffer):
        """Run one joint inner/meta update every ``update_to_data`` steps.

        Two independent batches are drawn from ``replay_buffer``. One drives the
        inner step on ``self.params``; the other drives the validation loss
        whose gradient updates ``self.gamma_logit``.

        Returns:
            ``(meta_loss, loss)`` on update steps, otherwise the int ``0``.
            NOTE(review): the two branches return different shapes. Callers
            must handle both; confirm against the training loop.
        """
        if step % self.update_to_data == 0:
            batch_samples = replay_buffer.sample_transition_batch()
            meta_batch_samples = replay_buffer.sample_transition_batch()

            self.gamma_logit, self.meta_optimizer_state, meta_loss, self.params, self.optimizer_state, loss = (
                self.meta_learn_on_batch(
                    self.gamma_logit,
                    self.meta_optimizer_state,
                    meta_batch_samples,
                    self.params,
                    self.target_params,
                    self.optimizer_state,
                    batch_samples,
                )
            )

            return meta_loss, loss
        return 0

    @partial(jax.jit, static_argnames="self")
    def meta_learn_on_batch(
        self,
        gamma_logit: float,
        meta_optimizer_state,
        meta_batch_samples,
        params: FrozenDict,
        params_target: FrozenDict,
        optimizer_state,
        batch_samples,
    ):
        """Perform one inner step on ``params`` and one meta step on ``gamma_logit``.

        The meta-gradient is taken with respect to ``gamma_logit`` only; the
        updated parameters and inner optimizer state come back as auxiliary
        outputs of ``meta_loss_on_batch``.

        Returns:
            ``(gamma_logit, meta_optimizer_state, meta_loss, params,
            optimizer_state, loss)``, all updated.
        """
        (meta_loss, (params, optimizer_state, loss)), meta_grad_loss = jax.value_and_grad(
            self.meta_loss_on_batch, has_aux=True
        )(gamma_logit, meta_batch_samples, params, params_target, optimizer_state, batch_samples)
        meta_updates, meta_optimizer_state = self.meta_optimizer.update(meta_grad_loss, meta_optimizer_state)
        gamma_logit = optax.apply_updates(gamma_logit, meta_updates)

        return gamma_logit, meta_optimizer_state, meta_loss, params, optimizer_state, loss

    def meta_loss_on_batch(
        self,
        gamma_logit: float,
        meta_batch_samples,
        params: FrozenDict,
        params_target: FrozenDict,
        optimizer_state,
        batch_samples,
    ):
        """Take one inner Adam step, then score the result with the validation discount.

        The inner step is differentiable with respect to ``gamma_logit`` (through
        the bootstrap target). This is what lets the outer ``value_and_grad``
        produce a meta-gradient.

        Returns:
            ``(validation_loss, (updated_params, updated_optimizer_state,
            inner_loss))``, the ``has_aux`` layout expected by the caller.
        """
        loss, grad_loss = jax.value_and_grad(self.loss_on_batch)(params, params_target, batch_samples, gamma_logit)
        updates, optimizer_state = self.optimizer.update(grad_loss, optimizer_state)
        params = optax.apply_updates(params, updates)

        # self is a static jit argument, so gamma_logit_validation is baked into the
        # compiled function as a constant.
        return (
            self.loss_on_batch(params, params_target, meta_batch_samples, self.gamma_logit_validation),
            (params, optimizer_state, loss),
        )

    def loss_on_batch(self, params: FrozenDict, params_target: FrozenDict, samples, gamma_logit):
        """Mean of the per-sample ``loss``, vmapped over the leading axis of ``samples``."""
        return jax.vmap(self.loss, in_axes=(None, None, 0, None))(params, params_target, samples, gamma_logit).mean()

    def loss(self, params: FrozenDict, params_target: FrozenDict, sample, gamma_logit: float):
        """Squared TD error for a single sample.

        NOTE(review): ``self.loss_type`` (default ``"huber"``) is stored but not
        used here. This method always uses squared error; confirm this is intended.
        """
        target = self.compute_target(params_target, sample, gamma_logit)
        # stop_gradient: conditioning the network on gamma must not contribute to
        # any gradient with respect to gamma_logit.
        q_value = self.q_network.apply(params, sample[IDX_RB["state"]], jax.lax.stop_gradient(gamma_logit))[
            sample[IDX_RB["action"]]
        ]
        return jnp.square(q_value - target)

    def compute_target(self, params: FrozenDict, samples, gamma_logit: float):
        """Bootstrap target for a single sample.

        Computes ``r + (1 - terminal) * gamma**update_horizon * max_a Q(s', a)``
        with ``gamma = sigmoid(gamma_logit)``. The ``gamma`` factor is
        deliberately NOT wrapped in ``stop_gradient``: this is the path through
        which the meta-gradient reaches ``gamma_logit``.
        """
        gamma = nn.sigmoid(gamma_logit)
        return samples[IDX_RB["reward"]] + (1 - samples[IDX_RB["terminal"]]) * (gamma**self.update_horizon) * jnp.max(
            self.q_network.apply(params, samples[IDX_RB["next_state"]], jax.lax.stop_gradient(gamma_logit))
        )

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, **kwargs):
        """Greedy action for a single state.

        The network's gamma input is taken from the required keyword argument
        ``gamma_logit``. The int8 cast only holds for action indices up to 127.
        """
        return jnp.argmax(self.q_network.apply(params, state, kwargs["gamma_logit"])).astype(jnp.int8)

    def get_model(self) -> Dict:
        """Return the current parameters and the current discount ``sigmoid(gamma_logit)``."""
        return {"params": self.params, "gamma": nn.sigmoid(self.gamma_logit)}
