from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict

from minto.networks.architectures.dqn import DQNNet
from minto.sample_collection.fixed_replay_buffer import FixedReplayBuffer
from minto.sample_collection.replay_buffer import ReplayElement
from minto.networks.cql import CQL


class DoubleCQL(CQL):
    """CQL agent whose bootstrap target uses Double DQN-style action selection.

    The greedy next action is selected with the online network and evaluated
    with the target network. Everything else is inherited unchanged from
    ``CQL``; only ``__init__`` (argument validation) and ``compute_target``
    are overridden here.
    """

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        n_actions: int,
        features: list,
        layer_norm: bool,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        target_update_frequency: int,
        alpha_cql: float,
        adam_eps: float = 0.0003125,
        target_function: str = "default"
    ):
        """Validate ``target_function`` and forward every argument to ``CQL``.

        All arguments are passed positionally, in the order shown, to
        ``CQL.__init__``.

        Raises:
            ValueError: if ``target_function`` is not ``"default"``; this class
                replaces the target computation, so other variants are unsupported.
        """
        # An explicit raise (rather than `assert`) keeps this check active under `python -O`.
        if target_function != "default":
            raise ValueError("DoubleCQL only works with the default target function in this class.")
        super().__init__(
            key, observation_dim, n_actions, features, layer_norm, architecture_type, learning_rate,
            gamma, update_horizon, target_update_frequency, alpha_cql, adam_eps, target_function
        )

    def compute_target(self, target_params: FrozenDict, online_params: FrozenDict, sample: ReplayElement):
        """Compute the Double DQN-style bootstrap target for CQL.

        target = r + (1 - done) * gamma**update_horizon * Q_target(s', argmax_a Q_online(s', a))

        Args:
            target_params: target-network parameters, used to evaluate the selected action.
            online_params: online-network parameters, used only to select the next action.
            sample: transition providing ``reward``, ``next_state`` and ``is_terminal``.

        Returns:
            A ``(target, aux)`` tuple where ``aux`` is an empty dict. Presumably this
            matches the return convention of ``CQL.compute_target``; TODO confirm.
        """
        # Q-values of the next state under the target network (used for evaluation).
        q_next_target = self.network.apply(target_params, sample.next_state)
        # Greedy next action under the online network (used for selection). argmax is
        # not differentiable, so stop_gradient only makes that intent explicit.
        q_next_online = jax.lax.stop_gradient(self.network.apply(online_params, sample.next_state))
        next_action = jnp.argmax(q_next_online, axis=-1)
        # Gather Q_target(s', a*) along the action (last) axis. Unlike `.at[a].get()`,
        # which indexes the leading axis, this is correct for a single transition
        # (q of shape (n_actions,)) and for a batch (q of shape (batch, n_actions)).
        q_next = jnp.take_along_axis(q_next_target, jnp.expand_dims(next_action, -1), axis=-1).squeeze(-1)
        discount = self.gamma**self.update_horizon
        return sample.reward + (1 - sample.is_terminal) * discount * q_next, {}
