import yaml
import sys
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.font_manager import FontProperties

script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
# Make the ``src`` package (imported below as ``src.seip_package...``)
# importable via an absolute path. A relative '../' entry would resolve
# against the current working directory rather than this file's location.
if project_root not in sys.path:
    sys.path.append(project_root)

from src.seip_package.utils.data_loading import *
from src.seip_package.utils.sensor_failure_sim import *
from src.seip_package.utils.testing import evaluate_models
from src.seip_package.utils.custom_datasets import *

# Set the global font to Times New Roman
#mpl.rcParams['font.family'] = 'Times New Roman'
#mpl.rcParams['mathtext.fontset'] = 'stix'  # Optional: for math text
#mpl.rcParams['font.size'] = 14

# True: run evaluate_models and overwrite the cached CSVs in ``folder_dir``.
# False: skip evaluation and read the previously cached CSVs.
EVAL_NEW = False
# Device string forwarded to evaluate_models (presumably a torch device).
device = "mps"
# Directory holding the cached result CSVs and the rendered figure.
folder_dir = "experiment_fig_3"
# Mask / failure types forwarded to evaluate_models.
mask_names = "mean bias noise scaling".split()

def plot_with_error_bars(ax, error_df, column_base_name, y_position, height=0.25, color=None,
                         label=None, zorder=2, ecolor='black', elinewidth=1.5, capsize=3):
    """Draw one horizontal bar series with asymmetric confidence-interval error bars.

    Three columns of ``error_df`` are read:
    ``{column_base_name}_mse_bootstrapped`` (bar length),
    ``{column_base_name}_ci_lower`` and ``{column_base_name}_ci_upper``
    (interval bounds).

    Args:
        ax: Matplotlib axes to draw on.
        error_df: DataFrame with one row per bar, in drawing order.
        column_base_name: Prefix of the three columns listed above.
        y_position: Bar centre position(s), one per row of ``error_df``.
        height: Bar thickness.
        color: Bar colour, forwarded to ``ax.barh``.
        label: Legend label for the series, forwarded to ``ax.barh``.
        zorder: z-order of the bars; error bars are drawn at ``zorder + 1``.
        ecolor, elinewidth, capsize: Error-bar styling forwarded to ``ax.errorbar``.

    Returns:
        The container returned by ``ax.barh``.
    """
    value_col = f"{column_base_name}_mse_bootstrapped"
    ci_lower = f"{column_base_name}_ci_lower"
    ci_upper = f"{column_base_name}_ci_upper"

    bars = ax.barh(y_position, error_df[value_col], height,
                   label=label, color=color, zorder=zorder)
    if len(bars) == 0:
        return bars

    # Distances from the estimate to each CI bound, clipped at 0 because
    # Matplotlib rejects negative error extents.
    lower_err = np.asarray(np.maximum(0, error_df[value_col] - error_df[ci_lower]))
    upper_err = np.asarray(np.maximum(0, error_df[ci_upper] - error_df[value_col]))

    # Anchor each error bar at the end and vertical centre of its drawn bar,
    # and draw the whole series with a single errorbar call.
    centers = [bar.get_y() + bar.get_height() / 2 for bar in bars]
    values = [bar.get_width() for bar in bars]
    ax.errorbar(
        x=values,
        y=centers,
        xerr=[lower_err, upper_err],
        fmt='none',
        ecolor=ecolor,
        elinewidth=elinewidth,
        capsize=capsize,
        zorder=zorder + 1,
    )

    return bars

# Cached evaluation results: written when EVAL_NEW is True, read otherwise.
iid_csv_path = os.path.join(folder_dir, "iid_rob_bias1_noise04_scale2.csv")
ood_csv_path = os.path.join(folder_dir, "ood_rob_bias1_noise04_scale2.csv")

