from flax import nnx
import jax.numpy as jnp
import typing
import jax
from flax.typing import Dtype, Initializer
import sys 
from lora_models import SirenLayer
from lora_models import RealGaborLayer
from lora_models.utils import FourierLinear

class LoRA(nnx.LoRA):
    """LoRA adapter built on ``nnx.LoRA`` with a zero-initialised ``lora_b``.

    The forward pass returns ``(x @ lora_a @ lora_b) * (1 / sqrt(lora_rank))``.
    When a ``base_module`` is present, ``base_module(x)`` is added to that result.
    Because ``lora_b`` starts at zero, a freshly built adapter contributes nothing
    on top of the base module.
    """

    def __init__(
        self,
        in_features: int,
        lora_rank: int,
        out_features: int,
        *,
        base_module: typing.Optional[nnx.Module] = None,
        dtype: typing.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        kernel_init: Initializer = nnx.initializers.lecun_normal(),
        lora_param_type: typing.Type[nnx.variablelib.Variable] = nnx.LoRAParam,
        rngs: nnx.rnglib.Rngs = nnx.rnglib.Rngs(0),
    ):
        # NOTE: the default ``rngs`` is one ``Rngs`` object created when the class is
        # defined; every call that omits ``rngs`` shares (and advances) it.
        parent_kwargs = dict(
            in_features=in_features,
            lora_rank=lora_rank,
            out_features=out_features,
            base_module=base_module,
            dtype=dtype,
            param_dtype=param_dtype,
            kernel_init=kernel_init,
            lora_param_type=lora_param_type,
            rngs=rngs,
        )
        # NOTE(review): passing ``kernel_init`` assumes the flax version whose
        # nnx.LoRA accepts that keyword — confirm against the installed flax.
        super().__init__(**parent_kwargs)

        self.lora_rank = lora_rank

        # Overwrite whatever the parent put in ``lora_b`` with zeros so the
        # low-rank update starts as a no-op. ``rngs.params()`` is still drawn
        # (the zeros initializer ignores it) to keep the RNG stream unchanged.
        zero_init = nnx.initializers.zeros
        b_shape = (lora_rank, out_features)
        self.lora_b = lora_param_type(zero_init(rngs.params(), b_shape, param_dtype))

    def __call__(self, x) -> jnp.ndarray:
        # Rank-dependent scaling of the low-rank path.
        scale = 1 / jnp.sqrt(self.lora_rank)
        low_rank = x @ self.lora_a @ self.lora_b
        result = low_rank * scale

        base = self.base_module
        if base is None:
            return result
        if not callable(base):
            raise ValueError('`self.base_module` must be callable.')
        return result + base(x)

def add_lora_to_model(
    model: nnx.Module,
    lora_rank: int,
    param_dtype: Dtype = jnp.float32,
    rngs: nnx.rnglib.Rngs = nnx.rnglib.Rngs(0)
):
    """Wrap the hidden layers of ``model`` with LoRA adapters, in place.

    Handling of each entry of ``model.hidden_layers.layers``:

    * ``nnx.Linear``: the list entry is replaced by a ``LoRA`` whose base module is
      that linear layer.
    * ``SirenLayer``: its ``linear`` attribute is wrapped in a ``LoRA``.
    * ``RealGaborLayer``: its ``freqs`` and ``scales`` attributes are each wrapped
      in a ``LoRA``.

    Every adapter takes its input/output sizes from the module it wraps. Layers
    whose shape differs from ``model.hidden_dim`` are therefore handled correctly.

    Args:
        model: module exposing a mutable ``hidden_layers.layers`` sequence.
        lora_rank: rank of every adapter.
        param_dtype: dtype of the adapter parameters.
        rngs: RNG streams used to initialise the adapters. NOTE: the default is a
            single ``Rngs`` object created at import time and shared by every call
            that omits this argument.

    Raises:
        ValueError: if any hidden layer has an unsupported type. All layers are
            checked before any of them is modified, so the model is left untouched
            in that case.
    """
    supported = (nnx.Linear, SirenLayer, RealGaborLayer)
    # Validate first so a bad layer cannot leave the model half-converted.
    for layer in model.hidden_layers.layers:
        if not isinstance(layer, supported):
            raise ValueError(f"Unsupported layer type: {type(layer)}. LoRA can only be applied to Linear, SirenLayer, and RealGaborLayer.")

    def _wrap(base: nnx.Module) -> "LoRA":
        # Build an adapter sized from the wrapped module. This assumes ``base`` is
        # nnx.Linear-like, i.e. exposes ``in_features``/``out_features``.
        return LoRA(
            in_features=base.in_features,
            lora_rank=lora_rank,
            out_features=base.out_features,
            param_dtype=param_dtype,
            base_module=base,
            rngs=rngs,
        )

    for i, layer in enumerate(model.hidden_layers.layers):
        if isinstance(layer, nnx.Linear):
            model.hidden_layers.layers[i] = _wrap(layer)
        elif isinstance(layer, SirenLayer):
            layer.linear = _wrap(layer.linear)
        else:  # RealGaborLayer; guaranteed by the validation pass above.
            layer.freqs = _wrap(layer.freqs)
            layer.scales = _wrap(layer.scales)

