import os, sys
# os.environ["JAX_PLATFORMS"] = "cpu"  
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to single GPU
from functools import partial
from tqdm import tqdm
from flax import nnx
import jax.numpy as jnp
import typing
from typing import Optional, Union, Sequence, Tuple, Callable, Literal 
import jax
jax.config.update('jax_enable_x64', False)
from flax.typing import Dtype, Initializer
from nnx_models import mlp
from nnx_models import SirenLayer
from nnx_models import RealGaborLayer
from nnx_models.utils import FourierLinear
from nnx_models import LoRA, add_lora_to_model, merge_lora_params, reset_lora_params
import optax 
from soap_jax import soap
import orbax.checkpoint as ocp 
from nnx_models.utils_adaptive_lora import AdaptiveLoRA, add_lora_to_model_adaptive
from nnx_models.utils_lora import add_lora_to_model
import nnx_models.utils_lora_vanilla as ulv 


def loss(model, x, y):
    """Return the mean-squared error between ``model(x)`` and the target ``y``.

    Args:
        model: Any callable mapping ``x`` to predictions.
        x: Inputs forwarded unchanged to ``model``.
        y: Target values; subtracted from the predictions with normal
            broadcasting rules.

    Returns:
        The mean of the squared residual over all elements.
    """
    # NOTE(review): predictions and targets are combined by broadcasting, so
    # mismatched shapes such as (N, 1) vs (N,) would not raise -- confirm the
    # caller passes matching shapes.
    residual = model(x) - y
    return jnp.mean(jnp.square(residual))

# Optimizer for the WARM-UP phase (trains the whole model)
def define_lr_optim_full(model):
    """Build the optimizer used while the whole model is being trained.

    The learning rate follows a cosine decay that starts at 1e-3 and spans
    5000 steps; updates are computed by AdamW and applied to the variables
    selected by ``nnx.Param``.

    Args:
        model: The NNX model the optimizer is created for.

    Returns:
        A new ``nnx.Optimizer`` instance.
    """
    lr_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=5000)
    transform = optax.adamw(lr_schedule)
    return nnx.Optimizer(model, transform, wrt=nnx.Param)

# Optimizer for the ADAPTIVE phase (trains ONLY LoRA parameters)
@nnx.jit
def define_lr_optim_lora(model):
    """Build the optimizer used while only the LoRA adapters are trained.

    The learning rate follows a cosine decay from 1e-2 over 5000 steps with
    ``alpha=1e-3``. Updates are computed by SOAP and restricted to the
    variables selected by ``nnx.LoRAParam``.

    Args:
        model: The NNX model (already carrying LoRA parameters) the optimizer
            is created for.

    Returns:
        A new ``nnx.Optimizer`` instance.
    """
    scheduler = optax.cosine_decay_schedule(init_value=1e-2, decay_steps=5000, alpha=1e-3)
    # precondition_frequency is a step count, so pass an integer; 1 keeps the
    # original "every step" behaviour that the float 1.0 expressed.
    transform = soap(learning_rate=scheduler, precondition_frequency=1)
    return nnx.Optimizer(model, transform, wrt=nnx.LoRAParam)

