import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from collections import defaultdict
from matplotlib.gridspec import GridSpec
import matplotlib as mpl

# Import the comparison entry point from decoder_comparison_nopretrain.
# That module must live in the same directory as this script or be on the Python path.
try:
    from decoder_comparison_nopretrain import run_single_comparison
except ImportError as e:
    import sys

    print(f"Error importing from decoder_comparison_nopretrain.py: {e}", file=sys.stderr)
    print(
        "Make sure this script is in the same directory as decoder_comparison_nopretrain.py "
        "(or that the module is on the Python path).",
        file=sys.stderr,
    )
    # Non-zero status so callers/shells see the failure; the bare `exit()` builtin
    # is an interactive `site` helper and would have reported success (status 0).
    sys.exit(1)

# Experiment schedule: run k (0-based) is seeded with BASE_SEED + k.
NUM_RUNS: int = 10
BASE_SEED: int = 42

# Directory (relative to the working directory) that receives the summary CSV and plots.
OUTPUT_DIR: str = "experiment_runs_results"

# Number of initial targets used by the comparison's run_single_comparison settings,
# kept here so phase boundaries can be plotted consistently. It is assumed to be the
# same for every run; update it by hand if the comparison settings change.
NUM_TARGETS_INITIAL_FOR_PLOT: int = 30

def _collect_reach_times(all_run_results, dt):
    """Group per-run reach times as ``{decoder: {phase: [times_run1, times_run2, ...]}}`` in seconds."""
    grouped = defaultdict(lambda: defaultdict(list))
    for run_result in all_run_results:
        for decoder, phase_dict in run_result.items():
            for phase, steps_list in phase_dict.items():
                # Convert steps to seconds - every trial is kept, including timeouts.
                grouped[decoder][phase].append(np.asarray(steps_list, dtype=float) * dt)
    return grouped


def _mean_std_curve(runs, smoothing_window):
    """Return per-trial (mean, std) across runs of unequal length, or None if there is no data.

    Shorter runs are NaN-padded so each trial index only averages the runs that
    reached it. Both curves get a centered rolling mean when the window is > 1
    and the curve is at least as long as the window.
    """
    max_len = max((len(r) for r in runs), default=0)
    if max_len == 0:
        return None

    padded = np.full((len(runs), max_len), np.nan)
    for row, times in enumerate(runs):
        padded[row, :len(times)] = times

    means = np.nanmean(padded, axis=0)
    stds = np.nanstd(padded, axis=0)

    if smoothing_window and smoothing_window > 1 and len(means) >= smoothing_window:
        def _smooth(values):
            return (
                pd.Series(values)
                .rolling(window=smoothing_window, min_periods=1, center=True)
                .mean()
                .to_numpy()
            )

        means, stds = _smooth(means), _smooth(stds)
    return means, stds


