import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter1d
from matplotlib.ticker import FuncFormatter

# Global matplotlib defaults applied to every figure created by this script:
# Arial as the sans-serif face, and a plain hyphen for negative tick labels.
plt.rcParams.update({
    'font.sans-serif': ['Arial'],
    'axes.unicode_minus': False,
})

def load_results(file_path):
    """Load the experiment arrays stored in an ``.npz`` archive.

    Args:
        file_path: Path to the archive. It must contain the entries
            ``decen_regrets``, ``decen_unstable``, ``test_regrets``,
            ``test_unstable`` and ``deltas``.

    Returns:
        dict: The five arrays, keyed by the entry names listed above.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        KeyError: If the archive lacks one of the expected entries.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File {file_path} does not exist")

    keys = (
        'decen_regrets',
        'decen_unstable',
        'test_regrets',
        'test_unstable',
        'deltas',
    )

    # SECURITY: allow_pickle=True lets NumPy unpickle object arrays, which can
    # execute arbitrary code. Only load archives produced by trusted runs.
    # The context manager closes the underlying zip file; every array is read
    # into memory inside the block, so nothing refers to the closed archive.
    with np.load(file_path, allow_pickle=True) as data:
        return {key: data[key] for key in keys}

def extreme_smooth_data(data, sg_window=201, sg_order=2, gaussian_sigma=50):
    """Heavily smooth a 1-D series: Savitzky-Golay filter, then Gaussian blur.

    Args:
        data: Sequence or array of numbers. It is never modified.
        sg_window: Requested Savitzky-Golay window length. An even value is
            bumped to the next odd one. If it exceeds ``len(data)`` it is
            replaced by the largest odd length below ``len(data)``, capped
            at 51.
        sg_order: Polynomial order for the Savitzky-Golay fit. It is lowered
            when it would not be smaller than the window.
        gaussian_sigma: Upper bound for the Gaussian sigma; the value used is
            ``min(gaussian_sigma, len(data) // 4)``.

    Returns:
        numpy.ndarray: A new float array with the same length as ``data``.
        The Savitzky-Golay step is skipped when the usable window is shorter
        than 3 or not shorter than the data. The Gaussian step is skipped
        when the data has 4 or fewer samples. Inputs too short for either
        step come back as an unsmoothed copy.
    """
    # Work on a float copy. Without it, integer input would be truncated by
    # gaussian_filter1d (which keeps the input dtype), and the short-input
    # path would hand the caller's own array back for them to overwrite.
    values = np.array(data, dtype=float)
    n = len(values)

    # The filter window is kept odd.
    if sg_window % 2 == 0:
        sg_window += 1

    if sg_window > n:
        # n - (n % 2 + 1) is always odd, so no further parity fix is needed.
        sg_window = min(n - (n % 2 + 1), 51)

    if sg_window <= sg_order:
        sg_order = max(sg_window - 2, 1)

    # A window below 3 cannot smooth anything, and savgol_filter rejects
    # non-positive windows and polyorder >= window. Skip the step instead of
    # crashing on very short series.
    if sg_window >= 3 and n > sg_window:
        values = savgol_filter(values, sg_window, sg_order)

    if n > 4:
        sigma = min(gaussian_sigma, n // 4)
        return gaussian_filter1d(values, sigma=sigma)
    return values

def _plot_delta_panel(ax, baseline, ours, ylabel, line_styles):
    """Draw one ODA-vs-Ours panel on ``ax``; leave it empty if a series is empty."""
    n = min(len(baseline), len(ours))
    if n == 0:
        return

    rounds = np.arange(n)

    # Only the "Ours" curve is smoothed; the ODA curve is plotted raw.
    ours_smooth = extreme_smooth_data(
        ours[:n],
        sg_window=201,
        sg_order=2,
        gaussian_sigma=100
    )
    # Pin both endpoints to the raw values so the smoothed curve starts and
    # ends exactly where the data does.
    ours_smooth[0] = ours[0]
    ours_smooth[-1] = ours[n - 1]

    ax.plot(rounds, baseline[:n], label='ODA', **line_styles['ODA'])
    ax.plot(rounds, ours_smooth, label='Ours', **line_styles['Ours'])

    ax.set_xlabel('Round t')
    ax.set_ylabel(ylabel)

    # One tick every 10000 rounds, labelled in thousands ("10k", "20k", ...).
    ticks = np.arange(0, n + 1, 10000)
    ax.set_xticks(ticks)
    ax.set_xticklabels([f'{x/1000:.0f}k' for x in ticks])

    ax.legend()
    ax.grid(True)


def plot_single_delta_comparison(results, delta_idx, output_dir):
    """Plot ODA vs. Ours for one delta value and save the figure as a PNG.

    The left panel shows the regret series and the right panel the
    instability series. Each pair of series is cut to the shorter length
    before plotting.

    Args:
        results: Mapping with the keys ``deltas``, ``decen_regrets``,
            ``test_regrets``, ``decen_unstable`` and ``test_unstable``; each
            value is indexed by ``delta_idx``.
        delta_idx: Index of the delta value to plot.
        output_dir: Directory for the image; created if missing. The file is
            named ``comparison_delta_<delta>.png``.
    """
    delta = results['deltas'][delta_idx]

    decen_regret = np.array(results['decen_regrets'][delta_idx])
    test_regret = np.array(results['test_regrets'][delta_idx])
    decen_unstable = np.array(results['decen_unstable'][delta_idx])
    test_unstable = np.array(results['test_unstable'][delta_idx])

    line_styles = {
        'ODA': {'color': 'blue', 'linestyle': '-', 'marker': '+', 'markevery': 6000, 'linewidth': 2},
        'Ours': {'color': 'red', 'linestyle': '-', 'marker': '+', 'markevery': 6000, 'linewidth': 2}
    }

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
    try:
        fig.suptitle(f"Different algorithms comparison, substitutable preferences, stable regret, N=10, K=10, delta={delta}", fontsize=16, y=0.92)

        _plot_delta_panel(ax1, decen_regret, test_regret,
                          'Maximum Cumulative Stable Regret', line_styles)
        _plot_delta_panel(ax2, decen_unstable, test_unstable,
                          'Cumulative Market Instability', line_styles)

        # Leave the top 5% of the canvas free for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'comparison_delta_{delta}.png')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # This function runs once per delta. Without closing, figures pile up
        # in pyplot's registry and would all appear at a later plt.show().
        plt.close(fig)

    print(f"Saved comparison for delta={delta} to {output_path}")

def _format_round_tick(x, pos):
    """Tick formatter: show values of magnitude 1000 or more in thousands."""
    if abs(x) < 1000:
        return str(int(x))
    thousands = float(x) / 1000
    if thousands.is_integer():
        return f'{thousands:.0f}k'
    # Keep the fraction so that 2000 and 2500 do not both read "2k".
    return f'{thousands:g}k'


def _plot_overall_panel(ax, baseline_runs, ours_runs, deltas, colors, ylabel,
                        marker=None):
    """Draw the ODA (solid) and Ours (dashed) curve for every delta on ``ax``."""
    marker_kwargs = {} if marker is None else {'marker': marker}

    for i, delta in enumerate(deltas):
        # Skip deltas for which either algorithm has no recorded run.
        if i >= len(baseline_runs) or i >= len(ours_runs):
            continue

        baseline = np.array(baseline_runs[i])
        ours = np.array(ours_runs[i])
        n = min(len(baseline), len(ours))
        if n == 0:
            continue

        rounds = np.arange(0, n)
        color = colors[i % len(colors)]

        ax.plot(rounds, baseline[:n],
                color=color, linestyle='-', linewidth=2,
                markevery=6000, label=f'ODA (delta={delta})', **marker_kwargs)

        # Only the "Ours" curve is smoothed; its endpoints are pinned to the
        # raw values so the curve starts and ends where the data does.
        ours_smooth = extreme_smooth_data(ours[:n], 201, 2, 100)
        ours_smooth[0] = ours[0]
        ours_smooth[-1] = ours[n - 1]
        ax.plot(rounds, ours_smooth,
                color=color, linestyle='--', linewidth=2,
                markevery=6000, label=f'Ours (delta={delta})', **marker_kwargs)

    ax.set_xlabel('Round t', fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper left')


def plot_overall_comparison(results, output_dir):
    """Plot every delta's ODA and Ours curves in one figure, save it and show it.

    The left panel holds the regret curves and the right panel the
    instability curves. Colours cycle through a fixed five-colour palette by
    delta index.

    Args:
        results: Mapping with the keys ``deltas``, ``decen_regrets``,
            ``test_regrets``, ``decen_unstable`` and ``test_unstable``; the
            per-delta sequences are indexed in step with ``deltas``.
        output_dir: Directory for the image; created if missing. The file is
            named ``substitute_stable_regret_new.png``.
    """
    deltas = results['deltas']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6), sharex=True)

    colors = ['red', 'green', 'blue', 'purple', 'orange']

    formatter = FuncFormatter(_format_round_tick)
    ax1.xaxis.set_major_formatter(formatter)
    ax2.xaxis.set_major_formatter(formatter)

    # NOTE(review): the regret panel has always been drawn without markers
    # while the instability panel uses '+'. That rendering is preserved here;
    # confirm whether the difference is intentional.
    _plot_overall_panel(ax1, results['decen_regrets'], results['test_regrets'],
                        deltas, colors, 'Maximum Cumulative Stable Regret')
    _plot_overall_panel(ax2, results['decen_unstable'], results['test_unstable'],
                        deltas, colors, 'Cumulative Market Instability',
                        marker='+')

    fig.suptitle("Different algorithms comparison, substitutable preferences, separate stability, N=10, K=10",
                 fontsize=20)

    # Leave the top 5% of the canvas free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'substitute_stable_regret_new.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure once it has been saved and shown.
    plt.close(fig)

    print(f"Saved overall comparison to {output_path}")


def main(argv=None):
    """Command-line entry point: load a results archive and write every plot.

    The input archive and output directory are taken from the command line.
    The function body used to read two names that were never assigned, so
    every run ended in a NameError that the error handler below swallowed.

    Args:
        argv: Optional list of command-line arguments. ``None`` means
            ``sys.argv[1:]``, as usual for argparse.
    """
    # Local import keeps this change self-contained within the entry point.
    import argparse

    parser = argparse.ArgumentParser(
        description="Plot ODA vs. Ours comparisons from a saved .npz results file."
    )
    parser.add_argument(
        'npz_file_path',
        help="path to the .npz archive holding the simulation results",
    )
    parser.add_argument(
        '-o', '--output-dir',
        default='plots',
        help="directory the figures are written to (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    npz_file_path = args.npz_file_path
    output_dir = args.output_dir

    try:
        results = load_results(npz_file_path)
        print(f"Successfully loaded data with {len(results['deltas'])} delta values")

        # One figure per delta, then a single figure with all deltas.
        for i in range(len(results['deltas'])):
            plot_single_delta_comparison(results, i, output_dir)

        plot_overall_comparison(results, output_dir)

        print("All visualizations completed!")

    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        print(f"Error during processing: {e}")

# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()