import haiku as hk
import jax
import jax.numpy as jnp


class ObsEncoder(hk.Module):
    """Encode a batch of raw observations into ``num_hidden``-wide embeddings.

    Image observations are passed through a small CNN (two 3x3 convolutions
    with 32 and 16 output channels) followed by 2x2 average pooling.
    State-vector observations skip the CNN. In both cases the result is
    flattened per example and passed through a two-layer MLP
    (Linear -> tanh -> Linear).

    Args:
        num_hidden: Width of the MLP hidden layer and of the output embedding.
        is_state_vector: If False, ``obs`` is treated as an image and run
            through the CNN first.
        name: Haiku module name (also the parameter-key prefix).
    """

    def __init__(
        self, num_hidden: int, is_state_vector: bool, name: str = "obs_encoder"
    ):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.is_state_vector = is_state_vector

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        """Embed ``obs``.

        Args:
            obs: Observation batch; the leading axis is treated as the batch
                axis by the flatten below. For images the pooling window
                pools axes 1 and 2, so the input is presumably NHWC
                (TODO confirm against the environment wrapper).

        Returns:
            Array of shape ``(batch, num_hidden)``.
        """
        x = obs.astype(jnp.float32)
        # If the state is an image, we use a CNN
        if not self.is_state_vector:
            x = hk.Conv2D(32, kernel_shape=3)(x)
            x = jax.nn.relu(x)
            x = hk.Conv2D(16, kernel_shape=3)(x)
            # 2x2 average pooling with stride 2 and no padding: sum each
            # window, then divide by the window size. Odd trailing rows and
            # columns are dropped by "VALID" padding.
            window_shape = (1, 2, 2, 1)
            strides = (1, 2, 2, 1)
            padding = "VALID"
            pooled = jax.lax.reduce_window(
                x,
                init_value=0.0,
                computation=jax.lax.add,
                window_dimensions=window_shape,
                window_strides=strides,
                padding=padding,
            )
            # Divide by window size to get the average. The channel count
            # here is 16 (from the second conv); spatial dims are halved
            # (floored) relative to the conv output.
            x = pooled / (2 * 2)

        # Flatten everything but the batch axis before the MLP.
        x = x.reshape((x.shape[0], -1))
        x = hk.Linear(self.num_hidden)(x)
        x = jax.nn.tanh(x)
        # Small-scale init keeps the initial embedding close to the bias.
        return hk.Linear(
            self.num_hidden,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(x)


class ActionEncoder(hk.Module):
    """Map an action array to a ``num_hidden``-wide embedding.

    The action is cast to float32 and fed through one linear layer with
    small uniform-scaling initialisers (scale 0.01 for weights, 0.25 for
    biases). No non-linearity is applied.

    Args:
        num_hidden: Size of the produced embedding.
        name: Haiku module name (also the parameter-key prefix).
    """

    def __init__(self, num_hidden: int, name: str = "action_encoder"):
        super().__init__(name=name)
        self.num_hidden = num_hidden

    def __call__(self, action: jnp.ndarray) -> jnp.ndarray:
        """Return the linear embedding of ``action`` (last dim ``num_hidden``)."""
        as_float = action.astype(jnp.float32)
        projection = hk.Linear(
            output_size=self.num_hidden,
            w_init=hk.initializers.UniformScaling(scale=0.01),
            b_init=hk.initializers.UniformScaling(scale=0.25),
        )
        return projection(as_float)


class ProjectionHead(hk.Module):
    """Residual MLP that projects an observation embedding to ``num_out`` features.

    Computes ``Linear_out(x + relu(Linear_hidden(x)))`` with ``x`` cast to
    float32.

    Args:
        num_hidden: Width of the residual hidden layer. The hidden activation
            is added back onto the input, so this must match the input
            embedding's last dimension.
        num_out: Size of the projected output.
        name: Haiku module name (also the parameter-key prefix).
    """

    def __init__(self, num_hidden: int, num_out: int, name: str = "projection_head"):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.num_out = num_out

    def __call__(self, obs_embedding: jnp.ndarray) -> jnp.ndarray:
        """Project ``obs_embedding``; the output's last dimension is ``num_out``."""
        # Cast once and use the cast value for both the hidden layer and the
        # skip path. Previously the hidden Linear saw the un-cast input, so
        # its parameter dtype followed the input dtype and the residual add
        # could mix dtypes.
        x = obs_embedding.astype(jnp.float32)
        hidden = jax.nn.relu(hk.Linear(self.num_hidden)(x))
        x = x + hidden
        # Small-scale init keeps the initial projection close to the bias.
        return hk.Linear(
            self.num_out,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(x)


class PolicyHead(hk.Module):
    """Residual MLP mapping an observation embedding to action logits.

    Computes ``Linear_out(x + relu(Linear_hidden(x)))`` with ``x`` cast to
    float32. The output is returned unnormalised (no softmax here).

    Args:
        num_hidden: Width of the residual hidden layer. The hidden activation
            is added back onto the input, so this must match the input
            embedding's last dimension.
        num_actions: Number of output logits.
        name: Haiku module name (also the parameter-key prefix).
    """

    def __init__(self, num_hidden: int, num_actions: int, name: str = "policy_head"):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.num_actions = num_actions

    def __call__(self, obs_embedding: jnp.ndarray) -> jnp.ndarray:
        """Return logits whose last dimension is ``num_actions``."""
        # Cast once so the hidden layer and the skip path share a dtype.
        # Previously only the skip path used the cast value.
        x = obs_embedding.astype(jnp.float32)
        hidden = jax.nn.relu(hk.Linear(self.num_hidden)(x))
        x = x + hidden
        # Small-scale init keeps the initial logits close to the bias.
        return hk.Linear(
            self.num_actions,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(x)


class ValueHead(hk.Module):
    """Residual MLP mapping an observation embedding to a scalar value.

    Computes ``Linear_1(x + relu(Linear_hidden(x)))`` with ``x`` cast to
    float32, then drops the trailing size-1 axis.

    Args:
        num_hidden: Width of the residual hidden layer. The hidden activation
            is added back onto the input, so this must match the input
            embedding's last dimension.
        name: Haiku module name (also the parameter-key prefix).
    """

    def __init__(self, num_hidden: int, name: str = "value_head"):
        super().__init__(name=name)
        self.num_hidden = num_hidden

    def __call__(self, obs_embedding: jnp.ndarray) -> jnp.ndarray:
        """Return one value per embedding (input shape minus the last axis)."""
        # Cast once so the hidden layer and the skip path share a dtype.
        # Previously only the skip path used the cast value.
        x = obs_embedding.astype(jnp.float32)
        hidden = jax.nn.relu(hk.Linear(self.num_hidden)(x))
        x = x + hidden
        value = hk.Linear(
            1,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(x)
        # Squeeze to get a scalar value for each observation.
        return value.squeeze(-1)


class DynamicsHead(hk.Module):
    """Transition model: (state embedding, action) -> (next embedding, reward).

    The action is one-hot encoded and embedded (Linear + relu), then used as
    the input of a GRU whose recurrent state is the current embedding. The
    GRU output plus a residual connection to the incoming embedding is
    projected to the next embedding. A separate linear head on the GRU state
    predicts a reward.

    Args:
        num_hidden: GRU / embedding width. The incoming ``obs_embedding`` is
            used directly as the GRU state and as the residual, so its last
            dimension must equal ``num_hidden``.
        num_actions: Number of discrete actions (size of the one-hot code).
        name: Haiku module name (also the parameter-key prefix).
    """

    def __init__(self, num_hidden: int, num_actions: int, name: str = "dynamics_head"):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.num_actions = num_actions

    def __call__(
        self, obs_embedding: jnp.ndarray, action: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Predict the next-state embedding and the reward.

        Args:
            obs_embedding: Current embedding with last dimension
                ``num_hidden``; presumably ``(batch, num_hidden)``
                (TODO confirm against callers).
            action: Discrete action indices; one-hot encoding adds a trailing
                ``num_actions`` axis.

        Returns:
            ``(next_state_embedding, reward)``. ``next_state_embedding`` has
            last dimension ``num_hidden``. ``reward`` is the reward head's
            output with its trailing size-1 axis squeezed away.
        """
        obs_embedding = obs_embedding.astype(jnp.float32)

        # one_hot takes class indices directly and returns a float array, so
        # no prior cast of ``action`` is needed.
        one_hot_action = jax.nn.one_hot(action, num_classes=self.num_actions)
        action_embedding = jax.nn.relu(hk.Linear(self.num_hidden)(one_hot_action))

        # hk.GRU(inputs, state) returns (output, new_state); for a GRU these
        # are the same array.
        gru_output, gru_state = hk.GRU(self.num_hidden)(
            action_embedding, obs_embedding
        )

        # Residual connection from the incoming embedding.
        next_state = gru_output + obs_embedding
        reward = hk.Linear(
            1,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(gru_state).squeeze(-1)
        next_state_embedding = hk.Linear(
            self.num_hidden,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(next_state)
        return next_state_embedding, reward


def make_network_apply_fns(args):
    """Build Haiku-transformed functions for each part of the network.

    ``args`` is read for ``num_hidden``, ``num_actions`` and
    ``is_state_vector``.

    Every transform builds its sub-models with their default module names.
    As a result, all transforms use the same parameter keys. Presumably this
    is so the params from ``init_model.init`` can be passed to each
    individual ``apply`` (confirm against callers).

    Returns:
        A tuple of ``hk.Transformed`` objects, each wrapped with
        ``hk.without_apply_rng``, in this order:
        ``(init_model, representation_apply, projection_apply,
        policy_apply, critic_apply, recurrent_inference)``.
    """

    # Factories keep each sub-model's hyper-parameters in one place. Haiku
    # modules must be constructed inside the transformed function, so these
    # return fresh instances rather than sharing pre-built ones.
    def obs_encoder():
        return ObsEncoder(
            num_hidden=args.num_hidden, is_state_vector=args.is_state_vector
        )

    def projection_head():
        return ProjectionHead(num_hidden=args.num_hidden, num_out=args.num_hidden)

    def policy_head():
        return PolicyHead(num_hidden=args.num_hidden, num_actions=args.num_actions)

    def value_head():
        return ValueHead(num_hidden=args.num_hidden)

    def dynamics_head():
        return DynamicsHead(
            num_hidden=args.num_hidden, num_actions=args.num_actions
        )

    def representation_apply_fn(obs: jnp.ndarray) -> jnp.ndarray:
        """Applies the representation model to the observation."""
        return obs_encoder()(obs)

    def projection_apply_fn(obs_embedding: jnp.ndarray) -> jnp.ndarray:
        """Applies the projection model to the observation embedding."""
        return projection_head()(obs_embedding)

    def critic_apply_fn(obs_embedding: jnp.ndarray) -> jnp.ndarray:
        """Applies the critic model to the observation embedding."""
        return value_head()(obs_embedding)

    def policy_apply_fn(obs_embedding: jnp.ndarray) -> jnp.ndarray:
        """Applies the policy model to the observation embedding."""
        return policy_head()(obs_embedding)

    def recurrent_inference_fn(
        obs_embedding: jnp.ndarray, action: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Applies the dynamics model; returns (next_state_embedding, reward)."""
        return dynamics_head()(obs_embedding, action)

    def init_model_fn(obs: jnp.ndarray, action: jnp.ndarray):
        """Run every sub-model once so ``init`` creates all parameters.

        The construction order is kept fixed because it determines the order
        in which parameters are initialised.
        """
        obs_embedding = obs_encoder()(obs)
        projection = projection_head()(obs_embedding)
        policy_logits = policy_head()(obs_embedding)
        value = value_head()(obs_embedding)
        dynamics_output = dynamics_head()(obs_embedding, action)
        return policy_logits, value, dynamics_output, projection

    init_model = hk.without_apply_rng(hk.transform(init_model_fn))
    policy_apply = hk.without_apply_rng(hk.transform(policy_apply_fn))
    representation_apply = hk.without_apply_rng(hk.transform(representation_apply_fn))
    projection_apply = hk.without_apply_rng(hk.transform(projection_apply_fn))
    critic_apply = hk.without_apply_rng(hk.transform(critic_apply_fn))
    recurrent_inference = hk.without_apply_rng(hk.transform(recurrent_inference_fn))

    return (
        init_model,
        representation_apply,
        projection_apply,
        policy_apply,
        critic_apply,
        recurrent_inference,
    )
