import yaml
import sys
import os

import pandas as pd

# Resolve the directory holding this script and its parent, which is used as
# the project root (the `configs/` directory is looked up beneath it later).
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
# '../' is interpreted relative to the current working directory, so on its
# own it only makes the `src` package importable when the script is launched
# from one specific directory. It is kept first so existing launches resolve
# exactly as before; `project_root` is appended as a CWD-independent fallback.
sys.path.append('../')
sys.path.append(project_root)

from src.seip_package.utils.data_loading import *
from src.seip_package.utils.sensor_failure_sim import *
from src.seip_package.utils.testing import evaluate_models
from src.seip_package.utils.custom_datasets import *
from src.seip_package.utils.experiment_code import plot_task_comparison_matplotlib_individual

# Experiment configuration.
# EVAL_NEW = True  -> run the evaluation again and (over)write the result CSVs.
# EVAL_NEW = False -> only read the result CSVs already stored in `folder_dir`.
EVAL_NEW = False
device = "mps"
folder_dir = "model_size_experiment"
mask_names = ["mean", "bias", "noise", "scaling"]

# Both branches below work on the same two result files, so the paths are
# built once. NOTE(review): the file names appear to encode the perturbation
# settings used in the EVAL_NEW branch (bias 1, noise 0.4, scale 2) - keep
# them in sync if those values change.
iid_csv_path = os.path.join(folder_dir, "iid_rob_bias1_noise04_scale2_all_models.csv")
ood_csv_path = os.path.join(folder_dir, "ood_rob_bias1_noise04_scale2_all_models.csv")

if EVAL_NEW:
    # Imported explicitly: `torch` was previously only in scope as a side
    # effect of the star imports at the top of the file.
    import torch

    BENCHMARK = "revs_benchmark_seed_42"

    PATH_TO_CONFS = os.path.join(project_root, "configs/")
    model_names = [
        f"seeding_experiment/{BENCHMARK}",
        "seeding_experiment/revs_m_b_n_ft_seed_42",
        "d_model_test/revs_m_b_n_ft_d_model_32_d_ff_128",
        "d_model_test/revs_m_b_n_ft_d_model_64_d_ff_256",
        "d_model_test/revs_m_b_n_ft_d_model_128_d_ff_512",
        "d_model_test/revs_m_b_n_ft_d_model_256_d_ff_1024",
    ]

    # One YAML config path per model; these are handed to evaluate_models.
    pretrain_tasks = [f"{PATH_TO_CONFS}{task}.yaml" for task in model_names]

    # The first model's config supplies the data-loading settings and the seed.
    # Files are opened via `with` so the handles are closed deterministically.
    # NOTE(review): yaml.FullLoader is kept from the original; consider
    # yaml.safe_load if these files only contain plain YAML types.
    with open(pretrain_tasks[0], "r") as conf_file:
        args = yaml.load(conf_file, Loader=yaml.FullLoader)

    # NOTE: this path (like the dataset paths below) is relative to the
    # current working directory, not to `project_root`.
    with open(f"scaling_info/{BENCHMARK}_mean_std.yaml", "r") as stats_file:
        mean_std_yaml = yaml.load(stats_file, Loader=yaml.FullLoader)
    # The "mean"/"std" entries are mappings; only their values are used, in
    # the order the mapping yields them.
    mean = torch.tensor(list(mean_std_yaml["mean"].values()))
    std = torch.tensor(list(mean_std_yaml["std"].values()))

    # Settings shared by both test loaders; only the dataset path differs.
    data_cfg = args["data"]
    shared_loader_kwargs = dict(
        batchsize=4096,
        columns_to_drop=data_cfg["columns_to_drop"],
        sequence_length=data_cfg["sequence_length"],
        target=data_cfg["target"],
        threshold=data_cfg["threshold_value"],
        threshold_column=data_cfg["threshold_column"],
        seed=args["experiment"]["seed"],
    )

    iid_test_dataloader = load_test_data(
        test_path="../datasets/REVS_Program_Vehicle_Dynamics_Database/2013_Montery_Motorsports_Reunion_Test",
        **shared_loader_kwargs,
    )
    ood_test_dataloader = load_test_data(
        test_path="../datasets/REVS_Program_Vehicle_Dynamics_Database/2013_Targa_Sixty_Six",
        **shared_loader_kwargs,
    )

    # Feature dimension = size of axis 2 of the first item in the first batch.
    feat_dim = next(iter(iid_test_dataloader))[0].size(2)

    # Perturbation strengths forwarded to evaluate_models.
    bias_mult = 1.0
    noise_l = 0.4
    scaling_mult = 2

    df = evaluate_models(mask_names, iid_test_dataloader, bias_mult, noise_l, scaling_mult,
                         pretrain_tasks, feat_dim, mean, std, device=device)
    df_ood = evaluate_models(mask_names, ood_test_dataloader, bias_mult, noise_l, scaling_mult,
                             pretrain_tasks, feat_dim, mean, std, device=device)

    # Make sure the output directory exists before writing the result tables.
    os.makedirs(folder_dir, exist_ok=True)
    df.to_csv(iid_csv_path)
    df_ood.to_csv(ood_csv_path)
else:
    # Reuse previously saved results; the first CSV column is the index
    # written by DataFrame.to_csv above.
    df = pd.read_csv(iid_csv_path, index_col=0)
    df_ood = pd.read_csv(ood_csv_path, index_col=0)

def _model_bases_from(columns):
    """Return the sorted unique prefixes of columns containing '_mse_bootstrapped'.

    The prefix is everything before the first occurrence of the marker in the
    column name.
    """
    marker = '_mse_bootstrapped'
    return sorted({col.split(marker)[0] for col in columns if marker in col})


# `df` and `df_ood` are always bound at this point (both branches of the
# load/evaluate step assign them), so only emptiness needs to be checked.
# `model_bases` defaults to an empty list so the name exists on every path.
model_bases = []
if not df.empty:
    all_cols = df.columns  # Use columns from the loaded df
    model_bases = _model_bases_from(all_cols)
    print(f"Identified model bases: {model_bases}")
elif not df_ood.empty:  # Fall back to the OOD frame if the IID frame is empty
    all_cols = df_ood.columns
    model_bases = _model_bases_from(all_cols)
    print(f"Identified model bases (from OOD data): {model_bases}")
else:
    print("Warning: Could not load data")

# Order of the failure types passed to the plotting helper.
error_order = ['no_failure', 'mean', 'bias', 'noise', 'scaling']

# One entry per figure: (label used in messages, variable name used in the
# skip message, result frame, figure title, output file name).
# `df` and `df_ood` are always bound by the load/evaluate step, so the former
# `'df' in locals()` guards were always true; only emptiness is checked.
plot_jobs = [
    ("IID", "df", df,
     "Model Size Comparison - Individual Sensors (Same Track)",
     "iid_task_comparison_individual.png"),
    ("OOD", "df_ood", df_ood,
     "Model Size Comparison - Individual Sensors (New Track)",
     "ood_task_comparison_individual.png"),
]

for label, frame_name, frame, plot_title, file_name in plot_jobs:
    if frame.empty:
        print(f"DataFrame '{frame_name}' not loaded or empty. Skipping {label} plot.")
        continue
    print(f"\nPlotting {label} Results (Individual Sensors)...")
    # Pass the unaggregated DataFrame; the figure is saved inside `folder_dir`.
    plot_task_comparison_matplotlib_individual(
        frame, model_bases, error_order=error_order, log=False,
        title=plot_title,
        save_path=os.path.join(folder_dir, file_name))


