import yaml
import sys
import os
import glob
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.lines import Line2D
from matplotlib.ticker import FormatStrFormatter

# Bias bounds, noise and failure strengths experiments

def load_strength_results_from_filenames(pattern, folder_path):
    """
    Loads robustness results from CSV files named according to the pattern
    'prefix_bias{b}_noise{n}_scale{s}.csv', extracting strength values.

    Args:
        pattern: Glob pattern, joined onto ``folder_path``, selecting the
            candidate CSV files.
        folder_path: Directory in which ``pattern`` is expanded.

    Returns:
        A tuple ``(combined_df, parsed_strengths)``.

        ``combined_df`` is the concatenation of every file that could be
        loaded, in sorted file-name order, with three added columns
        ('bias_strength', 'noise_strength', 'scale_strength') holding the
        values parsed from the file name. The first CSV column is read as the
        index and then turned back into a regular column; when that index is
        unnamed the column is called 'feature'.

        ``parsed_strengths`` is a list of dicts with the keys 'bias',
        'noise', 'scale' and 'file', sorted by (bias, noise, scale).

        When nothing could be loaded an empty DataFrame and an empty list are
        returned.
    """
    full_pattern = os.path.join(folder_path, pattern)

    # glob returns files in a filesystem-dependent order; sorting makes the
    # row order of the combined frame reproducible.
    all_files = sorted(glob.glob(full_pattern))

    # Extracts the bias, noise and scale values (integers or decimals).
    regex = re.compile(r"bias([\d.]+)_noise([\d.]+)_scale([\d.]+)\.csv")

    df_list = []
    parsed_strengths = []

    for f in all_files:
        match = regex.search(os.path.basename(f))
        if not match:
            print(f"Warning: Filename {f} did not match expected pattern. Skipping.")
            continue

        # The character class also accepts strings such as '1.2.3', so the
        # conversion itself can still fail.
        try:
            bias_strength = float(match.group(1))
            noise_strength = float(match.group(2))
            scale_strength = float(match.group(3))
        except ValueError:
            print(f"Warning: Could not parse numbers from filename {f}. Skipping.")
            continue

        # Kept separate from the parsing above: pandas' EmptyDataError and
        # ParserError derive from ValueError and would otherwise be reported
        # as a file-name problem.
        try:
            df_single_run = pd.read_csv(f, index_col=0)
        except Exception as e:
            print(f"Error loading file {f}: {e}")
            continue

        df_single_run['bias_strength'] = bias_strength
        df_single_run['noise_strength'] = noise_strength
        df_single_run['scale_strength'] = scale_strength
        df_list.append(df_single_run)
        parsed_strengths.append({
            'bias': bias_strength,
            'noise': noise_strength,
            'scale': scale_strength,
            'file': f
        })

    if not df_list:
        print(f"No data loaded for pattern '{pattern}'. Returning empty DataFrame.")
        return pd.DataFrame(), []

    # Keep the original index (features / no_failure) while concatenating,
    # then expose it as a regular column.
    combined_df = pd.concat(df_list, ignore_index=False)
    combined_df.reset_index(inplace=True)
    combined_df.rename(columns={'index': 'feature'}, inplace=True)

    # Sort strengths for consistent plotting later
    parsed_strengths.sort(key=lambda x: (x['bias'], x['noise'], x['scale']))

    return combined_df, parsed_strengths

def _draw_strength_panel(ax, df_feature, feature, error_type, strength_col, model_keys, line_labels, line_colors):
    """Draws one line (plus CI band when available) per model on a single panel."""
    for model_key in model_keys:
        mse_col = f"{model_key}_mse_bootstrapped"
        lower_col = f"{model_key}_ci_lower"
        upper_col = f"{model_key}_ci_upper"

        if mse_col not in df_feature.columns:
            print(f"Warning: Column {mse_col} not found for feature {feature}, error {error_type}. Skipping model {model_key}.")
            continue

        line, = ax.plot(df_feature[strength_col], df_feature[mse_col],
                        label=line_labels.get(model_key, model_key),
                        color=line_colors.get(model_key, None),
                        marker='o', markersize=4, linewidth=1.5)

        # Confidence interval band. Reusing the colour of the line keeps the
        # band matched to it even when line_colors has no entry for the model
        # and matplotlib picked the colour itself.
        if lower_col in df_feature.columns and upper_col in df_feature.columns:
            ax.fill_between(df_feature[strength_col], df_feature[lower_col], df_feature[upper_col],
                            color=line.get_color(), alpha=0.2)