def aggregate_and_plot_reach_times(all_run_results, filename="decoder_reach_time_comparison_aggregated.pdf", smoothing_window=15,
                                   dt=0.01, output_dir=None):
    """Aggregate reach times across runs and plot mean +/- std per decoder and phase.

    For each decoder the phases are concatenated on the x-axis (training first,
    then evaluation). A dotted vertical line marks the end of the training phase.

    Args:
        all_run_results: List of per-run results, each a mapping
            ``{decoder: {phase: sequence_of_step_counts}}``. Only the phases
            ``ONLINE_TRAINING_COLLECTION`` and ``POST_TRAINING_EVALUATION`` are plotted.
        filename: Name of the PDF written inside ``output_dir``.
        smoothing_window: Width of the centered rolling mean applied to the mean
            and std curves. Values <= 1 (or falsy) disable smoothing, and curves
            shorter than the window are left unsmoothed.
        dt: Seconds per step, used to convert step counts to seconds.
        output_dir: Destination directory. Defaults to ``OUTPUT_DIR``.

    Returns:
        None. The figure is saved to disk; save errors are printed, not raised.
    """
    if not all_run_results:
        print("No run results to aggregate.")
        return

    if output_dir is None:
        output_dir = OUTPUT_DIR

    # Publication-quality styling with large fonts, applied only to this figure
    # (rc_context restores the global rcParams on exit instead of mutating them).
    rc = {
        'font.size': 18,
        'font.family': 'Arial',
        'axes.labelsize': 24,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 20,
        'figure.dpi': 300,
        'savefig.dpi': 300,
        'lines.linewidth': 2.5,
        'axes.linewidth': 1.5,
        'xtick.major.width': 1.5,
        'ytick.major.width': 1.5,
        'xtick.major.size': 5,
        'ytick.major.size': 5,
    }

    decoder_colors = {'KF': 'blue', 'LSTM': 'green', 'SNN_BPTT': 'red', 'SNN_Online': 'purple', 'SNN-BPTT': 'red'}
    phase_styles = {'ONLINE_TRAINING_COLLECTION': '-', 'POST_TRAINING_EVALUATION': '--'}
    phase_labels = {'ONLINE_TRAINING_COLLECTION': 'Training', 'POST_TRAINING_EVALUATION': 'Evaluation'}
    phase_order = ['ONLINE_TRAINING_COLLECTION', 'POST_TRAINING_EVALUATION']

    aggregated = _collect_reach_times(all_run_results, dt)

    with plt.rc_context(rc):
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            plot_handles = []
            plot_labels = []
            boundaries = set()       # x-positions of training/evaluation boundaries (deduplicated)
            phases_plotted = set()
            max_x = 0

            for decoder, phase_dict in aggregated.items():
                color = decoder_colors.get(decoder, 'gray')
                x_offset = 0
                has_legend_entry = False

                for phase in phase_order:
                    curve = _mean_std_curve(phase_dict.get(phase, []), smoothing_window)
                    if curve is None:
                        continue
                    means, stds = curve
                    x_vals = np.arange(x_offset + 1, x_offset + 1 + len(means))

                    # Label the first phase actually plotted for this decoder, so a
                    # decoder without training data still appears in the legend.
                    label = None if has_legend_entry else f"{decoder}"
                    line, = ax.plot(x_vals, means, color=color, linestyle=phase_styles[phase],
                                    linewidth=2.5, label=label)
                    ax.fill_between(x_vals, means - stds, means + stds, color=color, alpha=0.15)

                    if not has_legend_entry:
                        plot_handles.append(line)
                        plot_labels.append(label)
                        has_legend_entry = True

                    if phase == phase_order[0]:
                        boundaries.add(float(x_vals[-1]) + 0.5)

                    phases_plotted.add(phase)
                    x_offset += len(means)
                    max_x = max(max_x, x_offset)

            # Draw each distinct phase boundary once instead of once per decoder.
            for x_boundary in sorted(boundaries):
                ax.axvline(x=x_boundary, color='black', linestyle=':', linewidth=1.5, alpha=0.8)

            # Style key so readers can tell the solid (training) and dashed (evaluation) lines apart.
            for phase in phase_order:
                if phase in phases_plotted:
                    proxy, = ax.plot([], [], color='black', linestyle=phase_styles[phase], linewidth=2.5)
                    plot_handles.append(proxy)
                    plot_labels.append(phase_labels[phase])

            ax.set_xlabel("Trial Index", fontsize=24)
            ax.set_ylabel("Time to Reach Target (seconds)", fontsize=24)
            ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.7)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.set_ylim(bottom=0)
            ax.set_xlim(left=0, right=max_x + 5)

            if plot_handles:
                fig.legend(plot_handles, plot_labels, loc='upper right', frameon=False, fontsize=20)

            fig.tight_layout()

            output_path = os.path.join(output_dir, filename)
            try:
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                fig.savefig(output_path, dpi=300, bbox_inches='tight', format='pdf')
                print(f"\nSaved aggregated reach time plot to {output_path}")
            except Exception as e:
                # Saving is best-effort: report and keep going.
                print(f"Error saving aggregated plot: {e}")
        finally:
            # Always release this specific figure, even if plotting raised.
            plt.close(fig)

