import jax 
import jax.numpy as jnp 
from collections import deque 
from typing import Callable, Tuple 

@jax.jit
def max_abs(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the largest absolute element-wise difference between x and y."""
    delta = x - y
    return jnp.abs(delta).max()

@jax.jit
def l2_norm(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return jnp.linalg.norm (default arguments) of the difference x - y."""
    delta = x - y
    return jnp.linalg.norm(delta)

@jax.jit
def mae(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the mean absolute error between x and y (mean over all elements)."""
    abs_err = jnp.abs(x - y)
    return abs_err.mean()

@jax.jit
def mse(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the mean squared error between x and y (mean over all elements)."""
    residual = x - y
    return (residual ** 2).mean()

@jax.jit
def pearson_corr(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the Pearson correlation coefficient between the flattened x and y."""
    # corrcoef yields the 2x2 correlation matrix of the two flattened inputs;
    # the off-diagonal entry is the cross-correlation.
    corr_matrix = jnp.corrcoef(jnp.ravel(x), jnp.ravel(y))
    return corr_matrix[0, 1]

def momentum_selector(trajectory: jax.Array,
                      window_size: int,
                      norm_metric: Callable[[jax.Array, jax.Array], jax.Array] = max_abs,
                      k: float = 1.0) -> jax.Array:
    """
    Select time steps whose change w.r.t. the last selected frame is unusually large.

    The first `window_size` steps (clamped to the length of the trajectory) are
    always selected and their consecutive-frame differences seed the statistics.
    Afterwards a step is selected when its difference to the last selected frame
    exceeds `mu + k * sigma`, where `mu` / `sigma` are the mean / standard
    deviation of the most recent `window_size` recorded differences. The
    statistics are refreshed only when a step is selected.

    Args:
        trajectory: array indexed by time along its first axis.
        window_size: number of warm-up steps and length of the statistics window.
        norm_metric: callable returning a scalar difference between two frames.
        k: number of standard deviations above the mean required for selection.

    Returns:
        1-D jax.Array with the selected time-step indices, starting with 0.

    Raises:
        ValueError: if the trajectory has no time steps.
    """
    T = trajectory.shape[0]
    if T < 1:
        raise ValueError("trajectory must contain at least one time step")

    # The warm-up must not run past the last frame: JAX clamps out-of-range
    # indices instead of raising, which would otherwise yield indices >= T.
    warmup_end = min(window_size, T - 1)

    selected_timesteps = [0]
    history = []
    for idx in range(1, warmup_end + 1):
        selected_timesteps.append(idx)
        history.append(norm_metric(trajectory[idx], trajectory[idx - 1]))

    if not history:
        # No warm-up differences means no usable statistics (the threshold
        # would be NaN), so nothing beyond step 0 can be selected.
        return jnp.array(selected_timesteps)

    recent = jnp.array(history)
    mu, sigma = jnp.mean(recent), jnp.std(recent)
    last_selected = warmup_end

    for idx in range(warmup_end + 1, T):
        e_t = norm_metric(trajectory[idx], trajectory[last_selected])
        history.append(e_t)

        if e_t > mu + k * sigma:
            selected_timesteps.append(idx)
            last_selected = idx
            # Refresh the threshold from the latest window of differences.
            recent = jnp.array(history[-window_size:])
            mu, sigma = jnp.mean(recent), jnp.std(recent)

    return jnp.array(selected_timesteps)


def temporal_selection(trajectory: jax.Array,
                       queue_size: int,
                       norm_metric: Callable[[jax.Array, jax.Array], jax.Array] = max_abs,
                       corr_metric: Callable[[jax.Array, jax.Array], jax.Array] = pearson_corr,
                       corr_treshold: float = 0.9,
                       freq_treshold: Tuple[float, float] = (None, None)) -> jax.Array:
    """
    Greedily sub-sample a trajectory, jumping further where frames change little.

    Starting from time step 0, a look-ahead queue of upcoming indices is filled.
    Whenever the queue holds `queue_size` indices (or has reached the last time
    step), the queued candidates are scanned in order and the furthest one that
    still satisfies both criteria below becomes the next selected step:

      * norm_metric(candidate, current) / norm_metric(next, current) <= factor,
        with factor = sqrt(largest consecutive-step norm seen so far /
        mean consecutive-step norm of the queued steps), optionally scaled by
        freq_treshold[1];
      * corr_metric(candidate, current) >= corr_treshold.

    The scan stops at the first failing candidate; if even the first candidate
    fails, the immediately following step is selected.

    Args:
        trajectory: jax.Array of shape (T, ...) representing the trajectory over time.
        queue_size: int >= 1, maximum number of look-ahead candidates per decision.
        norm_metric: Callable, function to compute the norm difference between two time steps.
        corr_metric: Callable, function to compute the correlation between two time steps.
        corr_treshold: float, minimum correlation with the current step for a
            candidate to be accepted.
        freq_treshold: Tuple[float, float], optional; mostly relevant for
            high-frequency data that is hard for the model to capture. When the
            first element is not None, the jump factor is multiplied by the
            second element. Inside this function the first element is only
            tested against None; it is never compared with a frequency.

    Returns:
        1-D jax.Array with the selected time-step indices, starting with 0.

    Raises:
        ValueError: if queue_size is smaller than 1.
    """
    if queue_size < 1:
        # A zero-length deque silently drops every append and the loop below
        # would fail on queue[-1]; reject the configuration up front.
        raise ValueError(f"queue_size must be >= 1, got {queue_size}")

    T = trajectory.shape[0]
    selected_timesteps = [0]
    if T == 1:
        # Nothing to look ahead to; avoid reading the non-existent frame 1.
        return jnp.array(selected_timesteps)

    start_idx = 0
    last_idx = 1
    queue = deque(maxlen=queue_size)
    queue.append(1)
    first_diff = norm_metric(trajectory[1], trajectory[0])
    maxx_diff = first_diff       # largest consecutive-step norm seen so far
    curr_diff_sum = first_diff   # sum of consecutive-step norms of queued steps

    while start_idx < T - 1:

        if len(queue) == queue_size or queue[-1] == T - 1:
            curr_mean = curr_diff_sum / len(queue)
            factor = jnp.sqrt(maxx_diff / curr_mean).item()
            if freq_treshold[0] is not None:
                factor *= freq_treshold[1]
            best_idx = start_idx + 1
            diff0 = norm_metric(trajectory[best_idx], trajectory[start_idx])
            for idx in queue:
                norm_val = norm_metric(trajectory[idx], trajectory[start_idx])
                corr_val = corr_metric(trajectory[idx], trajectory[start_idx])
                if norm_val / diff0 <= factor and corr_val >= corr_treshold:
                    best_idx = idx
                else:
                    break

            start_idx = best_idx
            selected_timesteps.append(best_idx)
            # Drop every queued step that is now behind the current position
            # and remove its contribution from the running sum.
            while len(queue) > 0 and queue[0] <= start_idx:
                prev_idx = queue.popleft()
                norm_val = norm_metric(trajectory[prev_idx], trajectory[prev_idx - 1])
                curr_diff_sum = curr_diff_sum - norm_val

        if last_idx + 1 < T:
            last_idx += 1
            queue.append(last_idx)
            norm_val = norm_metric(trajectory[last_idx], trajectory[last_idx - 1])
            maxx_diff = max(maxx_diff, norm_val)
            curr_diff_sum = curr_diff_sum + norm_val

    return jnp.array(selected_timesteps)

class StreamingTemporalSelector:
    """
    Stateful jump selector for windowed (streaming) simulation data.

    The running maximum (`maxx_diff`), running sum (`curr_diff_sum`) and number
    (`count`) of recorded step differences persist across calls to
    `select_jump`, so the adaptive factor reflects every window processed so
    far rather than only the current one. `window_size` is stored on the
    instance but is not read by `select_jump`.
    """

    def __init__(self, corr_threshold=0.9, window_size=5):
        # Configuration.
        self.window_size = window_size
        self.corr_threshold = corr_threshold
        # Running statistics; seeded on the first call to select_jump.
        self.initialized = False
        self.count = 0
        self.curr_diff_sum = 0.0
        self.maxx_diff = 0.0

    def select_jump(self, window_real):
        """
        Choose how far to jump inside `window_real`.

        `window_real[0]` is the current state and `window_real[1:]` are the
        candidate frames. Candidates are scanned in order; scanning stops at
        the first one that violates either the norm-ratio bound or the
        correlation threshold.

        Returns:
            Index (>= 1) into `window_real` of the selected frame.
        """
        n_frames = window_real.shape[0]

        if not self.initialized:
            # Warm-up: seed the running statistics with the first step.
            seed_diff = max_abs(window_real[1], window_real[0])
            self.maxx_diff = seed_diff
            self.curr_diff_sum = seed_diff
            self.count = 1
            self.initialized = True

        # Adaptive factor derived from the statistics accumulated so far.
        running_mean = self.curr_diff_sum / self.count
        factor = jnp.sqrt(self.maxx_diff / (running_mean + 1e-8)).item()

        anchor = window_real[0]
        base_diff = max_abs(window_real[1], window_real[0])

        jump = 1  # the smallest possible jump
        for candidate in range(1, n_frames):
            frame = window_real[candidate]
            norm_val = max_abs(frame, anchor)
            corr_val = pearson_corr(frame, anchor)

            accepted = (norm_val / (base_diff + 1e-8)) <= factor and corr_val >= self.corr_threshold
            if not accepted:
                break
            jump = candidate

        # Fold the difference between the chosen frame and its predecessor
        # in the window into the running statistics for the next call.
        step_diff = max_abs(window_real[jump], window_real[jump - 1])
        self.maxx_diff = max(self.maxx_diff, step_diff)
        self.curr_diff_sum += step_diff
        self.count += 1

        return jump
    
# %%%%%%%%%%% median absolute deviation based selector %%%%%%%%%%%%
def online_selector_mad(activity, window=20, kappa=3.5):
    """
    Select time steps online using a median-absolute-deviation outlier score.

    The activity values at the visited time steps are kept in a rolling buffer
    of length `window`. While the buffer is still filling, or whenever the
    current value is an outlier (|value - median| / MAD > kappa), the next time
    step is taken; otherwise the selector jumps `window` steps ahead. The last
    index is always reached and selected.

    Args:
        activity: indexable sequence of scalar activity values, one per time step.
        window: rolling buffer length and jump size in quiet phases (>= 1).
        kappa: outlier threshold on the MAD-normalised deviation.

    Returns:
        1-D jax.Array with the selected time-step indices, starting with 0.

    Raises:
        ValueError: if window is smaller than 1.
    """
    if window < 1:
        # window == 0 would make the quiet-phase jump 0 and never terminate.
        raise ValueError(f"window must be >= 1, got {window}")

    T = len(activity)
    buf = deque(maxlen=window)

    selected = [0]
    t = 0
    while t < T - 1:
        buf.append(activity[t])

        dt = 1
        if len(buf) == window:
            # Build the window array once, from a plain list, and reuse it
            # for both medians.
            values = jnp.array(list(buf))
            med = jnp.median(values)
            mad = jnp.median(jnp.abs(values - med)) + 1e-8  # avoid division by zero
            score = jnp.abs(activity[t] - med) / mad
            if not score > kappa:
                dt = window

        t = min(t + dt, T - 1)
        selected.append(t)
    return jnp.array(selected)

############## In-situ BSSN selector ################
def persistent_median_policy(t, T, activity_array, history, factor=1.2, **kwargs):
    """
    Decide the next time step from the activity relative to a median baseline.

    Jumps ahead while the activity stays at or below `median(history) * factor`
    and steps densely (one step at a time) while it is above that threshold.
    The baseline itself is not modified here; the caller decides which values
    are appended to `history`.

    Args:
        t: current time-step index; `activity_array[t + 1]` must exist.
        T: total number of time steps.
        activity_array: indexable sequence of scalar activity values.
        history: sized iterable of baseline activity values (e.g. a deque).
        factor: multiplier applied to the median to obtain the surge threshold
            (the default 1.2 is a 20% buffer above the median).
        **kwargs:
            window_size (default 5): maximum jump length in the baseline phase.
            surge_count (default 0): number of consecutive surge steps so far.

    Returns:
        Tuple `(next_t, surge_count)` with the next time-step index and the
        updated surge counter.
    """
    window_size = kwargs.get('window_size', 5)
    surge_count = kwargs.get('surge_count', 0)

    # 1. Warm-up: step densely until enough baseline samples were collected.
    if len(history) < 20:
        return t + 1, 0

    # 2. Surge threshold from the current baseline. The history is passed to
    #    jnp.array as a plain list so that any sized iterable is accepted.
    threshold = jnp.median(jnp.array(list(history))) * factor

    # 3. Surge: force dense sampling and count the consecutive surge steps.
    if activity_array[t + 1] > threshold:
        return t + 1, surge_count + 1

    # 4. Baseline: try to jump a full window, but stop at the first step that
    #    exceeds the threshold so the start of a surge is not skipped.
    best_idx = min(t + window_size, T - 1)
    for candidate_idx in range(t + 1, best_idx + 1):
        if activity_array[candidate_idx] > threshold:
            return candidate_idx, 0

    return best_idx, 0

def persistent_orchestrator(activity_array, window_size=5, factor=1.2):
    """
    Drive `persistent_median_policy` over a whole activity series.

    Keeps a bounded history (last 60 values) of baseline activity. A value is
    added to the history only while no surge is in progress, or once a surge
    has lasted longer than `5 * window_size` consecutive steps, in which case
    it is treated as the new baseline and the surge counter is reset.

    Args:
        activity_array: sized, indexable sequence of scalar activity values.
        window_size: maximum jump length in the baseline phase (>= 1).
        factor: surge threshold multiplier forwarded to the policy
            (1.2 is a 20% buffer above the median).

    Returns:
        1-D jax.Array with the selected time-step indices, starting with 0.

    Raises:
        ValueError: if window_size is smaller than 1.
    """
    if window_size < 1:
        # With window_size == 0 the policy returns the current index once the
        # warm-up is over, so the loop below would never advance.
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    T = len(activity_array)
    t = 0
    selected = [0]
    history = deque(maxlen=60)
    surge_count = 0
    patience_limit = window_size * 5

    while t < T - 1:
        # `factor` is a required input of the policy decision and must be
        # forwarded explicitly.
        next_t, surge_count = persistent_median_policy(
            t, T, activity_array, history, factor,
            window_size=window_size, surge_count=surge_count)

        t = int(next_t)
        selected.append(t)

        # Update the baseline only in a quiet phase, or when the surge has
        # lasted so long that it has become the new baseline.
        if surge_count == 0 or surge_count > patience_limit:
            history.append(activity_array[t])
            surge_count = 0

    return jnp.array(selected)
######################################################################

if __name__ == "__main__":
    # Smoke run: apply temporal_selection to a uniformly random trajectory
    # and print the selected indices.
    rng_key = jax.random.PRNGKey(0)
    demo_trajectory = jax.random.uniform(rng_key, (1000, 256, 256, 1))

    picked = temporal_selection(
        demo_trajectory,
        queue_size=10,
        norm_metric=max_abs,
        corr_metric=pearson_corr,
        corr_treshold=0.9,
        freq_treshold=(None, None),
    )
    print("Selected indices:", picked)