#!/usr/bin/env python3

from __future__ import annotations

import argparse
import json
import os
import warnings
from collections import defaultdict
from collections.abc import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import kl_div
from jaxtyping import PyTree

from envs.goofspiel import IIGoofspiel, IIGoofspielState
from goofspiel.generate_data import (
    counterfactual_reach_logprob,
    filter_possible_histories,
    get_all_histories,
    sample_action,
)
from goofspiel.generate_policies import load_policy_ckpts, load_policy_network
from goofspiel.utils import load_model


class ParticleSet(eqx.Module):
    """A weighted set of hypothesised game states used by the particle-based filters.

    ``particles`` is a batched pytree of game states (every leaf's leading axis indexes a
    particle) and ``weights`` holds one non-negative weight per particle. ``simulate`` and
    both resampling methods return normalised weights.
    """

    particles: eqx.Module
    weights: jax.Array

    @eqx.filter_jit
    def simulate(
        self,
        game: IIGoofspiel,
        policy_network: Callable,
        policy_ckpts: PyTree[jax.Array],
        gold_action: int,
        gold_outcome: int,
        step_key: jax.Array,
        key: jax.Array,
    ) -> ParticleSet:
        """Advance every particle by one round and reweight it by the observed outcome.

        In each particle, player 1's action is sampled from ``policy_ckpts[1]`` given that
        particle's player-1 infostate, and player 0 plays the known ``gold_action``. Particles
        whose implied outcome ``sign(gold_action - opponent_action)`` differs from
        ``gold_outcome`` get weight zero, and the rest are renormalised. If no particle is
        consistent, the weights fall back to uniform.

        ``step_key`` is shared unchanged by all particles. Callers pass the key used for the
        true game step so that the environment's own randomness matches.
        """

        def _simulate_single(
            state: IIGoofspielState, key: jax.Array
        ) -> tuple[IIGoofspielState, int]:
            key, sample_key = jax.random.split(key, 2)
            opponent_action = sample_action(
                policy_network,
                policy_ckpts[1],
                state.infostate_features(1),
                state.legal_actions_k_hot(1),
                sample_key,
            )

            state, *_ = game.step(state, [gold_action, opponent_action], step_key)
            outcome = jnp.sign(gold_action - opponent_action)

            return state, outcome

        _, *simulate_keys = jax.random.split(key, 1 + self.weights.shape[0])
        particles, outcomes = jax.vmap(_simulate_single)(self.particles, jnp.array(simulate_keys))

        weights = jnp.where(outcomes == gold_outcome, self.weights, 0.0)
        weights = jax.lax.cond(
            jnp.sum(weights) > 0.0,
            lambda: weights / jnp.sum(weights),
            # Every particle contradicts the observation: reset to uniform instead of
            # carrying an all-zero weight vector forward.
            lambda: jnp.ones_like(weights) / weights.shape[0],
        )

        return ParticleSet(particles, weights)

    @eqx.filter_jit
    def neff(self) -> float:
        """Return the effective sample size ``1 / sum(w**2)``, guarded against division by 0."""
        return 1 / (jnp.sum(self.weights**2) + 1e-12)

    @eqx.filter_jit
    def multinomial_resample(self, key: jax.Array) -> ParticleSet:
        """Draw N particles i.i.d. in proportion to their weights; returns uniform weights."""
        num_particles = self.weights.shape[0]

        indices = jax.random.choice(key, num_particles, shape=(num_particles,), p=self.weights)
        particles = jax.tree.map(lambda x: x[indices], self.particles)
        weights = jnp.full(num_particles, 1 / num_particles)

        return ParticleSet(particles, weights)

    @eqx.filter_jit
    def systematic_resample(self, key: jax.Array) -> ParticleSet:
        """Resample with systematic (low-variance) resampling; returns uniform weights.

        A single offset ``x0 ~ U[0, 1/N)`` defines N evenly spaced positions over the total
        weight mass. Each position selects the particle whose CDF interval
        ``[cdf[i-1], cdf[i])`` contains it.
        """
        num_particles = self.weights.shape[0]

        cdf = jnp.cumsum(self.weights)
        total = cdf[-1]
        x0 = jax.random.uniform(key, (), minval=0.0, maxval=1.0 / num_particles)
        # Spread the positions over the actual mass `total` instead of assuming it is exactly
        # 1. Float32 round-off in the cumsum can leave `total` slightly below 1, and positions
        # above it would land past the end of the CDF.
        positions = total * (x0 + jnp.arange(num_particles) / num_particles)
        # Rounding can still push the top position up to `total`. Keep every position strictly
        # below it so that each one falls inside some particle's interval.
        positions = jnp.minimum(positions, jnp.nextafter(total, 0.0))
        # side='right' returns the first i with cdf[i] > position, i.e. the half-open interval
        # [cdf[i-1], cdf[i]). A zero-weight particle owns an empty interval and is never
        # chosen. (side='left' could choose a zero-weight particle 0 for a position of 0.)
        indices = jnp.searchsorted(cdf, positions, side='right')
        # Clamp explicitly rather than relying on JAX silently clamping out-of-range gathers.
        indices = jnp.minimum(indices, num_particles - 1)

        particles = jax.tree.map(lambda x: x[indices], self.particles)
        weights = jnp.full(num_particles, 1 / num_particles)

        return ParticleSet(particles, weights)


