#!/usr/bin/env python3

import argparse
import json
import os
import sys
import warnings

sys.path.append('src')

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm import tqdm

from efg.goofspiel import IIGoofspiel
from goofspiel.generate_data import generate_data_rnn, precompute_all_histories
from goofspiel.generate_policies import load_policy_ckpts, load_policy_network
from goofspiel.model_utils import create_model, parse_network_hyperparams, train_step


# Upper bound on the number of pretrained policy checkpoints kept in the sampling pool
_MAX_POLICY_POOL = 256

# Width of the moving-average window used for the smoothed loss curve
_SMOOTHING_WINDOW = 32


def _save_model(path: str, hyperparams: dict, model) -> None:
    """Write the hyperparameters as one JSON header line followed by the model leaves."""
    with open(path, 'wb') as f:
        f.write((json.dumps(hyperparams) + '\n').encode())
        eqx.tree_serialise_leaves(f, model)


def _save_training_curves(path: str, losses: list) -> None:
    """Plot the raw and moving-average loss curves and save the figure to `path`."""
    plt.plot(losses, label='Raw')

    # Clamp the window to the number of recorded losses: with a kernel longer than the
    # signal, np.convolve(..., 'valid') swaps its operands and the result is no longer
    # a moving average; with an empty signal it raises.
    window = min(_SMOOTHING_WINDOW, len(losses))
    if window > 0:
        plt.plot(np.convolve(losses, np.ones(window), 'valid') / window, label='Smoothed')

    plt.legend()
    plt.grid()
    plt.savefig(path)

    # Release the implicit pyplot figure once it has been written
    plt.close()