@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one jitted optimisation step on the full model.

    Evaluates ``loss(model, x, y)`` together with its gradients, hands the
    gradients to ``optimizer.update`` and returns the loss value computed
    before the update.
    """
    value_and_grad_fn = nnx.value_and_grad(loss)
    loss_value, gradients = value_and_grad_fn(model, x, y)
    optimizer.update(gradients)
    return loss_value

@nnx.jit
def train_step_lora(model, optimizer, x, y):
    """Run one jitted optimisation step that differentiates only LoRA parameters.

    Same flow as a regular training step, except that the gradient of
    ``loss`` is taken with respect to the ``nnx.LoRAParam`` subset of the
    first argument (the model). Returns the loss value computed before the
    update.
    """
    lora_only = nnx.DiffState(0, filter=nnx.LoRAParam)
    value_and_grad_fn = nnx.value_and_grad(loss, argnums=lora_only)
    loss_value, gradients = value_and_grad_fn(model, x, y)
    optimizer.update(gradients)
    return loss_value

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Helper metrics (need to use them -- just in case) %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 
@partial(jax.jit, static_argnames=('alpha',))
def calculate_h1_norm_of_difference(current_solution, previous_solution, alpha: float = 1.0):
    """
    Calculates the discrete H1 Sobolev norm of the difference between two solution fields.

    The H1 norm is sqrt(||u||_L2^2 + alpha * ||∇u||_L2^2), where the gradient
    is the finite-difference gradient from ``jnp.gradient`` with unit spacing
    and the L2 norms are plain sums of squares (no cell-area weighting).

    The function is normally used with 2D fields of shape (H, W), but the
    gradient term sums over every axis of the difference, so 1D fields and
    higher-rank fields are handled as well.

    Args:
        current_solution: The solution array at the current step.
        previous_solution: The solution array at the previous step.
        alpha: A scaling factor to weigh the importance of the gradient term.
               A higher alpha makes the metric more sensitive to sharp features.
               It is a static (compile-time) argument of the jitted function,
               so it must be a hashable Python value.

    Returns:
        The H1 norm of the difference as a scalar array.
    """
    # Inputs with more than two axes get their singleton axes removed
    # (e.g. a trailing channel axis of size 1).
    if current_solution.ndim > 2:
        current_solution = current_solution.squeeze()
    if previous_solution.ndim > 2:
        previous_solution = previous_solution.squeeze()

    def _squared_l2(field):
        # Sum of squared magnitudes; avoids the sqrt-then-square round trip of norm(...)**2.
        return jnp.sum(jnp.abs(field) ** 2)

    # 1. Difference field and its squared L2 norm.
    diff = current_solution - previous_solution
    l2_norm_squared = _squared_l2(diff)

    # 2. Spatial gradients of the difference field. jnp.gradient returns one
    #    array per axis, except for 1D input where it returns the bare array.
    grads = jnp.gradient(diff)
    if diff.ndim == 1:
        grads = (grads,)

    # 3. Squared L2 norm of the gradient field, accumulated over every axis.
    grad_l2_norm_squared = sum(_squared_l2(component) for component in grads)

    # 4. Combine them into the H1 norm.
    return jnp.sqrt(l2_norm_squared + alpha * grad_l2_norm_squared)

@jax.jit
def check_tolerance_per_time_slice(u_new: jax.Array, u_hat: jax.Array, tau_r=1e-3, tau_a=1e-6) -> jax.Array:
    """
    Worst-case mixed relative/absolute error between two arrays.

    u_new, u_hat: arrays of compatible shape
    tau_r, tau_a: relative and absolute tolerances
    Returns: the maximum over all elements of
             |u_new - u_hat| / (tau_r * |u_new| + tau_a)
    """
    abs_error = jnp.abs(u_new - u_hat)
    scale = tau_r * jnp.abs(u_new) + tau_a
    return jnp.max(abs_error / scale)

@jax.jit
def check_rel_l2_per_time_slice(u_new: jax.Array, u_hat: jax.Array) -> jax.Array:
    """
    Relative L2 error of u_hat with respect to u_new.

    u_new, u_hat: arrays of compatible shape
    Returns: norm(u_new - u_hat) / norm(u_new) as a scalar array
             (no guard against norm(u_new) == 0)
    """
    error_norm = jnp.linalg.norm(u_new - u_hat)
    reference_norm = jnp.linalg.norm(u_new)
    return error_norm / reference_norm

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

def create_new_model():
    """Construct and return a new ``mlp.MLP`` with the script's fixed architecture.

    The network takes 2 inputs, produces 1 output and is configured with
    5 hidden layers of width 256 and a Fourier embedding scale of 7.0.
    A banner is printed every time a model is built.
    """
    print("--- Creating a new, randomly initialized model instance. ---")
    architecture = dict(
        input_dim=2,
        output_dim=1,
        hidden_dim=256,
        fourier_emb_scale=7.0,
        num_hidden_layers=5,
    )
    return mlp.MLP(**architecture)

def _fit_and_snapshot(model, make_optimizer, step_fn, x, target, num_epochs):
    """Fit ``model`` to a single snapshot ``target`` and return the resulting model state."""
    # A fresh optimizer per fit restarts the learning-rate schedule for every snapshot.
    optimizer = make_optimizer(model)
    for _ in range(num_epochs):
        step_fn(model, optimizer, x, target)
    _, state = nnx.split(model)
    return state


def adaptive_multi_rate_fitting_lora(
    x: jax.Array, 
    y: jax.Array,
    lora_rank: int = 4,
    total_timesteps: int = 100,
    warmup_steps: int = 10,
    tolerance_multiplier: float = 2.0,
    reset_threshold_factor: float = 4.0,
    initial_fit_interval: int = 1,
    max_fit_interval: int = 5,
    verbose: bool = True,
    epochs_per_fit: int = 500,
):
    """
    Fit a sequence of snapshots with a warm-up phase followed by adaptive LoRA fine-tuning.

    Warm-up: for each of the first ``warmup_steps`` snapshots the full model is
    trained on that snapshot, and the norm of the change between consecutive
    snapshots is recorded. The smallest recorded norm times
    ``tolerance_multiplier`` becomes the learned tolerance.

    Adaptive phase: LoRA adapters are added and, for every remaining timestep,
    exactly one of three things happens:
      1. the snapshot changed by more than the tolerance -> fit now and shrink
         the fit interval (never below ``initial_fit_interval``);
      2. otherwise, at least ``fit_interval`` steps passed since the last fit
         -> fit now and grow the interval (never above ``max_fit_interval``);
      3. otherwise -> skip the timestep.

    Args:
        x: Model inputs shared by all timesteps.
        y: Snapshots indexed by timestep along the first axis.
        lora_rank: Rank passed to ``add_lora_to_model``.
        total_timesteps: Number of timesteps to process (warm-up included).
        warmup_steps: Number of leading timesteps trained with the full model;
            at least 2 are needed to measure one snapshot-to-snapshot change.
        tolerance_multiplier: Factor applied to the smallest warm-up change.
        reset_threshold_factor: Reserved for the planned model-reset logic;
            currently unused.
        initial_fit_interval: Starting and minimum scheduled-fit interval.
        max_fit_interval: Maximum scheduled-fit interval.
        verbose: Whether to print the phase banners.
        epochs_per_fit: Number of optimisation steps per fitted snapshot.

    Returns:
        ``(stored_nefs, learned_tolerance)`` where ``stored_nefs`` maps each
        fitted timestep to ``{'state': <model state>, 'tstamp': {timestep}}``.

    Raises:
        ValueError: If ``warmup_steps < 2`` or if the requested timesteps
            exceed the number of snapshots in ``y``.
    """
    if warmup_steps < 2:
        raise ValueError(
            f"warmup_steps must be at least 2 to learn a tolerance, got {warmup_steps}"
        )
    # JAX clamps out-of-range indices instead of raising, which would silently
    # reuse the last snapshot, so the bounds are checked up front.
    num_snapshots = len(y)
    if max(warmup_steps, total_timesteps) > num_snapshots:
        raise ValueError(
            f"y holds {num_snapshots} snapshots but warmup_steps={warmup_steps} "
            f"and total_timesteps={total_timesteps} were requested"
        )

    # --- Initialize the model at the very beginning ---
    model = create_new_model()
    stored_nefs = {}

    # --- WARM-UP PHASE: Train the full model ---
    if verbose:
        print(f"--- Starting Warm-up & Tolerance Learning Phase ({warmup_steps} steps) ---")

    warmup_dynamics_norms = []
    previous_pde_solution = None
    for i in tqdm(range(warmup_steps), desc='Warm-up Full Training'):
        current_pde_solution = y[i]

        if previous_pde_solution is not None:
            warmup_dynamics_norms.append(
                jnp.linalg.norm(current_pde_solution - previous_pde_solution)
            )

        # Train on the current snapshot only (not on the whole trajectory).
        # NOTE(review): assumes model(x) and y[i] have matching shapes -- confirm.
        state = _fit_and_snapshot(
            model, define_lr_optim_full, train_step, x, current_pde_solution, epochs_per_fit
        )
        stored_nefs[i] = {'state': state, 'tstamp': {i}}
        previous_pde_solution = current_pde_solution

    # --- SET TOLERANCE ---
    min_dynamic_norm = jnp.min(jnp.stack(warmup_dynamics_norms))
    learned_tolerance = min_dynamic_norm * tolerance_multiplier
    # TODO: planned reset logic -- when dynamics_norm exceeds
    # learned_tolerance * reset_threshold_factor, rebuild the model with
    # create_new_model(), run a full fit on the current snapshot and re-add
    # the LoRA adapters before fine-tuning.

    # --- TRANSITION: Add LoRA adapters for the adaptive phase ---
    if verbose:
        print("\n--- Adding LoRA adapters and switching to fine-tuning mode. ---")
    # The return value is ignored: presumably the model is modified in place -- verify.
    add_lora_to_model(model, lora_rank=lora_rank)

    if verbose:
        print("--- Starting Adaptive Phase ---")

    # --- ADAPTIVE PHASE ---
    last_fit_timestep = warmup_steps - 1
    fit_interval = initial_fit_interval
    previous_pde_solution = y[warmup_steps - 1]

    pbar = tqdm(range(warmup_steps, total_timesteps), initial=warmup_steps, total=total_timesteps)
    for cst in pbar:
        time_since_last_fit = cst - last_fit_timestep
        current_pde_solution = y[cst]
        dynamics_norm = jnp.linalg.norm(current_pde_solution - previous_pde_solution)

        if dynamics_norm > learned_tolerance:
            # CONDITION 1: fast dynamics -- always fit, and tighten the interval.
            pbar.set_description(f"Dynamics {dynamics_norm:.3f} > Tol. Fitting.")
            fit_interval = max(initial_fit_interval, fit_interval - 1)
            should_fit = True
        elif time_since_last_fit >= fit_interval:
            # CONDITION 2: slow dynamics but the scheduled interval elapsed --
            # fit, and relax the interval.
            pbar.set_description(f"Interval reached ({time_since_last_fit}/{fit_interval}). Fitting.")
            fit_interval = min(max_fit_interval, fit_interval + 1)
            should_fit = True
        else:
            # CONDITION 3: slow dynamics and not yet time for a scheduled fit.
            pbar.set_description(f"Dynamics <= Tol. Skipping. ({time_since_last_fit}/{fit_interval})")
            should_fit = False

        if should_fit:
            state = _fit_and_snapshot(
                model, define_lr_optim_lora, train_step_lora, x, current_pde_solution, epochs_per_fit
            )
            stored_nefs[cst] = {'state': state, 'tstamp': {cst}}
            last_fit_timestep = cst

        # The change is always measured against the immediately preceding
        # snapshot, whether or not that snapshot was fitted.
        previous_pde_solution = current_pde_solution

    return stored_nefs, learned_tolerance

if __name__ == "__main__": 
    ns = 100  # maximum number of snapshots taken from the trajectory file
    coords = jnp.load("../2DNS_10000/coord.npy")
    # Same data directory as the coordinates (the original path started with three dots).
    vorticity = jnp.load("../2DNS_10000/vorticity_trajectory.npy")[: ns]
    # Scale every snapshot to unit norm.
    vorticity_normalized = jax.vmap(lambda field: field / jnp.linalg.norm(field))(vorticity)
    # Use the number of snapshots actually loaded, in case the file holds fewer than ns.
    num_snapshots = vorticity_normalized.shape[0]

    # Time axis of the loaded snapshots; not consumed by the fitting routine.
    dt = 0.0008765604502203664
    tlist = jnp.arange(0.0, 25.0, dt)[: ns] ## check if linspace or arange (minor change)

    # The fitting routine builds its own model, so none is constructed here.
    # TODO: persist stored_nefs (e.g. with an orbax checkpointer) once the format is settled.

    print("starting multi-rate self-adjusting LoRA temporal compression")
    stored_nefs, learned_tolerance = adaptive_multi_rate_fitting_lora(
        x=coords,
        y=vorticity_normalized,
        lora_rank=4,
        total_timesteps=num_snapshots,
        warmup_steps=5,
        tolerance_multiplier=2.0,
        reset_threshold_factor=4.0,
        initial_fit_interval=1,
        max_fit_interval=5,
        verbose=True,
    )
    print(f"Stored {len(stored_nefs)} fitted states (learned tolerance = {float(learned_tolerance):.3e})")

    print("Script successfully ended ......... ")

 