import yaml
import sys
import os

# Resolve the project root from this file's location rather than from the
# current working directory, so the `src.` package imports below work no
# matter where the script is launched from. (Data/output paths used later in
# this script are still relative to the current working directory.)
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if project_root not in sys.path:
    sys.path.append(project_root)

from src.seip_package.utils.data_loading import *
from src.seip_package.utils.sensor_failure_sim import *
from src.seip_package.utils.testing import evaluate_models
from src.seip_package.utils.custom_datasets import *
from src.seip_package.utils.experiment_code import load_strength_results_from_filenames, plot_strength_trends_specific

# --- Experiment configuration ----------------------------------------------
# True: re-run the robustness sweep and (re)write the per-strength CSVs.
# False: skip evaluation and only load/plot CSVs already in `folder_dir`.
EVAL_NEW = False
device = "mps"  # forwarded to evaluate_models; presumably a torch device string — confirm
folder_dir = "bias_scaling_experiment"  # result CSV directory (relative to the CWD); also passed to the plotter
mask_names = ["mean", "bias", "noise", "scaling"]  # forwarded to evaluate_models

BENCHMARK = "revs_benchmark_seed_42"
# The trailing "/" lets the f-strings below append config file names directly.
PATH_TO_CONFS = os.path.join(project_root, "configs/")

# Config names (relative to configs/, without ".yaml") of the models to compare.
# The first entry is the benchmark model; its config also supplies `args`.
model_names = [
    f"seeding_experiment/{BENCHMARK}",
    "bias_scaling/revs_m_b_n_ft_bias_1",
    "bias_scaling/revs_m_b_n_ft_bias_2",
    "bias_scaling/revs_m_b_n_ft_bias_3",
    "bias_scaling/revs_m_b_n_ft_bias_4",
]

# Full paths to every model's YAML config.
pretrain_tasks = [f"{PATH_TO_CONFS}{task}.yaml" for task in model_names]

# Context managers ensure the YAML files are closed once parsed.
with open(f"{PATH_TO_CONFS}{model_names[0]}.yaml", "r") as config_file:
    args = yaml.load(config_file, Loader=yaml.FullLoader)

with open(f"scaling_info/{BENCHMARK}_mean_std.yaml", "r") as stats_file:
    mean_std_yaml = yaml.load(stats_file, Loader=yaml.FullLoader)

# NOTE(review): tensor element order follows the key order of the YAML
# "mean"/"std" mappings; this assumes that order matches the model's feature
# order — confirm against how the scaling file is written.
mean = torch.tensor(list(mean_std_yaml["mean"].values()))
std = torch.tensor(list(mean_std_yaml["std"].values()))

if EVAL_NEW:
    # Re-run the robustness sweep. Every model in `pretrain_tasks` is evaluated
    # on two test sets: the "iid" set ("Same Track" in the plots) and the "ood"
    # set ("New Track"). This is repeated for each (bias, noise, scale)
    # strength triple, and one CSV per split and triple is written to
    # `folder_dir`.

    # Both test sets use identical loader settings from the benchmark config,
    # so the settings are defined once here.
    data_cfg = args["data"]
    test_loader_kwargs = dict(
        batchsize=16384,
        columns_to_drop=data_cfg["columns_to_drop"],
        sequence_length=data_cfg["sequence_length"],
        target=data_cfg["target"],
        threshold=data_cfg["threshold_value"],
        threshold_column=data_cfg["threshold_column"],
        seed=args["experiment"]["seed"],
    )
    dataset_root = "../datasets/REVS_Program_Vehicle_Dynamics_Database"  # relative to the CWD

    iid_test_dataloader = load_test_data(
        test_path=f"{dataset_root}/2013_Montery_Motorsports_Reunion_Test",
        **test_loader_kwargs,
    )
    ood_test_dataloader = load_test_data(
        test_path=f"{dataset_root}/2013_Targa_Sixty_Six",
        **test_loader_kwargs,
    )

    # Feature count = size of dim 2 of the first batch's input tensor
    # (presumably laid out as (batch, sequence, feature) — confirm).
    feat_dim = next(iter(iid_test_dataloader))[0].size(2)

    # Strength levels are swept in lockstep: the i-th bias, noise level and
    # scaling multiplier are evaluated together.
    biases = [0.5, 1, 1.5, 2, 2.5]
    noise_levels = [0.2, 0.4, 0.6, 0.8, 1.0]
    scaling_mult = [1.5, 2, 2.5, 3, 3.5]

    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(folder_dir, exist_ok=True)

    for bias, noise, scale in zip(biases, noise_levels, scaling_mult):
        df = evaluate_models(mask_names, iid_test_dataloader, bias, noise, scale,
                             pretrain_tasks, feat_dim, mean, std, device=device)
        df_ood = evaluate_models(mask_names, ood_test_dataloader, bias, noise, scale,
                                 pretrain_tasks, feat_dim, mean, std, device=device)
        # The filename encodes the strength triple; the loader below globs on it.
        suffix = f"bias{bias}_noise{noise}_scale{scale}.csv"
        df.to_csv(f"{folder_dir}/iid_rob_{suffix}")
        df_ood.to_csv(f"{folder_dir}/ood_rob_{suffix}")

