from flax import nnx
import jax.numpy as jnp
import typing
import jax
from flax.typing import Dtype, Initializer
from flax.nnx.nn import initializers
from nnx_models import SirenLayer
from nnx_models import RealGaborLayer
from nnx_models.utils import FourierLinear

# Default initializer for the LoRA `a` matrix; used as the default value of
# the `a_initializer` argument of `LoRA.__init__` below.
default_a_initializer = initializers.he_uniform()
# Default initializer for the LoRA `b` matrix; used as the default value of
# the `b_initializer` argument of `LoRA.__init__` below.
default_b_initializer = initializers.zeros

class LoRA(nnx.LoRA):
    """`nnx.LoRA` variant that remembers its initializers.

    The initializers are kept as attributes (`a_initializer`,
    `b_initializer`) so that the adapter can be re-initialized later, e.g. by
    `reset_lora_params`.

    Args:
        in_features: Number of input features.
        lora_rank: Rank of the low-rank update.
        out_features: Number of output features.
        base_module: Optional wrapped module; when given, its output is added
            to the low-rank term in `__call__`.
        dtype: Forwarded to `nnx.LoRA`. Note that `__call__` below does not
            perform any dtype promotion itself.
        param_dtype: Dtype used to initialize the LoRA matrices.
        a_initializer: Initializer for `lora_a`.
        b_initializer: Initializer for `lora_b`.
        lora_param_type: Variable type used to hold the LoRA matrices.
        rngs: RNG container used for initialization. NOTE: the default
            `Rngs(0)` object is created once when this module is imported and
            is therefore shared by every call that relies on the default.
    """

    def __init__(
        self,
        in_features: int,
        lora_rank: int,
        out_features: int,
        *,
        base_module: typing.Optional[nnx.Module] = None,
        dtype: typing.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        a_initializer: Initializer = default_a_initializer,
        b_initializer: Initializer = default_b_initializer,
        lora_param_type: typing.Type[nnx.variablelib.Variable] = nnx.LoRAParam,
        rngs: nnx.rnglib.Rngs = nnx.rnglib.Rngs(0),
    ):
        # Forward the initializers to the parent constructor so that the
        # matrices created at construction time actually use them; storing
        # them on `self` alone would leave the initial values on the parent's
        # defaults.
        super().__init__(
            in_features=in_features,
            lora_rank=lora_rank,
            out_features=out_features,
            base_module=base_module,
            dtype=dtype,
            param_dtype=param_dtype,
            a_initializer=a_initializer,
            b_initializer=b_initializer,
            lora_param_type=lora_param_type,
            rngs=rngs,
        )
        # Kept so the adapter can be re-initialized with the same scheme.
        self.a_initializer = a_initializer
        self.b_initializer = b_initializer

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the low-rank update, plus the base module when present.

        Args:
            x: Input array; it is right-multiplied by `lora_a` and `lora_b`.

        Returns:
            `x @ lora_a @ lora_b`, with `base_module(x)` added when a base
            module is set.

        Raises:
            ValueError: If `base_module` is set but is not callable.
        """
        out = x @ self.lora_a.value @ self.lora_b.value
        if self.base_module is not None:
            if not callable(self.base_module):
                raise ValueError('`self.base_module` must be callable.')
            out += self.base_module(x)

        return out

def add_lora_to_model(
    model: nnx.Module,
    lora_rank: int,
    param_dtype: Dtype = jnp.float32,
    rngs: nnx.rnglib.Rngs = nnx.rnglib.Rngs(0)
):
    """Wrap the linear parts of every hidden layer of `model` in a LoRA adapter.

    The model is modified in place and nothing is returned. For each entry of
    `model.hidden_layers.layers`:

    * an `nnx.Linear` is replaced in the list by a `LoRA` wrapping it;
    * a `SirenLayer` gets its `linear` attribute wrapped;
    * a `RealGaborLayer` gets its `freqs` and then its `scales` wrapped.

    Args:
        model: Model exposing `hidden_layers.layers`.
        lora_rank: Rank of every adapter that is created.
        param_dtype: Parameter dtype passed to every adapter.
        rngs: RNG container passed to every adapter.

    Raises:
        ValueError: If a hidden layer is of any other type.
    """

    def _wrap(target):
        # Build an adapter around `target`, sized from its own feature counts.
        return LoRA(
            in_features=target.in_features,
            lora_rank=lora_rank,
            out_features=target.out_features,
            param_dtype=param_dtype,
            base_module=target,
            rngs=rngs
        )

    for index, layer in enumerate(model.hidden_layers.layers):
        if isinstance(layer, nnx.Linear):
            model.hidden_layers.layers[index] = _wrap(layer)
            continue

        if isinstance(layer, SirenLayer):
            layer.linear = _wrap(layer.linear)
            continue

        if isinstance(layer, RealGaborLayer):
            # Order matters for RNG consumption: freqs first, then scales.
            layer.freqs = _wrap(layer.freqs)
            layer.scales = _wrap(layer.scales)
            continue

        raise ValueError(f"Unsupported layer type: {type(layer)}. LoRA can only be applied to Linear, SirenLayer, and RealGaborLayer.")

def reset_lora_leaf(path, val, key):
    """Return a re-initialized value for a LoRA leaf, or `val` for any other leaf.

    The leaf is identified by the second-to-last entry of `path`: `"lora_a"`
    leaves are redrawn with a LeCun-normal initializer, `"lora_b"` leaves are
    set to zeros, both with the shape and dtype of `val`. Every other leaf is
    returned unchanged, so the function can be mapped over a tree that also
    contains non-LoRA leaves without replacing them by `None`.

    Args:
        path: Sequence of key-path entries leading to the leaf (presumably as
            produced by `jax.tree_util.tree_map_with_path` -- confirm against
            the caller). Entries are inspected through their `key` attribute,
            falling back to `name`.
        val: Current leaf value; only its `shape` and `dtype` are used when
            the leaf is re-initialized.
        key: PRNG key used for the random initializer.

    Returns:
        The new leaf value, or `val` itself when the leaf is not a LoRA matrix.
    """
    # A path with fewer than two entries cannot name a LoRA matrix.
    if len(path) < 2:
        return val

    parent = path[-2]
    # Dict-style entries expose `.key`, attribute-style entries expose `.name`.
    parent_name = getattr(parent, "key", getattr(parent, "name", None))

    if parent_name == "lora_a":
        return jax.nn.initializers.lecun_normal()(key, val.shape, val.dtype)
    if parent_name == "lora_b":
        return jax.nn.initializers.zeros(key, val.shape, val.dtype)
    return val

@nnx.jit(static_argnames=('lora_rank', 'lora_layers'))
def reset_lora_params(model: nnx.Module, lora_rank: int, lora_layers: int, key: jax.random.PRNGKey = jax.random.PRNGKey(0)):
    """Re-initialize every LoRA adapter found in the model's hidden layers.

    Each adapter's `lora_a` / `lora_b` is overwritten in place with a fresh
    draw from the adapter's own `a_initializer` / `b_initializer`, using shape
    `(in_features, lora_rank)` and `(lora_rank, out_features)` and keeping the
    current dtype of each matrix.

    Adapters are searched inside each hidden layer, not only at the top level
    of `model.hidden_layers.layers`, because `add_lora_to_model` places them
    on sub-attributes of `SirenLayer` and `RealGaborLayer`.

    Args:
        model: Model exposing `hidden_layers.layers`.
        lora_rank: Rank used for the shapes of the new matrices (static under
            `nnx.jit`).
        lora_layers: Not used by this function; kept for interface
            compatibility (static under `nnx.jit`).
        key: PRNG key from which one independent key per matrix is derived.

    Raises:
        ValueError: If the model holds no `nnx.LoRAParam` state.
    """
    if not nnx.state(model, nnx.LoRAParam):
        raise ValueError("No LoRA parameters found in the model.")

    # Collect the adapters first (de-duplicated by identity) so that the
    # number of keys matches what is actually reset.
    adapters = []
    seen = set()
    for layer in model.hidden_layers.layers:
        for _, node in nnx.iter_graph(layer):
            if isinstance(node, LoRA) and id(node) not in seen:
                seen.add(id(node))
                adapters.append(node)

    if not adapters:
        return

    adapter_keys = jax.random.split(key, len(adapters))
    for index, adapter in enumerate(adapters):
        # Separate keys for the two matrices: reusing one key would correlate
        # them whenever both initializers are random.
        a_key, b_key = jax.random.split(adapter_keys[index])
        adapter.lora_a.value = adapter.a_initializer(
            a_key, (adapter.in_features, lora_rank), adapter.lora_a.value.dtype
        )
        adapter.lora_b.value = adapter.b_initializer(
            b_key, (lora_rank, adapter.out_features), adapter.lora_b.value.dtype
        )

@nnx.jit
def merge_lora_params(model: nnx.Module):
    """Fold every LoRA update of the hidden layers into its base kernel.

    For each adapter with a base module, `lora_a @ lora_b` is added to
    `base_module.kernel` (cast to the kernel's dtype so the kernel dtype is
    preserved), and `lora_b` is then set to zeros. Zeroing `lora_b` keeps the
    adapter's forward pass unchanged by the merge; otherwise the low-rank
    term would be applied twice (once through the kernel, once through the
    adapter).

    Adapters are searched inside each hidden layer, not only at the top level
    of `model.hidden_layers.layers`, because `add_lora_to_model` places them
    on sub-attributes of `SirenLayer` and `RealGaborLayer`. Adapters without
    a base module are skipped, since there is no kernel to merge into.

    Assumes each base module exposes a `kernel` variable whose shape matches
    `lora_a @ lora_b` (as `nnx.Linear` does) -- TODO confirm for other bases.

    Args:
        model: Model exposing `hidden_layers.layers`.

    Raises:
        ValueError: If the model holds no `nnx.LoRAParam` state.
    """
    if not nnx.state(model, nnx.LoRAParam):
        raise ValueError("No LoRA parameters found in the model.")

    # Collect first, mutate afterwards; de-duplicate by identity so a shared
    # adapter is merged only once.
    adapters = []
    seen = set()
    for layer in model.hidden_layers.layers:
        for _, node in nnx.iter_graph(layer):
            if isinstance(node, LoRA) and id(node) not in seen:
                seen.add(id(node))
                adapters.append(node)

    for adapter in adapters:
        if adapter.base_module is None:
            continue
        kernel = adapter.base_module.kernel
        delta = adapter.lora_a.value @ adapter.lora_b.value
        kernel.value = kernel.value + delta.astype(kernel.value.dtype)
        # The update now lives in the kernel; neutralize the adapter.
        adapter.lora_b.value = jnp.zeros_like(adapter.lora_b.value)
