#!/usr/bin/env python3

import argparse
import glob
import json
import os
from collections import defaultdict

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import pandas as pd


def load_results(path: str) -> dict[str, dict[str, jax.Array]]:
    """Load a results file: one JSON header line followed by Equinox-serialised leaves.

    The JSON header is turned into a template pytree by calling ``jnp.zeros`` on
    each of its leaves. ``eqx.tree_deserialise_leaves`` then fills that template
    from the rest of the file.

    Args:
        path: Path to the results file (opened in binary mode).

    Returns:
        The deserialised pytree. Callers in this script index it as
        ``results[name][key] -> array``. Other nesting may exist in files this
        script does not read; verify against the writer.
    """
    with open(path, 'rb') as f:
        header = json.loads(f.readline().decode().strip())
        # NOTE(review): tree_map descends into JSON lists, so a leaf shape stored
        # as a list (e.g. [3, 4]) would become a list of 1-D templates rather
        # than one 2-D template. This round-trips cleanly only if each header
        # leaf is a plain int; confirm against the writer.
        # NOTE(review): jax.tree_util rebuilds dicts in sorted-key order, and
        # JSON keys are strings, so the template's leaf order is lexicographic
        # by key. Confirm the writer serialised its leaves in the same order.
        template = jtu.tree_map(lambda shape: jnp.zeros(shape), header)
        data = eqx.tree_deserialise_leaves(f, template)

    return data


def plot_legend(handles=None, labels=None, path: str = 'legend.pdf', ncol: int = 7) -> None:
    """Save a standalone legend, cropped to its own extent, to *path*.

    Args:
        handles: Legend artists to show. Pass together with *labels*. If both
            are ``None``, entries are taken from pyplot's current axes.
        labels: Text for each entry in *handles*.
        path: Output file for the legend figure.
        ncol: Number of legend columns. The default of 7 matches the number of
            filters plotted by ``plot_filtering_results``.

    Raises:
        ValueError: If exactly one of *handles* / *labels* is given.
    """
    if (handles is None) != (labels is None):
        raise ValueError('pass both handles and labels, or neither')
    if handles is None:
        # Fallback: whatever axes pyplot considers current. If the plotted
        # figures were already closed (e.g. by plt.show()), this is a fresh,
        # empty axes, so callers should pass handles/labels explicitly.
        handles, labels = plt.gca().get_legend_handles_labels()

    fig_legend = plt.figure(figsize=(3, 2))
    ax_legend = fig_legend.add_subplot(111)

    legend = ax_legend.legend(handles, labels, loc='center', ncol=ncol, fancybox=True, shadow=False)
    ax_legend.axis('off')

    # Draw once so the legend has a real extent. Convert that extent from
    # display units to inches and shrink the figure to it.
    fig_legend.canvas.draw()
    bbox = legend.get_window_extent().transformed(fig_legend.dpi_scale_trans.inverted())
    fig_legend.set_size_inches(bbox.width, bbox.height)

    fig_legend.savefig(path, bbox_inches='tight')
    plt.close(fig_legend)


def _numeric_order(keys) -> list:
    """Return *keys* sorted by integer value if all parse as ints, else in their given order."""
    keys = list(keys)
    try:
        return sorted(keys, key=int)
    except (TypeError, ValueError):
        return keys


def plot_filtering_results(result_files: list[list[str]]) -> None:
    """Plot per-filter JS divergence against the number of played cards.

    One figure is produced per entry of *result_files*. For every filter and
    depth, each file contributes the mean of its JS divergences. The plot shows
    the mean of those per-file values with a ± standard-error band. Each figure
    is saved as ``goofspiel-filter-results-neurips-{num_cards}.pdf`` and shown.
    A shared ``legend.pdf`` is written at the end.

    Args:
        result_files: One non-empty list of result-file paths per figure. The
            card count is parsed from the fourth-from-last ``-``-separated field
            of the first file's basename.

    Raises:
        ValueError: If any group of result files is empty.
    """
    filters = [
        'Approx Beliefs', 'Recurrent', 'PF (128)', 'PF (256)', 'PF (512)', 'NBF (64)', 'NBF (128)'
    ]
    y_ticks = jnp.arange(0, 0.251, 0.05)
    # Legend entries are captured from each figure before it is shown (and
    # possibly closed), so the legend does not depend on pyplot's current axes.
    handles, labels = [], []

    for i, files in enumerate(result_files):
        if not files:
            raise ValueError(f'result group {i} is empty (its glob matched no files)')
        num_cards = int(os.path.basename(files[0]).split('-')[-4])

        # filter name -> depth key -> one mean JS divergence per file
        per_file_means = defaultdict(lambda: defaultdict(list))
        for file in files:
            results = load_results(file)
            for filter_name in filters:
                for depth, js_divs in results[filter_name].items():
                    per_file_means[filter_name][depth].append(jnp.mean(js_divs))

        plt.figure(figsize=(4, 3))
        plt.title(f'{num_cards} Cards')

        for filter_name in filters:
            by_depth = per_file_means[filter_name]
            means, sems = [], []

            # Depth keys are JSON object keys (strings), and load_results
            # rebuilds dicts via jax.tree_util in sorted-key order, so '10'
            # would precede '2'. Order them numerically to match xs = 1..N.
            for depth in _numeric_order(by_depth):
                samples = jnp.array(by_depth[depth])
                means.append(jnp.mean(samples))
                # NOTE(review): uses the population std (ddof=0); the
                # conventional standard error uses the sample std (ddof=1).
                sems.append(jnp.std(samples) / jnp.sqrt(samples.shape[0]))

            xs = jnp.arange(1, len(means) + 1)
            means, sems = jnp.array(means), jnp.array(sems)

            plt.plot(xs, means, '.-', label=filter_name)
            plt.fill_between(xs, means - sems, means + sems, alpha=0.25)

        plt.xticks(xs)
        plt.yticks(y_ticks)

        plt.xlabel('Number of Played Cards')
        if i == 0:
            plt.ylabel('JS Divergence')

        plt.ylim(0, 0.26)

        plt.grid()

        handles, labels = plt.gca().get_legend_handles_labels()

        plt.savefig(f'goofspiel-filter-results-neurips-{num_cards}.pdf', bbox_inches='tight')
        plt.show()

    plot_legend(handles, labels)


