"""lora adaption for Implicit Neural Representations (INRs)"""
from __future__ import annotations

import os, sys
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
import optax
from tqdm.auto import trange 

"""prototype Dense layers with LoRA weights added"""
class LoRALinear(eqx.Module):
    """Linear layer augmented with a trainable low-rank (LoRA) update.

    Computes ``base_layer(x) + scaling * (x @ lora_A @ lora_B)``. Here
    ``base_layer`` is a regular ``eqx.nn.Linear``, meant to stay frozen
    during fine-tuning. ``lora_A`` and ``lora_B`` are the low-rank factors
    that get trained. Because ``lora_B`` is initialised to zeros, a freshly
    constructed layer returns exactly what its base layer returns.

    Args:
        in_features: Size of the input vector.
        out_features: Size of the output vector.
        rank: Inner dimension of the low-rank factors.
        alpha: Optional LoRA scaling numerator (keyword-only). When given, the
            low-rank update is multiplied by ``alpha / rank`` (the usual LoRA
            convention). When ``None`` (the default), the factor is 1.0, i.e.
            no scaling.
        key: PRNG key used to initialise the base layer and ``lora_A``.
    """

    base_layer: eqx.nn.Linear  # the original layer; kept frozen by the filter spec
    lora_A: jnp.ndarray        # shape (in_features, rank); trainable
    lora_B: jnp.ndarray        # shape (rank, out_features); trainable, starts at zero
    # Static hyperparameter: part of the tree structure, never a trainable leaf.
    scaling: float = eqx.field(static=True)

    def __init__(self, in_features, out_features, rank, *, alpha=None, key):
        # Split into three keys (the third is unused) so that the base layer and
        # lora_A receive the same keys as with a plain 3-way split.
        base_key, a_key, _ = jrandom.split(key, 3)

        # The original (to-be-frozen) dense layer.
        self.base_layer = eqx.nn.Linear(in_features, out_features, key=base_key)

        # LoRA initialisation: A is random Gaussian, B is zero, so the low-rank
        # update contributes nothing until B is trained.
        self.lora_A = jrandom.normal(a_key, (in_features, rank))
        self.lora_B = jnp.zeros((rank, out_features))

        self.scaling = 1.0 if alpha is None else alpha / rank

    def __call__(self, x):
        """Return ``base_layer(x) + scaling * (x @ lora_A @ lora_B)`` for one sample.

        ``x`` is expected to be a single unbatched input of shape
        ``(in_features,)``. ``eqx.nn.Linear`` does not batch, so map over
        batches with ``jax.vmap``.
        """
        lora_output = (x @ self.lora_A) @ self.lora_B
        return self.base_layer(x) + self.scaling * lora_output