class BaseFilter(eqx.Module):
    """Shared helpers that turn a hypothesised opponent history into a full game state."""

    def _create_hand(self, history: jax.Array) -> jax.Array:
        """Return a float mask of cards not yet used by ``history`` (1 = still in hand).

        An entry of -1 means "no action recorded" and leaves the mask unchanged. Any other
        entry zeroes the slot it indexes. The mask has one slot per history entry.
        NOTE(review): this presumes the history length equals the number of cards.
        """
        num_slots = history.shape[0]

        def _remove_card(round_idx, remaining):
            card = history[round_idx]
            without_card = remaining.at[card].set(0)
            return jnp.where(card != -1, without_card, remaining)

        return jax.lax.fori_loop(0, num_slots, _remove_card, jnp.ones(num_slots))

    def _create_state(
        self, state: IIGoofspielState, history: jax.Array, hand: jax.Array
    ) -> IIGoofspielState:
        """Return a copy of ``state`` with player index 1's history column and hand replaced.

        Rows 1 onward of ``action_history[:, 1]`` are overwritten with ``history``, and
        ``hands[1]`` is overwritten with ``hand``. All other fields are shared with ``state``.
        """
        new_action_history = state.action_history.at[1:, 1].set(history)
        new_hands = state.hands.at[1].set(hand)

        return eqx.tree_at(
            lambda s: (s.action_history, s.hands), state, (new_action_history, new_hands)
        )


