import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import pandas as pd
import numpy as np

# Typography shared by every figure drawn in this module.
fontsize = 12          # axis labels and subplot titles
tick_fontsize = 11     # tick labels
legend_fontsize = 12   # legend entries

# Per-algorithm styling passed to seaborn; the keys are the 'Algorithm'
# labels produced by load_work_results.
palette = dict(VGB='green', ActionLevel='grey')
markers = dict(VGB='s', ActionLevel='o')

def load_training_results(K, c=None, n=None):
    """Load the training-progress curves stored in a parity results pickle.

    Reads ``parity_results_c{c}_n{n}.pkl`` from the current working directory.

    Args:
        K: Kept for backward compatibility; it is not used to build the
            filename. NOTE(review): K may be meant to equal ``c`` -- confirm
            with callers.
        c, n: Values substituted into the filename. The original version read
            these from module-level globals named ``c`` and ``n`` (which this
            module does not define, so it raised NameError); that lookup is
            kept as a fallback when the arguments are omitted.

    Returns:
        Tuple ``(train_progress_x, train_progress_y)`` exactly as stored in
        the pickle.

    Raises:
        ValueError: if ``c`` or ``n`` is neither passed nor defined globally.
        FileNotFoundError: if the results file does not exist.
        KeyError: if the pickle lacks either training-progress entry.
    """
    module_globals = globals()
    if c is None:
        c = module_globals.get('c')
    if n is None:
        n = module_globals.get('n')
    if c is None or n is None:
        raise ValueError(
            "load_training_results needs `c` and `n` to build the results "
            "filename; pass them as keyword arguments"
        )
    fname = f"parity_results_c{c}_n{n}.pkl"
    with open(fname, 'rb') as f:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        results = pickle.load(f)
    return results['train_progress_x'], results['train_progress_y']

def load_work_results(c, ns=(8, 9, 10), timeout=60):
    """Collect per-run work statistics for both algorithms into a long-form DataFrame.

    For every ``n`` in ``ns`` this reads ``parity_results_c{c}_n{n}.pkl`` and
    emits one row per recorded run, for each epoch in the file's
    ``epoch_list``. Within an epoch, all VGB rows precede the ActionLevel rows.

    Args:
        c: Substituted into the filename, stored in column ``c`` and used to
            compute ``H = c * n``.
        ns: The ``n`` values to load (default: the original 8, 9, 10).
        timeout: Time limit in seconds. A run whose recorded time is at or
            above it is flagged ``timeout=True``. The original used ``== 60``,
            which missed runs recorded slightly over the cap.

    Returns:
        DataFrame with columns ``c, n, H, epoch, Algorithm, time, steps,
        timeout``. The columns are present even when no rows were loaded.

    Raises:
        FileNotFoundError: if a results file is missing.
        KeyError: if a pickle lacks an expected entry.
    """
    columns = ['c', 'n', 'H', 'epoch', 'Algorithm', 'time', 'steps', 'timeout']
    # Pickle key holding each algorithm's per-epoch runs -> label used in the
    # 'Algorithm' column (and as the key in the module's palette/markers).
    sources = (('js_work_list', 'VGB'), ('tw_work_list', 'ActionLevel'))
    rows = []
    for n in ns:
        fname = f"parity_results_c{c}_n{n}.pkl"
        with open(fname, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            results = pickle.load(f)
        # Fetch both work lists up front so a missing key fails before any rows are added.
        work_lists = [(results[key], algorithm) for key, algorithm in sources]
        for epoch in results['epoch_list']:
            for work_list, algorithm in work_lists:
                # Each run is indexed as (time, steps).
                for t in work_list[epoch]:
                    rows.append({
                        'c': c, 'n': n, 'H': c * n, 'epoch': epoch,
                        'Algorithm': algorithm, 'time': t[0], 'steps': t[1],
                        'timeout': t[0] >= timeout,
                    })
    return pd.DataFrame(rows, columns=columns)



def visualize(df, K, save_path=None):
    """Plot wall-clock time, step count and time-out fraction per training batch.

    Draws a grid with one row per metric and one column per distinct ``H``
    value in *df*. Each panel is a seaborn line plot comparing algorithms,
    with standard-error bands.

    Args:
        df: Long-form DataFrame with at least the columns ``H``, ``epoch``,
            ``Algorithm``, ``time``, ``steps`` and ``timeout`` (the schema
            built by ``load_work_results``).
        K: Shown in the column titles (``T = H // K``) and used to select
            hand-tuned y-limits. For a K without tuned limits the axes are
            autoscaled; the original raised KeyError.
        save_path: If given, the figure is also saved there before being shown.
    """
    # Hand-tuned (bottom, top) y-limits per K for each metric row.
    ylims = {
        5: {'time': (0.5, 22), 'steps': (1000, 52000), 'timeout': (0, 0.135)},
        4: {'time': (0.5, 35), 'steps': (1000, 130000), 'timeout': (0, 0.75)},
    }
    metrics = [
        ('time', 'Wall-clock time (s)'),
        ('steps', 'Step count'),
        ('timeout', 'Fraction of time-outs'),
    ]
    log_scaled = ('time', 'steps')
    limits = ylims.get(K, {})
    h_values = sorted(df['H'].unique())
    ncols = max(len(h_values), 1)
    # One column per H value; 3 columns reproduce the original 10 x 12 inch figure.
    fig, axes = plt.subplots(len(metrics), ncols, figsize=(10 * ncols / 3, 12), squeeze=False)
    last_row = len(metrics) - 1
    for i, (y, ylabel) in enumerate(metrics):
        for j, (ax, H) in enumerate(zip(axes[i], h_values)):
            subdf = df[df['H'] == H]
            if y in ('time', 'steps'):
                # Exclude timed-out runs so these curves reflect completed runs only.
                subdf = subdf[~subdf['timeout'].astype(bool)]
            sns.lineplot(data=subdf, x="epoch", y=y, hue="Algorithm", style="Algorithm",
                         palette=palette, markers=markers, dashes=False, errorbar='se', ax=ax)
            # Labels only on the outer edges of the grid; titles only on the top row.
            if i == last_row:
                ax.set_xlabel("Training batch", fontweight="bold", fontsize=fontsize)
            else:
                ax.set_xlabel("")
            if j == 0:
                ax.set_ylabel(ylabel, fontweight="bold", fontsize=fontsize)
            else:
                ax.set_ylabel("")
            if i == 0:
                ax.set_title(f"K = {K}, T = {int(H)//K}", fontweight="bold", fontsize=fontsize)
            ax.tick_params(axis='both', labelsize=tick_fontsize)
            ax.tick_params(axis='x', labelrotation=45)
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))  # removes duplicate labels
            ax.legend(by_label.values(), by_label.keys(), fontsize=legend_fontsize)
            if y in limits:
                bottom, top = limits[y]
                ax.set_ylim(bottom=bottom, top=top)
            if y in log_scaled:
                ax.set_yscale("log")
    if save_path is None:
        fig.tight_layout(pad=2.0)
    else:
        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()