# Glob patterns matching the per-strength CSVs written by the evaluation sweep.
iid_pattern = "iid_rob_bias*_noise*_scale*.csv"
ood_pattern = "ood_rob_bias*_noise*_scale*.csv"

# Load both splits first, then print a short preview of each.
combined_iid_df, iid_strengths_info = load_strength_results_from_filenames(iid_pattern, folder_dir)
combined_ood_df, ood_strengths_info = load_strength_results_from_filenames(ood_pattern, folder_dir)

_result_previews = (
    ("Combined IID Data Head:", combined_iid_df,
     "\nParsed IID Strengths Info (first 5):", iid_strengths_info),
    ("\nCombined OOD Data Head:", combined_ood_df,
     "\nParsed OOD Strengths Info (first 5):", ood_strengths_info),
)
for _frame_title, _frame, _info_title, _info in _result_previews:
    print(_frame_title)
    print(_frame.head())
    print(_info_title)
    print(_info[:5])

# Plot keys are the bare config names: the last path component of each entry
# in `model_names`. Deriving them keeps the plotted models in sync, and in the
# same order, as the evaluated ones.
model_keys = [os.path.basename(name) for name in model_names]

line_colors = {
    "revs_benchmark_seed_42": (102/255, 0/255, 150/255),
    "revs_m_b_n_ft_bias_1": (7/255, 147/255, 195/255),
    "revs_m_b_n_ft_bias_2": (253/255, 184/255, 99/255),
    "revs_m_b_n_ft_bias_3": (230/255, 97/255, 1/255),
    "revs_m_b_n_ft_bias_4": (178/255, 24/255, 43/255),
}

line_labels = {
    "revs_benchmark_seed_42": "Baseline (no pretraining)",
    "revs_m_b_n_ft_bias_1": "Ours (Bias 1.0)",
    "revs_m_b_n_ft_bias_2": "Ours (Bias 2.0)",
    "revs_m_b_n_ft_bias_3": "Ours (Bias 3.0)",
    "revs_m_b_n_ft_bias_4": "Ours (Bias 4.0)",
}

# Check up front that every plotted model has a colour and a legend label.
_unstyled_models = [key for key in model_keys
                    if key not in line_colors or key not in line_labels]
if _unstyled_models:
    raise ValueError(f"No line colour/label configured for model(s): {_unstyled_models}")

# Build a filtered copy instead of calling .remove() on the list stored in
# `args`, so the loaded config is left untouched. Every 'no_failure' entry is
# dropped.
features = [col for col in args["data"]["columns_to_standardize"] if col != "no_failure"]

# Plot for IID data
print("\n--- Plotting IID Strength Trends ---")
plot_strength_trends_specific(combined_iid_df, model_keys, features, "Same Track", line_labels, line_colors, folder_dir, log=True)

# Plot for OOD data
print("\n--- Plotting OOD Strength Trends ---")
plot_strength_trends_specific(combined_ood_df, model_keys, features, "New Track", line_labels, line_colors, folder_dir, log=True)

