#!/usr/bin/env python3

import argparse
import glob
import os
import warnings
from collections.abc import Callable
from functools import lru_cache
from itertools import permutations

import distrax
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import PyTree
from tqdm import tqdm

from envs.goofspiel import IIGoofspiel, IIGoofspielState
from goofspiel.generate_policies import load_policy_ckpts, load_policy_network

# Softmax temperature: policy logits are divided by this before the softmax in
# `sample_action` and `counterfactual_reach_logprob`; values > 1 flatten the distribution.
TEMPERATURE = 2


@eqx.filter_jit
def sample_action(
    policy_network: Callable,
    policy_ckpt: PyTree[jax.Array],
    obs: jax.Array,
    mask: jax.Array,
    key: jax.Array,
) -> int:
    """Draw a single action from a tempered policy restricted to the legal actions.

    Args:
        policy_network: Applied as ``policy_network(policy_ckpt, batch)``; element 0 of its
            output must expose ``.logits``.
        policy_ckpt: Parameters passed to ``policy_network``.
        obs: Infostate features of one player, without a batch axis.
        mask: Legal-action indicator; entries ``> 0`` are legal.
        key: PRNG key used for sampling.

    Returns:
        The sampled action index (a scalar JAX array).
    """
    batch = obs[jnp.newaxis, :]
    logits = policy_network(policy_ckpt, batch)[0].logits

    # Illegal actions get -inf so they end up with zero probability after the softmax.
    legal_logits = jnp.where(mask > 0, logits, -jnp.inf)
    tempered_probs = jax.nn.softmax(legal_logits / TEMPERATURE)

    categorical = distrax.Categorical(probs=tempered_probs)
    return categorical.sample(seed=key)


@eqx.filter_jit
def sample_state(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    depth: int,
    key: jax.Array,
) -> IIGoofspielState:
    """Sample a state by self-play from a fresh reset.

    The loop runs over indices ``1 .. depth - 1``, so ``game.step`` is called ``depth - 1``
    times. At index 1 both actions are drawn uniformly from ``[0, game.num_cards)``; at
    later indices each player ``p`` acts via ``sample_action`` with ``policy_ckpts[p]``.

    Returns:
        The state reached after the last step.
    """
    key, reset_key = jax.random.split(key, 2)
    initial, _ = game.reset(reset_key)

    def _play_round(round_idx, carry):
        current, loop_key = carry
        loop_key, step_key, p0_key, p1_key = jax.random.split(loop_key, 4)
        player_keys = (p0_key, p1_key)

        def _uniform_joint_action():
            return jax.random.randint(p0_key, (2,), 0, game.num_cards)

        def _policy_joint_action():
            return jnp.array(
                [
                    sample_action(
                        policy_network,
                        policy_ckpts[player],
                        current.infostate_features(player),
                        current.legal_actions_k_hot(player),
                        player_keys[player],
                    )
                    for player in range(2)
                ]
            )

        joint_action = jax.lax.cond(round_idx == 1, _uniform_joint_action, _policy_joint_action)
        successor, *_ = game.step(current, joint_action, step_key)
        return successor, loop_key

    final_state, _ = jax.lax.fori_loop(1, depth, _play_round, (initial, key))
    return final_state


@eqx.filter_jit
def sample_trajectory(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    depth: int,
    key: jax.Array,
) -> IIGoofspielState:
    """Play ``depth`` steps of self-play from a fresh reset and record every state.

    Step 0 uses uniformly random actions in ``[0, game.num_cards)`` for both players; every
    later step samples each player ``p`` from ``policy_ckpts[p]`` via ``sample_action``.

    Returns:
        The states after each step, stacked along a leading axis of length ``depth``
        (the reset state itself is not included).
    """
    key, reset_key = jax.random.split(key, 2)
    initial, _ = game.reset(reset_key)

    def _play_round(carry, step_idx):
        current, loop_key = carry
        loop_key, step_key, p0_key, p1_key = jax.random.split(loop_key, 4)
        player_keys = (p0_key, p1_key)

        def _uniform_joint_action():
            return jax.random.randint(p0_key, (2,), 0, game.num_cards)

        def _policy_joint_action():
            return jnp.array(
                [
                    sample_action(
                        policy_network,
                        policy_ckpts[player],
                        current.infostate_features(player),
                        current.legal_actions_k_hot(player),
                        player_keys[player],
                    )
                    for player in range(2)
                ]
            )

        joint_action = jax.lax.cond(step_idx == 0, _uniform_joint_action, _policy_joint_action)
        successor, *_ = game.step(current, joint_action, step_key)
        return (successor, loop_key), successor

    _, trajectory = jax.lax.scan(_play_round, (initial, key), jnp.arange(depth))
    return trajectory