def plot_strength_trends_specific(combined_df, model_keys, features_to_plot, dataset_label, line_labels, line_colors, folder_dir, log = False):
    """
    Plots MSE trends against specific failure strength for different models,
    features, and relevant error types (bias, noise, scaling).

    One figure is produced per error type and written to
    '{folder_dir}/{dataset_label}_{error_type}_strength_trends.png'. Every
    figure holds one panel per entry of ``features_to_plot`` (at most six
    panels per row); error types without any rows are skipped with a message.

    Args:
        combined_df: DataFrame with the columns 'error', 'feature',
            'bias_strength', 'noise_strength', 'scale_strength' and, for each
            model key, '<key>_mse_bootstrapped' (optionally also
            '<key>_ci_lower' and '<key>_ci_upper').
        model_keys: Column prefixes of the models to draw.
        features_to_plot: Values of the 'feature' column, one panel each.
        dataset_label: Used in the figure title and in the output file name.
        line_labels: Mapping model key -> legend label; the key itself is
            used when no entry exists.
        line_colors: Mapping model key -> line colour; matplotlib chooses the
            colour when no entry exists.
        folder_dir: Output directory; created when it does not exist.
        log: When True the y axis is logarithmic.

    Returns:
        None.
    """
    features_to_plot = list(features_to_plot)
    num_features = len(features_to_plot)
    if num_features == 0:
        # The grid computation below would divide by zero.
        print(f"No features to plot for {dataset_label}. Skipping plots.")
        return

    error_types_to_plot = ['bias', 'noise', 'scaling'] # Exclude 'mean'
    strength_cols = {'bias': 'bias_strength', 'noise': 'noise_strength', 'scaling': 'scale_strength'}
    strength_labels = {'bias': 'Bias Strength', 'noise': 'Noise Strength', 'scaling': 'Scaling Strength'}

    ncols = min(num_features, 6)
    nrows = (num_features + ncols - 1) // ncols  # ceiling division

    normal_font = FontProperties(weight='normal', size=10)

    for error_type in error_types_to_plot:
        strength_col = strength_cols[error_type]
        strength_label = strength_labels[error_type]

        # Filter for the specific error type before creating any figure.
        df_error = combined_df[combined_df['error'] == error_type].copy()
        if df_error.empty:
            print(f"No data found for error type '{error_type}' in {dataset_label}. Skipping plot.")
            continue

        # Ensure strength is numeric and sorted so the lines run left to right.
        df_error[strength_col] = pd.to_numeric(df_error[strength_col])
        df_error.sort_values(strength_col, inplace=True)
        unique_strengths = sorted(df_error[strength_col].unique())

        # squeeze=False always yields a 2-D array of axes, including the
        # single-panel case where a bare Axes would otherwise be returned.
        fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 1.5), sharex=True, sharey=True,
                                 constrained_layout=False, squeeze=False)
        try:
            fig.suptitle(f'{dataset_label} - Error Type: {error_type.capitalize()}', fontsize=12, y=1.1)
            axes = axes.flatten()

            for i, feature in enumerate(features_to_plot):
                ax = axes[i]
                df_feature = df_error[df_error['feature'] == feature]

                if df_feature.empty:
                    ax.set_title(f"{feature} (No data)")
                    ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
                    continue

                if log:
                    ax.set_yscale('log')

                _draw_strength_panel(ax, df_feature, feature, error_type, strength_col,
                                     model_keys, line_labels, line_colors)

                ax.set_title(f"{feature}")
                ax.grid(True, linestyle='--', alpha=0.7)
                ax.set_xticks(unique_strengths)
                ax.tick_params(axis='x', rotation=45)

                # Only the left-most panel of each row carries the y label.
                if i % ncols == 0:
                    ax.set_ylabel("MSE", fontsize=10)
                else:
                    ax.set_ylabel("")

            # With sharex=True only the last grid row shows tick labels, so
            # the bottom-most used panel of every column is labelled
            # explicitly (the last row may be only partially filled).
            for c in range(ncols):
                bottom_ax_idx = range(c, num_features, ncols)[-1]
                ax_bottom = axes[bottom_ax_idx]
                ax_bottom.set_xlabel(strength_label, fontsize=10)
                ax_bottom.tick_params(axis='x', labelbottom=True, rotation=30, labelsize=9)

            # Hide the grid cells that have no feature assigned.
            for unused_ax in axes[num_features:]:
                unused_ax.set_visible(False)

            # Take the legend entries from the first panel that has any, so a
            # first feature without data does not suppress the legend.
            handles, labels = [], []
            for used_ax in axes[:num_features]:
                handles, labels = used_ax.get_legend_handles_labels()
                if handles:
                    break
            if handles:
                legend = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=len(model_keys), frameon=False)
                for text in legend.get_texts():
                    text.set_fontproperties(normal_font)

            fig.subplots_adjust(hspace=0.5)
            if folder_dir:
                os.makedirs(folder_dir, exist_ok=True)
            fig.savefig(f"{folder_dir}/{dataset_label}_{error_type}_strength_trends.png", dpi=600, bbox_inches='tight')
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)


