from typing import List

import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict

from slimdqn.networks.architectures.dqn import DQNNet
from slimdqn.sample_collection import IDX_RB


class BaseDQN:
    """Base agent wrapping a ``DQNNet`` with loss, gradient and greedy-action helpers."""

    def __init__(self, features: List[int], cnn: bool, n_actions: int):
        """Build the Q-network from the given architecture arguments."""
        self.q_network = DQNNet(features, cnn, n_actions)

    def value_and_grad(self, params: FrozenDict, targets, targets_validation, samples):
        """Return the validation loss together with the gradient of the training loss.

        The gradient is taken with respect to ``params``. Because ``has_aux=True``,
        only the first output of ``loss_from_targets`` is differentiated; the second
        output (the validation loss) is passed through untouched.
        """
        differentiated_loss = jax.value_and_grad(self.loss_from_targets, has_aux=True)
        losses, grad = differentiated_loss(params, targets, targets_validation, samples)
        # the training loss value itself is not needed by the caller
        _, loss_validation = losses

        return loss_validation, grad

    def loss_from_targets(self, params: FrozenDict, targets, targets_validation, samples):
        """Return ``(loss, loss_validation)`` for the sampled state/action pairs.

        Both values are mean squared errors between the Q-values of the sampled
        actions and, respectively, ``targets`` and ``targets_validation``.
        """

        def q_of_sampled_action(state, action):
            # Q-values of one state, indexed at the action stored with it
            return self.q_network.apply(params, state)[action]

        states = samples[IDX_RB["state"]]
        actions = samples[IDX_RB["action"]]
        q_values = jax.vmap(q_of_sampled_action)(states, actions)

        # the l2 loss is always used so that networks can be compared fairly
        loss = jnp.mean(jnp.square(q_values - targets))
        loss_validation = jnp.mean(jnp.square(q_values - targets_validation))
        return loss, loss_validation

    def best_action(self, params: FrozenDict, state: jnp.ndarray, **kwargs):
        """Return the greedy action for a single state, cast to ``int8``.

        Extra keyword arguments are accepted and ignored.
        """
        q_values = self.q_network.apply(params, state)
        greedy_action = jnp.argmax(q_values)
        return greedy_action.astype(jnp.int8)
