import pickle as pkl
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib as mpl
# plt.rcParams['axes.facecolor'] ='snow'


def plot_error(data, label, color, ax, alpha=0.4):
    """Plot the mean curve of ``data`` with a shaded +/-1 standard-error band.

    Args:
        data: Collection of per-run curves stacked along axis 0. The runs must
            all have the same length so they form a regular array; callers in
            this file trim ragged PSRO-rN runs before plotting. Presumably the
            shape is (n_runs, n_iterations) -- confirm for new callers.
        label: Legend label for the mean line.
        color: Matplotlib colour shared by the line and its band.
        ax: Axes-like object exposing ``plot`` and ``fill_between``.
        alpha: Opacity of the error band. This used to be read from a
            module-level ``alpha`` global that is only defined further down the
            file, which raised ``NameError`` if the function was used before
            that point or from another module. It is now an explicit keyword
            with the same value (0.4) used throughout this script.
    """
    runs = np.asarray(data)
    mean = np.mean(runs, axis=0)
    sem = stats.sem(runs, axis=0)  # ddof=1: NaN band when only one run exists
    x = np.arange(mean.size)
    ax.plot(mean, label=label, color=color)
    # squeeze drops trailing singleton axes so the bounds line up with ``x``.
    ax.fill_between(x,
                    np.squeeze(mean - sem),
                    np.squeeze(mean + sem),
                    color=color,
                    alpha=alpha)


def _load_results(path):
    """Unpickle and return the result object stored at ``path``.

    The file is opened in a ``with`` block, so the handle is closed as soon
    as the object is loaded. Previously ``pkl.load(open(...))`` left closing
    it to the garbage collector.

    NOTE: ``pickle.load`` can execute arbitrary code. Only point this at
    result files produced by our own experiment runs.
    """
    with open(path, "rb") as fh:
        return pkl.load(fh)


data1 = _load_results("results/20250515-015600_AlphaStar_0.5_0.8/data.p")
data2 = _load_results("results/20250515-130604_Random game of skill_0.5_0.8/data.p")
data3 = _load_results("results/20250515-171443_Blotto_0.5_0.8/data.p")
data4 = _load_results("results/20250516-004227_Kuhn-poker_0.5_0.8/data.p")
# NOTE(review): data5 (connect_four) is loaded but not used by the figures
# visible in this script -- confirm it is still needed.
data5 = _load_results("results/20250524-025807_connect_four_0.5_0.8/data.p")


# Shaded-band opacity (kept at module level; see plot_error).
alpha = 0.4
j = 0  # NOTE(review): not read anywhere in the visible part of this script

# One font size for every text element of every figure in this script.
font_size = 16
mpl.rcParams.update({
    **dict.fromkeys(
        (
            'font.size',
            'axes.labelsize',
            'axes.titlesize',
            'xtick.labelsize',
            'ytick.labelsize',
            'legend.fontsize',
        ),
        font_size,
    ),
    'figure.titleweight': 'normal',
})
fig, axes = plt.subplots(1, 4, figsize=(18, 5))

# Colours for the ablation methods, paired by position with ``all_methods``.
colors = ['blue', 'pink', 'gold', 'gray', 'red']
# Legend labels for every ``*_exps`` key stored in the result pickles.
label_dict = {
    "psro_exps": 'PSRO',
    "pipeline_psro_exps": 'Pipeline PSRO',
    "rectified_exps": 'PSRO-rN',
    "dpp_psro_exps": 'DPP-PSRO',
    "bd_rd_psro_exps": 'BD&RD-PSRO',
    "psd_psro_exps": 'PSD-PSRO',
    'sparse_psro_1_exps': 'Sparsification-PSRO',
    'sparse_psro_2_exps': 'Sparsity-PSRO',
    "our_psro_exps": 'Sparse-PSRO(ours)',
}

# Ablation: PSRO, PSD-PSRO, the two sparse baselines and ours.
all_methods = ['psro_exps', 'psd_psro_exps', 'sparse_psro_1_exps',
               'sparse_psro_2_exps', "our_psro_exps"]
# Panel titles; only the first len(datasets) entries are used by the figures.
titles = ["Random Game of Skill", "AlphaStar888", "Blotto", "Kuhn Poker",
          "Go (board size=3,komi=6.5)", "Go (board size=4,komi=6.5)"]
# Must follow the same order as ``titles``. data1 was loaded from the
# AlphaStar run and data2 from the Random-game-of-skill run, so the two are
# swapped here. Previously the first two panels showed each other's title.
datasets = [data2, data1, data3, data4]