class ParticleFilter(BaseFilter):
    """Sequential importance resampling (SIR) particle filter.

    Particles are resampled (systematically) whenever the effective sample size drops
    below half of ``num_particles``.
    """

    num_particles: int

    def __repr__(self) -> str:
        return f'PF ({self.num_particles})'

    @eqx.filter_jit
    def sample(self, particle_set: ParticleSet, num_samples: int, key: jax.Array) -> jax.Array:
        """Draw ``num_samples`` opponent histories from the weighted particles."""
        weights = particle_set.weights
        probs = weights / (jnp.sum(weights) + 1e-12)
        opponent_histories = particle_set.particles.action_history[:, 1:, 1]

        return jax.random.choice(key, opponent_histories, (num_samples,), p=probs)

    @eqx.filter_jit
    def compute_beliefs(
        self,
        particle_set: ParticleSet,
        histories: jax.Array,
        mask: jax.Array,
        key: jax.Array,
    ) -> jax.Array:
        """Return the normalised particle weight that matches each row of ``histories``.

        ``mask`` and ``key`` are accepted only for interface parity with the other filters.
        """
        particle_histories = particle_set.particles.action_history[:, 1:, 1]
        weights = particle_set.weights

        def _mass_on(history: jax.Array) -> float:
            is_match = jnp.all(particle_histories == history, axis=1).astype(jnp.float32)
            return weights @ is_match

        unnormalised = jax.vmap(_mass_on)(histories)

        return unnormalised / (jnp.sum(unnormalised) + 1e-12)

    @eqx.filter_jit
    def update(
        self,
        game: IIGoofspiel,
        policy_network: Callable,
        policy_ckpts: PyTree[jax.Array],
        gold_state: IIGoofspielState,
        gold_action: int,
        gold_outcome: int,
        particle_set: ParticleSet,
        step_key: jax.Array,
        key: jax.Array,
    ) -> ParticleSet:
        """Propagate the particles one round, reweight them, and resample if degenerate.

        ``gold_state`` is not read here. It is part of the shared filter interface.
        """
        _, simulate_key, resample_key = jax.random.split(key, 3)

        # `step_key` is forwarded untouched so that the particles see the same environment
        # randomness as the true game step.
        propagated = particle_set.simulate(
            game, policy_network, policy_ckpts, gold_action, gold_outcome, step_key, simulate_key
        )

        degenerate = propagated.neff() < self.num_particles / 2
        return jax.lax.cond(
            degenerate,
            lambda: propagated.systematic_resample(resample_key),
            lambda: propagated,
        )

    def reset(self, gold_state: IIGoofspielState, key: jax.Array) -> ParticleSet:
        """Build uniformly weighted particles from histories drawn uniformly at random.

        NOTE(review): candidates come from `get_all_histories` for the current depth and are
        not passed through `filter_possible_histories`. Confirm that not conditioning the
        initial particles on already-observed outcomes is intended.
        """
        n_cards = int(gold_state.prize_deck.shape[0])
        n_played = int(gold_state.round - 1)
        all_histories, all_hands = get_all_histories(n_cards, n_played)
        picks = jax.random.choice(key, all_histories.shape[0], shape=(self.num_particles,))

        build_states = jax.vmap(self._create_state, in_axes=(None, 0, 0))
        particles = build_states(gold_state, all_histories[picks], all_hands[picks])
        uniform = jnp.full(self.num_particles, 1 / self.num_particles)

        return ParticleSet(particles, uniform)


class NeuralFilter(BaseFilter):
    """Shared logic for filters whose belief is an embedding decoded by ``self.model``.

    This class declares no fields. Subclasses must provide ``model``, and ``reset`` below
    also reads ``num_particles``.
    """

    @eqx.filter_jit
    def sample(self, embedding: jax.Array, num_samples: int, key: jax.Array) -> jax.Array:
        """Draw ``num_samples`` histories from the model's belief for ``embedding``."""
        return self.model.sample(embedding, num_samples, key)

    @eqx.filter_jit
    def compute_beliefs(
        self, embedding: jax.Array, histories: jax.Array, mask: jax.Array, key: jax.Array
    ) -> jax.Array:
        """Return a distribution over ``histories``, with zero mass where ``mask`` is False.

        ``model.log_prob`` is evaluated under several independent keys. The vmap over keys
        stacks those draws along axis 0, and they are averaged in probability space
        (``logsumexp - log(n)``) before the masked softmax.
        """
        n_draws = 20
        _, *draw_keys = jax.random.split(key, 1 + n_draws)
        per_draw_logprobs, *_ = jax.vmap(self.model.log_prob, in_axes=(None, None, 0))(
            histories, embedding, jnp.array(draw_keys)
        )
        mean_logprobs = jax.nn.logsumexp(per_draw_logprobs, axis=0) - jnp.log(n_draws)

        return jax.nn.softmax(jnp.where(mask, mean_logprobs, -jnp.inf))

    def reset(self, gold_state: IIGoofspielState, key: jax.Array) -> jax.Array:
        """Embed ``num_particles`` histories drawn uniformly for the current depth."""
        n_cards = int(gold_state.prize_deck.shape[0])
        n_played = int(gold_state.round - 1)
        candidates, _ = get_all_histories(n_cards, n_played)
        picks = jax.random.choice(key, candidates.shape[0], shape=(self.num_particles,))

        return self.model.embed(candidates[picks])


