#!/usr/bin/env python3

import argparse
import json
import os
import sys
import warnings
from itertools import combinations
from typing import Callable

sys.path.append('src')

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import PyTree

from efg.goofspiel import IIGoofspiel
from goofspiel.generate_data import filter_possible_histories, generate_samples
from goofspiel.generate_data import get_all_histories, sample_state
from goofspiel.generate_policies import load_policy_ckpts, load_policy_network
from goofspiel.model_utils import ConditionalFlow, load_model
from jax.scipy.special import kl_div


def filter_valid_samples(samples: jax.Array, possible_samples: jax.Array) -> tuple[jax.Array, int]:
    """ Keep only the rows of `samples` that also occur as a row of `possible_samples`

    Returns the retained rows (in their original order, duplicates included) together
    with the number of retained rows.
    """

    # A row is kept when it is element-wise equal to at least one row of `possible_samples`
    keep = jax.vmap(lambda row: (row == possible_samples).all(axis=1).any())(samples)
    num_kept = keep.sum()

    return samples[keep], num_kept


def compute_js_divergence(samples1: jax.Array, samples2: jax.Array) -> tuple[float, int, int]:
    """ Estimate the Jensen-Shannon divergence between two sets of sampled rows

    Each argument is reduced to its unique rows and their counts, both are placed on the
    union of unique rows, and the divergence of the resulting probability vectors is returned
    together with the number of unique rows in `samples1` and in `samples2`.
    """

    eps = 1e-12

    def _probs_on_support(rows: jax.Array, counts: jax.Array, support: jax.Array) -> jax.Array:
        # Positions in `support` whose row equals one of the unique `rows`
        matches = jnp.all(support[:, jnp.newaxis, :] == rows, axis=2)
        positions = jnp.where(matches)[0]

        # NOTE(review): `counts` is written in the order of `rows`; this presumably lines up
        # with `positions` because `jnp.unique` returns rows in the same sorted order - confirm
        weights = jnp.zeros(support.shape[0]).at[positions].set(counts)

        # `eps` keeps the division finite when `rows` is empty and all weights are zero
        return weights / (jnp.sum(weights) + eps)

    def _jensen_shannon(p: jax.Array, q: jax.Array) -> float:
        # Smooth both vectors by `eps` and renormalise them before mixing
        p = (p + eps) / jnp.sum(p + eps)
        q = (q + eps) / jnp.sum(q + eps)
        mixture = 0.5 * (p + q)

        divergence = jnp.sum(0.5 * kl_div(p, mixture) + 0.5 * kl_div(q, mixture))

        # Clip from below at zero
        return jnp.clip(divergence, 0.0)

    unique1, counts1 = jnp.unique(samples1, return_counts=True, axis=0)
    unique2, counts2 = jnp.unique(samples2, return_counts=True, axis=0)
    support = jnp.unique(jnp.concatenate([unique1, unique2], axis=0), axis=0)

    probs1 = _probs_on_support(unique1, counts1, support)
    probs2 = _probs_on_support(unique2, counts2, support)

    return _jensen_shannon(probs1, probs2), unique1.shape[0], unique2.shape[0]


def compute_batch_statistics(
    action_histories: jax.Array, samples1: jax.Array, samples2: jax.Array, depth: int
) -> dict[str, jax.Array]:
    """ Summarise how two batches of sampled distributions compare, distribution by distribution

    `samples1` and `samples2` must share the shape (num_dists, num_samples, num_cards). For each
    distribution the JS divergence, the fraction of valid samples and several set sizes are
    collected; the returned dictionary holds their means and standard deviations over the batch.
    """

    assert samples1.shape == samples2.shape, 'Shape mismatch!'

    num_dists, num_samples, num_cards = samples1.shape
    histories, _ = get_all_histories(num_cards, depth)

    # Per-distribution tables, one row per distribution:
    #   js_divs:     [0] all samples, [1] valid samples only
    #   valid_freqs: [0] samples1, [1] samples2
    #   dists_sizes: [0] number of possible histories,
    #                [1] unique rows in samples1, [2] unique rows in samples2,
    #                [3] unique valid rows in samples1, [4] unique valid rows in samples2
    js_divs = np.zeros([num_dists, 2], np.float32)
    valid_freqs = np.zeros([num_dists, 2], np.float32)
    dists_sizes = np.zeros([num_dists, 5], np.float32)

    for i in range(num_dists):
        # Histories selected by `filter_possible_histories` for this action history
        possible_mask = filter_possible_histories(action_histories[i], histories)
        possible_histories = histories[possible_mask]
        num_possible = jnp.sum(possible_mask)

        # Drop samples that are not among the possible histories
        valid1, num_valid1 = filter_valid_samples(samples1[i], possible_histories)
        valid2, num_valid2 = filter_valid_samples(samples2[i], possible_histories)

        # Divergence on everything that was sampled, then on the valid samples only
        js_all, unique1, unique2 = compute_js_divergence(samples1[i], samples2[i])
        js_valid, valid_unique1, valid_unique2 = compute_js_divergence(valid1, valid2)

        js_divs[i] = np.array([js_all, js_valid])
        valid_freqs[i] = np.array([num_valid1 / num_samples, num_valid2 / num_samples])
        dists_sizes[i] = np.array([num_possible, unique1, unique2, valid_unique1, valid_unique2])

    # Aggregate every table over the batch of distributions
    return {
        'js_divs_means': js_divs.mean(axis=0), 'js_divs_stds': js_divs.std(axis=0),
        'valid_freqs_means': valid_freqs.mean(axis=0), 'valid_freqs_stds': valid_freqs.std(axis=0),
        'dists_sizes_means': dists_sizes.mean(axis=0), 'dists_sizes_stds': dists_sizes.std(axis=0)
    }