@eqx.filter_jit
def counterfactual_reach_logprob(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    state: IIGoofspielState,
    player_id: int,
    depth: int,
) -> float:
    """Log-probability of the opponent of ``player_id`` playing the actions in ``state``.

    The history is replayed from ``game.initial_state()``. ``action_history[0]`` is applied
    first. Then, for ``i = 1 .. depth - 1``, the opponent's recorded action
    ``action_history[i, opponent_id]`` is scored under the opponent's tempered policy at
    the current replayed state, and ``action_history[i]`` is applied. On round 1 the
    opponent policy is uniform over its legal actions. Only the opponent's action
    log-probabilities are accumulated.

    Returns:
        The summed opponent log-probability (a scalar JAX array).
    """
    opponent_id = (1 - player_id) % game.num_players

    def _policy_logprobs(current: IIGoofspielState, player: int) -> jax.Array:
        # Tempered log-probabilities of `player`'s actions at `current`.
        obs, mask = current.infostate_features(player), current.legal_actions_k_hot(player)
        num_actions = current.prize_deck.shape[0]

        logits = jax.lax.cond(
            current.round == 1,
            # Constant logits -> uniform distribution over the legal actions.
            lambda: jnp.full(num_actions, 1 / num_actions),
            lambda: policy_network(policy_ckpts[player], obs[jnp.newaxis, :])[0].logits,
        )
        logits = jnp.where(mask > 0, logits, -jnp.inf)

        # `log_softmax` instead of `log(softmax(...))`: the latter underflows to -inf for
        # legal actions with very small probability.
        return jax.nn.log_softmax(logits / TEMPERATURE)

    replay_start, *_ = game.next_state(game.initial_state(), state.action_history[0])

    def _accumulate_reach(i, carry):
        current, reach = carry

        opponent_action = state.action_history[i, opponent_id]
        reach = reach + _policy_logprobs(current, opponent_id)[opponent_action]

        successor = game.next_state(current, state.action_history[i])[0]
        return successor, reach

    _, reach = jax.lax.fori_loop(1, depth, _accumulate_reach, (replay_start, 0.0))
    return reach


@eqx.filter_jit
def generate_samples_mcmc(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    state: IIGoofspielState,
    depth: int,
    num_samples: int,
    num_filtering_steps: int,
    key: jax.Array,
) -> IIGoofspielState:
    """Draw ``num_samples`` states with independent Metropolis chains.

    Each chain starts from ``game.construct_mc_initial_state(state, 0, depth)`` and performs
    ``num_filtering_steps`` proposals from ``game.get_mc_neighbor``. A proposal is accepted
    when ``u < exp(logp(candidate) - logp(current))`` with ``u ~ U[0, 1)``, where ``logp`` is
    ``counterfactual_reach_logprob(..., player_id=0, depth)``.

    Returns:
        The final state of every chain, batched along a leading axis of size ``num_samples``.
    """

    def _reach_logprob(candidate_state: IIGoofspielState) -> jax.Array:
        return counterfactual_reach_logprob(
            game, policy_network, policy_ckpts, candidate_state, 0, depth
        )

    def _filter_state_mcmc(key: jax.Array) -> IIGoofspielState:
        initial_state = game.construct_mc_initial_state(state, 0, depth)

        def _step(i, args):
            # The current state's log-prob is carried in the loop state, so each step only
            # evaluates the (expensive) reach of the new candidate.
            current_state, current_logprob, key = args
            key, neighbor_key, branch_key = jax.random.split(key, 3)

            # Assume the neighbor generation probabilities (proposal distribution) are symmetric
            candidate, _ = game.get_mc_neighbor(neighbor_key, current_state, 0, depth)
            candidate_logprob = _reach_logprob(candidate)

            accept = jax.random.uniform(branch_key) < jnp.exp(candidate_logprob - current_logprob)
            return jax.lax.cond(
                accept,
                lambda: (candidate, candidate_logprob, key),
                lambda: (current_state, current_logprob, key),
            )

        init = (initial_state, _reach_logprob(initial_state), key)
        return jax.lax.fori_loop(0, num_filtering_steps, _step, init)[0]

    return jax.vmap(_filter_state_mcmc)(jax.random.split(key, num_samples))


