"""
Two-Point Statistics Analysis for Spinodal Decomposition
Compares three FluxNet models trained with different time steps (10dt, 100dt, 1000dt) against a phase-field baseline.

This script:
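1. Runs two baseline phase-field simulations from different random seeds; the error between them serves as the reference error level
2. Runs 100 ML model inference experiments, each with its own random seed: a 2000dt phase-field warm-up followed by a rollout of every model to 102000dt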
3. Collects two-point statistics errors and phase volume fractions
4. Saves all data for later plotting
"""

import os
import sys
import numpy as np
import torch
import torch.nn as nn
import time
import joblib
import random

# Add the FluxNet source path
sys.path.insert(0, '/home/ml4pf/zshlan/FluxNet/src')
from models.fluxnet_n_2d import FluxNet_N

# ==================== Configuration ====================
# Root directory shared by the model checkpoints and the analysis output
RESULTS_ROOT = '/home/ml4pf/zshlan/FluxNet/results/spinodal_decomposition'

# Model checkpoint paths, one per training time step of the ablation study
MODEL_PATHS = {
    tag: f'{RESULTS_ROOT}/ablation_{tag}/FluxNet_D_pf/best_checkpoint.pt'
    for tag in ('10dt', '100dt', '1000dt')
}

# Simulation parameters
GRID_SIZE = 1024  # 1024x1024 grid
C0 = 0.60  # Initial mean concentration
NOISE_AMP = 0.05  # Noise amplitude (peak-to-peak of the uniform noise)

# Time parameters (in dt units, i.e. phase-field steps)
WARMUP_TIME = 2000  # Initial phase-field simulation length; defines 0T
TRAINING_LENGTH = 50000  # Training data covered 50000dt, so T = 50000dt
# End time at 2T. Derived rather than hard-coded so it stays consistent
# if WARMUP_TIME or TRAINING_LENGTH change.
TOTAL_TIME = WARMUP_TIME + 2 * TRAINING_LENGTH  # 102000

# Statistics collection interval
# NOTE: (TOTAL_TIME - WARMUP_TIME) should be a multiple of STATS_INTERVAL and of
# every model time step (10, 100, 1000), so each rollout lands on the recording times.
STATS_INTERVAL = 1000  # Collect two-point statistics every 1000dt

# Number of random experiments
NUM_EXPERIMENTS = 100
BASE_SEED = 666  # Experiment i uses seed BASE_SEED + i

# Output directory
OUTPUT_DIR = f'{RESULTS_ROOT}/two_point_analysis'

# Phase-field simulation parameters
PF_DT = 1.0e-2  # Explicit time step of one phase-field step
PF_M = 1.0  # Mobility (multiplies the Laplacian of the chemical potential)
PF_K = 3.57e-1  # Gradient-energy coefficient (enters as -2*k*laplacian(c))
PF_R = 8.314  # Gas constant
PF_T = 973.15  # Temperature

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ==================== Random Seed Functions ====================
def setup_seed(seed):
    """
    Make a run reproducible.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU plus every
    CUDA device) with the same value. It also disables cuDNN autotuning and
    requests deterministic cuDNN kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# ==================== Noise Generation ====================
def generate_initial_noise(I, J, c0=0.60, noise_amp=0.05, seed=None):
    """
    Generate an initial concentration field with uniform random noise.

    con = c0 + noise_amp * (0.5 - random), where random ~ U[0, 1).
    Every value therefore lies in (c0 - noise_amp / 2, c0 + noise_amp / 2].

    Args:
        I, J: Field dimensions; the result has shape (I, J).
        c0: Mean concentration.
        noise_amp: Peak-to-peak amplitude of the noise.
        seed: If given, values are drawn from a private
            ``np.random.RandomState(seed)``. This gives exactly the values that
            ``np.random.seed(seed)`` followed by ``np.random.uniform`` would,
            but no longer resets NumPy's global RNG as a hidden side effect.
            If None, the global NumPy RNG is used.

    Returns:
        float64 numpy array of shape (I, J).
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    random_values = rng.uniform(0, 1, size=(I, J))
    return c0 + noise_amp * (0.5 - random_values)


# ==================== Phase-Field Simulation (GPU) ====================
def _apply_periodic_bc(field, m, n):
    """Fill the one-cell ghost layer of an (m+2, n+2) tensor periodically, in place."""
    # First axis: ghost row 0 <- last interior row, ghost row m+1 <- first interior row
    field[0, 1:n + 1] = field[m, 1:n + 1]
    field[m + 1, 1:n + 1] = field[1, 1:n + 1]
    # Second axis: same wrap-around for the ghost columns
    field[1:m + 1, 0] = field[1:m + 1, n]
    field[1:m + 1, n + 1] = field[1:m + 1, 1]
    # Corners (not read by the 5-point stencil; filled for completeness)
    field[0, 0] = field[m, n]
    field[0, n + 1] = field[m, 1]
    field[m + 1, 0] = field[1, n]
    field[m + 1, n + 1] = field[1, 1]


def _laplacian_interior(field, m, n, h2):
    """Return the 5-point Laplacian of the interior of a ghost-padded field, shape (m, n)."""
    return (
        field[2:m + 2, 1:n + 1] + field[0:m, 1:n + 1] +
        field[1:m + 1, 2:n + 2] + field[1:m + 1, 0:n] -
        4.0 * field[1:m + 1, 1:n + 1]
    ) / h2


def run_phase_field_step(con_gpu, m, n, dt=PF_DT, M=PF_M, k=PF_K, R=PF_R, T=PF_T):
    """
    Advance the concentration field by one explicit Cahn-Hilliard step, in place.

    Computes c <- c + dt * M * laplacian(mu), where
    mu = f'(c) / (R*T) - 2 * k * laplacian(c), on a periodic grid. f'(c) is the
    logarithmic mixing term plus the A0/A1 interaction terms.

    Args:
        con_gpu: Tensor of shape (m+2, n+2): the (m, n) interior plus a one-cell
            ghost layer. It is modified in place. On return, the ghost layer holds
            the periodic copy of the *pre-update* interior.
        m, n: Interior grid dimensions.
        dt, M, k, R, T: Time step, mobility, gradient-energy coefficient, gas
            constant and temperature.
    """
    dx = dy = 1.0  # unit grid spacing; the stencil below assumes dx == dy
    A0 = 15000.0 + 6.1 * T
    A1 = -7600 + 3.55 * T

    _apply_periodic_bc(con_gpu, m, n)

    # Chemical free energy derivative (interior only). Clamping keeps log finite.
    c_interior = con_gpu[1:m + 1, 1:n + 1]
    c_safe = torch.clamp(c_interior, 1e-6, 1 - 1e-6)
    dcon = (
        R * T * torch.log(c_safe / (1.0 - c_safe)) +
        (1.0 - 2.0 * c_safe) * A0 +
        (-6.0 * c_safe + 6.0 * c_safe ** 2 + 1.0) * A1
    ) / (R * T)

    lap_con = _laplacian_interior(con_gpu, m, n, dx * dy)

    # Chemical potential on a padded buffer so its Laplacian can use the same
    # stencil; every cell is written below, so no zero-initialisation is needed.
    dF = torch.empty_like(con_gpu)
    dF[1:m + 1, 1:n + 1] = dcon - 2 * k * lap_con
    _apply_periodic_bc(dF, m, n)

    # Update concentration
    con_gpu[1:m + 1, 1:n + 1] += dt * M * _laplacian_interior(dF, m, n, dx * dy)


def run_phase_field_simulation(initial_con, num_steps, m, n, record_interval=None):
    """
    Advance a phase-field simulation by ``num_steps`` explicit steps on DEVICE.

    Args:
        initial_con: (m, n) numpy array with the starting concentration field.
        num_steps: Number of calls to run_phase_field_step.
        m, n: Grid dimensions.
        record_interval: When not None, snapshot the field after every
            ``record_interval``-th step.

    Returns:
        (final_con, recorded_states): final_con is the (m, n) interior after the
        last step. recorded_states maps the 1-based step count to a numpy copy
        of the interior at that step; it is empty when record_interval is None.
    """
    # Padded float32 buffer: interior plus one ghost cell on every side
    con_gpu = torch.zeros((m + 2, n + 2), device=DEVICE, dtype=torch.float32)
    interior = (slice(1, m + 1), slice(1, n + 1))
    con_gpu[interior] = torch.from_numpy(initial_con).to(DEVICE)

    recorded_states = {}
    for step_count in range(1, num_steps + 1):
        run_phase_field_step(con_gpu, m, n)

        should_record = record_interval is not None and step_count % record_interval == 0
        if should_record:
            # .copy() matters on CPU, where .numpy() would alias the live buffer
            recorded_states[step_count] = con_gpu[interior].cpu().numpy().copy()

        if step_count % 10000 == 0:
            print(f"    Phase-field step: {step_count}/{num_steps}")

    final_con = con_gpu[interior].cpu().numpy()
    return final_con, recorded_states


# ==================== Two-Point Statistics Functions ====================
def _centered_circular_autocorrelation(arr):
    """Return the real circular autocorrelation of ``arr`` via FFT, zero lag centred by fftshift."""
    spectrum = np.fft.fftn(arr)
    power = spectrum * np.conj(spectrum)
    return np.fft.fftshift(np.fft.ifftn(power).real)


def autocor(data, periodicity='periodic'):
    """
    Compute the two-point autocorrelation of a 2D field using the FFT.

    Args:
        data: 2D numpy array.
        periodicity: 'periodic' treats the field as periodic and normalises by
            the number of pixels. 'nonperiodic' zero-pads by half the size along
            each axis and normalises every lag by the number of overlapping
            pixels (the autocorrelation of the padding mask).

    Returns:
        2D array with the zero-lag term at index (h // 2, w // 2) (fftshift
        layout). It has the shape of ``data`` for 'periodic' and the padded
        shape for 'nonperiodic'.

    Raises:
        ValueError: if ``periodicity`` is not 'periodic' or 'nonperiodic'.
            Previously any other value silently fell back to periodic.
    """
    m, n = data.shape

    if periodicity == 'periodic':
        return _centered_circular_autocorrelation(data) / (m * n)

    if periodicity == 'nonperiodic':
        cutoff_1 = int(m * 0.5)
        cutoff_2 = int(n * 0.5)
        padded_shape = [m + cutoff_1, n + cutoff_2]

        image_pad = np.zeros(padded_shape)
        image_pad[0:m, 0:n] = data
        mask = np.zeros(padded_shape)
        mask[0:m, 0:n] = 1

        return _centered_circular_autocorrelation(image_pad) / _centered_circular_autocorrelation(mask)

    raise ValueError(
        f"periodicity must be 'periodic' or 'nonperiodic', got {periodicity!r}"
    )


def two_point_correlation_radial(image_2pt, max_radius):
    """
    Radially average a centred two-point correlation map.

    Distances are measured from pixel (h // 2, w // 2), which matches the fftshift
    layout produced by autocor. Radius bin k collects the pixels whose distance r
    satisfies k - 0.5 <= r < k + 0.5.

    Args:
        image_2pt: 2D array (e.g. the output of autocor).
        max_radius: Upper limit of the radii; bins 0 .. max_radius - 1 are returned.

    Returns:
        (radii, correlation): radii is np.arange(0, max_radius). correlation[k]
        is the mean of image_2pt over bin k, or 0.0 for bins that contain no
        pixels.
    """
    h, w = image_2pt.shape
    y, x = np.ogrid[:h, :w]
    center_y, center_x = h // 2, w // 2
    r = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)

    radii = np.arange(0, max_radius)
    n_bins = len(radii)

    # One O(h*w) pass instead of a full-image mask per radius. floor(r + 0.5) == k
    # is exactly the condition k - 0.5 <= r < k + 0.5; pixels beyond the last bin
    # are dropped.
    bin_idx = np.floor(r + 0.5).astype(np.intp).ravel()
    in_range = bin_idx < n_bins
    bin_idx = bin_idx[in_range]
    values = np.asarray(image_2pt, dtype=float).ravel()[in_range]

    sums = np.bincount(bin_idx, weights=values, minlength=n_bins)[:n_bins]
    counts = np.bincount(bin_idx, minlength=n_bins)[:n_bins]

    correlation = np.zeros(n_bins, dtype=float)
    np.divide(sums, counts, out=correlation, where=counts > 0)
    return radii, correlation


def calculate_two_point_statistics(data):
    """
    Return (r, S): the radially averaged periodic autocorrelation of ``data``.

    The curve is sampled at integer radii 0 .. min(data.shape) // 2 - 1.
    """
    radius_limit = min(data.shape) // 2
    return two_point_correlation_radial(autocor(data, 'periodic'), radius_limit)


def calculate_two_point_error(S1, S2):
    """
    Return the mean absolute error between two radial correlation curves.

    Only the common prefix is compared: the longer curve is truncated to the
    length of the shorter one.
    """
    common = min(len(S1), len(S2))
    diff = S1[:common] - S2[:common]
    return np.mean(np.abs(diff))


def calculate_phase_fractions(data, threshold_high=0.8, threshold_low=0.3):
    """
    Return (high_fraction, low_fraction) for a concentration field.

    high_fraction is the share of cells with value >= threshold_high, and
    low_fraction is the share with value <= threshold_low. Cells strictly
    between the thresholds count toward neither.
    """
    rich_cells = data >= threshold_high
    lean_cells = data <= threshold_low
    return np.mean(rich_cells), np.mean(lean_cells)


# ==================== Model Loading ====================
def load_model(checkpoint_path, neighborhood_size):
    """
    Build a FluxNet_N model on DEVICE, load its weights and switch it to eval mode.

    The architecture hyper-parameters are fixed here and must match the ones the
    checkpoint was trained with. The checkpoint may be a bare state dict, or a
    dict holding it under 'model_state' or 'model_state_dict' (checked in that
    order).

    Note: torch.load may unpickle arbitrary objects (depending on the installed
    PyTorch's ``weights_only`` default), so only load trusted checkpoints.
    """
    architecture = dict(
        in_channels=1,
        base_channels=32,
        num_blocks=4,
        kernel_size=3,
        act_fn=nn.GELU,
        norm_2d=nn.BatchNorm2d,
        neighborhood_size=neighborhood_size,
    )
    model = FluxNet_N(**architecture)
    model.to(DEVICE)

    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    state_dict = checkpoint
    for wrapper_key in ('model_state', 'model_state_dict'):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break
    model.load_state_dict(state_dict)

    model.eval()
    return model


def run_model_inference(model, current_phi):
    """
    Advance a 2D field by one model step without tracking gradients.

    The field is wrapped as a (1, 1, H, W) float32 batch on DEVICE. If the model
    returns a tuple, its first element is taken as the prediction. The result is
    squeezed back to a numpy array.
    """
    with torch.no_grad():
        batch = torch.tensor(current_phi, dtype=torch.float32)[None, None].to(DEVICE)
        prediction = model(batch)
        if isinstance(prediction, tuple):
            prediction = prediction[0]
        predicted_phi = prediction.squeeze().cpu().numpy()
    return predicted_phi


# ==================== Main Analysis Functions ====================
def run_baseline_phase_field_analysis(seed1, seed2, output_dir):
    """
    Run two baseline phase-field simulations and compute their two-point statistics errors.

    Both runs start from independent noise fields (``seed1`` / ``seed2``). Each is
    advanced to TOTAL_TIME, with the state recorded every STATS_INTERVAL steps.
    For every recorded time from WARMUP_TIME (0T) onwards this computes:
      * the MAE between the two runs' radial two-point correlation curves
        (the PF-vs-PF reference error level),
      * the phase fractions of each run and their average.
    The radial correlation of the first run is kept for every recorded time as
    the reference for the ML comparison. At the key times 0T, 1T, 1.5T and 2T it
    is also kept together with the full field.

    Results are cached in ``<output_dir>/baseline_data.joblib``. The cache is
    reused only if it was produced with the same (seed1, seed2); otherwise it is
    recomputed and overwritten. The cache is unpickled by joblib, so it must come
    from a trusted location.

    Args:
        seed1, seed2: Seeds for the two initial noise fields.
        output_dir: Directory for the cache file (created if missing).

    Returns:
        dict with keys 'times', 'errors', 'pf1_phase1_fractions',
        'pf1_phase2_fractions', 'pf2_phase1_fractions', 'pf2_phase2_fractions',
        'avg_phase1_fractions', 'avg_phase2_fractions', 'pf1_radial_data',
        'radial_comparison_data', 'states1', 'seed1' and 'seed2'.
    """
    print("=" * 60)
    print("BASELINE PHASE-FIELD ANALYSIS")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)
    cache_file = os.path.join(output_dir, 'baseline_data.joblib')

    # A cache keyed only by file name would silently return results for other
    # seeds, so validate it against the requested seeds before reusing it.
    if os.path.exists(cache_file):
        print(f"Loading cached baseline data from {cache_file}")
        cached = joblib.load(cache_file)
        cached_seeds = (cached.get('seed1'), cached.get('seed2')) if isinstance(cached, dict) else None
        if cached_seeds == (seed1, seed2):
            return cached
        print(f"  Cached seeds {cached_seeds} do not match ({seed1}, {seed2}); recomputing...")

    m = n = GRID_SIZE

    # Generate two different initial conditions
    print(f"\nGenerating initial conditions with seeds {seed1} and {seed2}...")
    initial_con1 = generate_initial_noise(m, n, C0, NOISE_AMP, seed1)
    initial_con2 = generate_initial_noise(m, n, C0, NOISE_AMP, seed2)

    # Run phase-field simulations from 0 to TOTAL_TIME.
    # States are analysed from WARMUP_TIME (0T) onwards, every STATS_INTERVAL.
    record_times = list(range(WARMUP_TIME, TOTAL_TIME + 1, STATS_INTERVAL))

    print(f"\nRunning phase-field simulation 1...")
    start_time = time.time()
    _, states1 = run_phase_field_simulation(initial_con1, TOTAL_TIME, m, n, STATS_INTERVAL)
    print(f"  Completed in {time.time() - start_time:.1f}s")

    print(f"\nRunning phase-field simulation 2...")
    start_time = time.time()
    _, states2 = run_phase_field_simulation(initial_con2, TOTAL_TIME, m, n, STATS_INTERVAL)
    print(f"  Completed in {time.time() - start_time:.1f}s")

    # Compute two-point statistics for each recorded time
    print("\nComputing two-point statistics...")
    baseline_times = []
    baseline_errors = []
    pf1_phase1_fractions = []
    pf1_phase2_fractions = []
    pf2_phase1_fractions = []
    pf2_phase2_fractions = []

    # Radial correlation data of the first simulation (reference for ML comparison)
    pf1_radial_data = {}

    # Key time points for radial function comparison (0T, 1T, 1.5T, 2T)
    # 0T = 2000dt, T = 50000dt, so:
    # 0T = 2000, 1T = 52000, 1.5T = 77000, 2T = 102000
    key_times_T = {
        '0T': WARMUP_TIME,  # 2000
        '1T': WARMUP_TIME + TRAINING_LENGTH,  # 52000
        '1.5T': WARMUP_TIME + int(1.5 * TRAINING_LENGTH),  # 77000
        '2T': TOTAL_TIME  # 102000
    }
    radial_comparison_data = {}

    # record_times already starts at WARMUP_TIME, so no lower-bound check is needed
    for t in record_times:
        # Skip times that were not recorded, e.g. if WARMUP_TIME is not a multiple
        # of STATS_INTERVAL.
        if t not in states1 or t not in states2:
            continue

        data1 = states1[t]
        data2 = states2[t]

        # Calculate two-point statistics
        r1, S1 = calculate_two_point_statistics(data1)
        _, S2 = calculate_two_point_statistics(data2)

        # Calculate error between two PF simulations
        error = calculate_two_point_error(S1, S2)

        # Calculate phase fractions
        p1_1, p2_1 = calculate_phase_fractions(data1)
        p1_2, p2_2 = calculate_phase_fractions(data2)

        baseline_times.append(t)
        baseline_errors.append(error)
        pf1_phase1_fractions.append(p1_1)
        pf1_phase2_fractions.append(p2_1)
        pf2_phase1_fractions.append(p1_2)
        pf2_phase2_fractions.append(p2_2)

        # Store radial data for first PF simulation
        pf1_radial_data[t] = {'r': r1, 'S': S1}

        # Store key time points for radial comparison
        for key, key_t in key_times_T.items():
            if t == key_t:
                radial_comparison_data[key] = {
                    'time': t,
                    'r': r1,
                    'S': S1,
                    'pf1_data': data1.copy()
                }

        if t % 10000 == 0:
            print(f"  Processed t={t}: error={error:.6e}")

    # Average phase fractions between the two simulations
    avg_phase1_fractions = [(p1 + p2) / 2 for p1, p2 in zip(pf1_phase1_fractions, pf2_phase1_fractions)]
    avg_phase2_fractions = [(p1 + p2) / 2 for p1, p2 in zip(pf1_phase2_fractions, pf2_phase2_fractions)]

    baseline_data = {
        'times': baseline_times,
        'errors': baseline_errors,
        'pf1_phase1_fractions': pf1_phase1_fractions,
        'pf1_phase2_fractions': pf1_phase2_fractions,
        'pf2_phase1_fractions': pf2_phase1_fractions,
        'pf2_phase2_fractions': pf2_phase2_fractions,
        'avg_phase1_fractions': avg_phase1_fractions,
        'avg_phase2_fractions': avg_phase2_fractions,
        'pf1_radial_data': pf1_radial_data,
        'radial_comparison_data': radial_comparison_data,
        'states1': states1,  # Store for ML comparison reference
        'seed1': seed1,
        'seed2': seed2,
    }

    # Save baseline data
    joblib.dump(baseline_data, cache_file)
    print(f"\nBaseline data saved to {cache_file}")

    return baseline_data


def _record_ml_statistics(results, state, t, reference_S):
    """Append the two-point error vs. ``reference_S`` and the phase fractions of ``state`` at time ``t``; return its (r, S)."""
    r_ml, S_ml = calculate_two_point_statistics(state)
    error = calculate_two_point_error(S_ml, reference_S)
    p1, p2 = calculate_phase_fractions(state)

    results['times'].append(t)
    results['errors'].append(error)
    results['phase1_fractions'].append(p1)
    results['phase2_fractions'].append(p2)
    return r_ml, S_ml


def run_ml_experiments(models, baseline_data, num_experiments, base_seed, output_dir):
    """
    Run multiple ML inference experiments with different random seeds.

    Experiment i uses seed ``base_seed + i``. Each experiment:
      1. generates a random initial noise field,
      2. runs the phase-field solver for WARMUP_TIME steps,
      3. rolls every model in ``models`` forward from that common state to
         TOTAL_TIME, advancing each by its own time step (10, 100 or 1000 dt).
    Whenever a model reaches a time in the STATS_INTERVAL grid that also exists
    in the baseline reference, the function records two quantities. The first is
    the MAE between the model's radial two-point correlation and that of the
    first baseline PF run. The second is the model's phase fractions. Radial
    curves at the key times (0T, 1T, 1.5T, 2T) are kept from the first
    experiment only.

    Results are cached in ``<output_dir>/ml_experiments_data.joblib``. The cache
    is reused only if it was produced with the same ``num_experiments`` and
    ``base_seed``.

    Args:
        models: dict mapping '10dt' / '100dt' / '1000dt' to loaded models.
        baseline_data: dict returned by run_baseline_phase_field_analysis; its
            'pf1_radial_data' entry is the reference.
        num_experiments: Number of random experiments (must be >= 1).
        base_seed: Seed of the first experiment.
        output_dir: Directory for the cache file (created if missing).

    Returns:
        dict with:
          * 'aggregated_results': per model, the times plus the mean, std and
            per-experiment arrays of the errors and phase fractions,
          * 'ml_radial_comparison',
          * 'num_experiments',
          * 'base_seed'.

    Raises:
        ValueError: if num_experiments < 1 and there is nothing to aggregate.
    """
    print("\n" + "=" * 60)
    print("ML MODEL EXPERIMENTS")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)
    cache_file = os.path.join(output_dir, 'ml_experiments_data.joblib')

    # Reuse the cache only if it was produced with the same experiment settings
    if os.path.exists(cache_file):
        print(f"Loading cached ML experiments data from {cache_file}")
        cached = joblib.load(cache_file)
        cached_settings = ((cached.get('num_experiments'), cached.get('base_seed'))
                           if isinstance(cached, dict) else None)
        if cached_settings == (num_experiments, base_seed):
            return cached
        print(f"  Cached (num_experiments, base_seed)={cached_settings} do not match "
              f"({num_experiments}, {base_seed}); recomputing...")

    if num_experiments < 1:
        raise ValueError(f"num_experiments must be >= 1, got {num_experiments}")

    m = n = GRID_SIZE

    # Reference PF data for comparison (first PF simulation)
    pf1_radial_data = baseline_data['pf1_radial_data']

    # Time points for statistics collection; a set, because membership is
    # checked after every single model step
    stats_times = set(range(WARMUP_TIME, TOTAL_TIME + 1, STATS_INTERVAL))

    # Time advanced by one inference step of each model (in dt units).
    # NOTE: if (TOTAL_TIME - WARMUP_TIME) is not a multiple of a model's step,
    # that rollout overshoots TOTAL_TIME and misses the final recording time.
    model_dt = {
        '10dt': 10,
        '100dt': 100,
        '1000dt': 1000
    }

    # Results storage
    all_experiments = {
        '10dt': [],
        '100dt': [],
        '1000dt': []
    }

    # Key time points for radial comparison (only recorded from first experiment)
    key_times_T = {
        '0T': WARMUP_TIME,
        '1T': WARMUP_TIME + TRAINING_LENGTH,
        '1.5T': WARMUP_TIME + int(1.5 * TRAINING_LENGTH),
        '2T': TOTAL_TIME
    }
    ml_radial_comparison = {
        '10dt': {},
        '100dt': {},
        '1000dt': {}
    }

    for exp_idx in range(num_experiments):
        exp_seed = base_seed + exp_idx
        is_first_experiment = exp_idx == 0
        print(f"\n--- Experiment {exp_idx + 1}/{num_experiments} (seed={exp_seed}) ---")

        # Generate initial noise
        setup_seed(exp_seed)
        initial_con = generate_initial_noise(m, n, C0, NOISE_AMP, exp_seed)

        # Run phase-field warmup (0 to WARMUP_TIME)
        print("  Running phase-field warmup...")
        start_time = time.time()
        warmup_con, _ = run_phase_field_simulation(initial_con, WARMUP_TIME, m, n)
        print(f"    Completed in {time.time() - start_time:.1f}s")

        # All models start from the same warmed-up state
        current_states = {name: warmup_con.copy() for name in ('10dt', '100dt', '1000dt')}

        # Per-experiment results
        exp_results = {
            name: {'times': [], 'errors': [], 'phase1_fractions': [], 'phase2_fractions': []}
            for name in ('10dt', '100dt', '1000dt')
        }

        # Record initial state (at 0T = WARMUP_TIME)
        if WARMUP_TIME in pf1_radial_data:
            reference_S = pf1_radial_data[WARMUP_TIME]['S']
            for model_name in models:
                r_ml, S_ml = _record_ml_statistics(
                    exp_results[model_name], current_states[model_name], WARMUP_TIME, reference_S)
                if is_first_experiment:
                    ml_radial_comparison[model_name]['0T'] = {'r': r_ml, 'S': S_ml}

        # Rollout from WARMUP_TIME to TOTAL_TIME; each pass advances every
        # unfinished model by one of its own steps
        current_times = {name: WARMUP_TIME for name in models}

        print("  Running ML inference...")
        inference_start = time.time()

        while any(t < TOTAL_TIME for t in current_times.values()):
            for model_name, model in models.items():
                if current_times[model_name] >= TOTAL_TIME:
                    continue

                current_states[model_name] = run_model_inference(model, current_states[model_name])
                current_times[model_name] += model_dt[model_name]

                current_t = current_times[model_name]
                if current_t in stats_times and current_t in pf1_radial_data:
                    r_ml, S_ml = _record_ml_statistics(
                        exp_results[model_name], current_states[model_name], current_t,
                        pf1_radial_data[current_t]['S'])

                    # Record radial data at key times for first experiment
                    if is_first_experiment:
                        for key, key_t in key_times_T.items():
                            if current_t == key_t:
                                ml_radial_comparison[model_name][key] = {'r': r_ml.copy(), 'S': S_ml.copy()}

            # Progress report (driven by the slowest-advancing model)
            min_time = min(current_times.values())
            if min_time % 10000 == 0:
                print(f"    Progress: {min_time}/{TOTAL_TIME}")

        print(f"    Inference completed in {time.time() - inference_start:.1f}s")

        # Store experiment results
        for model_name in models:
            all_experiments[model_name].append(exp_results[model_name])

    # Aggregate results across experiments
    print("\nAggregating results...")
    aggregated_results = {}

    for model_name in models:
        all_exp = all_experiments[model_name]

        # Every experiment records the same time grid, so take it from the first one
        times = all_exp[0]['times']

        all_errors = np.array([exp['errors'] for exp in all_exp])
        all_phase1 = np.array([exp['phase1_fractions'] for exp in all_exp])
        all_phase2 = np.array([exp['phase2_fractions'] for exp in all_exp])

        aggregated_results[model_name] = {
            'times': times,
            'errors_mean': np.mean(all_errors, axis=0),
            'errors_std': np.std(all_errors, axis=0),
            'phase1_mean': np.mean(all_phase1, axis=0),
            'phase1_std': np.std(all_phase1, axis=0),
            'phase2_mean': np.mean(all_phase2, axis=0),
            'phase2_std': np.std(all_phase2, axis=0),
            'all_errors': all_errors,
            'all_phase1': all_phase1,
            'all_phase2': all_phase2,
        }

    ml_data = {
        'aggregated_results': aggregated_results,
        'ml_radial_comparison': ml_radial_comparison,
        'num_experiments': num_experiments,
        'base_seed': base_seed,
    }

    # Save ML experiments data
    joblib.dump(ml_data, cache_file)
    print(f"\nML experiments data saved to {cache_file}")

    return ml_data


def main():
    """
    Run the complete analysis pipeline.

    Stages: the baseline phase-field analysis, model loading, the ML
    experiments, and finally writing the combined results to
    OUTPUT_DIR/two_point_analysis_results.joblib. Each stage is announced
    with a banner.
    """
    rule = "=" * 60

    def announce_step(title):
        # Two blank lines, then the title framed by horizontal rules
        print("\n\n" + rule)
        print(title)
        print(rule)

    header_lines = (
        rule,
        "TWO-POINT STATISTICS ANALYSIS FOR SPINODAL DECOMPOSITION",
        rule,
        f"Device: {DEVICE}",
        f"Grid size: {GRID_SIZE}x{GRID_SIZE}",
        f"Time range: {WARMUP_TIME} (0T) to {TOTAL_TIME} (2T)",
        f"Number of experiments: {NUM_EXPERIMENTS}",
        rule,
    )
    for line in header_lines:
        print(line)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    announce_step("STEP 1: Baseline Phase-Field Analysis")
    baseline_data = run_baseline_phase_field_analysis(
        seed1=1234,
        seed2=5678,
        output_dir=OUTPUT_DIR,
    )

    announce_step("STEP 2: Loading Models")
    # Neighborhood size per model; assumed to match each checkpoint's training
    # configuration. Adjust if it differs.
    neighborhood_sizes = {'10dt': 3, '100dt': 5, '1000dt': 9}
    models = {}
    for name, checkpoint_path in MODEL_PATHS.items():
        print(f"Loading {name} model...")
        models[name] = load_model(checkpoint_path, neighborhood_sizes[name])
        print(f"  Loaded from {checkpoint_path}")

    announce_step("STEP 3: ML Model Experiments")
    ml_data = run_ml_experiments(
        models=models,
        baseline_data=baseline_data,
        num_experiments=NUM_EXPERIMENTS,
        base_seed=BASE_SEED,
        output_dir=OUTPUT_DIR,
    )

    announce_step("STEP 4: Saving Final Results")
    baseline_keys = (
        'times',
        'errors',
        'avg_phase1_fractions',
        'avg_phase2_fractions',
        'radial_comparison_data',
    )
    final_results = {
        'baseline': {key: baseline_data[key] for key in baseline_keys},
        'ml_results': ml_data['aggregated_results'],
        'ml_radial_comparison': ml_data['ml_radial_comparison'],
        'config': dict(
            grid_size=GRID_SIZE,
            warmup_time=WARMUP_TIME,
            total_time=TOTAL_TIME,
            training_length=TRAINING_LENGTH,
            stats_interval=STATS_INTERVAL,
            num_experiments=NUM_EXPERIMENTS,
            base_seed=BASE_SEED,
        ),
    }

    final_file = os.path.join(OUTPUT_DIR, 'two_point_analysis_results.joblib')
    joblib.dump(final_results, final_file)
    print(f"Final results saved to {final_file}")

    print("\n" + rule)
    print("ANALYSIS COMPLETE!")
    print(rule)
    print("\nOutput files:")
    output_files = (
        ("Baseline data", os.path.join(OUTPUT_DIR, 'baseline_data.joblib')),
        ("ML experiments", os.path.join(OUTPUT_DIR, 'ml_experiments_data.joblib')),
        ("Final results", final_file),
    )
    for label, path in output_files:
        print(f"  - {label}: {path}")


if __name__ == '__main__':
    main()