def report_metrics(stats: dict[str, jax.Array], name1: str, name2: str, depth: int | None) -> None:
    js_divs_means, js_divs_stds = stats['js_divs_means'], stats['js_divs_stds']
    valid_freqs_means, valid_freqs_stds = stats['valid_freqs_means'], stats['valid_freqs_stds']
    dists_sizes_means, dists_sizes_stds = stats['dists_sizes_means'], stats['dists_sizes_stds']

    print(f'Depth {depth}:' if depth else 'Average across all depths:')
    print(f'\tJS Div: {js_divs_means[0]:.5f} +- {js_divs_stds[0]:.5f}')
    print(f'\tValid-only JS Div: {js_divs_means[1]:.5f} +- {js_divs_stds[1]:.5f}')

    print(f'\t{name1} valid frequency: {valid_freqs_means[0]:.5f} +- {valid_freqs_stds[0]:.5f}')
    print(f'\t{name2} valid frequency: {valid_freqs_means[1]:.5f} +- {valid_freqs_stds[1]:.5f}')

    print(f'\tTrue distribution size: {dists_sizes_means[0]:.5f} +- {dists_sizes_stds[0]:.5f}')
    print(f'\t{name1} distribution size: {dists_sizes_means[1]:.5f} +- {dists_sizes_stds[1]:.5f}')
    print(f'\t{name2} distribution size: {dists_sizes_means[2]:.5f} +- {dists_sizes_stds[2]:.5f}')
    print(f'\t{name1} valid-only distribution size: {dists_sizes_means[3]:.5f} +- {dists_sizes_stds[3]:.5f}')
    print(f'\t{name2} valid-only distribution size: {dists_sizes_means[4]:.5f} +- {dists_sizes_stds[4]:.5f}')


def evaluate(
    game: IIGoofspiel, policy_network: Callable, policy_ckpts: PyTree[jax.Array],
    model: ConditionalFlow, num_cards: int, num_dists: int, num_samples: int,
    num_cond_samples: int, report: bool, key: jax.random.PRNGKey
) -> list[dict[str, jax.Array]]:
    """ Compare `generate_samples` output with `model.sample` output at every depth 1..`num_cards`

    For each depth, `num_dists` states are sampled and `num_samples` histories are drawn per
    state from both sources; the two batches are summarised by `compute_batch_statistics`.
    When `report` is set, the statistics of each depth are printed as soon as they are ready.
    Returns one statistics dictionary per depth, in increasing depth order.
    """
    # `generate_samples` is batched over its fourth positional argument (the states) only
    generate_in_axes = (None, None, None, 0, None, None, None, None)
    num_keys = 1 + num_dists

    stats_per_depth = []

    for depth in range(1, num_cards + 1):
        all_histories, all_hands = get_all_histories(num_cards, depth)
        depth_plus_one = jnp.int32(depth + 1)

        # Draw `num_dists` states; only the keyword `key` carries a batch axis here
        key, *state_keys = jax.random.split(key, num_keys)
        batched_sample_state = jax.vmap(sample_state, in_axes=None)
        states = batched_sample_state(
            game, policy_network, policy_ckpts, depth_plus_one, key=jnp.array(state_keys)
        )

        # Reference samples (opponent's histories) for every state, from the true posterior
        key, *posterior_keys = jax.random.split(key, num_keys)
        batched_generate = jax.vmap(generate_samples, in_axes=generate_in_axes)
        true_samples = batched_generate(
            game, policy_network, policy_ckpts, states, all_histories, all_hands,
            depth_plus_one, num_samples, key=jnp.array(posterior_keys)
        )

        # Flow samples, conditioned on embeddings of the first `num_cond_samples`
        # reference samples of each state
        key, *flow_keys = jax.random.split(key, num_keys)
        conditioning = true_samples[:, :num_cond_samples]
        embeddings = jax.vmap(model.embed)(conditioning)
        batched_flow_sample = jax.vmap(model.sample, in_axes=(0, None))
        flow_samples = batched_flow_sample(embeddings, num_samples, key=jnp.array(flow_keys))

        # Summarise how the two empirical distributions differ at this depth
        depth_stats = compute_batch_statistics(
            states.action_history, true_samples, flow_samples, depth
        )
        stats_per_depth.append(depth_stats)

        if not report:
            continue

        report_metrics(depth_stats, 'Posterior', 'NFs', depth)
        print('\n' + 32 * '=' + '\n')

    return stats_per_depth