class RecurrentFilter(NeuralFilter):
    """Baseline filter whose belief state comes from a recurrent model over infostates.

    The filter state is a 2-tuple. Entry 0 is the embedding that ``NeuralFilter`` decodes
    for sampling and beliefs. Entry 1 is passed back into ``model.embed`` on every update.
    """

    model: eqx.Module

    def __repr__(self) -> str:
        return 'Recurrent'

    @eqx.filter_jit
    def sample(
        self, state: tuple[jax.Array, jax.Array], num_samples: int, key: jax.Array
    ) -> jax.Array:
        """Sample histories from the embedding held in ``state[0]``."""
        embedding = state[0]
        return super().sample(embedding, num_samples, key)

    @eqx.filter_jit
    def compute_beliefs(
        self,
        state: tuple[jax.Array, jax.Array],
        histories: jax.Array,
        mask: jax.Array,
        key: jax.Array,
    ) -> jax.Array:
        """Compute beliefs from the embedding held in ``state[0]``."""
        embedding = state[0]
        return super().compute_beliefs(embedding, histories, mask, key)

    @eqx.filter_jit
    def update(
        self,
        game: IIGoofspiel,
        policy_network: Callable,
        policy_ckpts: PyTree[jax.Array],
        next_gold_state: IIGoofspielState,
        gold_action: int,
        gold_outcome: int,
        state: tuple[jax.Array, jax.Array],
        step_key: jax.Array,
        key: jax.Array,
    ) -> tuple[jax.Array, jax.Array]:
        """Feed the post-step infostate features into the recurrent model.

        Apart from ``next_gold_state`` and ``state``, the arguments only match the shared
        filter interface and are ignored.

        NOTE(review): this uses `infostate_features(1)`, the same call this file uses to
        build player 1's policy input. If index 1 is the opponent whose history is being
        inferred, the model is conditioned on that player's own infostate. Confirm this is
        intended rather than `infostate_features(0)`.
        """
        features = next_gold_state.infostate_features(1)
        carry = state[1]
        return self.model.embed(features, carry)

    @eqx.filter_jit
    def reset(self, gold_state: IIGoofspielState, key: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Start from the model's reset carry and embed the current infostate features."""
        initial_carry = self.model.reset()
        return self.model.embed(gold_state.infostate_features(1), initial_carry)


class NeuralGTFilter(NeuralFilter):
    """Reference filter that embeds particles drawn from the exact reach-probability belief."""

    model: eqx.Module
    num_particles: int

    def __repr__(self) -> str:
        return 'Approx Beliefs'

    def update(
        self,
        game: IIGoofspiel,
        policy_network: Callable,
        policy_ckpts: PyTree[jax.Array],
        next_gold_state: IIGoofspielState,
        gold_action: int,
        gold_outcome: int,
        embedding: jax.Array,
        step_key: jax.Array,
        key: jax.Array,
    ) -> jax.Array:
        """Sample ``num_particles`` histories from the reach-probability posterior and embed them.

        Every history of the current depth is scored with ``counterfactual_reach_logprob``
        (player 0, up to ``next_gold_state.round``). Histories that
        ``filter_possible_histories`` rejects get zero mass. The previous ``embedding`` is
        not used.
        """
        n_cards = int(game.num_cards)
        n_played = int(next_gold_state.round - 1)
        all_histories, all_hands = get_all_histories(n_cards, n_played)
        consistent = filter_possible_histories(next_gold_state.action_history, all_histories)

        def _reach_logprob(history: jax.Array, hand: jax.Array) -> float:
            candidate = self._create_state(next_gold_state, history, hand)
            return counterfactual_reach_logprob(
                game, policy_network, policy_ckpts, candidate, 0, next_gold_state.round
            )

        logits = jnp.where(consistent, jax.vmap(_reach_logprob)(all_histories, all_hands), -jnp.inf)
        posterior = jax.nn.softmax(logits)
        picks = jax.random.choice(key, all_histories.shape[0], (self.num_particles,), p=posterior)

        return self.model.embed(all_histories[picks])


class NeuralBayesFilter(NeuralFilter):
    """Neural Bayesian Filter (NBF), the proposed method.

    Each update samples particles from the model's current belief, simulates them one round,
    reweights them by the observed outcome, and re-embeds the weighted particles.
    """

    model: eqx.Module
    num_particles: int

    def __repr__(self) -> str:
        return f'NBF ({self.num_particles})'

    @eqx.filter_jit
    def update(
        self,
        game: IIGoofspiel,
        policy_network: Callable,
        policy_ckpts: PyTree[jax.Array],
        gold_state: IIGoofspielState,
        gold_action: int,
        gold_outcome: int,
        embedding: jax.Array,
        step_key: jax.Array,
        key: jax.Array,
    ) -> jax.Array:
        """Run one sample-simulate-reweight step and return the new embedding."""
        _, sample_key, simulate_key = jax.random.split(key, 3)

        # Clamp model samples to valid card indices, where -1 marks "no action".
        max_card = gold_state.prize_deck.shape[0] - 1
        proposals = jnp.clip(self.sample(embedding, self.num_particles, sample_key), -1, max_card)

        hands = jax.vmap(self._create_hand)(proposals)
        consistent = filter_possible_histories(gold_state.action_history, proposals)

        states = jax.vmap(self._create_state, in_axes=(None, 0, 0))(gold_state, proposals, hands)
        # Uniform weight over the consistent proposals and zero for the rest.
        prior_weights = jnp.where(consistent, 1.0 / jnp.sum(consistent), 0.0)

        # Reuse `step_key` so that the simulated dynamics match the true game step.
        posterior = ParticleSet(states, prior_weights).simulate(
            game, policy_network, policy_ckpts, gold_action, gold_outcome, step_key, simulate_key
        )

        opponent_histories = posterior.particles.action_history[:, 1:, 1]
        return self.model.embed(opponent_histories, posterior.weights)


# Per-filter belief state: a `ParticleSet` (ParticleFilter), a 2-tuple whose entry 0 is the
# embedding (RecurrentFilter), or an embedding array (NeuralGTFilter / NeuralBayesFilter).
FilterState = ParticleSet | tuple[jax.Array, jax.Array] | jax.Array


def evaluate_filter(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    filter: BaseFilter,
    gold_state: IIGoofspielState,
    filter_state: FilterState,
    num_cards: int,
    depth: int,
    key: jax.Array,
) -> float:
    """Score a filter's beliefs against reach-probability beliefs using JS divergence.

    The reference distribution covers every opponent history of length ``depth``. It gives
    zero mass to histories that ``filter_possible_histories`` rejects for ``gold_state``,
    and otherwise mass proportional to ``exp(counterfactual_reach_logprob(...))`` for
    player 0 up to round ``depth + 1``.

    Returns:
        The Jensen-Shannon divergence in nats (natural log), clipped at 0. Both
        distributions receive 1e-12 additive smoothing before the divergence is computed.
    """
    # Build candidate states with the same surgery the filters use (BaseFilter._create_state),
    # but through a plain BaseFilter so the reference never depends on the evaluated filter.
    state_builder = BaseFilter()

    def _calculate_cf_reach_logprob(history: jax.Array, hand: jax.Array) -> float:
        state = state_builder._create_state(gold_state, history, hand)

        return counterfactual_reach_logprob(game, policy_network, policy_ckpts, state, 0, depth + 1)

    def _calculate_js_div(p: jax.Array, q: jax.Array) -> float:
        p = (p + 1e-12) / jnp.sum(p + 1e-12)
        q = (q + 1e-12) / jnp.sum(q + 1e-12)
        m = 0.5 * (p + q)

        # `kl_div` is elementwise x*log(x/y) - x + y. The -x + y terms cancel when summed
        # because p, q and m all sum to 1.
        return jnp.clip(jnp.sum(0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)), 0.0)

    histories, hands = get_all_histories(num_cards, depth)
    mask = filter_possible_histories(gold_state.action_history, histories)

    reach_logprobs = jax.vmap(_calculate_cf_reach_logprob)(histories, hands)
    reach_logprobs = jnp.where(mask, reach_logprobs, -jnp.inf)
    reach_probs = jax.nn.softmax(reach_logprobs)

    _, belief_key = jax.random.split(key, 2)
    filter_beliefs = filter.compute_beliefs(filter_state, histories, mask, belief_key)

    return _calculate_js_div(reach_probs, filter_beliefs)


