import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from collections import defaultdict
from matplotlib.gridspec import GridSpec
import matplotlib as mpl

# Import the experiment entry point (and related names) from decoder_comparison.
# decoder_comparison.py must be in the same directory or on the Python path.
# Only run_single_comparison is called directly in this script; the remaining
# names were added by the original author as "potentially needed by imported
# functions indirectly" and are kept so nothing that relies on them breaks.
try:
    from decoder_comparison import (
        run_single_comparison,
        plot_phased_reach_time_comparison,
        Synthetic_Neuron,
        LSTMRegression,
        SNNRegression,
        TwoScaleMetaRLWeightUpdaterFull,
        train_test_kalman_filter,
        train_lstm_model,
        compute_correlation,
    )
except ImportError as e:
    print(f"Error importing from decoder_comparison.py: {e}")
    print("Make sure run_experiments.py is in the same directory as decoder_comparison.py")
    # Abort with a non-zero status so shells/CI see the failure. The bare
    # exit() helper is injected by the `site` module (absent under `python -S`)
    # and, called without an argument, reports success (status 0).
    raise SystemExit(1)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# How many independent runs are executed for every disruption configuration.
NUM_RUNS_PER_DISRUPTION: int = 10

# First random seed; the driver derives a distinct seed per run from this base.
BASE_SEED: int = 42

# Directory that receives the CSV summaries and PDF plots.
OUTPUT_DIR: str = "experiment_runs_results"

# Disruption configurations to test, one (disruption_type, disruption_intensity)
# pair per entry.
DISRUPTION_CONFIGS = [
    ("dropout", 0.5),
    ("remapping", 0.9),
    ("drift", 0.9),
]


def _mean_std_across_runs(runs):
    """Return per-trial (mean, std) over ragged runs, or None if there is no data."""
    longest = max((len(run) for run in runs), default=0)
    if longest == 0:
        return None
    # Pad shorter runs with NaN so the nan-aware reductions ignore the padding.
    padded = np.full((len(runs), longest), np.nan)
    for row, run in enumerate(runs):
        padded[row, :len(run)] = run
    return np.nanmean(padded, axis=0), np.nanstd(padded, axis=0)


def _rolling_mean(values, window):
    """Centered rolling mean; series shorter than the window are returned unchanged."""
    if len(values) < window:
        return values
    return pd.Series(values).rolling(window=window, min_periods=1, center=True).mean().to_numpy()


