#!/usr/bin/env python3

import argparse
import json
import os
import sys
import warnings
from collections.abc import Callable
from itertools import combinations

# Append the relative `src` directory to the module search path before the imports below.
# NOTE(review): the path is relative, so it resolves against the current working directory;
# presumably the project packages imported below live there -- confirm.
sys.path.append('src')

import equinox as eqx
import jax
import jax.numpy as jnp
import networkx as nx
import numpy as np
from efg.goofspiel import IIGoofspiel
from jax.scipy.special import kl_div
from jaxtyping import PyTree

from goofspiel.generate_data import (
    filter_possible_histories,
    generate_samples,
    get_all_histories,
    sample_state,
)
from goofspiel.generate_policies import load_policy_ckpts, load_policy_network
from goofspiel.utils import ConditionalFlow, load_model


def create_graph(actions: jax.Array, outcomes: jax.Array, num_cards: int) -> nx.DiGraph:
    """Build a directed graph linking each played card to the opponent cards it allows

    For every `(action, outcome)` pair a source node `P0<action>` is connected to target nodes
    `P1<card>` with `capacity=1`:

    * outcome `1`  -> every card strictly below `action`,
    * outcome `0`  -> the card equal to `action`,
    * outcome `-1` -> every card strictly above `action` and below `num_cards`.

    Edges are inserted turn by turn, in the order of `actions`.

    Raises:
        ValueError: If an outcome is not one of `1`, `0` or `-1`.
    """
    graph = nx.DiGraph()

    for played, outcome in zip(actions, outcomes):
        card = int(played)

        # Pick the opponent cards that are consistent with the observed outcome of this turn
        if outcome == 1:
            candidates = range(card)
        elif outcome == 0:
            candidates = (card,)
        elif outcome == -1:
            candidates = range(card + 1, num_cards)
        else:
            raise ValueError(f'Unexpected outcome value: {outcome}!')

        source = f'P0{card}'
        graph.add_edges_from(((source, f'P1{other}') for other in candidates), capacity=1)

    return graph


def find_maximum_matchings(graph: nx.DiGraph, depth: int) -> list[tuple[tuple[str, str], ...]]:
    """Enumerate every set of `depth` edges of `graph` in which no node is used by two edges

    All `depth`-sized combinations of `graph.edges()` are tried, so the cost grows
    combinatorially with the number of edges.
    """

    def _shares_no_node(edges: tuple[tuple[str, str], ...]) -> bool:
        used = set()

        for edge in edges:
            # Reject as soon as an endpoint was already claimed by an earlier edge
            if not used.isdisjoint(edge):
                return False

            used.update(edge)

        return True

    # `graph.edges()` and `combinations()` both keep the order in which the edges were added,
    # so `decode_maximum_matchings()` does not have to re-check the order afterwards.
    matchings = []
    for candidate in combinations(graph.edges(), depth):
        if _shares_no_node(candidate):
            matchings.append(candidate)

    return matchings


def decode_maximum_matchings(matchings: list[tuple[tuple[str, str], ...]]) -> jax.Array:
    """Decode matchings into integer card sequences, one row per matching

    Each edge contributes the integer that follows the two-character prefix of its target
    node name (e.g. `'P13'` -> `3`); the source node is ignored.

    The result is always created with dtype `int32`. Without an explicit dtype an empty
    matching (`[()]`) would be decoded into a float array, which cannot be used as an index
    by `convert_actions_to_hand()`.
    """
    decoded = [[int(target[2:]) for _, target in edges] for edges in matchings]

    return jnp.array(decoded, dtype=jnp.int32)


def convert_actions_to_hand(actions: jax.Array, num_cards: int) -> jax.Array:
    """Return an `int32` mask of length `num_cards` with zeros at the indices in `actions`

    Every other position keeps the value one.
    """
    full_hand = jnp.ones(num_cards, jnp.int32)
    remaining = full_hand.at[actions].set(0)

    return remaining