def table_eval_results(data_files: list[str], output_path: str = 'goofspiel_model_evals.tex') -> None:
    """Summarise model-evaluation results as a table, one row per results file.

    Each file's ``'mean_results'`` entry must provide:

    - ``js_divs_means`` / ``js_divs_stds``: index 0 is reported as "All
      Samples" and index 1 as "Only Valid Samples".
    - ``valid_freqs_means`` / ``valid_freqs_stds``: index 1 is reported.

    Cells are formatted as ``mean ± std`` to five decimals. The table is
    written as LaTeX to *output_path* and also printed.

    Args:
        data_files: Paths to results files. The row label's card count is
            parsed from the third-from-last ``-``-separated field of each
            file's basename.
        output_path: Destination of the LaTeX table.
    """
    rows, row_labels = [], []
    for data_file in data_files:
        num_cards = int(os.path.basename(data_file).split('-')[-3])
        stats = load_results(data_file)['mean_results']

        # Bind the arrays to locals so the f-strings below need no nested
        # quotes. Reusing the outer quote type inside an f-string is a
        # SyntaxError before Python 3.12.
        js_means, js_stds = stats['js_divs_means'], stats['js_divs_stds']
        freq_means, freq_stds = stats['valid_freqs_means'], stats['valid_freqs_stds']

        rows.append({
            'JS Divergence All Samples': f'{js_means[0]:.5f} ± {js_stds[0]:.5f}',
            'JS Divergence Only Valid Samples': f'{js_means[1]:.5f} ± {js_stds[1]:.5f}',
            'Valid Frequencies': f'{freq_means[1]:.5f} ± {freq_stds[1]:.5f}',
        })
        row_labels.append(f'{num_cards} Cards')

    df = pd.DataFrame(rows, index=row_labels)
    df.to_latex(output_path)
    print(df)


def main(args: argparse.Namespace) -> None:
    """Run the tasks requested on the command line.

    Glob patterns are resolved under ``{base_dir}/{results_dir}`` and their
    matches are sorted for deterministic ordering. If ``filter_globs`` is set,
    each pattern becomes one figure group for ``plot_filtering_results``. If
    ``eval_glob`` is set, its matches are tabulated by ``table_eval_results``.
    Both tasks run when both are given; nothing runs when neither is.
    """
    results_root = f'{args.base_dir}/{args.results_dir}'

    if args.filter_globs is not None:
        data_files = [
            sorted(glob.glob(f'{results_root}/{filter_glob}'))
            for filter_glob in args.filter_globs
        ]

        plot_filtering_results(data_files)

    # Independent of the branch above: previously an `elif`, which silently
    # dropped --eval-glob whenever --filter-globs was also passed.
    if args.eval_glob is not None:
        data_files = sorted(glob.glob(f'{results_root}/{args.eval_glob}'))

        table_eval_results(data_files)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Plot Goofspiel filtering results and/or tabulate model evaluations.'
    )
    parser.add_argument('--base-dir', type=str, default=os.getcwd(), help='Base directory')
    parser.add_argument('--eval-glob', type=str, default=None, help='Eval results to load')
    parser.add_argument('--filter-globs', type=str, nargs='+', default=None, help='Filter results to load')
    parser.add_argument('--results-dir', type=str, default='goofspiel-models', help='Results directory')
    args = parser.parse_args()

    # main() does nothing unless a task flag is given. Report that instead of
    # exiting silently with no output.
    if args.filter_globs is None and args.eval_glob is None:
        parser.error('nothing to do: pass --filter-globs and/or --eval-glob')

    main(args)