def main(args: argparse.Namespace) -> None:
    """Train a model on data generated from pairs of pretrained Goofspiel policies.

    In every epoch a random pair of policy checkpoints is drawn, data is generated
    in parallel on all JAX devices via `generate_data_rnn`, and one gradient update
    is performed with `train_step`. Intermediate checkpoints, the loss curves and
    the final model are written to `{base_dir}/{model_dir}`.

    Args:
        args: Parsed command-line arguments. The attributes read here are `seed`,
            `num_cards`, `base_dir`, `policy_dir`, `model_dir`, `embedder`,
            `embedding_net`, `flow_net`, `num_epochs`, `batch_size`, `num_samples`,
            `ckpt_freq` and `run_name`.

    Raises:
        ValueError: If fewer than two policy checkpoints are available, since a pair
            of distinct policies has to be sampled in every epoch.
    """
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)

    game = IIGoofspiel(2, args.num_cards)
    policy_network = load_policy_network(args.num_cards, seed=0)
    policy_ckpts = load_policy_ckpts(f'{args.base_dir}/{args.policy_dir}', args.num_cards)

    # Keep a random subset of the checkpoints, never asking for more than exist
    # (sampling without replacement would otherwise fail on smaller collections)
    pool_size = min(_MAX_POLICY_POOL, len(policy_ckpts))
    if pool_size < 2:
        raise ValueError(
            f'At least 2 policy checkpoints are required, found {len(policy_ckpts)}'
        )
    policy_ckpts = [
        policy_ckpts[i] for i in np.random.choice(len(policy_ckpts), pool_size, False)
    ]

    # Parse hyperparameters for the networks
    hyperparams = parse_network_hyperparams(args.embedding_net)
    embedding_net_hyperparams = {
        'embedder': args.embedder, 'num_cards': args.num_cards,
        'obs_size': game.num_infostate_features(), 'embedding_size': hyperparams[0],
        'num_layers': hyperparams[1], 'hidden_size': hyperparams[2]
    }

    hyperparams = parse_network_hyperparams(args.flow_net)
    flow_net_hyperparams = {
        'dequant_num_layers': hyperparams[0], 'dequant_hidden_size': hyperparams[1],
        'dequant_num_params': 2, 'num_coupling_layers': hyperparams[2],
        'num_layers': hyperparams[3], 'hidden_size': hyperparams[4], 'num_params': 5
    }

    # Collect hyperparameters for the networks; stored as a header in every saved model
    hyperparams = {
        'embedding_net_hyperparams': embedding_net_hyperparams,
        'flow_net_hyperparams': flow_net_hyperparams
    }

    # Create a new model
    key, model_key = jax.random.split(key, 2)
    model = create_model(embedding_net_hyperparams, flow_net_hyperparams, model_key)

    optimizer = optax.nadam(1e-3)
    optimizer_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

    num_cores = jax.device_count()
    losses = []
    progress_bar = tqdm(range(args.num_epochs))

    # Only the last argument (the PRNG keys) is mapped over the devices
    generate_data_parallel = jax.pmap(
        generate_data_rnn, in_axes=(None,) * 8 + (0,),
        static_broadcasted_argnums=(0, 1, 6, 7)
    )

    # Precompute all histories and hands
    histories_buffer, hands_buffer = precompute_all_histories(args.num_cards)

    # Traverse the game tree at all depths; these do not change between epochs
    depths = jnp.arange(1, args.num_cards + 1)
    histories = histories_buffer[depths - 1]
    hands = hands_buffer[depths - 1]

    # Prepare the model directory for checkpointing
    model_dir = f'{args.base_dir}/{args.model_dir}'
    os.makedirs(model_dir, mode=0o755, exist_ok=True)

    name_prefix = f'model-{args.run_name}-{args.num_cards:02}'

    for epoch in progress_bar:
        # Sample a random pair of pretrained policies
        policies = [policy_ckpts[i] for i in np.random.choice(len(policy_ckpts), 2, False)]

        # Sample data -- histories of the opponent in infostates at the given depth
        key, *data_keys = jax.random.split(key, 1 + num_cores)
        batch = generate_data_parallel(
            game, policy_network, policies, histories, hands, depths,
            args.batch_size // num_cores, args.num_samples, jnp.array(data_keys)
        )

        # Merge the two leading dimensions in the returned batch
        batch = jtu.tree_map(lambda x: jnp.reshape(x, [-1, *x.shape[2:]]), batch)

        # Perform one gradient update on the given batch
        key, train_key = jax.random.split(key, 2)
        model, optimizer_state, loss_value = train_step(
            model, optimizer, optimizer_state, batch, train_key
        )
        losses.append(loss_value)

        if epoch % 10 == 0:
            # Reported value is the mean over all losses recorded so far
            progress_bar.set_postfix({'loss': np.mean(losses)})

        # NOTE(review): the checkpoint is labelled with the zero-based epoch index,
        # whereas the final model below is labelled with the total number of epochs.
        if epoch > 0 and epoch % args.ckpt_freq == 0:
            # Save the intermediate model to the disk
            _save_model(f'{model_dir}/{name_prefix}-{epoch}-{args.seed}.eqx', hyperparams, model)

    model_name = f'{name_prefix}-{args.num_epochs}-{args.seed}'

    # Save training curves to the disk
    _save_training_curves(f'{model_dir}/{model_name}.png', losses)

    # Save the trained model to the disk
    _save_model(f'{model_dir}/{model_name}.eqx', hyperparams, model)


if __name__ == '__main__':
    # Silence all warnings for the duration of the script
    warnings.simplefilter('ignore')

    # Command-line options as (flag, type, default, help) rows, registered in this order
    cli_options = (
        ('--base-dir', str, os.getcwd(), 'Base directory'),
        ('--batch-size', int, 64, 'Batch size'),
        ('--ckpt-freq', int, 4096, 'Checkpoint frequency'),
        ('--embedder', str, 'mean', 'Type of the embedder'),
        ('--embedding-net', str, '32,2,64', 'Hyperparams of the embedding network'),
        ('--flow-net', str, '2,32,8,3,48', 'Hyperparams of the flow network'),
        ('--model-dir', str, 'goofspiel-models', 'Model directory'),
        ('--num-cards', int, 5, 'Cards in the game'),
        ('--num-epochs', int, 256, 'Number of training epochs'),
        ('--num-filtering-steps', int, 48, 'Number of filtering steps in MCMC'),
        ('--num-samples', int, 128, 'Number of histories to sample in each state'),
        ('--policy-dir', str, 'goofspiel-policies', 'Policy directory'),
        ('--run-name', str, 'rnn', 'Name of the training run'),
        ('--seed', int, 0, 'Random seed'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, description in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=description)
    args = parser.parse_args()

    main(args)
