import yaml
import sys
import os

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.lines import Line2D

# Resolve the project root from this file's own location so that the `src`
# package imported below is found regardless of the directory the script is
# launched from (a relative '../' entry only works for one specific cwd).
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if project_root not in sys.path:
    sys.path.append(project_root)

from src.seip_package.utils.data_loading import *
from src.seip_package.utils.sensor_failure_sim import *
from src.seip_package.utils.testing import evaluate_models
from src.seip_package.utils.custom_datasets import *

# --- Run configuration ---
# EVAL_NEW: when True, every seed's models are re-evaluated and fresh per-seed
#           CSV files are written into `folder_dir`; when False this whole
#           section is skipped and the rest of the script works from the CSV
#           files that are already on disk.
# device:   passed through to `evaluate_models` unchanged.
EVAL_NEW = False
device = "mps"
folder_dir = "seeding_experiment"

if EVAL_NEW:
    # `torch` is used directly below; import it explicitly instead of relying
    # on one of the wildcard imports happening to re-export it.
    import torch

    BENCHMARK = "revs_benchmark_seed_42"

    PATH_TO_CONFS = os.path.join(project_root, "configs/seeding_experiment/")

    # NOTE(review): yaml.FullLoader is kept so existing configs keep loading
    # exactly as before; it must only ever be pointed at trusted, project-local
    # files. The context managers make sure both file handles are closed.
    with open(f"{PATH_TO_CONFS}{BENCHMARK}.yaml", "r") as conf_file:
        args = yaml.load(conf_file, Loader=yaml.FullLoader)
    with open(f"scaling_info/{BENCHMARK}_mean_std.yaml", "r") as scaling_file:
        mean_std_yaml = yaml.load(scaling_file, Loader=yaml.FullLoader)

    # The scaling file stores one value per key under "mean" and "std"; the
    # values are packed into tensors in the mappings' iteration order.
    mean = torch.tensor(list(mean_std_yaml["mean"].values()))
    std = torch.tensor(list(mean_std_yaml["std"].values()))

    # Both test loaders share every setting except the dataset path.
    shared_loader_kwargs = dict(
        batchsize=16384,
        columns_to_drop=args["data"]["columns_to_drop"],
        sequence_length=args["data"]["sequence_length"],
        target=args["data"]["target"],
        threshold=args["data"]["threshold_value"],
        threshold_column=args["data"]["threshold_column"],
        seed=args["experiment"]["seed"],
    )

    iid_test_dataloader = load_test_data(
        test_path="../datasets/REVS_Program_Vehicle_Dynamics_Database/2013_Montery_Motorsports_Reunion_Test",
        **shared_loader_kwargs,
    )

    ood_test_dataloader = load_test_data(
        test_path="../datasets/REVS_Program_Vehicle_Dynamics_Database/2013_Targa_Sixty_Six",
        **shared_loader_kwargs,
    )

    # Size of dimension 2 of the first element of the first batch.
    # NOTE(review): presumably the number of input features -- confirm
    # against the batch layout produced by `load_test_data`.
    feat_dim = next(iter(iid_test_dataloader))[0].size(2)

    seeds = [42, 43, 44, 45, 46, 47, 48, 49, 50, 51]
    mask_names = ["mean", "bias", "noise", "scaling"]
    # Failure-simulation settings forwarded to `evaluate_models`; they are
    # also encoded in the CSV file names below (bias1 / noise04 / scale2).
    bias_mult = 1.0
    noise_l = 0.4
    scaling_mult = 2

    # `DataFrame.to_csv` does not create missing directories.
    os.makedirs(folder_dir, exist_ok=True)

    for seed in seeds:
        model_names = [
            f"revs_benchmark_seed_{seed}",
            f"revs_m_b_n_ft_seed_{seed}",
            f"revs_m_ft_seed_{seed}",
        ]
        # Each model is identified by the path of its YAML config.
        pretrain_tasks = [f"{PATH_TO_CONFS}{name}.yaml" for name in model_names]

        df = evaluate_models(mask_names, iid_test_dataloader, bias_mult, noise_l, scaling_mult,
                             pretrain_tasks, feat_dim, mean, std, device=device)
        df_ood = evaluate_models(mask_names, ood_test_dataloader, bias_mult, noise_l, scaling_mult,
                                 pretrain_tasks, feat_dim, mean, std, device=device)

        # One CSV per seed and per test set; the aggregation code further
        # down reads them back by these exact names.
        iid_csv_path = os.path.join(folder_dir, f"iid_rob_bias1_noise04_scale2_seed_{seed}.csv")
        ood_csv_path = os.path.join(folder_dir, f"ood_rob_bias1_noise04_scale2_seed_{seed}.csv")
        df.to_csv(iid_csv_path)
        df_ood.to_csv(ood_csv_path)