def aggregate_and_plot_reach_times(all_run_results, filename="decoder_reach_time_comparison_aggregated.pdf", smoothing_window=15, disruption_type=None, disruption_intensity=None, seconds_per_step=0.01):
    """Aggregate reach times across runs and plot mean +/- std per decoder.

    For every decoder the phases listed in ``phase_order`` are concatenated on
    the x-axis (trial index); each phase is drawn with its own line style and a
    shaded +/- one standard deviation band. The figure is written as a PDF to
    ``OUTPUT_DIR/filename``.

    Args:
        all_run_results: Iterable of per-run results shaped
            ``{decoder: {phase: [steps_per_trial, ...]}}``.
        filename: Name of the PDF file created inside ``OUTPUT_DIR``.
        smoothing_window: Width of the centered rolling mean applied to the
            mean and std curves; curves shorter than the window are not
            smoothed.
        disruption_type: Accepted for call-site compatibility; not used by the
            current rendering.
        disruption_intensity: Accepted for call-site compatibility; not used by
            the current rendering.
        seconds_per_step: Factor converting step counts to seconds.

    Returns:
        None. Prints a message and returns early when there is nothing to plot;
        save errors are reported on stdout rather than raised.
    """
    if not all_run_results:
        print("No run results to aggregate.")
        return

    # Publication-quality styling with large fonts. Applied through rc_context
    # so the overrides do not leak into figures created after this call.
    rc_overrides = {
        'font.size': 18,
        'font.family': 'Arial',
        'axes.labelsize': 24,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 20,
        'figure.dpi': 300,
        'savefig.dpi': 300,
        'lines.linewidth': 2.5,
        'axes.linewidth': 1.5,
        'xtick.major.width': 1.5,
        'ytick.major.width': 1.5,
        'xtick.major.size': 5,
        'ytick.major.size': 5
    }

    decoder_colors = {'KF': 'blue', 'LSTM': 'green', 'SNN_BPTT': 'red', 'SNN_Online': 'purple', 'SNN-BPTT': 'red'}
    phase_styles = {'INITIAL_LEARNING': '-', 'ADAPT_TO_DISRUPTION': '--'}
    phase_order = ['INITIAL_LEARNING', 'ADAPT_TO_DISRUPTION']

    # Aggregate data: {decoder: {phase: [run1_times, run2_times, ...]}}
    aggregated = defaultdict(lambda: defaultdict(list))
    for run_result in all_run_results:
        for decoder, phase_dict in run_result.items():
            for phase, steps_list in phase_dict.items():
                # Convert steps to seconds
                aggregated[decoder][phase].append(np.array(steps_list) * seconds_per_step)

    with plt.rc_context(rc_overrides):
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            plot_handles = []
            plot_labels = []
            drawn_boundaries = set()
            max_x = 0

            # For each decoder, concatenate phases on the x-axis
            for decoder, phase_dict in aggregated.items():
                color = decoder_colors.get(decoder, 'gray')
                x_offset = 0
                in_legend = False

                for phase_idx, phase in enumerate(phase_order):
                    # .get() avoids inserting empty entries into the defaultdict
                    stats = _mean_std_across_runs(phase_dict.get(phase, []))
                    if stats is None:
                        continue

                    means_smooth = _rolling_mean(stats[0], smoothing_window)
                    stds_smooth = _rolling_mean(stats[1], smoothing_window)
                    x_vals = np.arange(x_offset + 1, x_offset + 1 + len(means_smooth))

                    # Plot with phase-specific line style
                    line, = ax.plot(x_vals, means_smooth, color=color, linestyle=phase_styles[phase], linewidth=2.5,
                                    label=None if in_legend else f"{decoder}")
                    ax.fill_between(x_vals, means_smooth - stds_smooth, means_smooth + stds_smooth, color=color, alpha=0.15)

                    # One legend entry per decoder, taken from the first phase
                    # that is actually plotted (not necessarily phase index 0).
                    if not in_legend:
                        plot_handles.append(line)
                        plot_labels.append(f"{decoder}")
                        in_legend = True

                    # Phase boundary after every phase but the last; each
                    # distinct position is drawn once so overlapping lines from
                    # several decoders do not stack up.
                    if phase_idx < len(phase_order) - 1:
                        boundary = float(x_vals[-1]) + 0.5
                        if boundary not in drawn_boundaries:
                            ax.axvline(x=boundary, color='black', linestyle=':', linewidth=1.5, alpha=0.8)
                            drawn_boundaries.add(boundary)

                    x_offset += len(means_smooth)
                    max_x = max(max_x, x_offset)

            # Configure the plot
            ax.set_xlabel("Trial Index", fontsize=24)
            ax.set_ylabel("Time to Reach Target (seconds)", fontsize=24)
            ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.7)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.set_ylim(bottom=0)
            ax.set_xlim(left=0, right=max_x + 5)

            # Legend
            fig.legend(
                plot_handles,
                plot_labels,
                loc='upper right',
                frameon=False,
                fontsize=20
            )

            fig.tight_layout()

            output_path = os.path.join(OUTPUT_DIR, filename)
            try:
                # Make the function usable even when the caller has not
                # created the output directory beforehand.
                os.makedirs(OUTPUT_DIR, exist_ok=True)
                fig.savefig(output_path, dpi=300, bbox_inches='tight', format='pdf')
                print(f"\nSaved aggregated reach time plot to {output_path}")
            except Exception as e:
                # Best-effort save: report and carry on so a failed plot does
                # not abort the remaining experiments.
                print(f"Error saving aggregated plot: {e}")
        finally:
            # Close this specific figure, even if plotting raised.
            plt.close(fig)