@eqx.filter_jit
def filter_possible_histories(gold_history: jax.Array, histories: jax.Array) -> jax.Array:
    """Mark candidate opponent histories consistent with the true round outcomes.

    A candidate row is kept when, for every position,
    ``sign(own_action - candidate_action) == sign(own_action - true_opponent_action)``,
    where own/true actions are ``gold_history[1:, 0]`` / ``gold_history[1:, 1]``
    (row 0 of ``gold_history`` is skipped). Rows made entirely of -1 are padding and are
    always rejected, since they can otherwise match by accident.

    Args:
        gold_history: Presumably shape ``(1 + num_rounds, 2)`` — TODO confirm.
        histories: Candidate opponent histories, one per row.

    Returns:
        Boolean mask with one entry per row of ``histories``.
    """
    own_actions = gold_history[1:, 0]
    true_outcomes = jnp.sign(own_actions - gold_history[1:, 1])

    candidate_outcomes = jnp.sign(own_actions[jnp.newaxis, :] - histories)
    consistent = (candidate_outcomes == true_outcomes).all(axis=1)

    is_padding = (histories == -1).all(axis=1)
    return consistent & ~is_padding


@eqx.filter_jit
def generate_samples(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    state: IIGoofspielState,
    histories: jax.Array,
    hands: jax.Array,
    depth: int,
    num_samples: int,
    key: jax.Array,
) -> jax.Array:
    """Resample candidate opponent histories in proportion to their counterfactual reach.

    Each row of ``histories`` (with the matching row of ``hands``) is written into a copy of
    ``state`` as ``action_history[1:, 1]`` and ``hands[1]``. It is then scored with
    ``counterfactual_reach_logprob(..., player_id=0, depth)``. Rows rejected by
    ``filter_possible_histories`` get zero weight. ``num_samples`` rows are then drawn
    (with replacement) from the normalised weights.

    Returns:
        The selected rows of ``histories``.
    """

    def _reach_logprob(candidate_history: jax.Array, candidate_hand: jax.Array) -> float:
        patched = eqx.tree_at(
            lambda s: s.action_history,
            state,
            replace_fn=lambda h: h.at[1:, 1].set(candidate_history),
        )
        # NOTE(review): `counterfactual_reach_logprob` as written only reads
        # `action_history`, so this hand update looks unused — confirm before relying on it.
        patched = eqx.tree_at(
            lambda s: s.hands, patched, replace_fn=lambda h: h.at[1].set(candidate_hand)
        )
        return counterfactual_reach_logprob(game, policy_network, policy_ckpts, patched, 0, depth)

    consistent = filter_possible_histories(state.action_history, histories)
    logprobs = jax.vmap(_reach_logprob)(histories, hands)
    weights = jax.nn.softmax(jnp.where(consistent, logprobs, -jnp.inf))

    _, choice_key = jax.random.split(key, 2)
    chosen = jax.random.choice(choice_key, histories.shape[0], (num_samples,), p=weights)
    return histories[chosen]


@eqx.filter_jit
def generate_data_mcmc(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    batch_size: int,
    num_samples: int,
    depths: jax.Array,
    filtering_steps: int,
    key: jax.Array,
) -> jax.Array:
    """Sample ``batch_size`` states and draw opponent histories for each with MCMC.

    States come from ``sample_state`` with ``depths`` (one entry per state). Each state is
    then fed to ``generate_samples_mcmc`` with ``filtering_steps`` Metropolis steps.

    NOTE(review): ``generate_data_flow`` passes ``depths + 1`` to ``sample_state`` while this
    function passes ``depths``; confirm the intended depth convention.

    Returns:
        Player 1's entries of ``action_history[1:]`` for every sample, i.e. the array
        ``samples.action_history[:, :, 1:, 1]``.
    """
    state_keys = jax.random.split(key, 1 + batch_size)
    states = jax.vmap(sample_state, in_axes=(None, None, None, 0, 0))(
        game, policy_network, policy_ckpts, depths, state_keys[1:]
    )

    chain_keys = jax.random.split(state_keys[0], 1 + batch_size)[1:]
    mcmc_in_axes = (None, None, None, 0, 0, None, None, 0)
    samples = jax.vmap(generate_samples_mcmc, in_axes=mcmc_in_axes)(
        game,
        policy_network,
        policy_ckpts,
        states,
        depths,
        num_samples,
        filtering_steps,
        chain_keys,
    )

    # Keep only the opponent's (player 1) actions, skipping row 0 of the history.
    opponent_histories = samples.action_history[:, :, 1:, 1]
    return opponent_histories