seeds = [42, 43, 44, 45, 46, 47, 48, 49, 50, 51]
folder_dir = "seeding_experiment"
model_bases = ["revs_benchmark", "revs_m_ft", "revs_m_b_n_ft"]

# Column suffixes that must all be present for a model base to be averaged.
_METRIC_SUFFIXES = ("mse_bootstrapped", "ci_lower", "ci_upper")


def _load_seed_frames(split, seeds, folder):
    """Read the per-seed result CSVs of one test split.

    Args:
        split: File-name prefix of the split, "iid" or "ood".
        seeds: Seeds whose CSV files should be read.
        folder: Directory containing the CSV files.

    Returns:
        A list with one DataFrame per successfully read file. In each frame
        the "_seed_<seed>" part is stripped from the column names, a "seed"
        column is added, and the former index is turned into a regular column
        (named "feature" when the CSV index had no name). Missing or
        unreadable files are reported on stdout and skipped.
    """
    frames = []
    for seed in seeds:
        filepath = os.path.join(folder, f"{split}_rob_bias1_noise04_scale2_seed_{seed}.csv")
        if not os.path.exists(filepath):
            print(f"Warning: File not found {filepath}")
            continue
        try:
            df_seed = pd.read_csv(filepath, index_col=0)
            # Drop the seed from the column names so that the same model has
            # the same column name in every seed's frame.
            df_seed.columns = [col.replace(f"_seed_{seed}", "") for col in df_seed.columns]
            if df_seed.index.name is None:
                df_seed.index.name = 'feature'
            df_seed['seed'] = seed
            frames.append(df_seed.reset_index())
        except Exception as e:
            # Deliberately best-effort: one broken file must not abort the
            # aggregation of the remaining seeds.
            print(f"Error loading {filepath}: {e}")
    return frames


def _combine_and_average(frames, label, bases, heading_prefix=""):
    """Concatenate per-seed frames and average the metrics across seeds.

    Args:
        frames: Per-seed DataFrames as returned by `_load_seed_frames`.
        label: Split name used in the printed messages ("IID" or "OOD").
        bases: Model base names whose metric columns should be averaged.
        heading_prefix: Text printed in front of the results heading.

    Returns:
        A tuple `(combined, mean_results)`. `combined` holds all rows of all
        seeds; `mean_results` holds the per-(feature, error) mean of the
        metric columns. When `frames` is empty, `combined` is an empty frame
        that still carries the expected column names, so code that filters it
        by column keeps working, and `mean_results` is an empty DataFrame.
    """
    if not frames:
        print(f"No {label} seed data loaded.")
        expected_columns = ['feature', 'error', 'seed'] + [
            f"{base}_{suffix}" for base in bases for suffix in _METRIC_SUFFIXES
        ]
        return pd.DataFrame(columns=expected_columns), pd.DataFrame()

    combined = pd.concat(frames, ignore_index=True)

    cols_to_average = []
    for base in bases:
        metric_cols = [f"{base}_{suffix}" for suffix in _METRIC_SUFFIXES]
        if all(col in combined.columns for col in metric_cols):
            cols_to_average.extend(metric_cols)
        else:
            print(f"Warning: Missing one or more columns for base '{base}' in combined {label} data. Skipping averaging for this base.")

    if not cols_to_average:
        print(f"Error: No valid columns found for aggregation in {label} data.")
        return combined, pd.DataFrame()

    mean_results = combined.groupby(['feature', 'error'])[cols_to_average].mean().reset_index()
    print(f"{heading_prefix}Aggregated {label} Results (Mean across Seeds):")
    print(mean_results.head())
    return combined, mean_results


# --- Load and Aggregate IID Data ---
iid_dfs = _load_seed_frames("iid", seeds, folder_dir)
combined_iid_df, mean_iid_results = _combine_and_average(iid_dfs, "IID", model_bases)

# --- Load and Aggregate OOD Data ---
ood_dfs = _load_seed_frames("ood", seeds, folder_dir)
combined_ood_df, mean_ood_results = _combine_and_average(ood_dfs, "OOD", model_bases, heading_prefix="\n")

# Features in the order they appear on the x-axis.
plot_order_models = ['no_failure', 'chassisAccelRR', 'chassisAccelRL', 'chassisAccelFR', 'chassisAccelFL',
                     'axCG', 'brake', 'throttle', 'engineSpeed', 'handwheelAngle', 'ayCG', 'yawRate']

