from functools import partial

import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict
from jax import Array

from minto.networks.architectures.dqn import DQNNet
from minto.sample_collection.replay_buffer import ReplayBuffer, ReplayElement


class DQN:
    """DQN agent whose TD target can be built in several ways.

    ``target_function`` selects how the next-state value ``v(s')`` in the
    target ``r + (1 - done) * gamma**update_horizon * v(s')`` is formed:

    * ``"default"``: ``max_a Q_target(s', a)``
    * ``"online"``:  ``max_a Q_online(s', a)`` (gradient stopped)
    * ``"min"``:     ``max_a min(Q_online(s', a), Q_target(s', a))`` (MINTO)
    * ``"max"``:     ``max_a max(Q_online(s', a), Q_target(s', a))``
    * ``"mean"``:    ``max_a mean(Q_online(s', a), Q_target(s', a))``
    * ``"random"``:  per sample, ``max_a Q_online`` or ``max_a Q_target``,
      each chosen with probability 0.5.

    Gradients are only taken with respect to the online parameters; whenever
    the online network is used inside the target it is wrapped in
    ``jax.lax.stop_gradient``.
    """

    def __init__(
        self,
        key: Array,
        observation_dim,
        n_actions,
        features: list,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        data_to_update: int,
        target_update_frequency: int,
        adam_eps: float = 1e-8,
        target_function: str = "default",
        layer_norm: bool = False,
    ):
        """Build the network, optimizer and bookkeeping.

        Args:
            key: PRNG key used to initialise the network and seed ``self.key``.
            observation_dim: shape passed to ``jnp.zeros`` to build the dummy
                input used for parameter initialisation.
            n_actions: number of discrete actions (network output size).
            features: layer sizes forwarded to ``DQNNet``.
            architecture_type: architecture name forwarded to ``DQNNet``.
            learning_rate: Adam learning rate.
            gamma: discount factor.
            update_horizon: n of the n-step return; the bootstrap term is
                discounted by ``gamma**update_horizon``.
            data_to_update: a gradient step is taken every this many steps.
            target_update_frequency: the target network is synced (and the
                logs flushed) every this many steps.
            adam_eps: Adam epsilon.
            target_function: one of ``"default"``, ``"online"``, ``"min"``,
                ``"random"``, ``"max"``, ``"mean"``.
            layer_norm: forwarded to ``DQNNet``.

        Raises:
            KeyError: if ``target_function`` is not one of the names above.
        """
        self.network = DQNNet(
            features, architecture_type, n_actions, layer_norm=layer_norm
        )
        # NOTE(review): the same key initialises the network and is later
        # split for per-batch randomness. JAX advises against reusing a key;
        # kept unchanged so that existing seeds reproduce the same runs.
        self.key = key
        self.params = self.network.init(
            key, jnp.zeros(observation_dim, dtype=jnp.float32)
        )

        self.optimizer = optax.adam(learning_rate, eps=adam_eps)
        self.optimizer_state = self.optimizer.init(self.params)
        self.target_params = self.params

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.data_to_update = data_to_update
        self.target_update_frequency = target_update_frequency
        # Running sums of per-update scalars, averaged when logs are flushed.
        self.cumulated_info = {
            "loss": 0,
            "grad_norm": 0,
            "param_norm": 0,
            "online_fraction": 0,
            "q_value": 0,
            "target": 0,
        }
        # Per-update values kept as full lists (key = info key + "_all").
        self.listed_info = {"online_fraction_all": []}
        # Number of gradient steps accumulated since the last log flush; used
        # to average cumulated_info by the true number of updates.
        self._n_updates_since_log = 0

        target_fns = {
            "default": self.compute_target,
            "online": self.compute_target_online,
            "min": self.compute_target_min,  # MINTO
            "random": self.compute_target_random,
            "max": self.compute_target_max,
            "mean": self.compute_target_mean,
        }
        if target_function not in target_fns:
            raise KeyError(
                f"Unknown target_function {target_function!r}; "
                f"expected one of {sorted(target_fns)}"
            )
        self.compute_target_fn = target_fns[target_function]

        print(f"Using target function: {target_function}")

    def update_online_params(self, step: int, replay_buffer: ReplayBuffer):
        """Take one gradient step on a replay batch every ``data_to_update`` steps.

        Scalars returned by ``learn_on_batch`` are added to ``cumulated_info``
        and, for keys with a matching ``"<key>_all"`` entry, appended to
        ``listed_info``.
        """
        if step % self.data_to_update != 0:
            return

        batch_samples = replay_buffer.sample()

        self.key, key = jax.random.split(self.key)
        self.params, self.optimizer_state, info = self.learn_on_batch(
            key,
            self.params,
            self.target_params,
            self.optimizer_state,
            batch_samples,
        )
        self._n_updates_since_log += 1

        for name, value in info.items():
            if name in self.cumulated_info:
                self.cumulated_info[name] += value
            listed_name = f"{name}_all"
            if listed_name in self.listed_info:
                # .item() pulls the scalar back to the host.
                self.listed_info[listed_name].append(value.item())

    def update_target_params(self, step: int):
        """Sync the target network and flush logs every ``target_update_frequency`` steps.

        Returns:
            ``(True, logs)`` when a sync happened, where ``logs`` holds the
            cumulated scalars averaged over the gradient steps actually taken
            since the previous sync, plus the raw ``listed_info`` lists;
            ``(False, {})`` otherwise.
        """
        if step % self.target_update_frequency != 0:
            return False, {}

        self.target_params = self.params.copy()

        # Average by the real update count: the ratio
        # target_update_frequency / data_to_update is wrong when step starts
        # at 0, when the two periods are not multiples of each other, or when
        # data_to_update > target_update_frequency.
        n_updates = max(self._n_updates_since_log, 1)
        logs = {k: v / n_updates for k, v in self.cumulated_info.items()}
        logs.update(self.listed_info)

        self.cumulated_info = {k: 0 for k in self.cumulated_info}
        self.listed_info = {k: [] for k in self.listed_info}
        self._n_updates_since_log = 0

        return True, logs

    @partial(jax.jit, static_argnames="self")
    def learn_on_batch(
        self,
        key: Array,
        params: FrozenDict,
        params_target: FrozenDict,
        optimizer_state,
        batch_samples,
    ):
        """Apply one Adam step on the batch loss; return new params, opt state and info."""
        (loss, info), grad_loss = jax.value_and_grad(self.loss_on_batch, has_aux=True)(
            params, params_target, batch_samples, key
        )
        updates, optimizer_state = self.optimizer.update(grad_loss, optimizer_state)
        params = optax.apply_updates(params, updates)

        info.update(
            {
                "grad_norm": optax.global_norm(grad_loss),
                "param_norm": optax.global_norm(params),
                "loss": loss,
            }
        )

        return params, optimizer_state, info

    def loss_on_batch(
        self, params: FrozenDict, params_target: FrozenDict, samples, key: Array
    ):
        """Mean per-sample loss over the batch, with per-sample info averaged too.

        ``samples`` is vmapped over its leading axis; one PRNG key is split
        off per sample (batch size read from ``samples.state.shape[0]``).
        """
        keys = jax.random.split(key, samples.state.shape[0])
        loss, info = jax.vmap(self.loss, in_axes=(None, None, 0, 0))(
            params, params_target, samples, keys
        )
        info = {k: v.mean() for k, v in info.items()}  # average the info over the batch
        return loss.mean(), info

    def loss(
        self,
        params: FrozenDict,
        params_target: FrozenDict,
        sample: ReplayElement,
        key: Array,
    ):
        """Squared TD error ``(Q(s, a) - y)**2`` for a single sample, plus info."""
        target, info = self.compute_target_fn(key, params_target, params, sample)
        q_values = self.network.apply(params, sample.state)
        q_value = q_values[sample.action]

        info.update({"q_value": q_value, "target": target})

        return jnp.square(q_value - target), info

    def _bootstrapped_target(self, sample: ReplayElement, next_value):
        """Return ``r + (1 - done) * gamma**update_horizon * next_value`` for one sample."""
        return (
            sample.reward
            + (1 - sample.is_terminal)
            * (self.gamma**self.update_horizon)
            * next_value
        )

    def compute_target(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """Standard DQN target using ``max_a Q_target(s', a)``."""
        next_value = jnp.max(self.network.apply(target_params, sample.next_state))
        return self._bootstrapped_target(sample, next_value), {}

    def compute_target_online(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """Target using the (gradient-stopped) online network: ``max_a Q_online(s', a)``."""
        q_online_next = jax.lax.stop_gradient(
            self.network.apply(online_params, sample.next_state)
        )
        next_value = jnp.max(q_online_next)
        # float32 for consistency with the other variants' online_fraction.
        return self._bootstrapped_target(sample, next_value), {
            "online_fraction": jnp.float32(1.0)
        }

    # min_anchors -> max_a min_ot
    def compute_target_min(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """MINTO target: ``max_a min(Q_online(s', a), Q_target(s', a))``.

        ``online_fraction`` is 1.0 when the selected value comes from the
        online network (ties, where online <= target, count as online).
        """
        q_online_next = jax.lax.stop_gradient(
            self.network.apply(online_params, sample.next_state)
        )
        q_target_next = self.network.apply(target_params, sample.next_state)

        q_next = jnp.max(jnp.minimum(q_online_next, q_target_next))

        best_online_min = jnp.max(
            jnp.where(q_online_next <= q_target_next, q_online_next, -jnp.inf)
        )
        info = {"online_fraction": jnp.equal(q_next, best_online_min).astype(jnp.float32)}
        return self._bootstrapped_target(sample, q_next), info

    # min_anchors -> max_a max_ot
    def compute_target_max(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """Target ``max_a max(Q_online(s', a), Q_target(s', a))``.

        ``online_fraction`` is 1.0 when the selected value comes from an
        action where the online estimate is strictly larger (ties count as
        target).
        """
        q_online_next = jax.lax.stop_gradient(
            self.network.apply(online_params, sample.next_state)
        )
        q_target_next = self.network.apply(target_params, sample.next_state)

        q_next = jnp.max(jnp.maximum(q_online_next, q_target_next))

        best_online_max = jnp.max(
            jnp.where(q_online_next > q_target_next, q_online_next, -jnp.inf)
        )
        info = {"online_fraction": jnp.equal(q_next, best_online_max).astype(jnp.float32)}
        return self._bootstrapped_target(sample, q_next), info

    # min_anchors -> max_a mean_ot
    def compute_target_mean(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """Target ``max_a mean(Q_online(s', a), Q_target(s', a))``."""
        q_online_next = jax.lax.stop_gradient(
            self.network.apply(online_params, sample.next_state)
        )
        q_target_next = self.network.apply(target_params, sample.next_state)

        stacked = jnp.vstack((q_online_next, q_target_next))
        next_value = jnp.max(jnp.mean(stacked, axis=0))
        return self._bootstrapped_target(sample, next_value), {}

    def compute_target_random(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """Pick ``max_a Q_online`` or ``max_a Q_target`` with probability 0.5 each.

        ``online_fraction`` records which one was drawn (1.0 = online).
        """
        q_online_next = jnp.max(
            jax.lax.stop_gradient(self.network.apply(online_params, sample.next_state))
        )
        q_target_next = jnp.max(self.network.apply(target_params, sample.next_state))

        random_mask = jax.random.bernoulli(key, p=0.5, shape=q_target_next.shape)
        q_next = jnp.where(random_mask, q_online_next, q_target_next)

        info = {"online_fraction": random_mask.astype(jnp.float32)}
        return self._bootstrapped_target(sample, q_next), info

    @partial(jax.jit, static_argnames="self")
    def best_action(
        self,
        params: FrozenDict,
        state: jnp.ndarray,
        key: Array,
    ):
        """Greedy action ``argmax_a Q(state, a)`` for a single state (``key`` is unused)."""
        return jnp.argmax(self.network.apply(params, state))

    def get_model(self):
        """Return the online parameters wrapped in a dict under ``"params"``."""
        return {"params": self.params}