def _flatten_summary(summary_dict):
    """Flatten ``{decoder: {phase: metric}}`` into ``{"<decoder>_<phase>": metric}``."""
    return {
        f"{decoder_name}_{phase_name}": metric_string
        for decoder_name, phase_data in summary_dict.items()
        for phase_name, metric_string in phase_data.items()
    }


def _print_reach_time_debug(all_reach_results):
    """Print, per run/decoder/phase, the number of reach-time entries and the first few values."""
    for i, run_result in enumerate(all_reach_results):
        print(f"DEBUG: Run {i} data:")
        for decoder, phase_dict in run_result.items():
            print(f"  {decoder}: {list(phase_dict.keys())}")
            for phase, data in phase_dict.items():
                # len() instead of truthiness: `if data` raises ValueError for
                # NumPy arrays with more than one element.
                preview = data[:5] if len(data) > 0 else 'empty'
                print(f"    {phase}: {len(data)} data points, first few: {preview}")


def main():
    """Run NUM_RUNS seeded comparisons, save the per-run summary table, and plot reach times."""
    import traceback

    all_reach_results = []
    all_summary_dfs = []

    # Create output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Starting {NUM_RUNS} experiment runs...")

    for i in range(NUM_RUNS):
        current_seed = BASE_SEED + i
        print(f"\n===== Starting Run {i+1}/{NUM_RUNS} (Seed: {current_seed}) =====")
        try:
            # run_single_comparison returns (reach_times, summary_dict)
            reach_times, summary_dict_for_run = run_single_comparison(seed_value=current_seed)
            print(f"DEBUG: Run {i+1} returned reach_times: {reach_times}")  # DEBUG
            if reach_times is None or summary_dict_for_run is None:
                print(f"ERROR: Run {i+1} (Seed: {current_seed}) did not return valid results.")
                continue

            df_one_run = pd.DataFrame([_flatten_summary(summary_dict_for_run)])  # single-row DataFrame
            df_one_run['Seed'] = current_seed
        except Exception as e:
            # A failed run is reported and skipped; the remaining seeds still run.
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print(f"ERROR during Run {i+1} (Seed: {current_seed}): {e}")
            traceback.print_exc()  # Print full traceback for debugging
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            continue

        # Record both results only after the whole run was processed, so the
        # reach-time and summary lists always describe the same set of runs.
        all_reach_results.append(reach_times)
        all_summary_dfs.append(df_one_run)
        print(f"===== Run {i+1} Completed Successfully ====")

    print(f"\n{len(all_reach_results)} out of {NUM_RUNS} runs completed successfully.")

    if not all_summary_dfs:
        print("No summary DataFrames to concatenate. Exiting statistics calculation.")
    else:
        # Columns look like '<decoder>_<phase>' (e.g. 'KF_ONLINE_TRAINING_COLLECTION') plus 'Seed'.
        concatenated_summaries = pd.concat(all_summary_dfs, ignore_index=True)

        # NOTE(review): the metric values are presumably preformatted strings
        # (e.g. "X steps / Y% success"), so only the raw per-run table is saved.
        # TODO: parse them into numeric columns to report mean ± std across seeds.
        print("\n===== Concatenated Performance Summaries Across All Runs ====")
        print(concatenated_summaries)
        csv_path = os.path.join(OUTPUT_DIR, "concatenated_performance_summary.csv")
        concatenated_summaries.to_csv(csv_path, index=False)
        print(f"Saved concatenated performance summary table to {csv_path}")

    # --- Plot Aggregated Reach Times ---
    if all_reach_results:
        print(f"DEBUG: Final all_reach_results before plotting: {all_reach_results}")  # DEBUG
        _print_reach_time_debug(all_reach_results)
        aggregate_and_plot_reach_times(all_reach_results)
    else:
        print("No reach time results collected, skipping aggregated plot.")

    print("\nExperiment runs finished.")


if __name__ == "__main__":
    main()