if EVAL_NEW:
    import torch  # explicit import; previously only reachable via the star imports

    BENCHMARK = "revs_benchmark_seed_42"

    PATH_TO_CONFS = os.path.join(project_root, "configs/seeding_experiment/")
    model_names = [
        BENCHMARK,
        "revs_m_b_n_ft_seed_42",
        "revs_m_ft_seed_42",
    ]
    # One YAML config path per model to evaluate.
    pretrain_tasks = [f"{PATH_TO_CONFS}{task}.yaml" for task in model_names]

    # Context managers close the file handles promptly.
    # NOTE(review): FullLoader is used on local config files; prefer
    # yaml.safe_load if these files could ever come from an untrusted source.
    with open(f"{PATH_TO_CONFS}{BENCHMARK}.yaml", "r") as conf_file:
        args = yaml.load(conf_file, Loader=yaml.FullLoader)

    with open(f"scaling_info/{BENCHMARK}_mean_std.yaml", "r") as scale_file:
        mean_std_yaml = yaml.load(scale_file, Loader=yaml.FullLoader)
    mean = torch.tensor(list(mean_std_yaml["mean"].values()))
    std = torch.tensor(list(mean_std_yaml["std"].values()))

    def _load_split(test_path):
        """Build a test dataloader for ``test_path`` using the benchmark config ``args``."""
        data_cfg = args["data"]
        return load_test_data(test_path=test_path,
                              batchsize=16384,
                              columns_to_drop=data_cfg["columns_to_drop"],
                              sequence_length=data_cfg["sequence_length"],
                              target=data_cfg["target"],
                              threshold=data_cfg["threshold_value"],
                              threshold_column=data_cfg["threshold_column"],
                              seed=args["experiment"]["seed"])

    iid_test_dataloader = _load_split("../datasets/REVS_Program_Vehicle_Dynamics_Database/2013_Montery_Motorsports_Reunion_Test")
    ood_test_dataloader = _load_split("../datasets/REVS_Program_Vehicle_Dynamics_Database/2013_Targa_Sixty_Six")

    # Feature count = size of dim 2 of the first batch's input tensor.
    feat_dim = next(iter(iid_test_dataloader))[0].size(2)

    bias_mult = 1.0
    noise_l = 0.4
    scaling_mult = 2

    df = evaluate_models(mask_names, iid_test_dataloader, bias_mult, noise_l, scaling_mult, pretrain_tasks, feat_dim, mean, std, device = device)
    df_ood = evaluate_models(mask_names, ood_test_dataloader, bias_mult, noise_l, scaling_mult, pretrain_tasks, feat_dim, mean, std, device = device)

    # Make sure the cache directory exists before writing into it.
    os.makedirs(folder_dir, exist_ok=True)
    df.to_csv(iid_csv_path)
    df_ood.to_csv(ood_csv_path)

else:
    df = pd.read_csv(iid_csv_path, index_col=0)
    df_ood = pd.read_csv(ood_csv_path, index_col=0)

plot_df = df
plot_df_ood = df_ood

# Row labels in top-to-bottom display order (presumably the failed sensor,
# plus the 'no_failure' reference). Reversed because barh puts y=0 at the bottom.
plot_order = ['yawRate', 'ayCG', 'handwheelAngle', 'engineSpeed', 'throttle', 'brake', 'axCG',
        'chassisAccelFL', 'chassisAccelFR', 'chassisAccelRL', 'chassisAccelRR', 'no_failure']
plot_order = plot_order[::-1]

# No constrained_layout here: the figure is later laid out with
# tight_layout() + subplots_adjust(), which replace a constrained layout
# engine (with a warning). Equal width ratios are already the default.
fig, axes = plt.subplots(2, 4, figsize=(10, 9), sharey=True, sharex=True)
fig.set_facecolor('none')

# Row i of the grid shows datasets[i]; column j shows errors[j].
datasets = [plot_df, plot_df_ood]
dataset_labels = ["Same Track", "New Track"]
errors = ['bias', 'mean', 'noise', 'scaling']