@eqx.filter_jit
def generate_data_flow(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    histories: jax.Array,
    hands: jax.Array,
    depths: jax.Array,
    batch_size: int,
    num_samples: int,
    key: jax.Array,
) -> jax.Array:
    """Sample ``batch_size`` states and resample opponent histories for each.

    ``depths + 1`` is used as the ``depth`` argument of both ``sample_state`` (so state
    ``b`` is reached after ``depths[b]`` calls to ``game.step``) and ``generate_samples``.
    ``histories``, ``hands`` and ``depths`` are mapped over their leading axis, one entry
    per state.

    Returns:
        ``num_samples`` sampled histories per state, batched along a leading axis of size
        ``batch_size``.
    """
    loop_depths = depths + 1

    state_keys = jax.random.split(key, 1 + batch_size)
    states = jax.vmap(sample_state, in_axes=(None, None, None, 0, 0))(
        game, policy_network, policy_ckpts, loop_depths, state_keys[1:]
    )

    sample_keys = jax.random.split(state_keys[0], 1 + batch_size)[1:]
    sampler = jax.vmap(generate_samples, in_axes=(None, None, None, 0, 0, 0, 0, None, 0))
    return sampler(
        game,
        policy_network,
        policy_ckpts,
        states,
        histories,
        hands,
        loop_depths,
        num_samples,
        sample_keys,
    )