def save_results(path: str, data: dict[str, dict[int, np.ndarray]], seed: int) -> None:
    """Write evaluation results to a file derived from the checkpoint path ``path``.

    The output path is ``path`` with one trailing ``.eqx`` removed, followed by
    ``-{seed}.filter.eval.eqx``. The first line is a JSON header that mirrors ``data`` with
    each leaf replaced by its leading-axis length. The leaves themselves follow, written with
    ``eqx.tree_serialise_leaves``.
    """
    # Strip only a trailing '.eqx'. `str.replace` would also rewrite any '.eqx' elsewhere in
    # the path, and for a path without '.eqx' it would leave the path unchanged, overwriting
    # the model checkpoint itself.
    stem = path.removesuffix('.eqx')
    out_path = f'{stem}-{seed}.filter.eval.eqx'

    structure = jax.tree.map(lambda x: x.shape[0], data)
    with open(out_path, 'wb') as f:
        f.write((json.dumps(structure) + '\n').encode())
        eqx.tree_serialise_leaves(f, data)


def main(args: argparse.Namespace) -> None:
    """Evaluate every belief filter and save per-depth Jensen-Shannon divergences.

    The outer loop runs ``args.num_policy_pairs`` times, each with two distinct policy
    checkpoints. For each pair, ``args.num_eval_episodes`` episodes are played. Round one
    uses uniformly random actions, and later rounds sample both players' actions from the
    chosen policies. After every round each filter is updated and scored with
    ``evaluate_filter``. Scores are printed, summarised as mean ± std per filter and depth,
    and saved with ``save_results``.

    NOTE(review): ``args.num_particles_nf`` and ``args.num_particles_pf`` are parsed by the
    CLI but never read here. The particle counts in ``filters`` below are hard-coded.
    """
    np.random.seed(args.seed)
    key = jax.random.key(args.seed)

    game = IIGoofspiel(2, args.num_cards)
    policy_network = load_policy_network(args.num_cards, seed=0)
    policy_ckpts = load_policy_ckpts(f'{args.base_dir}/{args.policy_dir}', args.num_cards)

    key, load_key = jax.random.split(key, 2)
    rnn_model = load_model(f'{args.base_dir}/{args.model_dir}/{args.rnn_model_ckpt}', load_key)

    key, load_key = jax.random.split(key, 2)
    flow_model = load_model(f'{args.base_dir}/{args.model_dir}/{args.flow_model_ckpt}', load_key)

    filters = [
        ParticleFilter(128),
        ParticleFilter(256),
        ParticleFilter(512),
        RecurrentFilter(rnn_model),
        NeuralGTFilter(flow_model, 32),
        NeuralBayesFilter(flow_model, 64),
        NeuralBayesFilter(flow_model, 128),
    ]

    # Evaluate at most 8 rounds deep, or `num_cards` rounds if that is smaller. Using a plain
    # Python int avoids passing a JAX array to `range`.
    max_depth = min(args.num_cards, 8)

    results = defaultdict(lambda: defaultdict(list))
    for _ in range(args.num_policy_pairs):
        pair = np.random.choice(len(policy_ckpts), 2, replace=False)
        policies = [policy_ckpts[i] for i in pair]

        for _ in range(args.num_eval_episodes):
            key, reset_key = jax.random.split(key, 2)
            gold_state, _ = game.reset(reset_key)

            # Play the first round with actions sampled uniformly at random
            key, sample_key, step_key = jax.random.split(key, 3)
            gold_action = jax.random.randint(sample_key, (2,), 0, args.num_cards)
            gold_state, *_ = game.step(gold_state, gold_action, step_key)

            key, *reset_keys = jax.random.split(key, 1 + len(filters))
            filter_states = [
                belief_filter.reset(gold_state, filter_key)
                for belief_filter, filter_key in zip(filters, reset_keys)
            ]

            for depth in range(2, max_depth + 1):
                key, *sample_keys = jax.random.split(key, 1 + 2)
                gold_action = [
                    sample_action(
                        policy_network,
                        policies[i],
                        gold_state.infostate_features(i),
                        gold_state.legal_actions_k_hot(i),
                        sample_keys[i],
                    )
                    for i in range(2)
                ]

                key, step_key = jax.random.split(key, 2)
                next_gold_state, *_ = game.step(gold_state, gold_action, step_key)
                gold_outcome = jnp.sign(gold_action[0] - gold_action[1])

                next_filter_states = []
                for belief_filter, filter_state in zip(filters, filter_states):
                    # RecurrentFilter and NeuralGTFilter receive the post-step state.
                    # The other filters receive the pre-step state.
                    state = (
                        next_gold_state
                        if isinstance(belief_filter, RecurrentFilter | NeuralGTFilter)
                        else gold_state
                    )

                    key, update_key = jax.random.split(key, 2)
                    # `step_key` is the key used for the true step above, so filters that
                    # simulate particles reproduce the same environment randomness.
                    next_filter_states.append(
                        belief_filter.update(
                            game,
                            policy_network,
                            policies,
                            state,
                            gold_action[0],
                            gold_outcome,
                            filter_state,
                            step_key,
                            update_key,
                        )
                    )

                gold_state = next_gold_state
                filter_states = next_filter_states

                for belief_filter, filter_state in zip(filters, filter_states):
                    key, evaluate_key = jax.random.split(key, 2)
                    js_div = evaluate_filter(
                        game,
                        policy_network,
                        policies,
                        belief_filter,
                        gold_state,
                        filter_state,
                        args.num_cards,
                        depth,
                        evaluate_key,
                    )
                    results[repr(belief_filter)][depth].append(js_div)
                    print(f'Depth {depth}, {belief_filter}: {js_div}')
            print()

    for name, depths in results.items():
        for depth, js_divs in depths.items():
            mean_js_div, std_js_div = np.mean(js_divs), np.std(js_divs)
            print(f'{name} at depth {depth}: {mean_js_div:.4f} ± {std_js_div:.4f}')
        print()

    results = jax.tree.map(lambda x: np.array(x), results, is_leaf=lambda x: isinstance(x, list))
    save_results(f'{args.base_dir}/{args.model_dir}/{args.flow_model_ckpt}', results, args.seed)


