import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# ─── CONFIG ────────────────────────────────────────────────────────────────
# Root of the experiment tree; every other path below is derived from it.
# NOTE: this is a machine-specific absolute path - edit it before running elsewhere.
BASE_DIR = "/drive2/Kuntal/Pysindy-experiment"
DATA_DIR = f"{BASE_DIR}/aptos_theta_data"  # Directory holding the existing input data files
THETA_NPY = f"{DATA_DIR}/aptos_all_thetas.npy"  # Theta arrays, read with np.load in main()
THETA_IDS = f"{DATA_DIR}/aptos_theta_ids.npy"  # IDs read alongside the theta arrays
CSV_PATH = f"{DATA_DIR}/aptos_thetas.csv"  # CSV read for the 'id_code' -> 'diagnosis' mapping
OUTPUT_DIR = f"{BASE_DIR}/conformal_boundary/aptos-output"  # Median theta, results CSV, bounds summary
PLOTS_DIR = f"{OUTPUT_DIR}/plots"  # Histogram PNGs

# Create output directories.
# NOTE: these calls run at import time, so merely importing this module
# creates the directories (exist_ok=True makes repeated runs harmless).
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# Random seed for reproducibility; passed as random_state to both
# train_test_split calls in main().
RANDOM_SEED = 42

# ─── MARD CALCULATION FUNCTIONS ─────────────────────────────────────────────
def calculate_robust_mard(theta1, theta2, max_value=100.0, epsilon=1e-6):
    """
    Calculate Mean Absolute Relative Difference with numerical stability improvements.

    Each element contributes ``|theta1 - theta2| / (|theta2| + epsilon)``,
    capped at ``max_value``; the result is the mean of those capped terms.
    ``theta2`` is therefore the reference the difference is measured against.

    Args:
        theta1: Array-like of values to evaluate (any shape).
        theta2: Array-like of reference values with the same number of
            elements as ``theta1`` (a single value is broadcast).
        max_value: Upper cap applied to every per-element relative
            difference, so near-zero reference entries cannot dominate.
        epsilon: Small constant added to the denominator to avoid division
            by zero.

    Returns:
        The mean capped relative difference as a NumPy scalar.

    Raises:
        ValueError: Raised by NumPy when the flattened inputs have
            incompatible sizes.
    """
    # Flatten both inputs independently. Deciding from theta1's rank alone
    # would leave an N-D theta2 unflattened when theta1 is already 1-D,
    # which either fails or silently broadcasts to the wrong shape.
    theta1_flat = np.asarray(theta1).ravel()
    theta2_flat = np.asarray(theta2).ravel()

    # Per-element absolute relative difference against the reference
    abs_diff = np.abs(theta1_flat - theta2_flat)
    denominator = np.abs(theta2_flat) + epsilon

    # Cap extreme values before averaging
    rel_diff = np.minimum(abs_diff / denominator, max_value)

    return np.mean(rel_diff)

def calculate_normalized_mae(theta1, theta2, epsilon=1e-6):
    """
    Alternative metric: Normalized Mean Absolute Error.

    Computes ``mean(|theta1 - theta2|) / (mean(|theta2|) + epsilon)``, i.e.
    the mean absolute error scaled by the average magnitude of the
    reference ``theta2``.

    Args:
        theta1: Array-like of values to evaluate (any shape).
        theta2: Array-like of reference values with the same number of
            elements as ``theta1`` (a single value is broadcast).
        epsilon: Small constant added to the normalisation factor to avoid
            division by zero when the reference is all zeros.

    Returns:
        The normalized mean absolute error as a NumPy scalar.

    Raises:
        ValueError: Raised by NumPy when the flattened inputs have
            incompatible sizes.
    """
    # Flatten both inputs independently so a 1-D theta1 paired with an N-D
    # theta2 is compared element by element instead of being broadcast.
    theta1_flat = np.asarray(theta1).ravel()
    theta2_flat = np.asarray(theta2).ravel()

    abs_diff = np.abs(theta1_flat - theta2_flat)
    norm_factor = np.mean(np.abs(theta2_flat)) + epsilon

    return np.mean(abs_diff) / norm_factor

# ─── MAIN FUNCTION ───────────────────────────────────────────────────────────
def _plot_mard_histogram(values, mean_mard, lower_bound, upper_bound, *,
                         color, label, mean_label, title, filename,
                         log_scale=False):
    """Save one MARD histogram, annotated with the calibration mean and bounds, to PLOTS_DIR."""
    bins = 30
    xlabel = 'MARD Value'
    if log_scale:
        xlabel = 'MARD Value (log scale)'
        # Linear bins look badly distorted on a log axis, so use log-spaced
        # edges whenever the positive values span a non-empty range.
        # Non-positive values cannot be placed on a log axis anyway.
        positive = values[values > 0]
        if positive.size > 0 and positive.min() < positive.max():
            bins = np.geomspace(positive.min(), positive.max(), 31)

    fig = plt.figure(figsize=(10, 6))
    plt.hist(values, bins=bins, alpha=0.7, color=color, label=label)
    plt.axvline(mean_mard, color='red', linestyle='--', label=f'{mean_label}: {mean_mard:.4f}')
    plt.axvline(lower_bound, color='green', linestyle='--', label=f'Lower bound: {lower_bound:.4f}')
    plt.axvline(upper_bound, color='green', linestyle='--', label=f'Upper bound: {upper_bound:.4f}')
    if log_scale:
        plt.xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel('Frequency')
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(PLOTS_DIR, filename))
    # Release the figure; pyplot keeps every figure alive until it is closed.
    plt.close(fig)


