from functools import partial

import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax

from algorithms.utils.q_network import MLP


class Q_Agent:
    """Thin wrapper exposing an ``MLP`` network as a Q-value function.

    The wrapper itself holds no parameters: they are created by :meth:`init`
    and passed explicitly to every forward call.

    Both forward methods are jitted with ``self`` marked static. This class
    defines neither ``__hash__`` nor ``__eq__``, so an instance is hashed by
    identity: replacing ``self.network`` after the first call will NOT
    trigger a re-trace of an already compiled method.
    """

    def __init__(self, layers):
        """Build the underlying network.

        Args:
            layers: forwarded unchanged as ``MLP(features=layers)``
                (presumably the per-layer widths, the last one being the
                number of actions -- confirm against ``MLP``).
        """
        self.network = MLP(features=layers)

    def init(self, key, sample):
        """Initialise and return the network parameters.

        Args:
            key: PRNG key handed to ``self.network.init``.
            sample: example input handed to ``self.network.init``.

        Returns:
            Whatever ``self.network.init(key, sample)`` returns.
        """
        return self.network.init(key, sample)

    @partial(jax.jit, static_argnames=["self"])
    def __apply_with_action__(self, params, state, action):
        """Return the network output selected by ``action``.

        The network output is indexed along its LAST axis, so any number of
        leading batch dimensions is accepted, including none (a single state
        with a scalar action).

        Args:
            params: network parameters, as returned by :meth:`init`.
            state: network input; its network output must have shape
                ``action.shape + (n_outputs,)``.
            action: integer indices into the last axis of the network
                output, one per leading position.

        Returns:
            Array with the same shape as ``action`` holding the selected
            entries.
        """
        q_values = self.network.apply(params, state)
        # Append a length-1 axis so the indices line up with the output's
        # last axis, gather along it, then drop that helper axis again.
        # Using `...` / -1 rather than `:` / 1 keeps the usual
        # (batch, n_outputs) case identical while also covering other ranks.
        picked = jnp.take_along_axis(q_values, action[..., None], axis=-1)
        return picked[..., 0]

    @partial(jax.jit, static_argnames=["self"])
    def __apply__(self, params, state):
        """Return the full network output for ``state``.

        Args:
            params: network parameters, as returned by :meth:`init`.
            state: network input.

        Returns:
            The result of ``self.network.apply(params, state)``.
        """
        return self.network.apply(params, state)