def get_possible_hands(
    action_history: jax.Array, num_cards: int, depth: int
) -> tuple[jax.Array, jax.Array]:
    """Enumerate the opponent hands that are consistent with the observed turn outcomes

    Rows `1..depth-1` of `action_history` are used; column 0 is read as the player's cards and
    column 1 as the opponent's cards, and the sign of their difference is the turn outcome.
    NOTE(review): row 0 is skipped -- presumably it does not hold a played turn; confirm.

    Returns:
        A pair `(possible_hands, action_hists)`. `possible_hands` holds the unique hand masks
        produced by `convert_actions_to_hand()`. `action_hists` has one row per hand and
        `num_cards + 1` columns filled with `-1`, except columns `1..depth-1`, which hold the
        first decoded action sequence that leads to that hand.
    """
    turns = action_history[1:depth]
    own_cards, opp_cards = turns[:, 0], turns[:, 1]
    outcomes = jnp.sign(own_cards - opp_cards)

    # Build the graph of admissible opponent cards, enumerate its matchings of size
    # `depth - 1` and decode them. Each decoded sequence would have produced the same
    # outcomes had the opponent played it.
    graph = create_graph(own_cards, outcomes, num_cards)
    candidate_seqs = decode_maximum_matchings(find_maximum_matchings(graph, depth - 1))

    # Turn every candidate sequence into the hand the opponent would be left with, then drop
    # duplicate hands while remembering which sequence produced each of them first
    to_hand = jax.vmap(convert_actions_to_hand, in_axes=(0, None))
    hands = to_hand(candidate_seqs, num_cards)
    hands, first_seen = jnp.unique(hands, return_index=True, axis=0)

    hists = jnp.full((hands.shape[0], num_cards + 1), -1, jnp.int32)
    hists = hists.at[:, 1:depth].set(candidate_seqs[first_seen])

    return hands, hists


def filter_valid_samples(samples: jax.Array, possible_samples: jax.Array) -> tuple[jax.Array, int]:
    """Keep only the rows of `samples` that also occur as a row of `possible_samples`

    Returns the retained rows (original order, duplicates included) together with the number
    of retained rows.
    """
    # pairwise_equal[i, j] is True when samples[i] equals possible_samples[j] in every entry
    pairwise_equal = jnp.all(samples[:, jnp.newaxis, :] == possible_samples[jnp.newaxis], axis=2)
    keep = jnp.any(pairwise_equal, axis=1)

    return samples[keep], jnp.sum(keep)


def compute_js_divergence(samples1: jax.Array, samples2: jax.Array) -> tuple[float, int, int]:
    """Compute Jensen-Shannon divergence between two discrete empirical distributions

    Rows of `samples1` and `samples2` are treated as outcomes. Both empirical distributions
    are laid out over the union of their unique rows, smoothed with `1e-12` and compared.

    Returns:
        The divergence (clipped below at zero) followed by the number of unique rows in
        `samples1` and in `samples2`.
    """
    eps = 1e-12

    unique1, counts1 = jnp.unique(samples1, return_counts=True, axis=0)
    unique2, counts2 = jnp.unique(samples2, return_counts=True, axis=0)
    support = jnp.unique(jnp.concatenate([unique1, unique2], axis=0), axis=0)

    def _empirical_probs(rows: jax.Array, counts: jax.Array) -> jax.Array:
        # Positions in `support` of the rows that are present in `rows`
        matches = jnp.all(support[:, jnp.newaxis, :] == rows, axis=2)
        support_idx, _ = jnp.where(matches)
        masses = jnp.zeros(support.shape[0]).at[support_idx].set(counts)

        # `eps` avoids a division by zero when `rows` is empty
        return masses / (jnp.sum(masses) + eps)

    def _smooth(probs: jax.Array) -> jax.Array:
        # Shift away from exact zeros and renormalise
        shifted = probs + eps
        return shifted / jnp.sum(shifted)

    p = _smooth(_empirical_probs(unique1, counts1))
    q = _smooth(_empirical_probs(unique2, counts2))
    mixture = 0.5 * (p + q)
    divergence = jnp.sum(0.5 * kl_div(p, mixture) + 0.5 * kl_div(q, mixture))

    return jnp.clip(divergence, 0.0), unique1.shape[0], unique2.shape[0]