def modify_lora_params(path, val, key):
    """Return a freshly initialised value for one LoRA parameter leaf.

    Used as the ``jax.tree.map_with_path`` callback in ``reset_lora_params``. The
    parameter's attribute name is read from ``path[-2]``. The last path entry is
    presumably the variable's ``.value`` accessor (TODO: confirm for the installed
    flax version).

    Args:
        path: key path of the leaf.
        val: current leaf value; only its ``shape`` and ``dtype`` are used.
        key: PRNG key for this leaf.

    Returns:
        A LeCun-normal sample for ``lora_a`` and zeros for ``lora_b``. Any other
        leaf is returned unchanged, so it is never overwritten with ``None``.
    """
    # Non-dict path entries (e.g. GetAttrKey) have no ``.key``; treat them as no match.
    name = getattr(path[-2], "key", None) if len(path) >= 2 else None
    if name == "lora_a":
        return jax.nn.initializers.lecun_normal()(key, val.shape, val.dtype)
    if name == "lora_b":
        return jax.nn.initializers.zeros(key, val.shape, val.dtype)
    return val

def reset_lora_params(model: nnx.Module, key: jax.random.PRNGKey = jax.random.PRNGKey(0)):
    """Re-initialise every ``nnx.LoRAParam`` in ``model``, in place.

    ``key`` is split into one subkey per parameter leaf. Each leaf is then rebuilt
    by ``modify_lora_params`` (``lora_a`` gets LeCun-normal values, ``lora_b`` gets
    zeros). The new values are written back into the model with ``nnx.update``.

    Raises:
        ValueError: if the model has no ``nnx.LoRAParam`` state.
    """
    lora_state = nnx.state(model, nnx.LoRAParam)
    if not lora_state:
        raise ValueError("No LoRA parameters found in the model.")

    # Give every leaf its own subkey, laid out in the same tree structure.
    structure = jax.tree.structure(lora_state)
    subkeys = jax.random.split(key, structure.num_leaves)
    per_leaf_keys = jax.tree.unflatten(structure, subkeys)

    fresh_state = jax.tree.map_with_path(modify_lora_params, lora_state, per_leaf_keys)
    nnx.update(model, fresh_state)

def _iter_lora_adapters(layer):
    """Yield every ``LoRA`` adapter attached to one hidden layer.

    The layer may itself be a ``LoRA`` (an ``nnx.Linear`` replaced by
    ``add_lora_to_model``). Otherwise a ``LoRA`` may sit on its ``linear``
    attribute (``SirenLayer``) or on its ``freqs``/``scales`` attributes
    (``RealGaborLayer``).
    """
    if isinstance(layer, LoRA):
        yield layer
        return
    for attr in ("linear", "freqs", "scales"):
        sub = getattr(layer, attr, None)
        if isinstance(sub, LoRA):
            yield sub


def merge_lora_params(model: nnx.Module):
    """Fold each LoRA adapter's low-rank update into its base module's kernel.

    Every ``LoRA`` adapter reachable from ``model.hidden_layers.layers`` is
    merged. That covers adapters stored directly in the list and those stored on
    a layer's ``linear``, ``freqs`` or ``scales`` attribute. For each one,
    ``lora_a @ lora_b * (1 / sqrt(lora_rank))`` is added in place to
    ``base_module.kernel``. The kernel keeps its original dtype.

    NOTE: the adapters themselves are not modified, so they still contribute in
    their forward pass. Reset them afterwards (e.g. with ``reset_lora_params``),
    or the update will be applied twice.

    Raises:
        ValueError: if the model has no ``nnx.LoRAParam`` state.
    """
    if not nnx.state(model, nnx.LoRAParam):
        raise ValueError("No LoRA parameters found in the model.")

    for layer in model.hidden_layers.layers:
        for adapter in _iter_lora_adapters(layer):
            kernel = adapter.base_module.kernel
            scale = 1 / jnp.sqrt(adapter.lora_rank)
            delta = adapter.lora_a.value @ adapter.lora_b.value * scale
            # Cast so the merge does not promote e.g. a bf16 kernel to float32.
            kernel.value = kernel.value + delta.astype(kernel.value.dtype)