for i, ax in enumerate(axes.flatten()):
    data = datasets[i]
    print(titles[i])
    for ii, method in enumerate(all_methods):
        me_data = data[method]
        # Last value of the last stored entry for this method.
        print(f'{method} exp: {me_data[-1][-1]}')
        # No length trimming here: PSRO-rN ("rectified_exps"), the method
        # whose entries get trimmed elsewhere, is not part of this ablation.
        plot_error(me_data, label=label_dict[method], color=colors[ii], ax=ax)

    ax.grid()
    ax.set_title(titles[i], size=font_size)
    ax.set_xlabel("Iterations", size=font_size)
    ax.set_ylabel("Exploitability", size=font_size)
    ax.tick_params(axis='both', labelsize=font_size)
    ax.set_yscale('log')

# Shared legend across the top of the figure, all entries on one row.
handles, labels = axes[0].get_legend_handles_labels()
m = len(labels)
ncol = m
fig.legend(handles, labels, loc='upper center', prop={'size': font_size}, ncol=ncol)

# Work on ``fig`` explicitly rather than on pyplot's implicit current figure.
fig.tight_layout()
fig.subplots_adjust(top=0.8)  # leave room for the legend above the panels
fig.savefig('png/ablation_exp.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)
plt.show()

# --- Comparison figure: exploitability of every baseline vs. ours ----------
alpha, j = .4, 0
fig2, axes2 = plt.subplots(1, 4, figsize=(18, 5))
colors = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'red']
all_methods = [
    'psro_exps',
    "pipeline_psro_exps",
    "rectified_exps",
    "dpp_psro_exps",
    "bd_rd_psro_exps",
    'psd_psro_exps',
    "our_psro_exps",
]

for panel, ax in enumerate(axes2.flatten()):
    game = datasets[panel]
    print(titles[panel])
    for k, key in enumerate(all_methods):
        runs = game[key]
        print(f'{key} exp: {runs[-1][-1]}')
        if key == "rectified_exps":
            # PSRO-rN entries can differ in length: trim all of them, in place
            # inside the loaded results, to the shortest so they stack.
            shortest = min(len(run) for run in runs)
            runs[:] = [run[:shortest] for run in runs]
        plot_error(runs, label=label_dict[key], color=colors[k], ax=ax)

    ax.grid()
    ax.set_title(titles[panel], size=font_size)
    ax.set_xlabel("Iterations", size=font_size)
    ax.set_ylabel("Exploitability", size=font_size)
    ax.tick_params(axis='both', labelsize=font_size)
    ax.set_yscale('log')

# One-row legend above all panels.
handles, labels = axes2[0].get_legend_handles_labels()
m = len(labels)
ncol = m
fig2.legend(handles, labels, loc='upper center', prop={'size': font_size}, ncol=ncol)

plt.tight_layout()
plt.subplots_adjust(top=0.8)
plt.savefig('png/comparison_exp.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)
plt.show()


# --- Comparison figure: population exploitability of baselines vs. ours ----
# Legend labels for the ``*_cardinality`` result keys.
label_dict = {
    "psro_cardinality": 'PSRO',
    "pipeline_psro_cardinality": 'Pipeline PSRO',
    "rectified_cardinality": 'PSRO-rN',
    "dpp_psro_cardinality": 'DPP-PSRO',
    "bd_rd_psro_cardinality": 'BD&RD-PSRO',
    "psd_psro_cardinality": 'PSD-PSRO',
    'sparse_psro_1_cardinality': 'Sparsification-PSRO',
    'sparse_psro_2_cardinality': 'Sparsity-PSRO',
    "our_psro_cardinality": 'Sparse-PSRO(ours)',
}
alpha, j = .4, 0
fig2, axes2 = plt.subplots(1, 4, figsize=(18, 5))
colors = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'red']
all_methods = [
    'psro_cardinality',
    "pipeline_psro_cardinality",
    "rectified_cardinality",
    "dpp_psro_cardinality",
    "bd_rd_psro_cardinality",
    'psd_psro_cardinality',
    "our_psro_cardinality",
]