def compute_batch_statistics(
    action_histories: jax.Array, samples1: jax.Array, samples2: jax.Array, depth: int
) -> dict[str, jax.Array]:
    """Compute JS divergence, valid frequencies and sizes of two batches of distributions

    `samples1` and `samples2` must have the same shape, which is unpacked as
    `(num_dists, num_samples, num_cards)`. Distribution `i` of each batch is compared using
    the histories that `filter_possible_histories()` accepts for `action_histories[i]`.

    Returns:
        A dictionary with the per-column mean and standard deviation (taken over the
        distributions) of the JS divergences, the valid-sample frequencies and the
        distribution sizes.
    """

    assert samples1.shape == samples2.shape, 'Shape mismatch!'

    num_dists, num_samples, num_cards = samples1.shape
    all_histories, _ = get_all_histories(num_cards, depth)

    # Columns: [0] = all samples, [1] = valid samples only
    js_divs = np.zeros([num_dists, 2], np.float32)
    # Columns: [0] = samples1, [1] = samples2
    valid_freqs = np.zeros([num_dists, 2], np.float32)
    # Columns: [0] = possible histories, [1] / [2] = unique rows of samples1 / samples2,
    # [3] / [4] = unique rows of the valid subsets of samples1 / samples2
    dists_sizes = np.zeros([num_dists, 5], np.float32)

    for row in range(num_dists):
        # Histories that would produce the same sequence of outcomes as the observed one
        mask = filter_possible_histories(action_histories[row], all_histories)
        consistent = all_histories[mask]

        # Discard the samples that are not among those histories
        kept1, kept1_count = filter_valid_samples(samples1[row], consistent)
        kept2, kept2_count = filter_valid_samples(samples2[row], consistent)

        # Divergence on the raw samples and on the valid samples only
        js_all, *sizes_all = compute_js_divergence(samples1[row], samples2[row])
        js_valid, *sizes_valid = compute_js_divergence(kept1, kept2)

        js_divs[row] = np.array([js_all, js_valid])
        valid_freqs[row] = np.array([kept1_count / num_samples, kept2_count / num_samples])
        dists_sizes[row] = np.array([jnp.sum(mask), *sizes_all, *sizes_valid])

    return {
        'js_divs_means': np.mean(js_divs, axis=0),
        'js_divs_stds': np.std(js_divs, axis=0),
        'valid_freqs_means': np.mean(valid_freqs, axis=0),
        'valid_freqs_stds': np.std(valid_freqs, axis=0),
        'dists_sizes_means': np.mean(dists_sizes, axis=0),
        'dists_sizes_stds': np.std(dists_sizes, axis=0),
    }


def report_metrics(stats: dict[str, jax.Array], name1: str, name2: str, depth: int | None) -> None:
    js_divs_means, js_divs_stds = stats['js_divs_means'], stats['js_divs_stds']
    valid_freqs_means, valid_freqs_stds = stats['valid_freqs_means'], stats['valid_freqs_stds']
    dists_sizes_means, dists_sizes_stds = stats['dists_sizes_means'], stats['dists_sizes_stds']

    print(f'Depth {depth}:' if depth else 'Average across all depths:')
    print(f'\tJS Div: {js_divs_means[0]:.5f} +- {js_divs_stds[0]:.5f}')
    print(f'\tValid-only JS Div: {js_divs_means[1]:.5f} +- {js_divs_stds[1]:.5f}')

    print(f'\t{name1} valid frequency: {valid_freqs_means[0]:.5f} +- {valid_freqs_stds[0]:.5f}')
    print(f'\t{name2} valid frequency: {valid_freqs_means[1]:.5f} +- {valid_freqs_stds[1]:.5f}')

    print(f'\tTrue distribution size: {dists_sizes_means[0]:.5f} +- {dists_sizes_stds[0]:.5f}')
    print(f'\t{name1} distribution size: {dists_sizes_means[1]:.5f} +- {dists_sizes_stds[1]:.5f}')
    print(f'\t{name2} distribution size: {dists_sizes_means[2]:.5f} +- {dists_sizes_stds[2]:.5f}')
    print(
        f'\t{name1} valid-only distribution size: {dists_sizes_means[3]:.5f} +- {dists_sizes_stds[3]:.5f}'
    )
    print(
        f'\t{name2} valid-only distribution size: {dists_sizes_means[4]:.5f} +- {dists_sizes_stds[4]:.5f}'
    )


