from typing import Literal, Callable
import torch
import torch.nn as nn
import torch.nn.init as init


Activation = Literal["tanh", "relu", "gelu"]

activations__ = {
    "tanh": torch.tanh,
    "relu": torch.relu,
    "gelu": torch.nn.functional.gelu,
}


class MLP(nn.Module):
    """Multi-Layer Perceptron with optional skip connections.

    Args:
        input_dim: Input dimension
        output_dim: Output dimension
        hidden_dims: Tuple of hidden layer dimensions
        activation: Activation function (callable or string)
        use_bias: Whether to use bias in hidden layers
        use_bias_last: Whether to use bias in the last layer
        skip_every: Add skip connections every N layers (None = no skip connections)
        dtype: Data type for computations
        kernel_init: Initialization function for weights
        bias_init: Initialization function for biases
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        *,
        hidden_dims: tuple[int, ...],
        activation: Callable | Activation = "relu",
        use_bias: bool = True,
        use_bias_last: bool = False,
        skip_every: int | None = None,
        dtype: torch.dtype | None = None,
        kernel_init: Callable | None = None,
        bias_init: Callable | None = None,
    ):
        super().__init__()

        assert len(hidden_dims) > 0, "Must have at least one hidden layer"

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.use_bias = use_bias
        self.use_bias_last = use_bias_last
        self.skip_every = skip_every
        self.dtype = dtype

        if not isinstance(activation, Callable):
            self.activation = lambda x: activations__[activation](x)
        else:
            self.activation = activation

        layers = []

        layers.append(
            nn.Linear(
                in_features=input_dim,
                out_features=hidden_dims[0],
                bias=use_bias,
                dtype=dtype,
            )
        )

        width_pairs = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        for in_dim, out_dim in width_pairs:
            layers.append(
                nn.Linear(
                    in_features=in_dim,
                    out_features=out_dim,
                    bias=use_bias,
                    dtype=dtype,
                )
            )

        layers.append(
            nn.Linear(
                in_features=hidden_dims[-1],
                out_features=output_dim,
                bias=use_bias_last,
                dtype=dtype,
            )
        )

        self.layers = nn.ModuleList(layers)
        self._initialize_parameters(kernel_init, bias_init)

    def _initialize_parameters(
        self,
        kernel_init: Callable | None = None,
        bias_init: Callable | None = None,
    ):
        """Initialize layer parameters."""
        for layer in self.layers:
            if kernel_init is not None:
                kernel_init(layer.weight)

            if layer.bias is not None:
                if bias_init is not None:
                    bias_init(layer.bias)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass through the MLP.

        Args:
            inputs: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        last_layer_idx = len(self.layers) - 1

        skip_layers = None
        if self.skip_every:
            skip_layers = set(range(0, last_layer_idx, self.skip_every))

        y = self.layers[0](inputs)
        x = self.activation(y)

        if last_layer_idx > 1:
            for l, layer in enumerate(self.layers[1:-1]):
                add_skip = skip_layers is not None and l in skip_layers

                if add_skip:
                    skip_x = x

                y = layer(x)
                x = self.activation(y)

                if add_skip:
                    x = x + skip_x

        x = self.layers[-1](x)
        return x


# class Autoencoder(nnx.Module):
#     def __init__(self, input_dim: int, activation: Callable, rngs: nnx.Rngs):
#         super().__init__()
#         self.activation = activation
#         self.encoder = nnx.Sequential(
#             nnx.Linear(input_dim, 1000, rngs=rngs),
#             activation,
#             nnx.Linear(1000, 500, rngs=rngs),
#             activation,
#             nnx.Linear(500, 250, rngs=rngs),
#             activation,
#             nnx.Linear(250, 30, rngs=rngs),
#         )

#         self.decoder = nnx.Sequential(
#             nnx.Linear(30, 250, rngs=rngs),
#             activation,
#             nnx.Linear(250, 500, rngs=rngs),
#             activation,
#             nnx.Linear(500, 1000, rngs=rngs),
#             activation,
#             nnx.Linear(1000, input_dim, rngs=rngs),
#         )

#     def __call__(self, batch):
#         assert len(batch.shape) == 2, "Input Image Must be Flattened"
#         encoded = self.encoder(batch)
#         decoded = self.decoder(encoded)
#         return decoded


# def l2_loss(params):
#     """Computes the L2 norm for a collection of parameters using tree operations."""
#     squared_leaves = jax.tree.map(lambda p: jnp.sum(p**2), params)
#     return 0.5 * jax.tree.reduce(operator.add, squared_leaves)


# def autoencoder_loss(
#     model: Autoencoder,
#     batch,
#     l2_reg,
#     is_training: bool,
#     mse_loss: bool = False,
#     return_output: bool = False,
# ) -> jax.Array:
#     """Evaluates the loss of the autoencoder."""

#     logits = model(batch)

#     if mse_loss:
#         sigmoid_logits = jax.nn.sigmoid(logits)
#         losses = optax.losses.l2_loss(sigmoid_logits, batch)
#         avg_losses = jnp.mean(losses)
#     else:
#         losses = jnp.sum(
#             optax.losses.sigmoid_binary_cross_entropy(logits, batch), axis=1
#         )
#         avg_losses = jnp.mean(losses)

#     params = nnx.state(model, nnx.Param)

#     if is_training:
#         l2_reg_val = l2_loss(params)
#         avg_losses = avg_losses + l2_reg * l2_reg_val

#     if return_output:
#         return avg_losses, logits
#     else:
#         return avg_losses
