import argparse
import os
from glob import glob
import numpy as np
import pickle
import matplotlib.pyplot as plt

def parse_args():
    """Parse the command-line options of this script.

    Options:
        --dir          Directory to load stats from (default ``"results"``).
        --show_legend  Boolean flag, off unless given.
        --nstates      Integer number of states (default ``-1``).

    Returns:
        argparse.Namespace: with attributes ``dir``, ``show_legend`` and
        ``nstates``, parsed from ``sys.argv``.
    """
    # One (flag, add_argument keyword arguments) pair per supported option.
    option_specs = (
        ("--dir", dict(default="results", type=str, help="Directory to load stats.")),
        ("--show_legend", dict(default=False, action="store_true", help="Show legend in plots.")),
        ("--nstates", dict(default=-1, type=int, help="Number of states.")),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_combined_stats(args):
    """Load every ``<args.dir>/*/results.pkl`` and merge them into one dict.

    Each pickle is read as a two-level mapping ``{stat: {model: values}}``.
    Every ``values`` entry is averaged over its last axis, and the reduced
    arrays of all files are stacked along a new axis 1, so column ``j`` of a
    returned array comes from the ``j``-th results file in sorted path order.

    Args:
        args: Object exposing a ``dir`` attribute (the directory whose
            immediate sub-directories contain ``results.pkl`` files).

    Returns:
        dict: ``{stat: {model: np.ndarray}}``. Empty when no file matches.

    Raises:
        ValueError: From ``np.stack`` when the reduced arrays of one
            (stat, model) pair do not all share the same shape.

    Note:
        Assumes each ``values`` entry has at least two dimensions, so that
        the mean over the last axis is still an array that can be stacked on
        axis 1 -- TODO confirm against the code that writes ``results.pkl``.
    """
    pattern = os.path.join(args.dir, "*", "results.pkl")
    # glob yields paths in arbitrary, filesystem-dependent order; sorting
    # makes the column order of the stacked arrays reproducible.
    # recursive=True only has an effect if args.dir itself contains "**".
    results_paths = sorted(glob(pattern, recursive=True))

    per_file = {}
    for path in results_paths:
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # --dir at result files produced by trusted runs.
        with open(path, "rb") as f:
            data = pickle.load(f)
        # The file is closed before aggregation; only `data` is needed below.
        for stat_key, models in data.items():
            stat_bucket = per_file.setdefault(stat_key, {})
            for model_key, values in models.items():
                reduced = np.mean(values, axis=-1)
                stat_bucket.setdefault(model_key, []).append(reduced)

    return {
        stat_key: {
            model_key: np.stack(runs, axis=1)
            for model_key, runs in models.items()
        }
        for stat_key, models in per_file.items()
    }

def print_stats(stats):
    """Print the overall mean and standard deviation of every statistic.

    Args:
        stats: Two-level mapping ``{stat: {model: array}}`` whose leaf values
            provide ``.mean()`` and ``.std()``.

    For each statistic the name is printed, then a line of 30 asterisks, then
    for every model its name followed by ``"<mean> +/- <std>"`` (both taken
    over all elements), and finally another line of 30 asterisks.
    """
    separator = "*" * 30
    for stat_name, per_model in stats.items():
        print(stat_name)
        print(separator)
        for model_name, values in per_model.items():
            print(model_name)
            print(f"{values.mean()} +/- {values.std()}")
        print(separator)


def plot_legend_in_separate_figure(dirname):
    """Save the legend of the current axes as a stand-alone ``legend.pdf``.

    The handles and labels are read from ``plt.gca()``, so the plot whose
    legend is wanted must be the current axes when this is called.

    Args:
        dirname: Directory in which ``legend.pdf`` is written.
    """
    # Grab the legend entries before a new figure becomes the current one.
    source_axes = plt.gca()
    handles, labels = source_axes.get_legend_handles_labels()

    # A throw-away figure that holds nothing but the legend.
    legend_fig = plt.figure(figsize=(3, 2))
    legend_ax = legend_fig.add_subplot(111)
    legend_options = dict(loc='center', ncol=7, shadow=False, fancybox=True)
    legend = legend_ax.legend(handles, labels, **legend_options)
    legend_ax.axis('off')

    # The legend extent is only available after a draw; it is reported in
    # display units and converted to inches to shrink the figure around it.
    legend_fig.canvas.draw()
    to_inches = legend_fig.dpi_scale_trans.inverted()
    extent = legend.get_window_extent().transformed(to_inches)
    legend_fig.set_size_inches(extent.width, extent.height)

    target = os.path.join(dirname, 'legend.pdf')
    legend_fig.savefig(target, bbox_inches='tight')

    plt.close(legend_fig)

def plot_stats(args, stats):
    """Write one ``<stat>.pdf`` line plot per statistic into ``args.dir``.

    For every model the mean over axis 1 is drawn as a line, surrounded by a
    shaded band of +/- one standard error (std over axis 1 divided by the
    square root of the axis-1 length). The x axis is labelled "Steps".

    Args:
        args: Namespace providing ``dir`` (output directory), ``show_legend``
            and ``nstates``.
        stats: Two-level mapping ``{stat: {model: 2-D array}}``.
    """
    for stat, per_model in stats.items():
        plt.figure(figsize=(4, 3))
        for model, values in per_model.items():
            # The "True Beliefs" curve is left out of the JS-divergence plot.
            is_skipped = stat == "JS Divergence" and model == "True Beliefs"
            if is_skipped:
                continue
            sample_count = values.shape[1]
            means = values.mean(axis=1)
            err = values.std(axis=1) / np.sqrt(sample_count)
            positions = np.arange(len(means))
            plt.xticks(range(0, len(means), 2))
            plt.plot(means, ".-", label=model)
            plt.fill_between(positions, means - err, means + err, alpha=0.2)

        plt.ylim(bottom=0)
        plt.xlabel("Steps")
        # NOTE(review): the y-label is tied to --show_legend, not to a flag
        # of its own -- confirm this is intended.
        if args.show_legend:
            plt.ylabel(stat)
        if args.nstates > 0:
            plt.title(rf"$|X|={args.nstates}$ states")
        plt.tight_layout()
        plt.grid(True, which="both")

        statname = stat.replace(" ", "-")
        fname = os.path.join(args.dir, f"{statname}.pdf")
        plt.savefig(fname, bbox_inches="tight")

        # The shared legend is exported once, from the JS-divergence figure,
        # while that figure is still the current one.
        wants_legend_file = args.show_legend and stat == "JS Divergence"
        if wants_legend_file:
            plot_legend_in_separate_figure(args.dir)
        plt.close()
        

# Script entry point: parse the CLI options, aggregate all results.pkl files
# found under --dir, print summary statistics, then write the PDF plots.
if __name__ == "__main__":
    args = parse_args()
    # Merge the per-run result files into {stat: {model: array}}.
    stats = load_combined_stats(args)
    print_stats(stats)
    # Plots are saved into args.dir, next to the loaded results.
    plot_stats(args, stats)