import flax.linen as nn
import jax
import jax.numpy as jnp

from .protocol import QFixProtocol


def _drop_time_axis(x: jax.Array) -> jax.Array:
    """Remove the singleton leading time axis (T=1) inserted by ``FFAdapter``."""
    return x.squeeze(0)


def _drop_time_axis_from_info(info: dict) -> dict:
    """Apply ``_drop_time_axis`` to every value of an ``info`` dict.

    Each value is assumed to carry the same leading T=1 axis as the main
    output — NOTE(review): confirm against the wrapped module's ``info``.
    """
    return {key: _drop_time_axis(value) for key, value in info.items()}


class FFAdapter(nn.Module):
    """Adapts QFix and AdditiveQFix to timeless (feed-forward) batches.

    The wrapped ``module`` consumes inputs that carry a time axis ``T``.
    Every method here inserts a singleton time axis (T=1), delegates to the
    method of the same name on ``module``, and then removes that axis again
    from the result (and from every value of a returned ``info`` dict).
    """

    # The time-major mixer being adapted.
    module: QFixProtocol

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> tuple[jax.Array, dict]:
        """Mix per-agent Q-values for a batch without a time axis.

        Args:
            individual_qvalues: per-agent Q-values, shape (N, B).
            individual_vvalues: per-agent V-values, shape (N, B).
            states: states, shape (B, DS).
            joint_action_n_hot: N-hot joint-action encoding, shape (B, N*A).
            w_delta: forwarded unchanged to ``module.qvalues``.
            w_gt: forwarded unchanged to ``module.qvalues``.

        Returns:
            The mixed Q-values, shape (B,), and the ``info`` dict from
            ``module.qvalues`` with the leading time axis removed from
            every value.
        """
        # Agent-major inputs get the time axis after N; batch-major inputs
        # get it in front, matching the (N, T, B) / (T, B, ...) layouts.
        individual_qvalues = jnp.expand_dims(individual_qvalues, axis=1)
        # individual_qvalues.shape == (N, T=1, B)
        individual_vvalues = jnp.expand_dims(individual_vvalues, axis=1)
        # individual_vvalues.shape == (N, T=1, B)
        states = jnp.expand_dims(states, axis=0)
        # states.shape == (T=1, B, DS)
        joint_action_n_hot = jnp.expand_dims(joint_action_n_hot, axis=0)
        # joint_action_n_hot.shape == (T=1, B, N*A), N-hot encoding

        qvalues, info = self.module.qvalues(
            individual_qvalues,
            individual_vvalues,
            states,
            joint_action_n_hot,
            w_delta=w_delta,
            w_gt=w_gt,
        )
        # qvalues.shape == (T=1, B)

        qvalues = _drop_time_axis(qvalues)
        # qvalues.shape == (B,)
        return qvalues, _drop_time_axis_from_info(info)

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
    ) -> tuple[jax.Array, dict]:
        """Mix per-agent V-values for a batch without a time axis.

        Args:
            individual_vvalues: per-agent V-values, shape (N, B).
            states: states, shape (B, DS).

        Returns:
            The mixed V-values, shape (B,), and the ``info`` dict from
            ``module.vvalues`` with the leading time axis removed from
            every value.
        """
        individual_vvalues = jnp.expand_dims(individual_vvalues, axis=1)
        # individual_vvalues.shape == (N, T=1, B)
        states = jnp.expand_dims(states, axis=0)
        # states.shape == (T=1, B, DS)

        vvalues, info = self.module.vvalues(individual_vvalues, states)
        # vvalues.shape == (T=1, B)

        vvalues = _drop_time_axis(vvalues)
        # vvalues.shape == (B,)
        return vvalues, _drop_time_axis_from_info(info)

    def w(
        self,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Compute the wrapped module's ``w`` for a batch without a time axis.

        Args:
            states: states, shape (B, DS).
            joint_action_n_hot: N-hot joint-action encoding, shape (B, N*A).
            w_delta: forwarded unchanged to ``module.w``.
            w_gt: forwarded unchanged to ``module.w``.

        Returns:
            ``module.w`` output with the leading time axis removed, shape (B, N).
        """
        states = jnp.expand_dims(states, axis=0)
        # states.shape == (T=1, B, DS)
        joint_action_n_hot = jnp.expand_dims(joint_action_n_hot, axis=0)
        # joint_action_n_hot.shape == (T=1, B, N*A), N-hot encoding

        w = self.module.w(
            states,
            joint_action_n_hot,
            w_delta=w_delta,
            w_gt=w_gt,
        )
        # w.shape == (T=1, B, N)
        return _drop_time_axis(w)  # (B, N)

    def b(
        self,
        states: jax.Array,
    ) -> jax.Array:
        """Compute the wrapped module's ``b`` for a batch without a time axis.

        Args:
            states: states, shape (B, DS).

        Returns:
            ``module.b`` output with the leading time axis removed, shape (B,).
        """
        states = jnp.expand_dims(states, axis=0)
        # states.shape == (T=1, B, DS)

        b = self.module.b(states)
        # b.shape == (T=1, B)
        return _drop_time_axis(b)  # (B,)