for panel, ax in enumerate(axes2.flatten()):
    game = datasets[panel]
    print(titles[panel])
    for k, key in enumerate(all_methods):
        runs = game[key]
        print(f'{key} exp: {runs[-1][-1]}')
        if key == "rectified_cardinality":
            # Trim ragged PSRO-rN entries in place to the shortest length.
            shortest = min(len(run) for run in runs)
            runs[:] = [run[:shortest] for run in runs]
        plot_error(runs, label=label_dict[key], color=colors[k], ax=ax)

    ax.grid()
    ax.set_title(titles[panel], size=font_size)
    ax.set_xlabel("Iterations", size=font_size)
    ax.set_ylabel("Population Exploitability", size=font_size)
    ax.tick_params(axis='both', labelsize=font_size)
    ax.set_yscale('log')

# One-row legend above all panels.
handles, labels = axes2[0].get_legend_handles_labels()
m = len(labels)
ncol = m
fig2.legend(handles, labels, loc='upper center', prop={'size': font_size}, ncol=ncol)

plt.tight_layout()
plt.subplots_adjust(top=0.8)
plt.savefig('png/comparison_pe.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)
plt.show()

# --- Ablation figure: population exploitability of the sparse variants -----
alpha = .4
j = 0
fig2, axes2 = plt.subplots(1, 4, figsize=(18, 5))
# Colours paired by position with ``all_methods``.
colors = ['blue', 'pink', 'gold', 'gray', 'red']
all_methods = ['psro_cardinality', 'psd_psro_cardinality',
               'sparse_psro_1_cardinality', 'sparse_psro_2_cardinality',
               "our_psro_cardinality"]
for i, ax in enumerate(axes2.flatten()):
    data = datasets[i]
    print(titles[i])
    for ii, method in enumerate(all_methods):
        me_data = data[method]
        print(f'{method} exp: {me_data[-1][-1]}')
        # No length trimming here: PSRO-rN ("rectified_cardinality"), the only
        # method whose entries get trimmed, is not part of this ablation.
        plot_error(me_data, label=label_dict[method], color=colors[ii], ax=ax)

    ax.grid()
    ax.set_title(titles[i], size=font_size)
    ax.set_xlabel("Iterations", size=font_size)
    ax.set_ylabel("Population Exploitability", size=font_size)
    ax.tick_params(axis='both', labelsize=font_size)
    ax.set_yscale('log')

# Shared one-row legend above the panels.
handles, labels = axes2[0].get_legend_handles_labels()
m = len(labels)
ncol = m
fig2.legend(handles, labels, loc='upper center', prop={'size': font_size}, ncol=ncol)

# Work on ``fig2`` explicitly rather than on pyplot's implicit current figure.
fig2.tight_layout()
fig2.subplots_adjust(top=0.8)  # leave room for the legend
fig2.savefig('png/ablation_pe.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)
plt.show()


# --- Population-size summary: labels, colours and method keys --------------
# Legend labels for the ``*_pop`` result keys.
label_dict = {
    "psro_pop": "PSRO",
    "pipeline_psro_pop": "Pipeline PSRO",
    "rectified_pop": "PSRO-rN",
    "dpp_psro_pop": "DPP-PSRO",
    "bd_rd_psro_pop": "BD&RD-PSRO",
    "psd_psro_pop": "PSD-PSRO",
    "sparse_psro_1_pop": "Sparsification-PSRO",
    "sparse_psro_2_pop": "Sparsity-PSRO",
    "our_psro_pop": "Sparse-PSRO(ours)",
}
# Nine methods and nine colours, paired by position.
all_methods = [
    "psro_pop",
    "pipeline_psro_pop",
    "rectified_pop",
    "dpp_psro_pop",
    "bd_rd_psro_pop",
    "psd_psro_pop",
    "sparse_psro_1_pop",
    "sparse_psro_2_pop",
    "our_psro_pop",
]
colors = [
    "blue", "orange", "green", "purple", "brown",
    "pink", "gold", "gray", "red",
]
alpha, j = .4, 0
fig2, axes2 = plt.subplots(1, 4, figsize=(18, 5))
# For each panel/dataset, print every method's per-entry sizes and their mean.
# NOTE(review): ``ax`` is not used in the lines visible here; the loop body
# may continue further on -- confirm before relying on this being print-only.
for i, ax in enumerate(axes2.flatten()):
    data = datasets[i]
    print(titles[i])
    for ii, method in enumerate(all_methods):
        me_data = data[method]
        # len() of each stored entry; presumably one entry per run holding that
        # run's policy population -- verify against the experiment script.
        pool_size = [len(l) for l in me_data]
        print(f'{method} size: {np.mean(pool_size)}')
        print(f'{pool_size}')
