import jax
import jax.numpy as jnp
from flax.core import FrozenDict
from jax import Array

from minto.networks.dqn import DQN
from minto.sample_collection.replay_buffer import ReplayElement


class DoubleDQN(DQN):
    """Double DQN agent.

    Identical to ``DQN`` except for the bootstrap target: the online network
    selects the greedy next action and the target network evaluates it, which
    decouples action selection from action evaluation.
    """

    def __init__(
        self,
        key: Array,
        observation_dim,
        n_actions,
        features: list,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        data_to_update: int,
        target_update_frequency: int,
        adam_eps: float = 1e-8,
    ):
        """Forward every argument, in order, to ``DQN.__init__`` unchanged."""
        super().__init__(
            key,
            observation_dim,
            n_actions,
            features,
            architecture_type,
            learning_rate,
            gamma,
            update_horizon,
            data_to_update,
            target_update_frequency,
            adam_eps,
        )

    def compute_target(
        self,
        key: Array,
        target_params: FrozenDict,
        online_params: FrozenDict,
        sample: ReplayElement,
    ):
        """Compute the Double DQN bootstrap target.

        target = reward
                 + (1 - is_terminal) * gamma**update_horizon
                 * Q_target(next_state, argmax_a Q_online(next_state, a))

        Args:
            key: Unused here; accepted so the signature matches the other
                ``compute_target`` implementations (presumably ``DQN``'s).
            target_params: Parameters of the target network, used to evaluate
                the selected next action.
            online_params: Parameters of the online network, used only to
                select the greedy next action.
            sample: Transition(s) providing ``next_state``, ``reward`` and
                ``is_terminal``. Actions are assumed to lie on the last axis of
                the network output (as the ``argmax`` below requires); a single
                transition or a leading batch dimension are both handled.

        Returns:
            A tuple ``(target, {})``. ``target`` has the network output's shape
            with the action axis removed. The empty dict is presumably the
            auxiliary-info slot expected by the ``DQN`` interface — confirm
            against the base class.
        """
        # Online network picks the next action. stop_gradient makes explicit
        # that action selection must not feed gradients into the online params.
        q_next_online = jax.lax.stop_gradient(
            self.network.apply(online_params, sample.next_state)
        )
        next_action = jnp.argmax(q_next_online, axis=-1)

        # Target network evaluates that action. Gather along the action (last)
        # axis: plain `q[next_action]` indexes axis 0, which is only correct for
        # an unbatched sample and selects whole rows when a batch is present.
        q_next_target = self.network.apply(target_params, sample.next_state)
        q_selected = jnp.take_along_axis(
            q_next_target, jnp.expand_dims(next_action, axis=-1), axis=-1
        ).squeeze(axis=-1)

        # n-step discount; terminal transitions do not bootstrap.
        discount = self.gamma**self.update_horizon
        target = sample.reward + (1 - sample.is_terminal) * discount * q_selected
        return target, {}
