#!/usr/bin/env python3

import argparse
import glob
import json
import os
from collections import defaultdict

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd

from global_utils import plot_legend


def load_results(path: str) -> dict[str, list[dict[str, jax.Array]]]:
    """Load a serialised results pytree from ``path``.

    The first line of the file is a JSON document describing the pytree. Every
    leaf of that document is passed to ``jnp.zeros`` to build a template, and
    the remainder of the file is deserialised into that template with
    ``eqx.tree_deserialise_leaves``.

    Args:
        path: Path of the results file to read.

    Returns:
        The deserialised pytree.
    """
    with open(path, 'rb') as stream:
        header = stream.readline().decode().strip()
        # presumably each JSON leaf is an array size/shape - verify against the writer
        template = jax.tree.map(jnp.zeros, json.loads(header))
        return eqx.tree_deserialise_leaves(stream, template)


def plot_filtering_results(result_files: list[list[str]]) -> None:
    """Plot the mean JS divergence per depth for every filter, one figure per file group.

    For each group, every file is loaded with ``load_results`` and the mean JS
    divergence of each (filter, depth) pair is collected. The plotted line is the
    mean over the files of a group and the shaded band is the standard error of
    that mean. Each figure is saved as
    ``goofspiel-filter-results-neurips-{num_cards}.pdf`` and then shown.
    ``plot_legend()`` is called once after all groups have been processed.

    Args:
        result_files: One list of result-file paths per figure. The number of cards
            is parsed from the fourth-from-last ``-``-separated field of the first
            file name in the group. Empty groups (e.g. a glob that matched nothing)
            are skipped with a message instead of failing on ``files[0]``.

    Raises:
        KeyError: If a results file lacks an entry for one of the expected filters.
        ValueError: If the number of cards cannot be parsed from the file name.
    """
    filter_names = [
        'Approx Beliefs',
        'Recurrent',
        'PF (128)',
        'PF (256)',
        'PF (512)',
        'NBF (64)',
        'NBF (128)',
    ]
    # The y-axis ticks are identical for every figure, so build them once.
    y_ticks = jnp.arange(0, 0.251, 0.05)

    for i, files in enumerate(result_files):
        if not files:
            print(f'Skipping result group {i}: no files matched')
            continue

        # basename() equals split('/')[-1] on POSIX and also copes with Windows separators.
        num_cards = int(os.path.basename(files[0]).split('-')[-4])

        # data[filter_name][depth] -> one mean JS divergence per result file.
        data = defaultdict(lambda: defaultdict(list))
        for path in files:
            results = load_results(path)

            for filter_name in filter_names:
                for depth, js_divs in results[filter_name].items():
                    data[filter_name][depth].append(jnp.mean(js_divs))

        # Turn every per-depth list into an array; lists are treated as leaves.
        data = jax.tree.map(jnp.array, data, is_leaf=lambda x: isinstance(x, list))

        plt.figure(figsize=(4, 3))
        plt.title(f'{num_cards} Cards')

        for filter_name in filter_names:
            # NOTE(review): x positions are 1..N in the iteration order of the
            # per-depth dict - confirm that order matches increasing depth.
            per_depth = list(data[filter_name].values())
            means = jnp.array([jnp.mean(js_divs) for js_divs in per_depth])
            sems = jnp.array(
                [jnp.std(js_divs) / jnp.sqrt(js_divs.shape[0]) for js_divs in per_depth]
            )
            xs = jnp.arange(1, len(per_depth) + 1)

            plt.plot(xs, means, '.-', label=filter_name)
            plt.fill_between(xs, means - sems, means + sems, alpha=0.25)

        # The x ticks follow the depths of the last plotted filter.
        plt.xticks(xs)
        plt.yticks(y_ticks)

        plt.xlabel('Number of Played Cards')
        # Only the first group's figure carries the y-axis label.
        if i == 0:
            plt.ylabel('JS Divergence')

        plt.ylim(0, 0.26)

        plt.grid()

        plt.savefig(f'goofspiel-filter-results-neurips-{num_cards}.pdf', bbox_inches='tight')
        plt.show()

    plot_legend()


def table_eval_results(data_files: list[str]) -> None:
    """Summarise evaluation results as a table, written to LaTeX and printed.

    One row is produced per file from the ``'mean_results'`` entry of the loaded
    results. The row label is ``'<N> Cards'``, where ``N`` is the third-from-last
    ``-``-separated field of the file name. The table is written to
    ``goofspiel_model_evals.tex`` and printed.

    Args:
        data_files: Paths of the evaluation result files, one per table row.
    """

    def mean_pm_std(mean, std) -> str:
        # Render a "mean ± std" cell with five decimals.
        return f'{mean:.5f} ± {std:.5f}'

    def summarise(stats) -> dict[str, str]:
        # Index 0 feeds the "All Samples" column; index 1 feeds the other two columns.
        return {
            'JS Divergence All Samples': mean_pm_std(
                stats['js_divs_means'][0], stats['js_divs_stds'][0]
            ),
            'JS Divergence Only Valid Samples': mean_pm_std(
                stats['js_divs_means'][1], stats['js_divs_stds'][1]
            ),
            'Valid Frequencies': mean_pm_std(
                stats['valid_freqs_means'][1], stats['valid_freqs_stds'][1]
            ),
        }

    # Row labels are parsed from the file names before any file is loaded.
    labels = [f'{int(path.split("/")[-1].split("-")[-3])} Cards' for path in data_files]

    table = pd.DataFrame(
        [summarise(load_results(path)['mean_results']) for path in data_files],
        index=labels,
    )
    table.to_latex('goofspiel_model_evals.tex')
    print(table)


def main(args: argparse.Namespace) -> None:
    """Dispatch to plotting or tabulating based on the parsed arguments.

    If ``args.filter_globs`` is set, each pattern is expanded under
    ``{base_dir}/{results_dir}`` and the sorted matches are plotted as one group
    per pattern. Otherwise, if ``args.eval_glob`` is set, its sorted matches are
    tabulated. If neither is set, nothing happens.

    Args:
        args: Parsed command-line arguments.
    """

    def matches(pattern: str) -> list[str]:
        # Sorted paths matching ``pattern`` inside the results directory.
        return sorted(glob.glob(f'{args.base_dir}/{args.results_dir}/{pattern}'))

    if args.filter_globs is not None:
        plot_filtering_results([matches(pattern) for pattern in args.filter_globs])
        return

    if args.eval_glob is not None:
        table_eval_results(matches(args.eval_glob))


if __name__ == '__main__':
    # Command-line entry point: build the parser, parse the arguments, run main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--base-dir', default=os.getcwd(), type=str, help='Base directory')
    cli.add_argument('--eval-glob', default=None, type=str, help='Eval results to load')
    cli.add_argument(
        '--filter-globs', default=None, nargs='+', type=str, help='Filter results to load'
    )
    cli.add_argument(
        '--results-dir', default='goofspiel-models', type=str, help='Results directory'
    )

    main(cli.parse_args())