"""Replace the MLP layers with the LoRA enhanced Linear layer class above"""
class LoRAMLP(eqx.Module):
    """Two-layer MLP whose linear layers are ``LoRALinear``.

    Layer stack:
    ``LoRALinear(in -> hidden) -> relu -> LoRALinear(hidden -> out)``.
    The defaults give a 2 -> 64 -> 1 network with LoRA ranks 4 (first layer)
    and 2 (output layer). Pass different sizes and ranks to adapt it to
    another architecture.

    Args:
        key: PRNG key used to initialise both LoRA layers.
        in_features: Input dimension (keyword-only, default 2).
        hidden_features: Width of the hidden layer (keyword-only, default 64).
        out_features: Output dimension (keyword-only, default 1).
        hidden_rank: LoRA rank of the first layer (keyword-only, default 4).
        out_rank: LoRA rank of the output layer (keyword-only, default 2).
    """

    layers: list

    def __init__(self, key, *, in_features=2, hidden_features=64, out_features=1,
                 hidden_rank=4, out_rank=2):
        # Split into three keys (the third is unused) so the two layers get the
        # same keys as with a plain 3-way split of `key`.
        key1, key2, _ = jrandom.split(key, 3)
        self.layers = [
            LoRALinear(in_features=in_features, out_features=hidden_features,
                       rank=hidden_rank, key=key1),
            jax.nn.relu,
            LoRALinear(in_features=hidden_features, out_features=out_features,
                       rank=out_rank, key=key2),
        ]

    def __call__(self, x):
        """Apply every layer in ``self.layers`` to ``x`` in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
    
# Filter spec that selects the trainable LoRA parameters based on the module structure.
# The freezing pattern follows the Equinox "Freezing parameters" example:
# https://docs.kidger.site/equinox/examples/frozen_layer/
def get_filter_spec(model):
    """Build a boolean PyTree marking which leaves of ``model`` are trainable.

    Every leaf is ``False`` (frozen) except the ``lora_A`` and ``lora_B``
    arrays of each ``LoRALinear``, which are ``True``. A ``LoRALinear`` is
    found anywhere in ``model``, at any depth, not only in a top-level
    ``layers`` list. The result has the same structure as ``model`` and can
    be passed directly to ``eqx.partition``.

    Args:
        model: Any Equinox module / PyTree, possibly containing
            ``LoRALinear`` sub-modules at arbitrary depth.

    Returns:
        A PyTree of booleans with the same structure as ``model``.
    """

    def _is_lora(node):
        return isinstance(node, LoRALinear)

    def _spec_for(node):
        # Non-LoRA leaves (base arrays, activation functions, ...) stay frozen.
        if not _is_lora(node):
            return False
        # Freeze everything in the LoRA layer (including its base_layer), then
        # flip only the two low-rank factors to trainable.
        frozen = jtu.tree_map(lambda _: False, node)
        return eqx.tree_at(
            lambda layer: (layer.lora_A, layer.lora_B), frozen, replace=(True, True)
        )

    # `is_leaf` stops the traversal at each LoRALinear, so _spec_for receives
    # the whole layer rather than its individual arrays.
    return jtu.tree_map(_spec_for, model, is_leaf=_is_lora)

# Build the LoRA model from a fixed seed.
model_key = jrandom.PRNGKey(42)
model = LoRAMLP(key=model_key)

# Boolean mask over the model: True only for the LoRA factors (lora_A / lora_B).
filter_spec = get_filter_spec(model)

# Split the model into the trainable part (LoRA factors) and the frozen
# remainder (base weights/biases and the activation function).
diff_model, static_model = eqx.partition(model, filter_spec)

# Adam optimiser; its state tracks only the trainable LoRA factors.
optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(diff_model)

@eqx.filter_jit
def make_step(diff_model, static_model, opt_state, x, y):
    """Perform one optimisation step on the trainable (LoRA) parameters.

    Args:
        diff_model: Trainable part of the model as returned by
            ``eqx.partition`` (LoRA factors, ``None`` elsewhere).
        static_model: Frozen remainder of the model.
        opt_state: Current state of the module-level ``optim`` for
            ``diff_model``.
        x: Batch of inputs; the model is ``jax.vmap``-ed over its leading axis.
        y: Targets compared against the model output with a mean squared error.

    Returns:
        ``(diff_model, static_model, opt_state)`` after the update.
        ``static_model`` is returned unchanged.
    """

    @eqx.filter_grad
    def loss(diff_model, static_model, x, y):
        # Recombine for the forward pass. filter_grad only differentiates the
        # first argument, so gradients are produced for the LoRA factors only.
        model = eqx.combine(diff_model, static_model)
        pred_y = jax.vmap(model)(x)
        return jnp.mean((y - pred_y) ** 2)

    grads = loss(diff_model, static_model, x, y)

    # Pass the current params too, so optimisers that need them (e.g. ones
    # with weight decay) also work. Plain Adam simply ignores them.
    updates, opt_state = optim.update(grads, opt_state, diff_model)

    # `updates` mirrors the structure of `diff_model` (None at frozen leaves),
    # so apply them directly. No need to recombine and re-partition the model.
    diff_model = eqx.apply_updates(diff_model, updates)

    return diff_model, static_model, opt_state

def run_training(batch: tuple, epochs: int, *, log_every: int = 100,
                 state: tuple | None = None) -> tuple:
    """Train the LoRA factors on a single full batch.

    Args:
        batch: ``(x, y)`` pair of inputs and targets.
        epochs: Number of optimisation steps to run.
        log_every: Print the training loss every ``log_every`` epochs,
            starting with epoch 0 (keyword-only, must be positive).
        state: Optional ``(diff_model, static_model, opt_state)`` to start
            from (keyword-only). Defaults to the module-level
            ``diff_model``, ``static_model`` and ``opt_state``.

    Returns:
        The final ``(diff_model, static_model, opt_state)``. Recombine with
        ``eqx.combine`` to get the trained model, or pass it back as
        ``state`` to resume training.

    Raises:
        ValueError: If ``log_every`` is not positive.
    """
    if log_every <= 0:
        raise ValueError(f"log_every must be positive, got {log_every}")

    x, y = batch
    if state is None:
        # These names are only read here, never assigned in this function, so
        # they resolve to the module-level globals. Assigning to them inside the
        # function would make them locals and raise UnboundLocalError.
        state = (diff_model, static_model, opt_state)
    diff_part, static_part, opt_part = state

    print("Training LoRA with Equinox...")
    for epoch in trange(epochs, desc="Training test MLP model"):
        # Only the partitioned parts flow through the jitted step.
        diff_part, static_part, opt_part = make_step(diff_part, static_part, opt_part, x, y)

        if epoch % log_every == 0:
            # Recombine the parts to evaluate the full model.
            full_model = eqx.combine(diff_part, static_part)
            pred_y = jax.vmap(full_model)(x)
            current_loss = jnp.mean((y - pred_y) ** 2)
            print(f"Epoch {epoch} | Loss: {float(current_loss):.6e}")

    return diff_part, static_part, opt_part

if __name__ == '__main__':
    # Put the parent directory (relative to the current working directory) on
    # the import path. Use abspath, not dirname: os.path.dirname("..") is "",
    # which would only re-add the current directory.
    sys.path.append(os.path.abspath(".."))

    # Inputs: coordinates reshaped to (N, 2) float32.
    grid_k = jnp.load("../grid_k.npy", allow_pickle=True).reshape(-1, 2).astype(jnp.float32)
    # Targets: entry 500 along the first axis (presumably one time snapshot,
    # TODO confirm), flattened to (N, 1) float32.
    vorticity_k = jnp.load("../vorticity_k.npy", allow_pickle=True)[500].astype(jnp.float32).reshape(-1, 1)
    # Normalise the targets by their global L2 norm. jnp.max(jnp.abs(...)) is an
    # alternative. For a quick smoke test without data files, random (N, 2)
    # inputs and (N, 1) targets can be used instead.
    vorticity_k /= jnp.linalg.norm(vorticity_k)

    # %%%%%%%%%%% Run training loop %%%%%%%%%%%%%%%%%%%%%%%%%%%
    epochs = 10_000
    run_training(batch=(grid_k, vorticity_k), epochs=epochs)
    # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    