def evaluate(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    model: ConditionalFlow,
    num_cards: int,
    num_dists: int,
    num_samples: int,
    num_cond_samples: int,
    report: bool,
    key: jax.Array,
) -> list[dict[str, jax.Array]]:
    """Compare true-posterior samples with flow samples at every depth from 1 to `num_cards`

    At each depth `num_dists` states are sampled, `num_samples` histories are drawn per state
    from the true posterior and from `model` (conditioned on an embedding of the first
    `num_cond_samples` true samples), and the two sample sets are summarised with
    `compute_batch_statistics()`.

    Returns:
        One statistics dictionary per depth, ordered by increasing depth. When `report` is
        true the statistics of each depth are also printed.
    """

    def _fresh_keys(rng: jax.Array) -> tuple[jax.Array, jax.Array]:
        # Advance the main key and derive one sub-key per distribution
        rng, *subkeys = jax.random.split(rng, 1 + num_dists)
        return rng, jnp.array(subkeys)

    # `generate_samples` is mapped over the sampled states only (its 4th positional argument)
    generate_axes = (None, None, None, 0, None, None, None, None)
    all_stats = []

    for depth in range(1, num_cards + 1):
        step = jnp.int32(depth + 1)
        histories, hands = get_all_histories(num_cards, depth)

        # Sample `num_dists` states at the given depth `depth`
        key, state_keys = _fresh_keys(key)
        sample_states = jax.vmap(sample_state, in_axes=None)
        states = sample_states(game, policy_network, policy_ckpts, step, key=state_keys)

        # Generate samples (opponent's histories) from `states` using the true posterior
        key, posterior_keys = _fresh_keys(key)
        sample_posterior = jax.vmap(generate_samples, in_axes=generate_axes)
        true_samples = sample_posterior(
            game,
            policy_network,
            policy_ckpts,
            states,
            histories,
            hands,
            step,
            num_samples,
            key=posterior_keys,
        )

        # Generate samples (opponent's histories) from `states` using a Normalizing Flow
        # model parameterized by embedding vectors computed from some `true_samples`
        key, flow_keys = _fresh_keys(key)
        embeddings = jax.vmap(model.embed)(true_samples[:, :num_cond_samples])
        sample_flow = jax.vmap(model.sample, in_axes=(0, None))
        flow_samples = sample_flow(embeddings, num_samples, key=flow_keys)

        # Compare two empirical distributions of histories generated from `states`
        depth_stats = compute_batch_statistics(
            states.action_history, true_samples, flow_samples, depth
        )
        all_stats.append(depth_stats)

        if report:
            report_metrics(depth_stats, 'Posterior', 'NFs', depth)
            print('\n' + 32 * '=' + '\n')

    return all_stats


def save_results(path: str, data: dict[str, list[dict[str, jax.Array]]]) -> None:
    """Serialise `data` next to the checkpoint at `path`, under an `.eval.eqx` name

    The output name is derived from the file extension only: `model.eqx` becomes
    `model.eval.eqx`, and any other name gets `.eval.eqx` appended. A plain substring
    replacement of `eqx` would also rewrite directory names that contain `eqx`, and would
    leave a path without `eqx` unchanged, so the results would overwrite the checkpoint.

    The file starts with one JSON line holding the leading dimension of every leaf of
    `data`, followed by the leaves written by `eqx.tree_serialise_leaves()`.
    """
    root, ext = os.path.splitext(path)
    eval_path = f'{root}.eval.eqx' if ext == '.eqx' else f'{path}.eval.eqx'

    # Build the header before opening the file so a failure here leaves no partial output
    structure = jax.tree.map(lambda x: x.shape[0], data)
    header = (json.dumps(structure) + '\n').encode()

    with open(eval_path, 'wb') as f:
        f.write(header)
        eqx.tree_serialise_leaves(f, data)