# Colour and legend label per bar series, keyed by the MSE column name.
bar_colors = {
    "revs_benchmark_seed_42_mse_bootstrapped": (102/255, 0/255, 150/255),
    "revs_m_ft_seed_42_mse_bootstrapped": (178/255, 24/255, 43/255),
    "revs_m_b_n_ft_seed_42_mse_bootstrapped": (7/255, 147/255, 195/255)
}

bar_labels = {
    "revs_benchmark_seed_42_mse_bootstrapped": "Baseline (no pretraining)",
    "revs_m_ft_seed_42_mse_bootstrapped": "Baseline (mean pretraining)",
    "revs_m_b_n_ft_seed_42_mse_bootstrapped": "Ours (mean, bias, noise pretraining)"
}

height = 0.3          # thickness of each bar; three bars share one row
padding_factor = 0.5  # extra y-margin above/below the outer bars, in bar heights

# (model column base, vertical offset in bar heights). List order fixes the
# drawing order and therefore the legend order.
model_offsets = [
    ("revs_benchmark_seed_42", -1),
    ("revs_m_ft_seed_42", 0),
    ("revs_m_b_n_ft_seed_42", 1),
]

for i, (dataset, dataset_label) in enumerate(zip(datasets, dataset_labels)):

    for j, error in enumerate(errors):
        ax = axes[i, j]
        # One row per entry of plot_order; labels absent from the data
        # become NaN rows after reindexing.
        error_df = dataset[dataset['error'] == error].reindex(plot_order)

        y = np.arange(len(error_df.index))

        for model, offset in model_offsets:
            series_key = f"{model}_mse_bootstrapped"
            plot_with_error_bars(
                ax=ax,
                error_df=error_df,
                column_base_name=model,
                y_position=y + offset * height,
                height=height,
                color=bar_colors[series_key],
                label=bar_labels[series_key],
                ecolor='black',
                elinewidth=1.5,
                capsize=2,
            )

        ax.grid(True, linestyle='--', alpha=0.7, zorder=0)

        # Panel title drawn inside the axes, anchored near the lower-right corner.
        title = f'{dataset_label} - {error.capitalize()}'
        ax.text(0.99, 0.05, title,
            transform=ax.transAxes,
            horizontalalignment='right',
            verticalalignment='top',
            fontsize=12,
            zorder=2,
            bbox=dict(facecolor='white',
                alpha=0.8,
                edgecolor='none',
                pad=3))

        ax.set_yticks(y)
        ax.set_yticklabels(error_df.index, fontsize=12)
        ax.set_facecolor('none')

        ax.set_ylim(y[0] - height - padding_factor * height, y[-1] + height + padding_factor * height)

        # Only the bottom row gets an x-axis label.
        if i == 1:
            ax.set_xlabel('MSE')

# Every panel draws the same three labelled bar series, so the first panel's
# handles are enough for the shared figure legend.
handles, labels = axes[0, 0].get_legend_handles_labels()

normal_font = FontProperties(weight='normal', size=12)
# NOTE(review): bold_font is currently unused; every entry, including the
# third ("Ours ...") one, is rendered with normal_font. Apply bold_font to
# that entry if emphasis is wanted.
bold_font = FontProperties(weight='bold', size=12)

legend = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.55, 0.11),
                   ncol=3, frameon=False)

# Set the font on each label after creation. This changes only the label text;
# the legend's padding and spacing stay based on its own default font size.
for text in legend.get_texts():
    text.set_fontproperties(normal_font)

# Auto-layout first, then override the bottom margin (leaving room where the
# figure-level legend is anchored) and tighten the gaps between panels.
fig.tight_layout()
fig.subplots_adjust(bottom=0.15, hspace=0.02, wspace=0.02)

# Build the path the same way as the cached CSV paths, and make sure the
# target directory exists.
os.makedirs(folder_dir, exist_ok=True)
fig.savefig(os.path.join(folder_dir, "figure_3.png"), dpi=600, bbox_inches='tight', pad_inches=0.05)
plt.close(fig)