# Model size and other masks experiment

def plot_task_comparison_matplotlib_individual(df, model_bases, error_order=None, log=True, title="Model Comparison", save_path="task_comparison_individual.png"):
    """
    Plots per-feature MSE values of several models as jittered points, one
    colour and marker per error type, with the 'no_failure' MSE of each model
    drawn as a dashed stepped line.

    Models are placed on the x axis in ascending order of their MSE in the
    row whose feature identifier equals 'no_failure'; the order of
    ``model_bases`` is kept when that row is missing or sorting fails.

    Args:
        df: DataFrame with an 'error' column and one
            '<model>_mse_bootstrapped' column per model. The feature
            identifier is the index when it is named (and not named 'error');
            otherwise it is the first column of the frame after the index has
            been reset.
        model_bases: Model name prefixes used to build the MSE column names.
        error_order: Optional sequence of error types to plot, in that order;
            entries that do not occur in the data are ignored. By default all
            error types found in the data are used, sorted.
        log: When True the y axis is logarithmic, otherwise linear.
        title: Title of the axes.
        save_path: Output file. Missing parent directories are created.

    Returns:
        None. The figure is written to ``save_path`` and closed; an error
        while saving is printed instead of raised.
    """
    model_bases = list(model_bases)

    # Identify the feature column
    if df.index.name is not None and df.index.name != 'error':
        feature_col_name = df.index.name
        df_plot = df.reset_index() # Reset index for consistent column access later
        print(f"Using index '{feature_col_name}' as the feature identifier.")
    else:
        # An index named 'error' is dropped instead of becoming a column.
        df_plot = df.reset_index(drop=df.index.name == 'error')
        potential_feature_col = df_plot.columns[0]
        if str(potential_feature_col).startswith('Unnamed'):
            feature_col_name = 'feature'
            df_plot.rename(columns={potential_feature_col: feature_col_name}, inplace=True)
            print(f"Warning: Assuming first unnamed column is the feature identifier, renaming to '{feature_col_name}'.")
        else:
            feature_col_name = potential_feature_col
        print(f"Using column '{feature_col_name}' as the feature identifier.")

    # Data Preparation
    # Define model order: Sort by MSE of the 'no_failure' FEATURE row
    original_model_order = list(model_bases)
    try:
        no_failure_sensor_rows = df_plot[df_plot[feature_col_name] == 'no_failure']

        if not no_failure_sensor_rows.empty:
            # Take the first row found for 'no_failure' feature (should be consistent)
            nf_sensor_data = no_failure_sensor_rows.iloc[0]

            # Models without a usable value sort last.
            nf_sensor_mses = {}
            for base in model_bases:
                mse_col = f"{base}_mse_bootstrapped"
                if mse_col in nf_sensor_data.index and pd.notna(nf_sensor_data[mse_col]):
                    nf_sensor_mses[base] = nf_sensor_data[mse_col]
                else:
                    nf_sensor_mses[base] = float('inf')

            model_order = sorted(model_bases, key=lambda m: nf_sensor_mses.get(m, float('inf')))

            if model_order == original_model_order and len(model_order) > 1:
                print("Warning: Model order unchanged after sorting. Check MSE values in 'no_failure' feature row.")
                print("Sorting MSE values ('no_failure' feature row):")
                for model in model_order:
                    print(f"  {model}: {nf_sensor_mses.get(model, 'N/A')}")
            else:
                print(f"Model order based on 'no_failure' feature row MSE: {model_order}")
        else:
            print(f"Warning: Data row for feature='{feature_col_name}' == 'no_failure' not found. Using provided model order.")
            model_order = original_model_order

    except KeyError as e:
        print(f"Error during sorting preparation (KeyError): {e}. Check column names ('{feature_col_name}', 'error'). Using provided model order.")
        model_order = original_model_order
    except Exception as e:
        print(f"An unexpected error occurred during sorting: {e}. Using provided model order.")
        model_order = original_model_order

    model_to_x = {model: i for i, model in enumerate(model_order)}

    # Define error types and plotting properties
    errors_in_data = df_plot['error'].unique()
    if error_order is None:
        error_types_to_plot = sorted(errors_in_data)
        print(f"Found error types in data: {error_types_to_plot}")
    else:
        # Use provided order, filtering by those present in the data
        error_types_to_plot = [e for e in error_order if e in errors_in_data]
        print(f"Using specified error order (filtered): {error_types_to_plot}")

    # Colors/markers are assigned per error type; the 'no_failure' error type
    # is excluded because it is shown as a dashed line instead of points.
    error_types_for_color = [e for e in error_types_to_plot if e != 'no_failure']
    colors = plt.cm.viridis(np.linspace(0, 0.9, len(error_types_for_color)))
    error_colors = {error: color for error, color in zip(error_types_for_color, colors)}

    markers = ['o', 's', '^', 'X', 'P', '*']
    error_markers = {error: markers[i % len(markers)] for i, error in enumerate(error_types_for_color)}

    # Horizontal offsets so the error types of one model do not overlap.
    num_errors_plot_offset = len(error_types_for_color)
    offsets = np.linspace(-0.2, 0.2, num_errors_plot_offset) if num_errors_plot_offset > 1 else [0]
    error_offsets = {error: offset for error, offset in zip(error_types_for_color, offsets)}

    point_size = 15
    point_alpha = 0.6

    # Plotting
    fig, ax = plt.subplots(figsize=(12, 5))
    fig.set_facecolor('white')
    legend_handles = {}

    # Extract the 'no_failure' MSE of each model (first non-missing value).
    is_no_failure = df_plot[feature_col_name] == 'no_failure'
    no_failure_rows = df_plot[is_no_failure]
    no_failure_values = {}
    if not no_failure_rows.empty:
        for model_base in model_order:
            mse_col = f"{model_base}_mse_bootstrapped"
            if mse_col not in df_plot.columns:
                continue
            mse_values = no_failure_rows[mse_col].dropna()
            if not mse_values.empty:
                no_failure_values[model_base] = mse_values.iloc[0]

    # Plot the dashed line for 'no_failure' across all models if values exist
    if no_failure_values:
        x_values = [model_to_x[model] for model in model_order if model in no_failure_values]
        y_values = [no_failure_values[model] for model in model_order if model in no_failure_values]

        # A line needs at least two models.
        if len(x_values) > 1:
            # Use 'steps-mid' for a stepped line between model points
            ax.plot(x_values, y_values, 'k--', linewidth=1.25, drawstyle='steps-mid',
                    label='No Failure', zorder=5, alpha=0.8)
            legend_handles['no_failure'] = Line2D([0], [0], linestyle='--', color='k',
                                                  label='No Failure', linewidth=1.25, alpha=0.8)

    # The rows of each error type do not depend on the model, so select them
    # once; the 'no_failure' feature row is never drawn as a point.
    error_subsets = {}
    for error in error_types_for_color:
        subset = df_plot[(df_plot['error'] == error) & ~is_no_failure]
        if not subset.empty:
            error_subsets[error] = subset

    # Plot individual sensor error points
    for model_base in model_order:
        x_pos_base = model_to_x[model_base]
        mse_col = f"{model_base}_mse_bootstrapped"

        if mse_col not in df_plot.columns:
            continue

        for error, error_specific_df in error_subsets.items():
            mses = error_specific_df[mse_col].dropna()
            if mses.empty:
                continue

            current_color = error_colors.get(error, 'black')
            current_marker = error_markers.get(error, 'x')
            current_offset = error_offsets.get(error, 0)
            display_name = error.replace('_', ' ').capitalize()

            x_coords_plot = [x_pos_base + current_offset] * len(mses)
            ax.scatter(x_coords_plot, mses.values,
                       marker=current_marker, color=current_color, s=point_size, alpha=point_alpha,
                       label=display_name if error not in legend_handles else "_nolegend_",
                       zorder=3, edgecolors='none')

            # Store handle for legend (only once per error type)
            if error not in legend_handles:
                legend_handles[error] = Line2D([0], [0], linestyle='none', marker=current_marker, color=current_color,
                                               markersize=np.sqrt(point_size), label=display_name)

    # Styling
    ax.set_xticks(list(model_to_x.values()))
    xticklabels = [name.replace('revs_', '').replace('_ft', '').replace('benchmark', 'Base').replace('seed', '').replace('42', '').replace('d_model', '-').replace('d_ff', '-').replace('_', '').upper() for name in model_order]
    ax.set_xticklabels(xticklabels, rotation=90, ha='right', fontsize=10)
    ax.set_xlabel('Pretraining Tasks', fontsize=12)
    ax.set_ylabel('MSE per Sensor', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, linestyle='--', alpha=0.6, axis='y')

    # Y-Scale and Limits
    all_mse_cols = [f"{mb}_mse_bootstrapped" for mb in model_bases if f"{mb}_mse_bootstrapped" in df_plot]
    if log:
        ax.set_yscale('log')
        if all_mse_cols:
            all_plotted_mses = df_plot[all_mse_cols].stack().dropna().values
            # Only strictly positive values can be shown on a log axis.
            positive_mses = all_plotted_mses[all_plotted_mses > 0]
            min_mse_val = np.min(positive_mses) if len(positive_mses) > 0 else 1e-6
            ax.set_ylim(bottom=max(1e-7, min_mse_val * 0.5))
        else:
            ax.set_ylim(bottom=1e-7)
    else:
        ax.set_yscale('linear')
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        if all_mse_cols:
            all_plotted_mses = df_plot[all_mse_cols].stack().dropna().values
            if len(all_plotted_mses) > 0:
                min_mse_val = np.min(all_plotted_mses)
                max_mse_val = np.max(all_plotted_mses)
                # 5% padding; falls back to 5% of the value when all values are equal.
                padding = (max_mse_val - min_mse_val) * 0.05 if max_mse_val > min_mse_val else max_mse_val * 0.05
                ax.set_ylim(bottom=max(0, min_mse_val - padding), top=max_mse_val + padding)
            else:
                ax.set_ylim(bottom=0)
        else:
            ax.set_ylim(bottom=0)

    # Legend
    # Order the legend with 'no_failure' first, then other error types
    ordered_legend_keys = ['no_failure'] + [err for err in error_types_for_color if err in legend_handles]
    ordered_handles = [legend_handles[err] for err in ordered_legend_keys if err in legend_handles]

    if ordered_handles:
        ax.legend(handles=ordered_handles, title='Applied Error', loc='upper left', bbox_to_anchor=(1.02, 1),
                  fontsize=10, title_fontsize=11, borderaxespad=0.)

    fig.tight_layout(rect=[0.05, 0.05, 0.85, 0.95])

    # Saving
    try:
        # A bare file name has an empty dirname, and os.makedirs('') raises;
        # only create the directory when there is one to create.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")
    except Exception as e:
        print(f"Error saving figure to {save_path}: {e}")
    finally:
        # Release the figure whether or not saving succeeded.
        plt.close(fig)

