from typing import Protocol

import jax


class QFixProtocol(Protocol):
    """Structural (duck-typed) interface for a QFix-style value mixer.

    Array shapes below use the following letters. Their meanings are
    inferred from naming only — NOTE(review): confirm against implementations:
      N  -- presumably the number of agents
      T  -- presumably the number of time steps
      B  -- presumably the batch size
      DS -- presumably the state feature size
      A  -- presumably the per-agent action count (joint actions are
            encoded N-hot over N*A slots)
    """

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> tuple[jax.Array, dict]:
        """Return mixed Q-values together with an auxiliary dictionary.

        Args:
            individual_qvalues: array of shape (N, T, B).
            individual_vvalues: array of shape (N, T, B).
            states: array of shape (T, B, DS).
            joint_action_n_hot: array of shape (T, B, N*A) holding an
                N-hot encoding of the joint action.
            w_delta: keyword-only float; same name as the parameter of
                ``w`` — presumably passed through to it (confirm).
            w_gt: keyword-only float; same name as the parameter of
                ``w`` — presumably passed through to it (confirm).

        Returns:
            A ``(jax.Array, dict)`` tuple. This protocol does not
            specify the dict's contents.

        Notes:
            The original shape notes for this method also record
            ``w.shape == (T, B, N)`` and ``b.shape == (T, B)``. These
            presumably describe the weights and bias produced by ``w``
            and ``b`` — TODO confirm.
        """

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
    ) -> jax.Array:
        """Return mixed state values as a single array.

        Args:
            individual_vvalues: array of shape (N, T, B).
            states: array of shape (T, B, DS).
        """

    def w(
        self,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Return the weight array used for mixing.

        Args:
            states: array of shape (T, B, DS).
            joint_action_n_hot: array of shape (T, B, N*A) holding an
                N-hot encoding of the joint action.
            w_delta: keyword-only float.
            w_gt: keyword-only float.

        Returns:
            A jax.Array, presumably of shape (T, B, N) — TODO confirm.
        """

    def b(
        self,
        states: jax.Array,
    ) -> jax.Array:
        """Return the bias array used for mixing.

        Args:
            states: array of shape (T, B, DS).

        Returns:
            A jax.Array, presumably of shape (T, B) — TODO confirm.
        """