@eqx.filter_jit
def generate_data_rnn(
    game: IIGoofspiel,
    policy_network: Callable,
    policy_ckpts: PyTree[jax.Array],
    histories: jax.Array,
    hands: jax.Array,
    depths: jax.Array,
    batch_size: int,
    num_samples: int,
    key: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Sample whole trajectories and opponent-history samples for every trajectory step.

    ``batch_size`` trajectories of ``game.num_cards`` steps are generated with
    ``sample_trajectory``. For each step of each trajectory, player 1's infostate features
    are recorded, and ``generate_samples`` draws ``num_samples`` histories.

    NOTE(review): not called by ``main`` (which uses ``generate_data_flow``). ``histories``,
    ``hands`` and ``depths`` are mapped per trajectory step by the inner ``vmap``;
    presumably ``depths[s]`` must equal ``s + 1`` for step ``s`` — confirm with the caller.

    Returns:
        ``(observations, samples)``; both carry leading axes ``(batch_size, num_cards)``.
    """
    key, *sample_keys = jax.random.split(key, 1 + batch_size)
    # `in_axes=None` broadcasts every positional argument. `key` is passed by keyword, and
    # `jax.vmap` always maps keyword arguments over their leading axis, so each trajectory
    # gets its own key.
    sequences = jax.vmap(sample_trajectory, in_axes=None)(
        game, policy_network, policy_ckpts, game.num_cards, key=jnp.array(sample_keys)
    )

    def _get_observation(state: IIGoofspielState) -> jax.Array:
        return state.infostate_features(1)

    # Outer vmap over trajectories, inner vmap over the stacked per-step states.
    observations = jax.vmap(jax.vmap(_get_observation))(sequences)

    key, generate_key = jax.random.split(key, 2)
    # Inner vmap: maps state, histories, hands and depth over trajectory steps (axis 0);
    # `num_samples` is broadcast and the keyword `key` is mapped.
    generate_samples_batch = jax.vmap(generate_samples, in_axes=(None,) * 3 + (0, 0, 0, 0, None))
    # Outer vmap: maps only the trajectories (`sequences`) and the keyword `key`;
    # histories/hands/depths are shared by all trajectories.
    samples = jax.vmap(generate_samples_batch, in_axes=(None,) * 3 + (0,) + (None,) * 4)(
        game,
        policy_network,
        policy_ckpts,
        sequences,
        histories,
        hands,
        depths + 1,
        num_samples,
        key=jax.random.split(generate_key, (batch_size, game.num_cards)),
    )

    return observations, samples


@lru_cache
def get_all_histories(num_cards: int, depth: int) -> tuple[jax.Array, jax.Array]:
    """Enumerate every ordered choice of ``depth`` distinct cards out of ``num_cards``.

    Returns:
        ``(histories, hands)``, each with ``num_cards`` columns and one row per permutation.
        ``histories`` holds the chosen cards followed by -1 padding. ``hands`` is 1 for cards
        not yet played and 0 for played ones. For ``depth == 0`` a single all -1 history and
        a full hand are returned. Results are memoised per ``(num_cards, depth)``.
    """
    if depth != 0:
        ordered = jnp.array(list(permutations(range(num_cards), depth)))
        num_rows = ordered.shape[0]

        padded = jnp.pad(ordered, ((0, 0), (0, num_cards - depth)), constant_values=-1)

        # Scatter zeros at (row, played card) for every row in one update.
        rows = jnp.arange(num_rows)[:, jnp.newaxis]
        remaining = jnp.ones((num_rows, num_cards), jnp.int32).at[rows, ordered].set(0)

        return padded, remaining

    return jnp.full([1, num_cards], -1, jnp.int32), jnp.ones([1, num_cards], jnp.int32)


def precompute_all_histories(num_cards: int) -> tuple[jax.Array, jax.Array]:
    """Stack the opponent histories and hands for every depth ``1 .. num_cards``.

    Each depth's arrays (from ``get_all_histories``) are padded along the row axis to a
    common length: histories with rows of -1, hands with rows of ones.

    Returns:
        ``(histories, hands)``, both ``int8`` with shape
        ``(num_cards, max_num_histories, num_cards)``; index ``depth - 1`` selects a depth.

    Raises:
        ValueError: If ``num_cards < 1``.
    """
    if num_cards < 1:
        raise ValueError(f'`num_cards` must be positive, got {num_cards}')

    per_depth = [get_all_histories(num_cards, depth) for depth in range(1, num_cards + 1)]

    # Pad to the largest per-depth row count (num_cards! for the last two depths). Deriving
    # it from the actual shapes keeps it a Python int, instead of an int32 device value
    # from `jnp.prod` that overflows for num_cards >= 13.
    max_num_histories = max(histories.shape[0] for histories, _ in per_depth)

    histories_buffer, hands_buffer = [], []
    for histories, hands in per_depth:
        padding = ((0, max_num_histories - histories.shape[0]), (0, 0))
        histories_buffer.append(jnp.pad(histories, padding, constant_values=-1))
        hands_buffer.append(jnp.pad(hands, padding, constant_values=1))

    histories = jnp.stack(histories_buffer, axis=0, dtype=jnp.int8)
    hands = jnp.stack(hands_buffer, axis=0, dtype=jnp.int8)

    return histories, hands


def load_data(data_dir: str, num_cards: int) -> jax.Array:
    """Load a dataset file written by ``main`` for the given number of cards.

    Files are expected to be named ``data-{num_cards:02}-{suffix}.npy`` inside ``data_dir``.
    If several match, the lexicographically first one is loaded, so the result does not
    depend on filesystem ordering.

    Raises:
        FileNotFoundError: If no matching file exists in ``data_dir``.
    """
    # Escape the directory so glob metacharacters in the path are matched literally.
    pattern = os.path.join(glob.escape(data_dir), f'data-{num_cards:02}-*.npy')
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No data file matching {pattern!r}')
    return jnp.load(matches[0])


def main(args: argparse.Namespace) -> None:
    """Generate sampled opponent histories and save them as one ``.npy`` file.

    Each of the ``args.num_iters`` iterations works as follows:

    - A random pair of pretrained policies is drawn.
    - ``args.batch_size`` states are sampled at random depths in ``[1, num_cards]``.
    - ``args.num_samples`` opponent histories are drawn for each state with
      ``generate_data_flow``, parallelised over devices with ``jax.pmap``.

    The resulting array of shape ``(num_iters * batch_size, num_samples, num_cards)`` is
    written to ``{base_dir}/{data_dir}/data-{num_cards:02}-{seed}.npy``.

    Raises:
        ValueError: If ``args.batch_size`` is not divisible by the number of devices.
    """
    np.random.seed(args.seed)
    key = jax.random.key(args.seed)

    game = IIGoofspiel(2, args.num_cards)
    policy_network = load_policy_network(args.num_cards, seed=0)
    policy_ckpts = load_policy_ckpts(f'{args.base_dir}/{args.policy_dir}', args.num_cards)

    num_cores = jax.device_count()
    # Raise rather than `assert` so the check is not stripped under `python -O`.
    if args.batch_size % num_cores != 0:
        raise ValueError('`batch_size` must be divisible by `num_cores`')

    # One row per sampled state across all iterations.
    num_states, pointer = args.num_iters * args.batch_size, 0
    data = np.empty([num_states, args.num_samples, args.num_cards], np.int32)

    # NOTE: Vectorizing over pairs of policies would be even more efficient, however,
    # individual policies are `flax.linen.Module` which don't play nicely with `jax.vmap`.
    generate_data_parallel = jax.pmap(
        generate_data_flow,
        in_axes=(None,) * 3 + (0, 0, 0, None, None, 0),
        static_broadcasted_argnums=(0, 1, 6, 7),
    )

    # Precompute all histories and hands
    histories_buffer, hands_buffer = precompute_all_histories(args.num_cards)

    for _ in tqdm(range(args.num_iters)):
        # Sample a random pair of pretrained policies
        policies = [policy_ckpts[i] for i in np.random.choice(len(policy_ckpts), 2, False)]

        # Sample random valid depths to sample infostates from; `shape` must be a sequence.
        key, depth_key = jax.random.split(key, 2)
        depths = jax.random.randint(depth_key, (args.batch_size,), 1, args.num_cards + 1)
        histories = histories_buffer[depths - 1]
        hands = hands_buffer[depths - 1]

        # Reshape the data to be compatible with `jax.pmap`
        depths = jnp.reshape(depths, (num_cores, -1, *depths.shape[1:]))
        histories = jnp.reshape(histories, (num_cores, -1, *histories.shape[1:]))
        hands = jnp.reshape(hands, (num_cores, -1, *hands.shape[1:]))

        # Sample data -- histories of the opponent in infostates at the given depth
        key, *data_keys = jax.random.split(key, 1 + num_cores)
        batch = generate_data_parallel(
            game,
            policy_network,
            policies,
            histories,
            hands,
            depths,
            args.batch_size // num_cores,
            args.num_samples,
            jnp.array(data_keys),
        )

        # Merge the two leading dimensions (devices, per-device batch) in the returned batch
        batch = jnp.reshape(batch, [-1, *batch.shape[2:]])

        data[pointer : pointer + args.batch_size] = batch
        pointer += args.batch_size

    output_dir = f'{args.base_dir}/{args.data_dir}'
    os.makedirs(output_dir, mode=0o755, exist_ok=True)
    np.save(f'{output_dir}/data-{args.num_cards:02}-{args.seed}.npy', data)


if __name__ == '__main__':
    # Suppress all Python warnings for the duration of the run.
    warnings.simplefilter('ignore')

    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', type=str, default=os.getcwd(), help='Base directory')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=64,
        help='Number of states to sample per iteration (each at a random depth)',
    )
    parser.add_argument('--data_dir', type=str, default='goofspiel_data', help='Data directory')
    parser.add_argument('--num_cards', type=int, default=5, help='Cards in the game')
    # NOTE: Only consumed by `generate_data_mcmc`; `main` samples with `generate_data_flow`
    # and currently ignores this flag.
    parser.add_argument(
        '--num_filtering_steps', type=int, default=48, help='Number of filtering steps in MCMC'
    )
    parser.add_argument(
        '--num_iters', type=int, default=256, help='Number of iterations to sample for'
    )
    parser.add_argument(
        '--num_samples', type=int, default=128, help='Number of histories to sample in each state'
    )
    parser.add_argument(
        '--policy_dir', type=str, default='goofspiel_policies', help='Policy directory'
    )
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    args = parser.parse_args()

    main(args)
