import os
import sys
# Make modules in the current working directory (e.g. ``utils``) importable.
# Note this is the CWD, not necessarily the directory containing this script.
sys.path.extend(['.'])
# Directory part of the invoked script path; the CSV inputs and the combined
# figures are resolved relative to it ('' when the script is run by bare name).
path = os.path.split(sys.argv[0])[0]

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import glob
from matplotlib import rcParams

# Render all text in a serif typeface. Candidates are listed in order of
# preference, with the widely shipped DejaVu/Liberation families first and
# Times New Roman only as a last resort. Math text uses the matching DejaVu
# serif set, and LaTeX rendering stays off so no TeX installation is needed.
# (``matplotlib.rcParams`` is the same object as ``plt.rcParams``.)
rcParams.update({
    "font.family": "serif",
    "font.serif": [
        "DejaVu Serif",
        "Liberation Serif",
        "Bitstream Vera Serif",
        "Nimbus Roman",
        "Times New Roman",
    ],
    "mathtext.fontset": "dejavuserif",
    "text.usetex": False,
})

from utils import prepare_dataset


# Models compared in the figure; each one gets its own row of panels.
NAME_MODEL = ['HVBLL', 'VBLL']
# Function identifiers passed to prepare_dataset, one per column: every
# 'lin'/'sin' pairing, giving 'lin-lin', 'lin-sin', 'sin-lin', 'sin-sin'.
NAME_FUNC = [f'{first}-{second}' for first in ('lin', 'sin') for second in ('lin', 'sin')]

# Sample count handed to prepare_dataset (presumably the training-set size).
N_SAMPLE = 100


def _panel_caption(index, name_model, name_func):
    """Return the lettered caption drawn under panel ``index``.

    Panels are lettered from 'a' in row-major order, e.g. index 0 gives
    "(a) HVBLL: lin-lin".
    """
    return f"({chr(ord('a') + index)}) {name_model}: {name_func}"


def _plot_panel(ax, df, X_train, Y_train):
    """Draw the training data, mean curves and 1-sigma bands of one panel.

    Args:
        ax: Matplotlib axes to draw into.
        df: Table read from ``<model>-<func>.csv``. It must contain the columns
            ``x_points``, ``y_true``, ``y_pred`` and ``a_uncertainty_true``.
            ``a_uncertainty_pred`` is optional and is only drawn when present.
        X_train, Y_train: Training points returned by ``prepare_dataset``.

    Raises:
        KeyError: if one of the required columns is missing. All required
            columns are read before anything is drawn.
    """
    x = df['x_points']
    y_true, y_pred = df['y_true'], df['y_pred']
    noise_true = df['a_uncertainty_true']

    ax.scatter(X_train, Y_train, color='black', s=40, alpha=0.6, label='Training data')
    ax.plot(x, y_true, color='blue', linestyle='-', linewidth=2, label='Ground truth mean')
    ax.plot(x, y_pred, color='red', linestyle='-', linewidth=2, label='Predicted mean')
    ax.fill_between(x, y_true - noise_true, y_true + noise_true,
                    color='blue', alpha=0.2, label='Ground truth noise level (1σ)')

    # Only the aleatoric part of the predicted uncertainty is shaded; the total
    # predictive std (``y_std_pred``) is intentionally not drawn.
    if 'a_uncertainty_pred' in df.columns:
        noise_pred = df['a_uncertainty_pred']
        ax.fill_between(x, y_pred - noise_pred, y_pred + noise_pred,
                        color='red', alpha=0.2, label='Aleatoric uncertainty (1σ)')


def _style_panel(ax, caption):
    """Apply the shared caption, axis labels, tick sizes and open frame to a panel."""
    # Caption sits below the axes, in axes-fraction coordinates.
    ax.text(0.5, -0.15, caption, transform=ax.transAxes,
            fontsize=40, ha='center', va='center', fontweight='bold')
    ax.set_xlabel('x', fontsize=40)
    ax.set_ylabel('y', fontsize=40)
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def _collect_legend_entries(ax, handles, labels):
    """Append the labelled lines and collections of ``ax`` to ``handles``/``labels``.

    Empty labels, matplotlib's ``_``-prefixed placeholder labels and labels
    already collected are skipped. This way each series appears once in the
    shared legend, even though every panel draws the same series. Both lists
    are mutated in place; lines are visited before collections.
    """
    for artist in [*ax.get_lines(), *ax.collections]:
        label = artist.get_label()
        if not label or label.startswith('_') or label in labels:
            continue
        handles.append(artist)
        labels.append(label)


def main():
    """Plot every model/function pair from its CSV file and save the combined figure.

    For each entry of ``NAME_MODEL`` x ``NAME_FUNC``, reads
    ``<model>-<func>.csv`` from ``path`` and plots it over the training data
    from ``prepare_dataset``. Writes ``combined_plots.png`` and
    ``combined_plots.pdf`` into ``path``. A panel that fails to load or plot
    shows an error message instead of aborting the whole figure.
    """
    n_rows, n_cols = len(NAME_MODEL), len(NAME_FUNC)
    # One row per model and one column per function. 12.5 inches per panel
    # gives a 50x25-inch figure for the default 2x4 grid.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(12.5 * n_cols, 12.5 * n_rows), squeeze=False)

    # Handles/labels for the single figure-level legend, deduplicated across panels.
    legend_handles, legend_labels = [], []

    for i, name_model in enumerate(NAME_MODEL):
        for j, name_func in enumerate(NAME_FUNC):
            # Resolve the target panel before the try so the error handler
            # below always annotates the panel that actually failed.
            ax = axes[i, j]
            subplot_idx = i * n_cols + j
            file_name = f"{name_model}-{name_func}.csv"
            csv_path = os.path.join(path, file_name)

            try:
                X_train, Y_train = prepare_dataset(name_func, N_SAMPLE, GPU_ID=-1, net_name=name_model,
                                                   only_get_training_data=True)
                df = pd.read_csv(csv_path)
                _plot_panel(ax, df, X_train, Y_train)
                _style_panel(ax, _panel_caption(subplot_idx, name_model, name_func))
                _collect_legend_entries(ax, legend_handles, legend_labels)
            except Exception as e:
                # Per-panel boundary: report and continue so the other panels still render.
                print(f"Error plotting {file_name}: {e}")
                # Axes-fraction coordinates keep the message centred even if
                # some artists were drawn before the failure.
                ax.text(0.5, 0.5, f"Error plotting {file_name}", transform=ax.transAxes,
                        horizontalalignment='center', verticalalignment='center')

    if legend_handles:
        fig.legend(legend_handles, legend_labels, loc='upper center', ncol=3,
                   bbox_to_anchor=(0.5, 0.95), fontsize=40, frameon=True,
                   fancybox=False, shadow=False)

    # bbox_inches='tight' fits the saved image to the drawn content.
    fig.savefig(os.path.join(path, 'combined_plots.png'), dpi=50, bbox_inches='tight')
    fig.savefig(os.path.join(path, 'combined_plots.pdf'), dpi=300, bbox_inches='tight')
    plt.close(fig)


if __name__ == '__main__':
    main()