if __name__ == '__main__':
    warnings.simplefilter('ignore')

    # Command-line flags as (flag, type, default, help) rows, registered in this order.
    # NOTE(review): --num_particles_nf / --num_particles_pf are not read by `main`.
    # fmt: off
    cli_flags = [
        ('--base_dir', str, os.getcwd(), 'Base directory'),
        ('--flow_model_ckpt', str, 'model-test.eqx', 'Flow model checkpoint filename'),
        ('--model_dir', str, 'goofspiel_models', 'Model directory'),
        ('--num_cards', int, 5, 'Cards in the game'),
        ('--num_eval_episodes', int, 50, 'Number of evaluation episodes'),
        ('--num_particles_nf', int, 64, 'Number of particles for NFs'),
        ('--num_particles_pf', int, 64, 'Number of particles for PF'),
        ('--num_policy_pairs', int, 10, 'Number of policy pairs to evaluate'),
        ('--policy_dir', str, 'goofspiel_policies', 'Policy directory'),
        ('--rnn_model_ckpt', str, 'model-test.eqx', 'RNN model checkpoint filename'),
        ('--seed', int, 0, 'Random seed'),
    ]
    # fmt: on

    parser = argparse.ArgumentParser()
    for flag, flag_type, default, help_text in cli_flags:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    args = parser.parse_args()

    main(args)
