import argparse
import os
import pickle
from glob import glob

import matplotlib.pyplot as plt
import numpy as np

from global_utils import plot_legend


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Returns:
        Namespace with ``nstates`` (int, -1 when not given), ``results_dir``
        (str, defaults to 'gridworld_models') and ``show_legend`` (bool flag).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--nstates',
        type=int,
        default=-1,
        help='Number of states',
    )
    cli.add_argument(
        '--results_dir',
        type=str,
        default='gridworld_models',
        help='Directory to load stats from',
    )
    cli.add_argument(
        '--show_legend',
        action='store_true',
        default=False,
        help='Show legend in plots',
    )
    return cli.parse_args()


def load_combined_stats(args: argparse.Namespace) -> dict[str, dict[str, np.ndarray]]:
    """Aggregate statistics from every ``<results_dir>/*/results.pkl`` file.

    Each pickle is read as a nested mapping ``stat_name -> model_name -> array``.
    Every array is averaged over its last axis, and the per-file results for a
    given (stat, model) pair are stacked along axis 1 (one slice per file/run).

    Args:
        args: Parsed CLI options; only ``args.results_dir`` is read.

    Returns:
        ``{stat_name: {model_name: stacked_array}}``; an empty dict when no
        result files match.
    """
    pattern = os.path.join(args.results_dir, '*', 'results.pkl')
    # glob() returns matches in arbitrary filesystem order; sort so the run
    # order along axis 1 is deterministic. (No '**' in the pattern, so the
    # `recursive` flag would have no effect and is omitted.)
    results_paths = sorted(glob(pattern))

    per_run: dict[str, dict[str, list[np.ndarray]]] = {}
    for path in results_paths:
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files you produced yourself.
        with open(path, 'rb') as f:
            data = pickle.load(f)

        for stat_key, models in data.items():
            stat_bucket = per_run.setdefault(stat_key, {})
            for model_key, values in models.items():
                stat_bucket.setdefault(model_key, []).append(np.mean(values, axis=-1))

    return {
        stat_key: {model_key: np.stack(runs, axis=1) for model_key, runs in models.items()}
        for stat_key, models in per_run.items()
    }


def print_stats(stats: dict[str, dict[str, np.ndarray]]) -> None:
    """Print a mean +/- std summary for every (stat, model) array.

    Mean and standard deviation are taken over all elements of each array.
    Arrays containing any infinite value are reported as ``Inf`` instead.
    """
    separator = '*' * 30
    for stat_name, per_model in stats.items():
        print(stat_name)
        print(separator)

        for model_name, values in per_model.items():
            print(model_name)
            if not np.isinf(values).any():
                print(f'{np.mean(values)} +/- {np.std(values)}')
            else:
                print('Inf')

        print(separator + '\n')


def plot_stats(args: argparse.Namespace, stats: dict[str, dict[str, np.ndarray]]) -> None:
    """Save one PDF line plot per statistic into ``args.results_dir``.

    For each model, the mean over axis 1 is plotted against the step index,
    with a shaded band of +/- one standard error (std over axis 1 divided by
    sqrt of axis-1 size). Models whose mean curve contains an infinite value are
    skipped, as is 'True Beliefs' on the 'JS Divergence' plot.

    Args:
        args: Parsed CLI options. Reads ``nstates`` (title shown when > 0),
            ``show_legend`` (adds the y-label, and for 'JS Divergence' calls
            ``plot_legend``) and ``results_dir`` (output directory).
        stats: ``{stat_name: {model_name: array}}`` with runs along axis 1,
            as returned by ``load_combined_stats``.
    """
    for stat, per_model in stats.items():
        fig = plt.figure(figsize=(4, 3))
        # Length of the longest curve actually drawn; drives the x-ticks.
        num_steps = 0

        for model, values in per_model.items():
            if stat == 'JS Divergence' and model == 'True Beliefs':
                continue

            means = np.mean(values, axis=1)
            if np.any(np.isinf(means)):
                continue

            stderrs = np.std(values, axis=1) / np.sqrt(values.shape[1])
            steps = np.arange(len(means))

            plt.plot(steps, means, '.-', label=model)
            plt.fill_between(steps, means - stderrs, means + stderrs, alpha=0.2)
            num_steps = max(num_steps, len(means))

        if args.nstates > 0:
            plt.title(rf'$|X|={args.nstates}$ states')

        plt.xlabel('Steps')
        # The y-label is tied to --show_legend, which appears to mark the
        # annotated (e.g. left-most) panel. NOTE(review): confirm intent.
        if args.show_legend:
            plt.ylabel(stat)

        # Ticks come from the curves actually drawn; when every model was
        # skipped there is nothing to size them from, so keep the defaults.
        if num_steps > 0:
            plt.xticks(range(0, num_steps, 2))
        plt.ylim(bottom=0)

        file_name = stat.replace(' ', '-').lower()
        file_name = os.path.join(args.results_dir, f'{file_name}.pdf')

        plt.grid()
        plt.tight_layout()
        plt.savefig(file_name, bbox_inches='tight')

        if args.show_legend and stat == 'JS Divergence':
            plot_legend(args.results_dir)

        # Close the figure this iteration created, not whichever figure is
        # current after plot_legend() has run.
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: parse CLI options, aggregate all runs, then report
    # the summary to stdout and write the per-statistic plots.
    cli_args = parse_args()
    combined = load_combined_stats(cli_args)
    print_stats(combined)
    plot_stats(cli_args, combined)