def _plot_training_group(ax, fig_x_data, fig_y_data, keys, cmap):
    """Plot the non-empty curves for *keys*, shading each by its rank within the group."""
    count = len(keys)
    for rank, h in enumerate(keys):
        if len(fig_y_data[h]) > 0:
            # Colormap value runs from 1/count up to 1.0; for Blues/Reds a
            # higher value is a darker shade, so larger h is drawn darker.
            color = cmap((rank + 1) / count)
            ax.plot(fig_x_data, fig_y_data[h], label=f"h={h}", color=color)


def visualize_training(fig_x_data, fig_y_data, save_path=None, chunk_size=None, verbose=False):
    """Plot one training-error curve per key ``h`` of *fig_y_data*.

    Keys that are multiples of ``chunk_size`` are drawn in shades of blue and
    all other keys in shades of red. Within each group, larger ``h`` gets a
    darker shade.

    Args:
        fig_x_data: x values shared by every curve (matplotlib requires each
            plotted y sequence to have the same length).
        fig_y_data: Mapping from integer ``h`` to that curve's y values.
            Entries with an empty sequence are skipped.
        save_path: If truthy, save the figure there instead of showing it.
            Otherwise lay out and show it interactively.
        chunk_size: Divisor used to split the keys into the two colour groups.
            If omitted, it is inferred as the smallest gap between the sorted
            keys (1 when there are fewer than two keys).
        verbose: If True, print the chunk size and the two key groups. The
            original printed these debug lines unconditionally.
    """
    fig, ax = plt.subplots()
    if chunk_size is None:
        keys_sorted = sorted(fig_y_data)
        chunk_size = min((b - a for a, b in zip(keys_sorted, keys_sorted[1:])), default=1)
    group0 = sorted(h for h in fig_y_data if h % chunk_size == 0)
    group1 = sorted(h for h in fig_y_data if h % chunk_size != 0)
    if verbose:
        print(chunk_size)
        print(group0)
        print(group1)
    _plot_training_group(ax, fig_x_data, fig_y_data, group0, plt.cm.Blues)
    _plot_training_group(ax, fig_x_data, fig_y_data, group1, plt.cm.Reds)
    # Place the legend outside the axes, on the right.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.set_xlabel("Training batch", fontweight="bold", fontsize=fontsize)
    ax.set_ylabel("Conditional training error", fontweight="bold", fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=tick_fontsize)
    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        print(f"Training visualization saved to: {save_path}")
    else:
        fig.tight_layout()
        plt.show()
        print("Training visualization displayed")

