import pickle
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
from datetime import datetime

def _parse_p_value(path):
    """Return the p value encoded in a ``p_<value>[_<anything>].pkl`` file name.

    Only the base name (without extension) is inspected, so underscores in
    directory names such as ``Anchored_AMDP`` or ``chi_square`` no longer shift
    the token position. Returns ``None`` when no float can be parsed.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    parts = stem.split('_')
    try:
        # parts[0] is the literal "p"; the value is the next token.
        return float(parts[1])
    except (IndexError, ValueError):
        return None


def _fit_power_law(n_samples, errors, num_points=100):
    """Fit ``log(error) = slope * log(n) + intercept`` by least squares.

    Only points with strictly positive ``n`` and ``error`` are used, because
    the logarithm is undefined otherwise.

    Returns:
        ``(slope, intercept, n_fit, errors_fit)``, where ``n_fit``/``errors_fit``
        sample the fitted line at ``num_points`` log-spaced positions spanning
        the usable data. Returns ``None`` when fewer than two usable points exist.
    """
    x = np.asarray(n_samples, dtype=float)
    y = np.asarray(errors, dtype=float)
    usable = (x > 0) & (y > 0)
    if np.count_nonzero(usable) < 2:
        return None
    log_x = np.log(x[usable])
    log_y = np.log(y[usable])
    slope, intercept = np.polyfit(log_x, log_y, 1)
    log_x_fit = np.linspace(log_x.min(), log_x.max(), num_points)
    return slope, intercept, np.exp(log_x_fit), np.exp(slope * log_x_fit + intercept)


def plot_results(data_dir=os.path.join("data", "Anchored_AMDP", "kl"), output_dir=None):
    """Plot average error vs. number of samples on log-log axes, one series per p.

    Each ``p_*.pkl`` file in ``data_dir`` is unpickled. The result is used as a
    mapping from sample count to average error. Its points are drawn as markers,
    together with a fitted power-law line whose slope appears in the legend.
    The figure is saved as a timestamped PNG.

    Args:
        data_dir: Directory containing the ``p_*.pkl`` files. The default equals
            the original ``data\\Anchored_AMDP\\kl`` on Windows and is portable.
        output_dir: Where to write the PNG; defaults to ``data_dir``.

    Returns:
        Path of the saved PNG file.
    """
    p_styles = {
        0.9: {'color': 'blue', 'marker': 'o', 'linestyle': '--'},
        0.5: {'color': 'red', 'marker': 's', 'linestyle': '--'},
        0.1: {'color': 'green', 'marker': '^', 'linestyle': '--'},
    }
    default_style = {'color': 'black', 'marker': 'o', 'linestyle': '--'}
    if output_dir is None:
        output_dir = data_dir

    rc = {
        # Use matplotlib's built-in mathtext renderer, not an external LaTeX install.
        "text.usetex": False,
        "font.family": "serif",
        "mathtext.fontset": "cm",  # Computer Modern; needed for \mathfrak in labels
    }

    data_files = sorted(glob.glob(os.path.join(data_dir, "p_*.pkl")))

    # rc_context scopes the style changes to this figure instead of leaking them
    # into the global rcParams; savefig must stay inside it (text renders there).
    with plt.rc_context(rc):
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for path in data_files:
                p_value = _parse_p_value(path)
                if p_value is None:
                    print(f"Skipping {path}: cannot parse p value from file name")
                    continue

                # NOTE: pickle.load can execute arbitrary code; only load trusted files.
                with open(path, 'rb') as f:
                    sample_error_data = pickle.load(f)

                n_samples_list = sorted(sample_error_data)
                errors = [sample_error_data[n] for n in n_samples_list]

                fit = _fit_power_law(n_samples_list, errors)
                if fit is None:
                    print(f"Skipping {path}: need at least two positive (n, error) points to fit")
                    continue
                slope, _, n_sample_fit, errors_fit = fit

                style = p_styles.get(p_value, default_style)

                # Raw data points: markers only, kept out of the legend.
                ax.plot(n_samples_list, errors,
                        marker=style['marker'],
                        linestyle='none',
                        color=style['color'],
                        label='_nolegend_')

                # Fitted power law; %g avoids rounding e.g. p=0.05 to "0.1".
                ax.plot(n_sample_fit, errors_fit,
                        linestyle=style['linestyle'],
                        color=style['color'],
                        label=r'$\mathfrak{p}=%g$ (slope=%.2f)' % (p_value, slope))

            ax.set_xscale('log')
            ax.set_yscale('log')
            ax.set_xlabel('number of samples', fontsize=24)
            ax.set_ylabel('average error', fontsize=24)
            ax.tick_params(axis='both', which='major', labelsize=18)
            ax.grid(True, which="both", ls="--", alpha=0.5)

            handles, labels = ax.get_legend_handles_labels()
            if handles:  # avoid matplotlib's "no artists with labels" warning
                ax.legend(handles, labels,
                          loc='upper right',
                          ncol=1,
                          framealpha=1,
                          handlelength=2.5,
                          markerscale=1.2,
                          fontsize=18)

            os.makedirs(output_dir, exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            output_filename = os.path.join(output_dir, f"final_graph_{timestamp}.png")
            fig.savefig(output_filename, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if loading/fitting/saving fails.
            plt.close(fig)

    print(f"Plot saved to {output_filename}")
    return output_filename

if __name__ == "__main__":
    # Script entry point: build the log-log error plot and save it to disk.
    plot_results()