import haiku as hk
import jax
import jax.numpy as jnp


class PredictionNetwork(hk.Module):
    """Network with a shared trunk and separate value and policy heads.

    When ``is_state_vector`` is False the input is first passed through two
    3x3 convolutions and a 2x2 average pooling. The result (or the raw input,
    for vector states) is flattened per batch element, fed through one shared
    hidden layer, and then through two independent two-layer heads that
    produce a scalar value and ``num_actions`` policy logits.
    """

    def __init__(
        self,
        num_hidden: int,
        num_actions: int,
        is_state_vector: bool,
        activation: str = "tanh",
    ):
        """Store the network configuration.

        Args:
            num_hidden: Width of the shared hidden layer and of the hidden
                layer inside each head.
            num_actions: Number of policy logits to output.
            is_state_vector: If True the input is treated as a flat vector
                and the convolutional encoder is skipped.
            activation: Nonlinearity used by the dense layers, either
                ``"relu"`` or ``"tanh"``.

        Raises:
            ValueError: If ``activation`` is not ``"relu"`` or ``"tanh"``.
        """
        super().__init__()
        # Validate before selecting, and with a real exception: an ``assert``
        # is stripped under ``python -O``, which would silently turn any
        # unknown activation name into tanh.
        activations = {"relu": jax.nn.relu, "tanh": jax.nn.tanh}
        if activation not in activations:
            raise ValueError(
                f"activation must be 'relu' or 'tanh', got {activation!r}"
            )
        self.num_hidden = num_hidden
        self.num_actions = num_actions
        self.is_state_vector = is_state_vector
        self.activation = activations[activation]

    def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Compute policy logits and a value estimate for a batch of states.

        Args:
            x: Batch of states with the batch dimension first. It is cast to
                float32. For image states a 4-D array is expected
                (presumably ``(b, h, w, c)``, Haiku's default conv layout).

        Returns:
            A tuple ``(logits, value)`` where ``logits`` has shape
            ``(b, num_actions)`` and ``value`` has shape ``(b,)``.
        """
        # NOTE: the order in which Haiku modules are created below determines
        # their auto-generated parameter names; keep it stable so existing
        # parameter trees remain loadable.
        x = x.astype(jnp.float32)

        # If the state is an image, we use a CNN
        if not self.is_state_vector:
            x = hk.Conv2D(32, kernel_shape=3)(x)
            x = jax.nn.relu(x)
            x = hk.Conv2D(16, kernel_shape=3)(x)
            # NOTE(review): there is no nonlinearity between the second conv
            # and the pooling/dense layer that follow -- confirm this is
            # intended.

            # 2x2 average pooling with stride 2 over axes 1 and 2 (the
            # spatial axes, assuming a (b, h, w, c) layout); batch and
            # channel axes are left untouched. "VALID" padding drops a
            # trailing row/column when the spatial size is odd.
            window_shape = (1, 2, 2, 1)
            strides = (1, 2, 2, 1)
            pooled = jax.lax.reduce_window(
                x,
                init_value=0.0,
                computation=jax.lax.add,
                window_dimensions=window_shape,
                window_strides=strides,
                padding="VALID",
            )
            # Divide the window sums by the window size to get the average.
            x = pooled / (2 * 2)  # (b, h // 2, w // 2, 16)

        # Flatten everything except the batch axis.
        x = x.reshape((x.shape[0], -1))  # (b, features)
        x = hk.Linear(self.num_hidden)(x)  # (b, num_hidden)
        x = self.activation(x)

        # Value head
        value = hk.Linear(self.num_hidden)(x)  # (b, num_hidden)
        value = self.activation(value)
        value = hk.Linear(1)(value)  # (b, 1)
        value = jnp.squeeze(value, axis=-1)  # (b,)

        # Policy head
        logits = hk.Linear(self.num_hidden)(x)  # (b, num_hidden)
        logits = self.activation(logits)
        logits = hk.Linear(self.num_actions)(logits)  # (b, num_actions)
        return logits, value


def make_network_apply_fns(args):
    """Build the Haiku-transformed prediction network.

    Args:
        args: Configuration object exposing ``num_hidden``, ``num_actions``
            and ``is_state_vector`` attributes; they are read each time the
            wrapped function is traced.

    Returns:
        The result of ``hk.without_apply_rng(hk.transform(...))`` wrapped
        around a function that instantiates ``PredictionNetwork`` and
        applies it to its input.
    """

    def forward(x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Instantiate the prediction network and run it on ``x``."""
        return PredictionNetwork(
            num_hidden=args.num_hidden,
            num_actions=args.num_actions,
            is_state_vector=args.is_state_vector,
        )(x)

    transformed = hk.transform(forward)
    # Drop the rng argument from ``apply``.
    return hk.without_apply_rng(transformed)
