from flax import nnx
import jax.numpy as jnp
import typing
import jax
from flax.typing import Dtype, Initializer
from nnx_models import SirenLayer
from nnx_models import RealGaborLayer
from nnx_models.utils import FourierLinear

class AdaptiveLoRA(nnx.LoRA):
    """
    LoRA variant that stores a stack of low-rank updates ("micro-layers")
    and only applies the first ``adaptive_ml`` of them.

    The parameters ``lora_a`` and ``lora_b`` are 3D, with shapes
    ``(lora_layers, in_features, lora_rank)`` and
    ``(lora_layers, lora_rank, out_features)``. The forward pass computes

        base_module(x) + (x @ sum_i(lora_a[i] @ lora_b[i])) / sqrt(lora_rank)

    where ``i`` runs over the first ``adaptive_ml`` micro-layers.
    ``adapt_ml`` grows ``adaptive_ml`` in steps of 10.
    """
    def __init__(
        self,
        in_features: int,
        lora_rank: int,
        out_features: int,
        lora_layers: int,  # Total number of available micro-layers
        max_ml: int = 2000,
        ml: int = 20,       # Initial number of active micro-layers
        *,
        base_module: typing.Optional[nnx.Module] = None,
        dtype: typing.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        kernel_init: Initializer = nnx.initializers.lecun_normal(),
        lora_param_type: typing.Type[nnx.variablelib.Variable] = nnx.LoRAParam,
        rngs: nnx.rnglib.Rngs = nnx.rnglib.Rngs(0),
    ):
        """
        Build the base module (if needed) and the stacked LoRA parameters.

        Args:
            in_features: Input dimension of the wrapped linear map.
            lora_rank: Rank of every micro-layer.
            out_features: Output dimension of the wrapped linear map.
            lora_layers: Number of micro-layers allocated in the parameters.
            max_ml: Default upper bound used by ``adapt_ml``.
            ml: Initial value of ``adaptive_ml``.
            base_module: Module to wrap. When None, a bias-free ``nnx.Linear``
                is created.
            dtype: Stored on the instance and forwarded to the default base
                module.
            param_dtype: Dtype of the LoRA parameters (and of the default
                base module's parameters).
            kernel_init: Initializer for ``lora_a`` (called with the full 3D
                shape) and for the default base module's kernel.
            lora_param_type: Variable class wrapping ``lora_a``/``lora_b``.
            rngs: RNG streams. NOTE: the default ``Rngs(0)`` instance is
                created once, when the class is defined, and is shared by
                every call that omits ``rngs``.
        """
        # nnx.LoRA.__init__ is deliberately not called: it would create the
        # standard 2D lora_a / lora_b parameters, which are replaced below by
        # 3D stacks. The attributes the parent would have set are set by hand.
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype

        # If no base_module is provided, create a standard Linear layer.
        if base_module is None:
            self.base_module = nnx.Linear(
                in_features,
                out_features,
                use_bias=False,  # Standard for LoRA base
                dtype=dtype,
                param_dtype=param_dtype,
                kernel_init=kernel_init,
                rngs=rngs,
            )
        else:
            self.base_module = base_module

        # Store LoRA-specific configuration.
        self.lora_rank = lora_rank
        self.lora_layers = lora_layers
        self.max_ml = max_ml
        self.adaptive_ml = ml  # The dynamic number of active micro-layers

        # 3D LoRA matrices holding all micro-layers along the leading axis.
        self.lora_a = lora_param_type(
            kernel_init(rngs.params(), (lora_layers, in_features, lora_rank), param_dtype)
        )
        self.lora_b = lora_param_type(
            # lora_b starts at zero so the initial LoRA update is zero.
            nnx.initializers.zeros(rngs.params(), (lora_layers, lora_rank, out_features), param_dtype)
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Apply the base module plus the scaled sum of active micro-layers.

        Args:
            x: Input whose last dimension is ``in_features``.

        Returns:
            ``base_module(x) + (x @ summed_update) / sqrt(lora_rank)``.
        """
        # adaptive_ml is a plain Python int, so the slices below have a
        # static size. A value larger than the number of allocated
        # micro-layers simply selects all of them.
        active_count = self.adaptive_ml

        # Slice only the active LoRA matrices *before* computation.
        active_lora_a = self.lora_a.value[:active_count]
        active_lora_b = self.lora_b.value[:active_count]

        # Per-micro-layer products A_i @ B_i, then summed over micro-layers.
        lora_updates = jnp.einsum('ijk,ikl -> ijl', active_lora_a, active_lora_b)
        summed_update = jnp.sum(lora_updates, axis=0)

        # `@` and `*` share precedence and associate left to right, so this
        # is (x @ summed_update) * scale.
        lora_out = x @ summed_update * (1 / jnp.sqrt(self.lora_rank))

        # The base_module is an nnx.Module and is called directly.
        base_out = self.base_module(x)

        return base_out + lora_out

    def adapt_ml(self, tol: float, max_ml: typing.Optional[int] = None):
        """
        Increase the number of active LoRA micro-layers by up to 10.

        The count grows only when ``tol >= 1.5e-04`` and the current count is
        below the limit; the result is capped at the limit.
        This method must be called outside of JIT-compiled functions.

        Args:
            tol: Value compared against the fixed threshold ``1.5e-04``.
            max_ml: Upper bound for this call. When None, ``self.max_ml`` is
                used.
        """
        # Honour the per-call limit when given (it was previously ignored).
        limit = self.max_ml if max_ml is None else max_ml

        if tol >= 1.5e-04 and self.adaptive_ml < limit:
            self.adaptive_ml = min(self.adaptive_ml + 10, limit)

def add_lora_to_model_adaptive(
    model: nnx.Module,
    lora_rank: int,
    lora_layers: int = 4,
    max_ml: int = 1000,
    ml: int = 20,  # Initial number of active micro-layers
    param_dtype: Dtype = jnp.float32,
    rngs: nnx.rnglib.Rngs = nnx.rnglib.Rngs(0),
):
    """
    Wrap the linear maps of ``model.hidden_layers.layers`` in AdaptiveLoRA.

    The model is modified in place:
      * an ``nnx.Linear`` entry is replaced by an AdaptiveLoRA wrapping it;
      * a ``SirenLayer`` gets its ``linear`` attribute wrapped;
      * a ``RealGaborLayer`` gets its ``freqs`` and ``scales`` attributes
        wrapped.

    Args:
        model: Model exposing ``hidden_layers.layers``.
        lora_rank: Rank of every micro-layer.
        lora_layers: Number of micro-layers allocated per wrapper.
        max_ml: Upper bound on active micro-layers passed to each wrapper.
        ml: Initial number of active micro-layers.
        param_dtype: Dtype of the LoRA parameters.
        rngs: RNG streams used to initialise the LoRA parameters. NOTE: the
            default instance is created once at definition time and shared
            by all calls that omit ``rngs``.

    Raises:
        ValueError: If a layer is not Linear, SirenLayer or RealGaborLayer.
    """

    def wrap(linear):
        # One construction path for every wrapped module, so all of them get
        # the same LoRA configuration. The RealGaborLayer.freqs wrapper used
        # to omit the required `lora_layers` argument.
        return AdaptiveLoRA(
            in_features=linear.in_features,
            lora_rank=lora_rank,
            out_features=linear.out_features,
            lora_layers=lora_layers,
            max_ml=max_ml,
            ml=ml,
            param_dtype=param_dtype,
            base_module=linear,
            rngs=rngs,
        )

    for i, layer in enumerate(model.hidden_layers.layers):
        if isinstance(layer, nnx.Linear):
            # Replace the linear layer itself with a LoRA layer.
            model.hidden_layers.layers[i] = wrap(layer)
        elif isinstance(layer, SirenLayer):
            layer.linear = wrap(layer.linear)
        elif isinstance(layer, RealGaborLayer):
            # freqs first, then scales, so RNG consumption order is stable.
            layer.freqs = wrap(layer.freqs)
            layer.scales = wrap(layer.scales)
        else:
            raise ValueError(f"Unsupported layer type: {type(layer)}. LoRA can only be applied to Linear, SirenLayer, and RealGaborLayer.")

def reset_lora_leaf(path, val, key):
    """
    Return a freshly initialised value for a LoRA leaf, or ``val`` otherwise.

    The second-to-last path entry is inspected through its ``key``
    attribute: ``"lora_a"`` leaves are re-drawn with ``lecun_normal`` and
    ``"lora_b"`` leaves are set to zeros, both with the shape and dtype of
    ``val``.

    Args:
        path: Sequence of path entries leading to the leaf.
        val: Current leaf value; must expose ``shape`` and ``dtype`` when it
            is a LoRA leaf.
        key: PRNG key used for the ``lora_a`` initialisation.

    Returns:
        The new leaf value for LoRA leaves; ``val`` itself for every other
        leaf, so non-LoRA leaves are never replaced by None.
    """
    # Paths that are too short, or whose entry has no `.key`, cannot name a
    # LoRA parameter and are passed through untouched.
    owner = getattr(path[-2], "key", None) if len(path) >= 2 else None

    if owner == "lora_a":
        return jax.nn.initializers.lecun_normal()(key, val.shape, val.dtype)
    if owner == "lora_b":
        return jax.nn.initializers.zeros(key, val.shape, val.dtype)
    return val

def reset_lora_params(model: nnx.Module, lora_layers: int = 1, key: jax.random.PRNGKey = jax.random.PRNGKey(0)):
    """
    Re-initialise the LoRA parameters of every AdaptiveLoRA in the model.

    For each entry of ``model.hidden_layers.layers`` this resets the entry
    itself when it is an AdaptiveLoRA, and otherwise any AdaptiveLoRA stored
    in its ``linear``, ``freqs`` or ``scales`` attribute (the places where
    ``add_lora_to_model_adaptive`` installs wrappers).

    ``lora_a`` is re-drawn with ``lecun_normal`` and ``lora_b`` is set to
    zeros. NOTE: the new parameters have ``lora_layers`` micro-layers, so
    with the default of 1 the leading dimension shrinks to 1 regardless of
    how many micro-layers were allocated before. ``adaptive_ml`` is left
    unchanged.

    Args:
        model: Model exposing ``hidden_layers.layers`` and
            ``num_hidden_layers``.
        lora_layers: Leading dimension of the re-initialised parameters.
        key: PRNG key; split into ``model.num_hidden_layers`` per-layer keys.

    Raises:
        ValueError: If the model holds no ``nnx.LoRAParam`` state.
    """
    if not nnx.state(model, nnx.LoRAParam):
        raise ValueError("No LoRA parameters found in the model.")

    lecun_normal = jax.nn.initializers.lecun_normal()
    zeros = jax.nn.initializers.zeros

    # One key per hidden layer, indexed by position in the layer list.
    # assumes model.num_hidden_layers == len(model.hidden_layers.layers) — TODO confirm
    keys = jax.random.split(key, model.num_hidden_layers)

    for i, layer in enumerate(model.hidden_layers.layers):
        if isinstance(layer, AdaptiveLoRA):
            targets = [layer]
        else:
            # Wrappers nested inside composite layers (SirenLayer.linear,
            # RealGaborLayer.freqs / .scales).
            candidates = (getattr(layer, name, None) for name in ("linear", "freqs", "scales"))
            targets = [sub for sub in candidates if isinstance(sub, AdaptiveLoRA)]
        if not targets:
            continue

        # A single wrapper uses the layer key directly (unchanged behavior
        # for top-level AdaptiveLoRA layers); several wrappers in one layer
        # get independent sub-keys.
        if len(targets) == 1:
            subkeys = [keys[i]]
        else:
            subkeys = jax.random.split(keys[i], len(targets))

        for lora, subkey in zip(targets, subkeys):
            old_a = lora.lora_a.value
            old_b = lora.lora_b.value
            lora.lora_a.value = lecun_normal(subkey, (lora_layers,) + tuple(old_a.shape[-2:]), old_a.dtype)
            # The zeros initializer ignores the key; it is passed only to
            # satisfy the initializer signature.
            lora.lora_b.value = zeros(subkey, (lora_layers,) + tuple(old_b.shape[-2:]), old_b.dtype)

def merge_lora_params(model: nnx.Module, ml: int = 5):
    """
    Add the LoRA update of the first ``ml`` micro-layers to the base kernels.

    For each entry of ``model.hidden_layers.layers`` this merges the entry
    itself when it is an AdaptiveLoRA, and otherwise any AdaptiveLoRA stored
    in its ``linear``, ``freqs`` or ``scales`` attribute (the places where
    ``add_lora_to_model_adaptive`` installs wrappers).

    The merged update is ``sum_i(lora_a[i] @ lora_b[i]) / sqrt(lora_rank)``
    over the first ``ml`` micro-layers, added in place to
    ``base_module.kernel``. The LoRA parameters themselves are not modified,
    so the forward pass keeps adding the LoRA term until they are reset.

    Args:
        model: Model exposing ``hidden_layers.layers``.
        ml: Number of leading micro-layers to merge. This is independent of
            each wrapper's ``adaptive_ml``.

    Raises:
        ValueError: If the model holds no ``nnx.LoRAParam`` state.
    """
    if not nnx.state(model, nnx.LoRAParam):
        raise ValueError("No LoRA parameters found in the model.")

    for layer in model.hidden_layers.layers:
        if isinstance(layer, AdaptiveLoRA):
            targets = [layer]
        else:
            # Wrappers nested inside composite layers (SirenLayer.linear,
            # RealGaborLayer.freqs / .scales).
            candidates = (getattr(layer, name, None) for name in ("linear", "freqs", "scales"))
            targets = [sub for sub in candidates if isinstance(sub, AdaptiveLoRA)]

        for lora in targets:
            # Slice the LoRA matrices up to the given ml before the einsum.
            active_lora_a = lora.lora_a.value[:ml]
            active_lora_b = lora.lora_b.value[:ml]
            # Same two-step product/sum and scaling as AdaptiveLoRA.__call__.
            delta = jnp.einsum('ijk,ikl -> ijl', active_lora_a, active_lora_b)
            delta = jnp.sum(delta, axis=0) * (1 / jnp.sqrt(lora.lora_rank))
            # assumes base_module exposes a `kernel` variable (e.g. nnx.Linear) — TODO confirm
            lora.base_module.kernel.value += delta