def main():
    """
    Run the APTOS MARD-boundary pipeline end to end.

    Steps:
      1. Load the theta arrays, their IDs and (best effort) the diagnosis CSV.
      2. Print magnitude diagnostics for the theta values.
      3. Split 60/40 into train/test, then split train 60/40 again into
         proper-train/calibration (about 36% / 24% / 40% of all samples).
      4. Take the element-wise median theta over the proper-train set.
      5. Score every calibration sample with calculate_robust_mard against
         that median and derive the bounds mean +/- 2*std (lower bound
         clipped at 0).
      6. Score the test samples, flag those inside the bounds, and write the
         median theta, three histograms, a results CSV and a text summary
         under OUTPUT_DIR / PLOTS_DIR.

    Raises:
        ValueError: If no theta arrays were loaded, or the number of theta
            arrays differs from the number of IDs.
    """
    print("Loading APTOS theta data...")

    # Load theta arrays and IDs.
    # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
    # objects, which can execute code. Only point these paths at trusted files.
    thetas = np.load(THETA_NPY, allow_pickle=True)
    ids = np.load(THETA_IDS, allow_pickle=True)

    # Fail early with a clear message instead of an IndexError further down.
    if len(thetas) == 0:
        raise ValueError(f"No theta arrays found in {THETA_NPY}")
    if len(thetas) != len(ids):
        raise ValueError(
            f"Mismatched inputs: {len(thetas)} theta arrays but {len(ids)} ids"
        )

    # Load CSV for diagnosis information. This is deliberately best effort:
    # the pipeline still runs without diagnoses, it just skips the breakdown.
    try:
        df = pd.read_csv(CSV_PATH)
        # Create a mapping from id to diagnosis
        id_to_diagnosis = dict(zip(df['id_code'], df['diagnosis']))
        has_diagnosis = True
    except Exception as e:
        print(f"Warning: Could not load diagnosis data from CSV: {e}")
        id_to_diagnosis = {}
        has_diagnosis = False

    print(f"Loaded {len(thetas)} theta arrays with shape {thetas[0].shape}")

    # Diagnostic: analyse theta value ranges (one flattened theta per row)
    all_thetas = np.vstack([t.flatten() for t in thetas])
    abs_thetas = np.abs(all_thetas)
    non_zero_abs = abs_thetas[all_thetas != 0]

    if non_zero_abs.size > 0:
        min_theta = np.min(non_zero_abs)
        max_theta = np.max(abs_thetas)
        mean_theta = np.mean(abs_thetas)
        median_theta_val = np.median(abs_thetas)
        # min_theta is non-zero by construction, so the ratio is well defined
        max_min_ratio = max_theta / min_theta

        print("\nTheta Value Statistics:")
        print(f"Min non-zero absolute value: {min_theta:.10e}")
        print(f"Max absolute value: {max_theta:.10e}")
        print(f"Mean absolute value: {mean_theta:.10e}")
        print(f"Median absolute value: {median_theta_val:.10e}")
        print(f"Max/Min ratio: {max_min_ratio:.10e}")

        # If extreme ratio, warn about potential numerical issues
        if max_min_ratio > 1e6:
            print("WARNING: Extreme ratio detected. Using robust MARD calculation.")

    # Step 1: Split data into 60% train, 40% test
    train_thetas, test_thetas, train_ids, test_ids = train_test_split(
        thetas, ids, test_size=0.4, random_state=RANDOM_SEED
    )

    # Step 2: Further split training data into proper-train (60%) and
    # calibration (40%) sets
    proper_train_thetas, calibration_thetas, proper_train_ids, calibration_ids = train_test_split(
        train_thetas, train_ids, test_size=0.4, random_state=RANDOM_SEED
    )

    print(f"Split into {len(proper_train_thetas)} proper-train, {len(calibration_thetas)} calibration, and {len(test_thetas)} test samples")

    # Step 3: Calculate the element-wise median theta from the proper-train set
    stacked_proper_train_thetas = np.stack(proper_train_thetas)
    median_theta = np.median(stacked_proper_train_thetas, axis=0)

    print(f"Calculated median theta with shape {median_theta.shape}")

    # Diagnostic: zeros in the median make the relative difference for those
    # entries depend entirely on the epsilon/cap inside the MARD function.
    zero_count = np.sum(median_theta == 0)
    if zero_count > 0:
        print(f"WARNING: Median theta contains {zero_count} zero values out of {median_theta.size}")

    # Step 4: Robust MARD of each calibration sample against the median theta
    calibration_mard_values = np.array(
        [calculate_robust_mard(theta, median_theta) for theta in calibration_thetas]
    )

    # Step 5: Calculate statistics from calibration set
    mean_mard = np.mean(calibration_mard_values)
    std_mard = np.std(calibration_mard_values)

    # Boundaries are mean +/- 2*std with a non-negative lower bound.
    # NOTE(review): this is a mean/std band rather than a calibration-quantile
    # threshold - confirm this is the intended definition of "conformal".
    lower_bound = max(0, mean_mard - 2 * std_mard)
    upper_bound = mean_mard + 2 * std_mard

    print("\nMARD Statistics (from calibration set):")
    print(f"Mean MARD: {mean_mard:.6f}")
    print(f"Std MARD: {std_mard:.6f}")
    print(f"Conformal boundaries: [{lower_bound:.6f}, {upper_bound:.6f}]")

    # Step 6: MARD for test samples, plus in-domain flag (inside the bounds)
    test_mard_values = np.array(
        [calculate_robust_mard(theta, median_theta) for theta in test_thetas]
    )
    in_domain_flags = (test_mard_values >= lower_bound) & (test_mard_values <= upper_bound)
    in_domain_pct = np.mean(in_domain_flags) * 100

    # Diagnosis per test sample (None when the ID is absent from the CSV)
    test_diagnoses = (
        [id_to_diagnosis.get(test_id) for test_id in test_ids] if has_diagnosis else []
    )

    # Save median theta for future use
    np.save(os.path.join(OUTPUT_DIR, "aptos_median_theta.npy"), median_theta)

    # Plot calibration MARD distribution
    _plot_mard_histogram(
        calibration_mard_values, mean_mard, lower_bound, upper_bound,
        color='blue', label='Calibration Set', mean_label='Mean',
        title='Distribution of Calibration MARD Values (APTOS Dataset)',
        filename="aptos_calibration_mard_distribution.png",
    )

    # Plot test MARD distribution
    _plot_mard_histogram(
        test_mard_values, mean_mard, lower_bound, upper_bound,
        color='orange', label='Test Set', mean_label='Mean (cal)',
        title='Distribution of Test MARD Values (APTOS Dataset)',
        filename="aptos_test_mard_distribution.png",
    )

    # Another view: log-scale histogram (can help with extreme distributions)
    _plot_mard_histogram(
        test_mard_values, mean_mard, lower_bound, upper_bound,
        color='orange', label='Test Set', mean_label='Mean',
        title='Log-Scale Distribution of Test MARD Values (APTOS Dataset)',
        filename="aptos_test_mard_log_distribution.png",
        log_scale=True,
    )

    # Save results with in-domain flags and diagnoses
    results_dict = {
        'id': test_ids,
        'mard': test_mard_values,
        'in_domain': in_domain_flags
    }

    # Add diagnosis if available
    if has_diagnosis:
        results_dict['diagnosis'] = test_diagnoses

    results_df = pd.DataFrame(results_dict)
    results_df.to_csv(os.path.join(OUTPUT_DIR, "aptos_test_results.csv"), index=False)

    # Save conformal boundaries info
    with open(os.path.join(OUTPUT_DIR, "aptos_conformal_bounds.txt"), 'w', encoding="utf-8") as f:
        f.write(f"Mean MARD (from calibration): {mean_mard}\n")
        f.write(f"Std MARD (from calibration): {std_mard}\n")
        f.write(f"Lower bound (max(0, mean-2std)): {lower_bound}\n")
        f.write(f"Upper bound (mean+2std): {upper_bound}\n")
        f.write(f"Percentage of test samples in domain: {in_domain_pct:.2f}%\n")

        # Add breakdown by diagnosis if available. Samples without a
        # diagnosis are excluded rather than suppressing the whole section.
        if has_diagnosis:
            flags_by_level = {}
            for diagnosis, flag in zip(test_diagnoses, in_domain_flags):
                if diagnosis is not None:
                    flags_by_level.setdefault(diagnosis, []).append(flag)
            missing_count = sum(d is None for d in test_diagnoses)

            if flags_by_level:
                f.write("\nIn-domain percentage by diagnosis level:\n")
                for level in sorted(flags_by_level):
                    f.write(f"DR Level {level}: {np.mean(flags_by_level[level])*100:.2f}% in domain\n")
                if missing_count:
                    f.write(f"({missing_count} test samples without a diagnosis were excluded)\n")
            if missing_count:
                print(f"Warning: {missing_count} test samples have no diagnosis in the CSV")

    print(f"\nResults saved to {OUTPUT_DIR}")
    print(f"Plots saved to {PLOTS_DIR}")
    print(f"Percentage of test samples in domain: {in_domain_pct:.2f}%")

if __name__ == "__main__":
    main()