from typing import Callable, Sequence   
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
import distrax

# Type of the element-wise non-linearities passed around below (e.g. nn.relu,
# nn.tanh): a callable taking one JAX array and returning one JAX array.
ActivationFn = Callable[[jnp.ndarray], jnp.ndarray]

class ResidualBlock(nn.Module):
    """Fully connected residual unit made of two equally sized Dense layers.

    Forward pass::

        h   = activation_fn(dense_1(x))
        out = activation_fn(dense_2(h) + skip(x))

    ``skip`` is the identity when ``x.shape[-1] == features``; otherwise it is a
    Dense layer named ``shortcut_projection`` that maps ``x`` to ``features``
    channels so the two terms can be added.

    Attributes:
        features: Output width of every Dense layer in the block.
        activation_fn: Non-linearity applied after ``dense_1`` and after the sum.
        kernel_init_main / bias_init_main: Initializers shared by ``dense_1``
            and ``dense_2``.
        kernel_init_shortcut / bias_init_shortcut: Initializers for the
            optional shortcut projection.
    """

    features: int
    activation_fn: ActivationFn
    kernel_init_main: Callable = orthogonal(np.sqrt(2))
    bias_init_main: Callable = constant(0.0)
    kernel_init_shortcut: Callable = orthogonal(np.sqrt(2))
    bias_init_shortcut: Callable = constant(0.0)

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        def main_layer(layer_name: str) -> nn.Dense:
            # Both main-path layers share width and initializers; only the name differs.
            return nn.Dense(
                features=self.features,
                kernel_init=self.kernel_init_main,
                bias_init=self.bias_init_main,
                name=layer_name,
            )

        hidden = self.activation_fn(main_layer("dense_1")(x))
        branch = main_layer("dense_2")(hidden)

        # Identity skip when widths already agree, learned projection otherwise.
        needs_projection = x.shape[-1] != self.features
        skip = (
            nn.Dense(
                features=self.features,
                kernel_init=self.kernel_init_shortcut,
                bias_init=self.bias_init_shortcut,
                name="shortcut_projection",
            )(x)
            if needs_projection
            else x
        )

        return self.activation_fn(branch + skip)

class ResNetActorCritic(nn.Module):
    """Actor-critic with two separate (non-shared) residual MLP trunks.

    Each branch independently:
      1. projects the input to ``hidden_features`` with Dense + activation, but
         only when the input width differs from ``hidden_features`` AND at
         least one residual block follows;
      2. applies ``num_residual_blocks`` :class:`ResidualBlock` layers.

    The actor trunk feeds a Dense head producing ``action_dim`` logits, wrapped
    in a ``distrax.Categorical``. The critic trunk feeds a single-unit Dense
    head.

    Attributes:
        action_dim: Number of discrete actions (width of the logits layer).
        activation: ``"relu"`` or ``"tanh"``; anything else raises ``ValueError``.
        hidden_features: Width of the projection and of every residual block.
        num_residual_blocks: Residual blocks per branch. With 0, each head reads
            the raw input directly.
    """

    action_dim: int
    activation: str = "tanh"
    hidden_features: int = 512
    num_residual_blocks: int = 1  # residual blocks in EACH of the two branches

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Return ``(pi, value)``.

        ``pi`` is a ``distrax.Categorical`` over ``action_dim`` actions.
        ``value`` is the critic output with its trailing size-1 axis squeezed.
        """
        activations = {"relu": nn.relu, "tanh": nn.tanh}
        if self.activation not in activations:
            raise ValueError(f"Unsupported activation: {self.activation}")
        activation_fn = activations[self.activation]

        def trunk(prefix: str) -> jnp.ndarray:
            # Builds one branch's feature extractor. Submodule names are derived
            # from `prefix` exactly as before ("actor_initial_projection",
            # "critic_resblock_1", ...), so parameter trees/checkpoints are unchanged.
            y = x
            # Lift the input to hidden_features only when a residual block will
            # consume it; the block's own shortcut then stays an identity.
            if x.shape[-1] != self.hidden_features and self.num_residual_blocks > 0:
                y = nn.Dense(
                    features=self.hidden_features,
                    kernel_init=orthogonal(np.sqrt(2)),
                    bias_init=constant(0.0),
                    name=f"{prefix}_initial_projection",
                )(y)
                y = activation_fn(y)
            for i in range(self.num_residual_blocks):
                y = ResidualBlock(
                    features=self.hidden_features,
                    activation_fn=activation_fn,
                    name=f"{prefix}_resblock_{i + 1}",
                )(y)
            return y

        # --- Actor branch ---
        # Build the trunk before constructing the head so module creation order
        # matches the original (actor trunk, actor head, critic trunk, critic head).
        actor_features = trunk("actor")
        actor_logits = nn.Dense(
            features=self.action_dim,
            kernel_init=orthogonal(0.01),  # small init -> near-uniform initial policy
            bias_init=constant(0.0),
            name="actor_logits_layer",
        )(actor_features)
        pi = distrax.Categorical(logits=actor_logits)

        # --- Critic branch (its own weights; nothing is shared with the actor) ---
        critic_features = trunk("critic")
        critic_value = nn.Dense(
            features=1,
            kernel_init=orthogonal(1.0),
            bias_init=constant(0.0),
            name="critic_value_layer",
        )(critic_features)

        return pi, jnp.squeeze(critic_value, axis=-1)


class FFNActorCritic(nn.Module):
    """Actor-critic with two independent two-hidden-layer MLPs.

    Actor path: two ``hidden_features``-wide Dense + activation layers, then a
    Dense layer producing ``action_dim`` logits wrapped in a
    ``distrax.Categorical``. The critic has its own two hidden layers (no weights
    are shared with the actor) and a single-unit head whose trailing axis is
    squeezed away.

    Attributes:
        action_dim: Number of discrete actions, i.e. the logits width. It is
            passed straight to ``nn.Dense`` as ``features``, so it is an ``int``
            (it was previously mis-annotated as ``Sequence[int]``).
        activation: ``"relu"`` or ``"tanh"``. Any other value raises
            ``ValueError``, the same as ``ResNetActorCritic``, instead of
            silently falling back to tanh.
        hidden_features: Width of every hidden layer. Defaults to 512, the value
            previously hard-coded.
    """

    action_dim: int
    activation: str = "tanh"
    hidden_features: int = 512

    @nn.compact
    def __call__(self, x):
        """Return ``(pi, value)`` with ``value``'s trailing size-1 axis squeezed."""
        if self.activation == "relu":
            activation = nn.relu
        elif self.activation == "tanh":
            activation = nn.tanh
        else:
            raise ValueError(f"Unsupported activation: {self.activation}")

        # Layers are deliberately left unnamed. Flax auto-names them Dense_0 ..
        # Dense_5 at construction time, in creation order. Each head is
        # therefore built only after its hidden stack, which keeps the original
        # order (actor hidden x2, actor head, critic hidden x2, critic head) and
        # keeps existing parameter trees/checkpoints compatible.
        def hidden_stack(h):
            # Two Dense + activation layers of width hidden_features.
            for _ in range(2):
                h = nn.Dense(
                    self.hidden_features,
                    kernel_init=orthogonal(np.sqrt(2)),
                    bias_init=constant(0.0),
                )(h)
                h = activation(h)
            return h

        actor_features = hidden_stack(x)
        logits = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_features)
        pi = distrax.Categorical(logits=logits)

        critic_features = hidden_stack(x)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic_features
        )

        return pi, jnp.squeeze(critic, axis=-1)