if __name__ == "__main__":
    # Script entry point: for every configured disruption, execute the
    # configured number of seeded runs, save a per-disruption CSV and plot,
    # then write a combined CSV and print per-disruption run counts.
    import traceback

    # Create output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Track results for each disruption type.
    # NOTE: keyed by disruption type only, so two DISRUPTION_CONFIGS entries
    # sharing a type (with different intensities) would overwrite each other
    # here and in the per-disruption output file names.
    all_disruption_results = {}

    print(f"Starting experiments with {len(DISRUPTION_CONFIGS)} disruption types, {NUM_RUNS_PER_DISRUPTION} runs each...")

    for disruption_idx, (disruption_type, disruption_intensity) in enumerate(DISRUPTION_CONFIGS):
        print(f"\n{'='*80}")
        print(f"DISRUPTION TYPE {disruption_idx+1}/{len(DISRUPTION_CONFIGS)}: {disruption_type.upper()} (intensity: {disruption_intensity})")
        print(f"{'='*80}")

        all_reach_results = []
        all_summary_dfs = []

        for run_idx in range(NUM_RUNS_PER_DISRUPTION):
            # Every (disruption, run) pair gets its own seed.
            current_seed = BASE_SEED + disruption_idx * NUM_RUNS_PER_DISRUPTION + run_idx
            print(f"\n===== Starting Run {run_idx+1}/{NUM_RUNS_PER_DISRUPTION} (Seed: {current_seed}) =====")
            try:
                # run_single_comparison returns (reach_times, summary_dict)
                reach_times, summary_dict_for_run = run_single_comparison(
                    seed_value=current_seed,
                    disruption_type=disruption_type,
                    disruption_intensity=disruption_intensity
                )
                if reach_times is None or summary_dict_for_run is None:
                    print(f"ERROR: Run {run_idx+1} (Seed: {current_seed}) did not return valid results.")
                    continue

                all_reach_results.append(reach_times)

                # Flatten {decoder: {phase: metric}} into "<decoder>_<phase>" columns
                flattened_summary = {}
                for decoder_name, phase_data in summary_dict_for_run.items():
                    for phase_name, metric_string in phase_data.items():
                        flattened_summary[f"{decoder_name}_{phase_name}"] = metric_string

                df_one_run = pd.DataFrame([flattened_summary])  # single-row DataFrame
                df_one_run['Seed'] = current_seed
                df_one_run['Disruption_Type'] = disruption_type
                df_one_run['Disruption_Intensity'] = disruption_intensity
                all_summary_dfs.append(df_one_run)

                print(f"===== Run {run_idx+1} Completed Successfully ====")
            except Exception as e:
                # Top-level boundary: report the failure with a full traceback
                # and move on to the next run instead of aborting the batch.
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                print(f"ERROR during Run {run_idx+1} (Seed: {current_seed}): {e}")
                traceback.print_exc()
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        print(f"\n{len(all_reach_results)} out of {NUM_RUNS_PER_DISRUPTION} runs completed successfully for {disruption_type}.")

        # Store results for this disruption type
        all_disruption_results[disruption_type] = {
            'reach_results': all_reach_results,
            'summary_dfs': all_summary_dfs
        }

        # Save individual disruption results
        if all_summary_dfs:
            concatenated_summaries = pd.concat(all_summary_dfs, ignore_index=True)
            csv_path = os.path.join(OUTPUT_DIR, f"performance_summary_{disruption_type}.csv")
            concatenated_summaries.to_csv(csv_path, index=False)
            print(f"Saved {disruption_type} performance summary to {csv_path}")

        # Plot individual disruption results
        if all_reach_results:
            plot_filename = f"decoder_reach_time_comparison_{disruption_type}.pdf"
            aggregate_and_plot_reach_times(
                all_reach_results,
                filename=plot_filename,
                disruption_type=disruption_type,
                disruption_intensity=disruption_intensity
            )
        else:
            print(f"No reach time results collected for {disruption_type}, skipping plot.")

    # --- Create combined summary across all disruption types ---
    print(f"\n{'='*80}")
    print("CREATING COMBINED SUMMARY ACROSS ALL DISRUPTION TYPES")
    print(f"{'='*80}")

    all_combined_summaries = []
    for results in all_disruption_results.values():
        all_combined_summaries.extend(results['summary_dfs'])

    if all_combined_summaries:
        combined_summaries = pd.concat(all_combined_summaries, ignore_index=True)
        combined_csv_path = os.path.join(OUTPUT_DIR, "combined_performance_summary_all_disruptions.csv")
        combined_summaries.to_csv(combined_csv_path, index=False)
        print(f"Saved combined performance summary to {combined_csv_path}")

        # Print summary statistics. The decoder columns are discovered from the
        # data instead of being hard-coded: aggregating over a column name that
        # is not present raises KeyError, which would crash the script here,
        # after all experiments have already finished.
        print("\n===== Combined Performance Summaries Across All Disruption Types ====")
        grouped = combined_summaries.groupby('Disruption_Type')
        run_counts = grouped.size().to_frame('Total_Runs')
        initial_phase_columns = [
            column for column in combined_summaries.columns
            if str(column).endswith('_INITIAL_LEARNING')
        ]
        if initial_phase_columns:
            # Non-null entries per decoder for the initial-learning phase.
            run_counts = run_counts.join(grouped[initial_phase_columns].count())
        print(run_counts)
    else:
        print("No summary DataFrames to concatenate across disruption types.")

    print(f"\n{'='*80}")
    print("ALL EXPERIMENTS COMPLETED")
    print(f"Tested {len(DISRUPTION_CONFIGS)} disruption types:")
    for disruption_type, intensity in DISRUPTION_CONFIGS:
        num_runs = len(all_disruption_results.get(disruption_type, {}).get('reach_results', []))
        print(f"  - {disruption_type.upper()}: {num_runs}/{NUM_RUNS_PER_DISRUPTION} successful runs")
    print(f"{'='*80}")