import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import os
import sys
# Make the current working directory importable, then record the directory
# holding the invoked script; data files are read from it and figures are
# saved into it.
sys.path.append('.')
path = os.path.split(sys.argv[0])[0]  # equivalent to os.path.dirname

# Global typography for every figure in this script:
#   * a serif family, with candidate fonts listed in order of preference and
#     chosen because they are commonly installed;
#   * mathtext in the matching DejaVu serif, with external LaTeX disabled;
#   * very large text (40 pt everywhere, 35 pt for the legend) so labels stay
#     legible on the wide combined figure.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": [
            "DejaVu Serif",
            "Liberation Serif",
            "Bitstream Vera Serif",
            "Nimbus Roman",
            "Times New Roman",
        ],
        "mathtext.fontset": "dejavuserif",
        "text.usetex": False,
        **dict.fromkeys(
            (
                "font.size",
                "axes.labelsize",
                "axes.titlesize",
                "xtick.labelsize",
                "ytick.labelsize",
            ),
            40,
        ),
        "legend.fontsize": 35,
    }
)

# Function to process data and create plot
def process_and_plot(ax, data_file, subplot_title):
    """Read a whitespace-delimited noise table and plot estimated/real noise ratios on ``ax``.

    The first line of the file is a header of column names, optionally
    prefixed with ``Variables=``. Every following non-blank line must hold
    exactly one value per header column; lines with a different number of
    values are reported on stdout and skipped. All values are converted to
    float. The table must provide the columns ``i_func``, ``dim_input``,
    ``num_samples``, ``E_noise`` and ``E_noise_real``.

    One log-log line is drawn per (``i_func``, ``dim_input``) pair, showing
    ``E_noise / E_noise_real`` against ``num_samples``. A dashed reference
    line is also drawn at ratio 1.

    Args:
        ax: Matplotlib axes to draw on.
        data_file: File name, joined onto the module-level ``path``.
        subplot_title: Title for the axes.

    Raises:
        ValueError: If the data file is empty (has no header line).
    """
    data_path = os.path.join(path, data_file)
    with open(data_path, 'r') as f:
        lines = f.readlines()
    if not lines:
        # Without this guard, lines[0] below would fail with a bare IndexError.
        raise ValueError(f"Data file {data_path!r} is empty; expected a header line.")

    # Header: drop an optional "Variables=" prefix, then split on whitespace.
    header = lines[0].strip().replace("Variables=", "").split()

    # Keep only rows whose value count matches the header.
    data_lines = []
    for line in lines[1:]:
        values = line.split()  # split() also discards surrounding whitespace
        if not values:  # blank line
            continue
        if len(values) == len(header):
            data_lines.append(values)
        else:
            print(f"Warning: Line with {len(values)} values doesn't match header with {len(header)} columns. Skipping.")

    df = pd.DataFrame(data_lines, columns=header).astype(float)

    # Ratio > 1 means the noise was over-estimated, < 1 under-estimated.
    df['noise_ratio'] = df['E_noise'] / df['E_noise_real']

    # Colour encodes i_func and marker encodes dim_input (both cycle with modulo).
    colors = ['blue', 'red', 'green', 'purple']
    markers = ['o', 's', '^', 'd']

    for (i_func, dim_input), group in df.groupby(['i_func', 'dim_input']):
        color = colors[int(i_func) % len(colors)]
        marker = markers[int(dim_input) % len(markers)]

        # Sort by sample count so the connecting line runs left to right.
        group = group.sort_values('num_samples')

        ax.loglog(
            group['num_samples'],
            group['noise_ratio'],
            marker=marker,
            linestyle='-',
            color=color,
            linewidth=3.5,
            markersize=15,
            # NOTE(review): the offset 24 presumably maps i_func to equation
            # numbers in the accompanying paper (i_func 0 -> Eq. 24); confirm.
            label=f'Eq. {24+int(i_func)}, $n_x$={int(dim_input)}'
        )

    # Reference line: ratio 1 means the estimated noise equals the real noise.
    ax.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, linewidth=3.0, label='Ideal')

    ax.set_xlabel('Number of Samples (log scale)')
    ax.set_ylabel('Ratio of Estimated to Real Noise \n (log scale)')
    ax.set_title(subplot_title)

    # The y-axis is already log-scaled by loglog(); fix its range to [0.01, 100].
    ax.set_ylim(0.01, 100.0)

# Create the figure with two side-by-side subplots.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(40, 13))

# Left: clustering-based noise estimate; right: MC-dropout noise estimate.
process_and_plot(ax1, 'variance-multi-dimension.dat', 'Clustering-based Estimation')
process_and_plot(ax2, 'variance-mc-dropout.dat', 'MC Dropout Estimation')

# Build one shared legend from BOTH subplots. Taking entries from ax1 alone
# would silently drop any (i_func, dim_input) series present only in the
# second dataset. Colour and marker are derived from i_func/dim_input, so
# equal labels have identical styling. Keeping the first handle per label
# therefore loses nothing, and the dict preserves first-seen order.
handles, labels = [], []
for subplot_ax in (ax1, ax2):
    ax_handles, ax_labels = subplot_ax.get_legend_handles_labels()
    handles.extend(ax_handles)
    labels.extend(ax_labels)
by_label = {}
for handle, label in zip(handles, labels):
    by_label.setdefault(label, handle)
fig.legend(by_label.values(), by_label.keys(),
           loc='center right',         # anchor the legend's right-centre point...
           bbox_to_anchor=(1.0, 0.5),  # ...to the figure's right edge (inside the figure)
)

# tight_layout() first; subplots_adjust() then overrides right/wspace.
# This reserves the right-hand strip for the legend and widens the gap
# between the subplots.
plt.tight_layout()
plt.subplots_adjust(right=0.82, wspace=0.3)

# Save raster and vector versions next to the script.
output_path = os.path.join(path, 'noise_ratio_comparison_combined')
# NOTE(review): 30 dpi (a 1200x390 px PNG for a 40x13 in figure) is far below
# the PDF's 300 dpi; confirm the PNG is meant to be a low-resolution preview.
plt.savefig(output_path + '.png', dpi=30)
plt.savefig(output_path + '.pdf', dpi=300)
print(f"Figure saved to {output_path}.png and {output_path}.pdf")

plt.show()
