from functools import partial
from typing import Any, Callable

import flax.linen as nn
import jax.numpy as jnp
from jax import random


class BoxDense(nn.Module):
    """Dense layer ``activation(x @ kernel - bias)`` with "box" initialization.

    Each output unit ``j`` is initialized as the hyperplane
    ``n_j . (x - p_j) = 0``:

    * ``n_j`` is a random unit normal.
    * ``p_j`` is a point drawn uniformly from a box ``[0, m]^d``.

    The hyperplane is rescaled so that the pre-activation takes a fixed value
    at the box corner lying furthest along ``n_j``.
    NOTE(review): this looks like the box initialization of Cyr et al. (2019);
    confirm the reference.

    Supported ``arch_type`` values:

    * ``"plain"``: points in ``[0, 1]^d``; the pre-activation is ``1`` at the
      far corner.
    * ``"resnet"``: points in ``[0, m]^d`` with ``m = (1 + 1/depth) ** layer``;
      the pre-activation is ``1/depth`` at the far corner.

    Attributes:
        features: Number of output units.
        activation: Callable applied to ``x @ kernel - bias``.
        depth: Total depth of the network; only read for ``"resnet"``.
        layer: Index of this layer; only read for ``"resnet"``.
        arch_type: ``"plain"`` or ``"resnet"``.
        dtype: dtype of the sampled parameters. Presumably requires JAX x64
            mode (``jax_enable_x64``) for the ``jnp.float64`` default to take
            effect. TODO confirm for the target setup.
    """
    # No `name` field is declared: `name` is a reserved Flax Module attribute
    # that Flax provides itself, so callers can still pass `name=...`. A fixed
    # class-level default would give every instance the same scope name.
    features: int
    activation: Callable
    depth: int
    layer: int
    arch_type: str
    dtype: Any = jnp.float64

    # Architectures understood by `box_init` (a class constant, not a field).
    SUPPORTED_ARCHITECTURES = frozenset({"plain", "resnet"})

    def setup(self):
        """Validate the configured architecture type."""
        if self.arch_type not in self.SUPPORTED_ARCHITECTURES:
            raise ValueError(f"Unsupported architecture type: {self.arch_type}")

    @staticmethod
    def _box_parameters(rng, shape, dtype, box_size, scale):
        """Sample normals, points and per-unit scales for box initialization.

        Args:
            rng: PRNG key.
            shape: ``(in_features, out_features)``. Column ``j`` holds the
                normal and the point of output unit ``j``.
            dtype: dtype of the sampled arrays.
            box_size: Side length ``m`` of the box ``[0, m]^d`` that the
                points are drawn from.
            scale: Target pre-activation value at the box corner furthest
                along each normal.

        Returns:
            ``(norms, points, scaling_factor)``. ``norms`` and ``points`` have
            shape ``shape``; ``scaling_factor`` has shape
            ``(1, out_features)``.
        """
        rng_points, rng_norms = random.split(rng, 2)
        norms = random.normal(rng_norms, shape=shape, dtype=dtype)
        # Normalize each column, i.e. each unit's normal vector. Normalizing
        # rows (axis=1) would skew the directions away from uniform on the
        # sphere; with a single output unit it would collapse every normal
        # onto a +/-1 diagonal.
        norms = norms / jnp.linalg.norm(norms, axis=0, keepdims=True)
        # Corner of [0, box_size]^d that maximizes n . x for each unit.
        p_max = box_size * jnp.maximum(0, jnp.sign(norms))
        points = random.uniform(
            rng_points, shape=shape, dtype=dtype, minval=0., maxval=box_size)
        # Chosen so that scaling_factor * n . (p_max - p) == scale.
        # (p_max - p) . n >= 0 by construction of p_max.
        scaling_factor = scale / jnp.sum(
            (p_max - points) * norms, axis=0, keepdims=True)
        return norms, points, scaling_factor

    @staticmethod
    def _initialize_plain(rng, shape, dtype):
        """Box parameters for 'plain': points in [0, 1]^d, unit far-corner value."""
        return BoxDense._box_parameters(rng, shape, dtype, 1., 1.)

    @staticmethod
    def _initialize_resnet(rng, shape, depth, layer, dtype):
        """Box parameters for 'resnet': a box grown per layer, 1/depth at its far corner."""
        # The box side grows by a factor (1 + 1/depth) with each layer index.
        scaling_multiplier = (1. + 1. / depth)**layer
        return BoxDense._box_parameters(
            rng, shape, dtype, scaling_multiplier, 1. / depth)

    def box_init(self, rng, shape, arch_type: str):
        """Return a box-initialized ``(kernel, bias)`` pair for ``arch_type``.

        Args:
            rng: PRNG key supplied by ``self.param``.
            shape: ``(in_features, out_features)`` of the kernel.
            arch_type: ``"plain"`` or ``"resnet"``.

        Returns:
            ``kernel`` of shape ``shape`` and ``bias`` of shape
            ``(out_features,)``. They are chosen so that, for each unit,
            ``x @ kernel - bias == scaling_factor * n . (x - p)``.

        Raises:
            ValueError: If ``arch_type`` is not supported.
        """
        if arch_type == "plain":
            norms, points, scaling_factor = self._initialize_plain(
                rng, shape, self.dtype)
        elif arch_type == "resnet":
            norms, points, scaling_factor = self._initialize_resnet(
                rng, shape, self.depth, self.layer, self.dtype)
        else:
            raise ValueError(f"Unsupported architecture type: {arch_type}")

        kernel = scaling_factor * norms
        # Bias is k * (n . p), so the hyperplane passes through p.
        bias = scaling_factor * jnp.sum(points * norms, axis=0)
        return kernel, bias.ravel()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return ``activation(x @ kernel - bias)``, contracting the last axis of ``x``."""
        init_fn = partial(self.box_init, arch_type=self.arch_type)
        # kernel and bias are stored together as a single (kernel, bias) param.
        kernel, bias = self.param(
            "layer_weights", init_fn, (x.shape[-1], self.features))
        return self.activation(jnp.tensordot(x, kernel, axes=(-1, 0)) - bias)