def main(args: argparse.Namespace) -> None:
    """Evaluate the flow model on random policy pairs, report the averages and save them

    For `args.num_policy_pairs` rounds two distinct policy checkpoints are drawn and passed
    to `evaluate()`. The per-depth statistics are averaged over the rounds, then averaged
    again over the depths; both averages are printed and written with `save_results()`.
    """
    np.random.seed(args.seed)
    key = jax.random.key(args.seed)

    model_path = f'{args.base_dir}/{args.model_dir}/{args.model_ckpt}'

    game = IIGoofspiel(2, args.num_cards)
    policy_network = load_policy_network(args.num_cards, seed=0)
    policy_ckpts = load_policy_ckpts(f'{args.base_dir}/{args.policy_dir}', args.num_cards)

    key, load_key = jax.random.split(key, 2)
    model = load_model(model_path, load_key)

    def _stack_mean(*leaves: jax.Array) -> jax.Array:
        # Element-wise mean of the matching leaves of several result trees
        return jnp.mean(jnp.stack(leaves), axis=0)

    per_pair_results = []
    for _ in range(args.num_policy_pairs):
        first, second = np.random.choice(len(policy_ckpts), 2, replace=False)
        policies = [policy_ckpts[first], policy_ckpts[second]]

        key, evaluate_key = jax.random.split(key, 2)
        per_pair_results.append(
            evaluate(
                game,
                policy_network,
                policies,
                model,
                args.num_cards,
                args.num_dists,
                args.num_samples,
                args.num_cond_samples,
                report=args.report_metrics,
                key=evaluate_key,
            )
        )

    # Average across all pairs of policies
    mean_results_per_depth = jax.tree.map(_stack_mean, *per_pair_results)

    # Report averaged metrics for each depth
    for depth, depth_stats in enumerate(mean_results_per_depth, start=1):
        report_metrics(depth_stats, 'Posterior', 'NFs', depth)
    print('\n' + 32 * '=' + '\n')

    # Average across all depths and report the result
    mean_results = jax.tree.map(_stack_mean, *mean_results_per_depth)
    report_metrics(mean_results, 'Posterior', 'NFs', None)

    # Save the results to the disk
    save_results(
        model_path,
        {'mean_results_per_depth': mean_results_per_depth, 'mean_results': mean_results},
    )


if __name__ == '__main__':
    # Silence every warning for the duration of the script
    warnings.simplefilter('ignore')

    # Command-line options as (flag, `add_argument` keyword arguments), in `--help` order
    cli_options = [
        ('--base-dir', {'type': str, 'default': os.getcwd(), 'help': 'Base directory'}),
        (
            '--model-ckpt',
            {'type': str, 'default': 'model-test.eqx', 'help': 'Model checkpoint filename'},
        ),
        ('--model-dir', {'type': str, 'default': 'goofspiel-models', 'help': 'Model directory'}),
        ('--num-cards', {'type': int, 'default': 5, 'help': 'Cards in the game'}),
        (
            '--num-cond-samples',
            {'type': int, 'default': 64, 'help': 'Number of true samples to embed'},
        ),
        ('--num-dists', {'type': int, 'default': 64, 'help': 'Number of infostates to sample'}),
        (
            '--num-policy-pairs',
            {'type': int, 'default': 8, 'help': 'Number of policy pairs to evaluate'},
        ),
        (
            '--num-samples',
            {'type': int, 'default': 128, 'help': 'Number of histories to sample in each state'},
        ),
        (
            '--policy-dir',
            {'type': str, 'default': 'goofspiel-policies', 'help': 'Policy directory'},
        ),
        (
            '--report-metrics',
            {'default': False, 'action': 'store_true', 'help': 'Report metrics'},
        ),
        ('--seed', {'type': int, 'default': 0, 'help': 'Random seed'}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main(args)
