import os
os.environ["JAX_PLATFORMS"] = "cpu"
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from soap_jax import soap
import optax
import shutil
from nnx_models.utils_lora import add_lora_to_model, merge_lora_params, reset_lora_params
from nnx_models.mlp import MLP
import orbax.checkpoint as ocp
from typing import Callable, Tuple
import logging
import absl.logging
from collections import deque

# Suppress every absl log message below FATAL severity. The absl output is presumably
# emitted by the JAX/orbax stack imported above (unverified); this keeps the console quiet
# when running the selection scripts.
absl.logging.set_verbosity(absl.logging.FATAL)

@jax.jit
def max_abs(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the largest element-wise absolute difference between ``x`` and ``y``.

    This is the L-infinity distance of ``x - y``, reduced over all elements.
    """
    return jnp.abs(x - y).max()


@jax.jit
def l2_norm(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the Euclidean norm of ``x - y``.

    ``jnp.linalg.norm`` with default arguments computes the 2-norm of the flattened
    difference (the Frobenius norm for 2-D inputs).
    """
    residual = x - y
    return jnp.linalg.norm(residual)

@jax.jit
def mae(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the mean absolute error between ``x`` and ``y`` over all elements."""
    return jnp.abs(x - y).mean()

@jax.jit
def mse(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the mean squared error between ``x`` and ``y`` over all elements."""
    delta = x - y
    return jnp.mean(delta * delta)

@jax.jit
def pearson_corr(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the Pearson correlation coefficient between ``x`` and ``y``.

    Both inputs are flattened to 1-D first, so arrays of any (matching-size) shape are
    compared element by element.
    """
    x_flat = x.reshape(-1)
    y_flat = y.reshape(-1)
    correlation_matrix = jnp.corrcoef(x_flat, y_flat)
    return correlation_matrix[0, 1]

def temporal_selection(trajectory: jax.Array,
                       queue_size: int,
                       norm_metric: Callable[[jax.Array, jax.Array], jax.Array] = max_abs,
                       corr_metric: Callable[[jax.Array, jax.Array], jax.Array] = pearson_corr,
                       corr_treshold: float = 0.9,
                       freq_treshold: Tuple[float | None, float | None] = (None, None)) -> jax.Array:
    """
    Greedily select a subset of time steps from a trajectory.

    Starting from step 0, a look-ahead queue holds up to ``queue_size`` consecutive indices
    following the current (last selected) step. When the queue is full, or has reached the
    final step, the function jumps from the current step to the furthest queued index such
    that, for every queued index up to it:
      * ``norm_metric(candidate, current) / norm_metric(current + 1, current)`` is at most an
        adaptive factor, and
      * ``corr_metric(candidate, current) >= corr_treshold``.
    The scan stops at the first queued index that violates either condition. The jump is
    always at least one step.

    The adaptive factor is ``sqrt(max_step_diff / mean_step_diff_in_queue)``. Here
    ``max_step_diff`` is the largest one-step ``norm_metric`` difference seen so far, and the
    mean is over the one-step differences of the indices currently queued.

    Args:
        trajectory: array of shape (T, ...); the leading axis is time.
        queue_size: maximum look-ahead queue length (largest possible jump); must be >= 1.
        norm_metric: distance between two time steps (e.g. ``max_abs``).
        corr_metric: similarity between two time steps (e.g. ``pearson_corr``).
        corr_treshold: minimum ``corr_metric`` value required to skip ahead to a step.
        freq_treshold: ``(threshold, scale)``. If ``threshold`` is not None, the adaptive
            factor is multiplied by ``scale``. This is mostly relevant for high-frequency data
            that the model struggles to capture. NOTE: ``threshold`` is currently only used as
            an on/off switch; no frequency is compared against it.

    Returns:
        Integer array of selected time-step indices in increasing order. For T >= 2 it starts
        at 0 and ends at T - 1. For T < 2 every frame index is returned.

    Raises:
        ValueError: if ``queue_size < 1``, or if ``freq_treshold`` sets a threshold without
            a scale.
    """
    logger = logging.getLogger(__name__)

    if queue_size < 1:
        raise ValueError(f"queue_size must be >= 1, got {queue_size}")
    freq_cutoff = freq_treshold[0]
    freq_scale = freq_treshold[1]
    if freq_cutoff is not None and freq_scale is None:
        raise ValueError("freq_treshold=(threshold, scale) requires a scale when a threshold is set")

    T = trajectory.shape[0]
    if T < 2:
        # With zero or one frame there is nothing to compress; keep every frame.
        return jnp.arange(T)

    start_idx = 0
    # The queue always holds consecutive indices start_idx + 1 .. last_idx (see pops below).
    queue = deque(maxlen=queue_size)
    queue.append(1)
    first_step = norm_metric(trajectory[1], trajectory[0])
    maxx_diff = first_step       # running max of one-step differences (never shrinks)
    curr_diff_sum = first_step   # sum of one-step differences for the queued indices
    selected_timesteps = [0]
    last_idx = 1

    while start_idx < T - 1:

        if len(queue) == queue_size or queue[-1] == T - 1:
            curr_mean = curr_diff_sum / len(queue)
            factor = jnp.sqrt(maxx_diff / curr_mean).item()
            if freq_cutoff is not None:
                factor *= freq_scale
            best_idx = start_idx + 1
            # Distances are measured relative to the one-step change from the current frame.
            diff0 = norm_metric(trajectory[best_idx], trajectory[start_idx])
            for idx in queue:
                norm_val = norm_metric(trajectory[idx], trajectory[start_idx])
                corr_val = corr_metric(trajectory[idx], trajectory[start_idx])
                if norm_val / diff0 <= factor and corr_val >= corr_treshold:
                    best_idx = idx
                else:
                    break

            logger.debug("Selected idx: %d from %d with jump %d",
                         best_idx, start_idx, best_idx - start_idx)
            start_idx = best_idx
            selected_timesteps.append(best_idx)
            # Drop indices we jumped over (or landed on), removing their one-step differences.
            while queue and queue[0] <= start_idx:
                prev_idx = queue.popleft()
                curr_diff_sum -= norm_metric(trajectory[prev_idx], trajectory[prev_idx - 1])

        if last_idx + 1 < T:
            last_idx += 1
            queue.append(last_idx)
            norm_val = norm_metric(trajectory[last_idx], trajectory[last_idx - 1])
            maxx_diff = max(maxx_diff, norm_val)
            curr_diff_sum += norm_val

    return jnp.array(selected_timesteps)

# %%%%%%%%%%%%%%%%%% In-situ selector %%%%%%%%%%%%%%%%%%%%%%
def persistent_median_policy(t, T, activity_array, history, factor=1.2, **kwargs):
    """
    Choose the next time step using a running-median baseline of past activity.

    If the activity at ``t + 1`` exceeds ``factor`` times the median of ``history`` (a
    surge), the policy steps by exactly one and increments the surge counter. The caller can
    then keep spike values out of the baseline, freezing it during spikes. Otherwise it jumps
    up to ``window_size`` steps ahead. It stops early at the first step whose activity exceeds
    the threshold, so it never jumps over the start of a surge.

    Args:
        t: current (last selected) index. Callers are expected to keep ``t < T - 1``.
        T: number of time steps in ``activity_array``.
        activity_array: indexable sequence of scalar activity values.
        history: sequence of past activity values used for the median baseline.
        factor: surge threshold as a multiple of the median (default 1.2, i.e. a 20% buffer).
        **kwargs:
            window_size (int, default 5): maximum jump length in quiet phases.
            surge_count (int, default 0): consecutive surge steps seen so far.
            min_history (int, default 20): history length required before jumping.

    Returns:
        ``(next_t, surge_count)``: the next index to sample and the updated surge counter.
        A counter of 0 means the returned step is quiet.
    """
    window_size = kwargs.get('window_size', 5)
    surge_count = kwargs.get('surge_count', 0)
    min_history = kwargs.get('min_history', 20)

    # Warm-up: sample densely until the median baseline has enough support.
    if len(history) < min_history:
        return t + 1, 0

    local_median = jnp.median(jnp.array(history))
    threshold = local_median * factor

    if activity_array[t + 1] > threshold:
        # Surge: force dense sampling and count how long it has lasted.
        return t + 1, surge_count + 1

    # Quiet phase: jump ahead (always at least one step so callers make progress).
    best_idx = min(t + max(window_size, 1), T - 1)

    # activity_array[t + 1] was just found quiet, so the look-ahead check starts at t + 2.
    for candidate_idx in range(t + 2, best_idx + 1):
        if activity_array[candidate_idx] > threshold:
            # Stop at the start of a surge. Report it as a surge step (count 1) so the
            # caller does not feed this spike into the median baseline.
            return candidate_idx, 1

    return best_idx, 0

def persistent_orchestrator(activity_array, window_size=5, factor=1.2, history_size=60):
    """
    Select time steps by repeatedly applying ``persistent_median_policy``.

    The median history is updated only in quiet phases (surge counter 0), or once a surge has
    lasted longer than ``5 * window_size`` steps. In the latter case the surge is accepted as
    the new baseline and the counter is reset.

    Args:
        activity_array: indexable sequence of scalar activity values (length T).
        window_size: maximum jump length in quiet phases, forwarded to the policy.
        factor: surge threshold as a multiple of the running median, forwarded to the policy.
        history_size: maximum number of past activity values kept for the median.

    Returns:
        Integer array of selected indices, starting at 0 and ending at T - 1 for T >= 1.
        The array is empty when ``activity_array`` is empty.
    """
    T = len(activity_array)
    if T == 0:
        return jnp.zeros((0,), dtype=jnp.int32)

    t = 0
    selected = [0]
    history = deque(maxlen=history_size)
    surge_count = 0
    # Number of consecutive surge steps before a surge is accepted as the new baseline.
    patience_limit = window_size * 5

    while t < T - 1:
        next_t, surge_count = persistent_median_policy(
            t, T, activity_array, history, factor,
            window_size=window_size, surge_count=surge_count)

        # Always advance by at least one step so a misbehaving policy cannot stall the loop.
        t = max(int(next_t), t + 1)
        selected.append(t)

        # Only update the median history if:
        # 1. we are in a quiet phase (surge_count == 0), or
        # 2. the surge has lasted long enough to become the new baseline.
        if surge_count == 0 or surge_count > patience_limit:
            history.append(activity_array[t])
            if surge_count > patience_limit:
                surge_count = 0  # reset after adopting the surge into the baseline

    return jnp.array(selected)

# %%%%%%%%%%%%%%%%%% offline selector %%%%%%%%%%%%%%%%%%%%%%
def universal_in_situ_orchestrator(activity_array, selection_policy, **kwargs):
    """
    Drive a pluggable selection policy over an activity signal.

    Starting at index 0, ``selection_policy(t, T, activity_array, **kwargs)`` is called
    repeatedly to propose the next index. Proposals are forced to advance by at least one
    step and are clamped to the last index ``T - 1``. All ``kwargs`` (physics-specific policy
    settings such as ``window_size`` or ``kappa``) are forwarded unchanged to the policy.

    Args:
        activity_array: indexable sequence of scalar activity values (length T).
        selection_policy: callable ``(t, T, activity_array, **kwargs) -> next_index``.
        **kwargs: forwarded to ``selection_policy``.

    Returns:
        Integer array of selected indices, starting at 0 and ending at T - 1 for T >= 1.
        The array is empty when ``activity_array`` is empty.
    """
    T = len(activity_array)
    if T == 0:
        return jnp.zeros((0,), dtype=jnp.int32)

    selected = [0]
    t = 0
    while t < T - 1:
        next_t = selection_policy(t, T, activity_array, **kwargs)
        # Force progress (at least one step) and never run past the final index.
        t = min(int(max(next_t, t + 1)), T - 1)
        selected.append(t)

    return jnp.array(selected)

def bssn_flux_policy(t, T, activity_array, **kwargs):
    """
    Selection policy for 3+1D BSSN runs based on cumulative innovation.

    Starting at ``t + 1``, activity values (e.g. a Psi4 norm) are summed step by step.
    The policy advances to the furthest index whose running sum stays within the budget
    ``kappa``, looking at most ``window_size`` steps ahead and never past ``T - 1``.
    The first step ``t + 1`` is always returned as a fallback. No correlation check is done.

    Keyword Args:
        window_size (int, default 20): maximum look-ahead distance.
        kappa (float, default 0.05): budget for the accumulated activity.

    Returns:
        The chosen next index (an int).
    """
    look_ahead = kwargs.get('window_size', 20)
    budget = kwargs.get('kappa', 0.05)

    last_allowed = min(t + look_ahead, T - 1)
    chosen = t + 1
    accumulated = 0.0

    candidate = t + 1
    while candidate <= last_allowed:
        accumulated += activity_array[candidate]
        # Written as "not <=" so a NaN sum also stops the scan.
        if not (accumulated <= budget):
            # Curvature is changing too fast (e.g. a merger); stop extending the jump.
            break
        chosen = candidate
        candidate += 1

    return chosen


if __name__ == "__main__":
    jax.config.update("jax_default_matmul_precision", "highest")

    # The data location can be overridden with VORTICITY_PATH; the default is the original path.
    data_path = os.environ.get("VORTICITY_PATH", "/2DNS_1024x1024_1000T/vorticity_trajectory.npy")
    vorticity = jnp.load(data_path)

    # Flatten each frame to a vector, keeping time as the leading axis (no hard-coded T).
    vorticity = vorticity.reshape(vorticity.shape[0], -1)

    # Min-max normalise to [0, 1]; a constant field (zero range) would otherwise divide by 0.
    v_min = vorticity.min()
    v_range = vorticity.max() - v_min
    vorticity = (vorticity - v_min) / jnp.where(v_range > 0, v_range, 1.0)

    selected_timesteps = temporal_selection(
        vorticity,
        queue_size=8,
        norm_metric=max_abs,
        corr_metric=pearson_corr,
        corr_treshold=0.95,
        freq_treshold=(None, None),
    )

    print("Number of selected timesteps:", selected_timesteps.shape[0])
    print(f"Temporal compression ratio: {vorticity.shape[0] / selected_timesteps.shape[0]:.2f}")
    

    