def save_results(path: str, data: dict[str, list[dict[str, jax.Array]]]) -> None:
    """ Serialise evaluation results next to the model checkpoint at `path`

    The output file name is `path` with `.eval` inserted before its extension, e.g.
    `dir/model-test.eqx` becomes `dir/model-test.eval.eqx`. The file starts with one JSON
    line holding the leading dimension of every leaf of `data`, followed by the leaves as
    written by `eqx.tree_serialise_leaves`.

    Args:
        path: Path of the model checkpoint the results belong to.
        data: PyTree of arrays to store; every leaf must have at least one dimension.
    """
    # Only touch the extension of the final path component. A plain `str.replace` would also
    # rewrite `eqx` inside directory names, and would leave the path unchanged (overwriting
    # the checkpoint itself) whenever the path does not contain `eqx` at all.
    root, ext = os.path.splitext(path)
    eval_path = f'{root}.eval{ext}'

    # Build the header first so that a malformed `data` fails before the file is created
    structure = jtu.tree_map(lambda x: x.shape[0], data)

    with open(eval_path, 'wb') as f:
        f.write((json.dumps(structure) + '\n').encode())
        eqx.tree_serialise_leaves(f, data)


def main(args: argparse.Namespace) -> None:
    """ Evaluate the flow model over random policy pairs, report the averages and save them """

    def _average(*leaves: jax.Array) -> jax.Array:
        # Stack matching leaves and average over the new leading axis
        return jnp.mean(jnp.stack(leaves), axis=0)

    policy_path = f'{args.base_dir}/{args.policy_dir}'
    model_path = f'{args.base_dir}/{args.model_dir}/{args.model_ckpt}'

    # NumPy's global RNG picks the policy pairs; the JAX key is split for loading and evaluation
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)

    game = IIGoofspiel(2, args.num_cards)
    policy_network = load_policy_network(args.num_cards, seed=0)
    policy_ckpts = load_policy_ckpts(policy_path, args.num_cards)

    key, load_key = jax.random.split(key, 2)
    model = load_model(model_path, load_key)

    per_pair_results = []
    for _ in range(args.num_policy_pairs):
        # Two distinct checkpoints form one policy pair
        first, second = np.random.choice(len(policy_ckpts), 2, replace=False)
        policy_pair = [policy_ckpts[first], policy_ckpts[second]]

        key, evaluate_key = jax.random.split(key, 2)
        pair_results = evaluate(
            game, policy_network, policy_pair, model, args.num_cards,
            args.num_dists, args.num_samples, args.num_cond_samples,
            report=args.report_metrics, key=evaluate_key
        )
        per_pair_results.append(pair_results)

    # Average over the policy pairs, keeping one entry per depth
    mean_results_per_depth = jtu.tree_map(_average, *per_pair_results)

    for index in range(args.num_cards):
        report_metrics(mean_results_per_depth[index], 'Posterior', 'NFs', index + 1)
    print('\n' + 32 * '=' + '\n')

    # Collapse the per-depth averages into a single average over all depths
    mean_results = jtu.tree_map(_average, *mean_results_per_depth)
    report_metrics(mean_results, 'Posterior', 'NFs', None)

    # Persist both levels of aggregation to disk
    save_results(
        model_path,
        {'mean_results_per_depth': mean_results_per_depth, 'mean_results': mean_results}
    )


if __name__ == '__main__':
    # Silence all warnings for the whole run
    warnings.simplefilter('ignore')

    # Command-line options as (flag, `add_argument` keyword arguments), registered in this order
    cli_options = [
        ('--base-dir', dict(type=str, default=os.getcwd(), help='Base directory')),
        ('--model-ckpt', dict(type=str, default='model-test.eqx', help='Model checkpoint filename')),
        ('--model-dir', dict(type=str, default='goofspiel-models', help='Model directory')),
        ('--num-cards', dict(type=int, default=5, help='Cards in the game')),
        ('--num-cond-samples', dict(type=int, default=64, help='Number of true samples to embed')),
        ('--num-dists', dict(type=int, default=64, help='Number of infostates to sample')),
        ('--num-policy-pairs', dict(type=int, default=8, help='Number of policy pairs to evaluate')),
        ('--num-samples', dict(type=int, default=128, help='Number of histories to sample in each state')),
        ('--policy-dir', dict(type=str, default='goofspiel-policies', help='Policy directory')),
        ('--report-metrics', dict(default=False, action='store_true', help='Report metrics')),
        ('--seed', dict(type=int, default=0, help='Random seed')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main(args)