# Mapping from feature name to its integer x position.
feature_to_x = {feature: i for i, feature in enumerate(plot_order_models)}

# Grid of 4 rows (error types) x 2 columns (datasets). The layout is computed
# by the explicit tight_layout() call at the end; constrained layout is not
# requested here because tight_layout() would replace it (with a warning).
fig_models, axes_models = plt.subplots(4, 2, figsize=(10, 14), sharex=True, sharey=True)
fig_models.set_facecolor('white')

datasets_plotting = [(combined_iid_df, "Same Track"), (combined_ood_df, "New Track")]
errors = ['bias', 'mean', 'noise', 'scaling']

# Per-model MSE column, colour, legend label and horizontal offset. The
# offsets place the three models side by side around each feature's position.
model_plot_info = {
    "revs_benchmark": {
        "mse_col": "revs_benchmark_mse_bootstrapped",
        "color": (102/255, 0/255, 150/255),
        "label": "Baseline (no pretraining)",
        "offset": -0.2
    },
    "revs_m_ft": {
        "mse_col": "revs_m_ft_mse_bootstrapped",
        "color": (178/255, 24/255, 43/255),
        "label": "Baseline (mean pretraining)",
        "offset": 0.0
    },
    "revs_m_b_n_ft": {
        "mse_col": "revs_m_b_n_ft_mse_bootstrapped",
        "color": (7/255, 147/255, 195/255),
        "label": "Ours (mean, bias, noise pretraining)",
        "offset": 0.2
    }
}

point_size = 15
point_alpha = 0.6

# Rows iterate over error types, columns over datasets.
for j, error in enumerate(errors):
    for i, (dataset, dataset_label) in enumerate(datasets_plotting):
        ax = axes_models[j, i]

        # Column titles only on the top row.
        if j == 0:
            ax.set_title(dataset_label, fontsize=14, pad=15)

        # Rows of the current error type, with each feature mapped to its
        # x position; features not listed in plot_order_models map to NaN
        # and are dropped.
        error_df_filtered = dataset[dataset['error'] == error].copy()
        error_df_filtered['x_pos_base'] = error_df_filtered['feature'].map(feature_to_x)
        error_df_filtered = error_df_filtered.dropna(subset=['x_pos_base'])

        # One scatter series per model: every seed contributes one point per
        # feature, shifted horizontally by the model's offset.
        for info in model_plot_info.values():
            mse_col = info["mse_col"]
            # A model whose column is absent from the loaded data is skipped
            # instead of aborting the whole figure with a KeyError.
            if mse_col not in error_df_filtered.columns:
                continue
            plot_data = error_df_filtered.dropna(subset=[mse_col])
            if plot_data.empty:
                continue

            ax.scatter(plot_data['x_pos_base'] + info["offset"], plot_data[mse_col],
                       color=info["color"],
                       alpha=point_alpha,
                       s=point_size,
                       label=info["label"],  # the figure legend below uses its own proxy handles
                       zorder=3,
                       edgecolors='none')

        ax.grid(True, linestyle='--', alpha=0.7, zorder=0, axis='y')  # horizontal grid lines only

        # Left column: error-type caption beside the row plus the y-axis label.
        if i == 0:
            ax.text(-0.15, 0.5, f'{error.capitalize()}', transform=ax.transAxes, fontsize=14,
                    verticalalignment='center', horizontalalignment='right', rotation=90)
            ax.set_ylabel('MSE (Individual Seeds)', fontsize=10)

        # Feature names on the bottom row only; upper rows get no tick labels.
        if j == len(errors) - 1:
            ax.set_xticks(list(feature_to_x.values()))
            ax.set_xticklabels(plot_order_models, rotation=90, fontsize=9)
        else:
            ax.set_xticklabels([])

# Figure-level legend built from proxy handles (one round marker per model),
# so it does not depend on which subplots actually contain data.
handles = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=info['color'], markersize=point_size//2,
           label=info['label'], linestyle='None')
    for info in model_plot_info.values()
]
labels = [info['label'] for info in model_plot_info.values()]

normal_font = FontProperties(weight='normal', size=10)
legend = fig_models.legend(handles=handles, labels=labels, loc='lower center', bbox_to_anchor=(0.5, 0.01),
                           ncol=3, frameon=False, fontsize=10)
for text in legend.get_texts():
    text.set_fontproperties(normal_font)

# Keep the bottom 5% of the figure free for the legend.
fig_models.tight_layout(rect=[0, 0.05, 1, 1])
fig_models.savefig(os.path.join(folder_dir, "seeding_experiment.png"), dpi=600, bbox_inches='tight')

