import os
import sys
sys.path.append('.')
path0 = os.path.dirname(sys.argv[0])

import numpy as np
import matplotlib.pyplot as plt
import re

# Global typography: a serif family whose candidate list starts with fonts
# likely to be installed on most systems, mathtext rendered with the matching
# DejaVu serif set, and no external LaTeX rendering.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = [
    "DejaVu Serif",
    "Liberation Serif",
    "Bitstream Vera Serif",
    "Nimbus Roman",
    "Times New Roman",
]
plt.rcParams["mathtext.fontset"] = "dejavuserif"
plt.rcParams["text.usetex"] = False


# Read the data file. The encoding is stated explicitly so the result does not
# depend on the platform's default locale encoding.
with open(os.path.join(path0, 'nll-comparison.dat'), 'r', encoding='utf-8') as f:
    data = f.read()

# Split the text into zones. The file appears to use Tecplot-style
# `zone T="..."` headers. re.split with one capturing group returns
# [preamble, title_1, body_1, title_2, body_2, ...]; the preamble (anything
# before the first zone header) is discarded. The header keyword is matched
# case-insensitively so both "zone T=" and "ZONE T=" are recognised.
zones = re.split(r'zone T=\s*"([^"]+)"', data, flags=re.IGNORECASE)[1:]
zones = [(title, body.strip()) for title, body in zip(zones[0::2], zones[1::2])]

# Zone titles are expected to look like `<model> (N = <int>)`, e.g. "HVBLL (N = 100)".
title_pattern = re.compile(r'(\w+)\s*\(\s*N\s*=\s*(\d+)\s*\)')

# Parse data for each model and sample size:
#   models_data[n_samples][model_name] -> list of (wishart_scale, train_nll, test_nll)
models_data = {}
for model_str, zone_data in zones:
    model_str = model_str.strip()
    match = title_pattern.search(model_str)
    if not match:
        # A zone without a recognisable model/N label cannot be keyed; report
        # it instead of dropping it silently.
        print(f"Warning: skipping zone with unrecognised title {model_str!r}",
              file=sys.stderr)
        continue

    model_name, n_samples = match.groups()
    n_samples = int(n_samples)

    data_points = []
    for line in zone_data.splitlines():
        values = line.split()
        # Only rows with exactly three numeric columns are data rows; any
        # other line in the zone body is skipped.
        if len(values) != 3:
            continue
        try:
            wishart_scale, train_nll, test_nll = (float(v) for v in values)
        except ValueError:
            continue
        data_points.append((wishart_scale, train_nll, test_nll))

    # A repeated (model, N) pair overwrites the earlier zone's points.
    models_data.setdefault(n_samples, {})[model_name] = data_points

# Create one panel per sample size, laid out in a single row. With two sample
# sizes this reproduces the 12x5-inch two-panel figure. squeeze=False keeps
# `axes` indexable even for a single panel, and at least one panel is always
# created.
sample_sizes = sorted(models_data.keys())
n_panels = max(len(sample_sizes), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)
axes = axes[0]  # single row -> 1-D array of Axes

colors = {'HVBLL': 'blue', 'VBLL': 'red'}
markers = {'train': 'o', 'test': 's'}
linestyles = {'train': '-', 'test': '--'}

# Points whose |train NLL| or |test NLL| is not strictly below this value are
# omitted from the plot. NaN and inf also fail the comparison and are omitted.
threshold = 50  # Adjust this threshold as needed

# Shared-legend entries keyed by label, in order of first appearance.
# Collecting from every panel (not only the first) keeps a model that is
# absent from the first panel in the legend. setdefault prevents duplicates.
legend_entries = {}

for i, n in enumerate(sample_sizes):
    ax = axes[i]
    ax.set_xlabel('$\\sigma_0^2$')
    ax.set_ylabel('Negative Log-Likelihood')

    for model in ['HVBLL', 'VBLL']:
        if model not in models_data[n]:
            continue

        # Tuples sort primarily by wishart_scale, so lines are drawn left to right.
        data_points = sorted(models_data[n][model])
        kept = [p for p in data_points
                if abs(p[1]) < threshold and abs(p[2]) < threshold]

        filtered_x = [p[0] for p in kept]
        filtered_train = [p[1] for p in kept]
        filtered_test = [p[2] for p in kept]

        # No per-panel legend; the handles feed the single figure-level legend.
        train_line, = ax.plot(filtered_x, filtered_train,
                              color=colors[model], marker=markers['train'],
                              linestyle=linestyles['train'])
        test_line, = ax.plot(filtered_x, filtered_test,
                             color=colors[model], marker=markers['test'],
                             linestyle=linestyles['test'])

        legend_entries.setdefault(f'{model} Train', train_line)
        legend_entries.setdefault(f'{model} Test', test_line)

    ax.set_xscale('log')

    # Lettered caption below each panel: (a), (b), ...
    subplot_letter = chr(97 + i)  # ASCII 'a' is 97
    ax.text(0.5, -0.25, f"({subplot_letter}) $N_s$ = {n}", transform=ax.transAxes,
            fontsize=12, ha='center', va='center', fontweight='bold')

legend_labels = list(legend_entries)
legend_handles = list(legend_entries.values())

# Lay out the panels first, reserving the top 15% of the figure for the legend.
plt.tight_layout(rect=[0, 0, 1, 0.85])

# Single shared legend at the top of the figure. Do not call subplots_adjust
# after this point: it would override the layout computed by tight_layout.
fig.legend(legend_handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, 0.95),
           ncol=4, fontsize=12, frameon=True, fancybox=False, shadow=False)

# NOTE(review): the PNG is written at dpi=50 (low resolution) while the PDF
# uses dpi=300; presumably the PNG is only a quick preview — confirm.
plt.savefig(os.path.join(path0, 'sparse-nll-comparison.png'), dpi=50, bbox_inches='tight')
plt.savefig(os.path.join(path0, 'sparse-nll-comparison.pdf'), dpi=300, bbox_inches='tight')
plt.close(fig)
