#!/usr/bin/env python3

import argparse
import json
import os
import warnings

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm import tqdm

from envs.goofspiel import IIGoofspiel
from goofspiel.generate_data import generate_data_flow, precompute_all_histories
from goofspiel.generate_policies import load_policy_ckpts, load_policy_network
from goofspiel.utils import create_model, parse_network_hyperparams, train_step


def moving_average(values, window: int = 32) -> np.ndarray:
    """Return the trailing moving average of ``values`` over ``window`` points.

    The window is clamped to ``len(values)``. Without the clamp, a series shorter
    than ``window`` would hit ``np.convolve``'s argument swap in ``'valid'`` mode
    and produce meaningless partial sums. An empty input yields an empty array.

    Args:
        values: 1-D sequence of numbers (e.g. per-step losses).
        window: Maximum number of consecutive points to average; must be >= 1.

    Returns:
        Float64 array of length ``len(values) - min(window, len(values)) + 1``,
        or an empty array when ``values`` is empty.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f'window must be positive, got {window}')
    values = np.asarray(values, dtype=np.float64)
    if values.size == 0:
        return values
    window = min(window, values.size)
    return np.convolve(values, np.ones(window), 'valid') / window


def save_model(path: str, hyperparams: dict, model) -> None:
    """Write ``model`` to ``path`` as a JSON header line followed by its serialised leaves.

    The header is ``json.dumps(hyperparams)`` plus a newline. The rest of the
    file is produced by ``eqx.tree_serialise_leaves``. Presumably the loader
    reads the header first to rebuild the model skeleton; confirm against the
    loading code.
    """
    with open(path, 'wb') as f:
        f.write((json.dumps(hyperparams) + '\n').encode())
        eqx.tree_serialise_leaves(f, model)


def main(args: argparse.Namespace) -> None:
    """Train a flow model on data generated from pretrained Goofspiel policies.

    Each epoch does the following:
      1. Sample two distinct policy checkpoints.
      2. Draw ``args.batch_size`` random depths in ``[1, args.num_cards]``.
      3. Generate data for those depths in parallel across devices with
         ``jax.pmap``.
      4. Take a single optimiser step.

    The model is checkpointed every ``args.ckpt_freq`` epochs and once more at
    the end, alongside a PNG plot of the raw and smoothed training loss.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` block).

    Raises:
        ValueError: If ``args.batch_size`` is not divisible by the number of
            JAX devices, which the per-device data split requires.
    """
    np.random.seed(args.seed)
    key = jax.random.key(args.seed)

    # `jax.pmap` splits every batch evenly across devices, so reject batch sizes
    # that would leave a remainder before doing any expensive setup.
    num_cores = jax.device_count()
    if args.batch_size % num_cores != 0:
        raise ValueError(
            f'batch_size ({args.batch_size}) must be divisible by the number of '
            f'JAX devices ({num_cores})'
        )
    per_device_batch_size = args.batch_size // num_cores

    game = IIGoofspiel(2, args.num_cards)
    policy_network = load_policy_network(args.num_cards, seed=0)
    policy_ckpts = load_policy_ckpts(f'{args.base_dir}/{args.policy_dir}', args.num_cards)

    # Parse hyperparameters for the networks
    embedding_params = parse_network_hyperparams(args.embedding_net)
    embedding_net_hyperparams = {
        'embedder': args.embedder,
        'num_cards': args.num_cards,
        'obs_size': game.num_infostate_features(),
        'embedding_size': embedding_params[0],
        'num_layers': embedding_params[1],
        'hidden_size': embedding_params[2],
    }

    flow_params = parse_network_hyperparams(args.flow_net)
    flow_net_hyperparams = {
        'dequant_num_layers': flow_params[0],
        'dequant_hidden_size': flow_params[1],
        'dequant_num_params': 2,
        'num_coupling_layers': flow_params[2],
        'num_layers': flow_params[3],
        'hidden_size': flow_params[4],
        'num_params': 5,
    }

    # Collected hyperparameters; written as the header of every checkpoint file
    hyperparams = {
        'embedding_net_hyperparams': embedding_net_hyperparams,
        'flow_net_hyperparams': flow_net_hyperparams,
    }

    # Create a new model and an optimiser over its inexact (floating-point) leaves
    key, model_key = jax.random.split(key, 2)
    model = create_model(embedding_net_hyperparams, flow_net_hyperparams, model_key)

    optimizer = optax.nadam(1e-3)
    optimizer_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

    # Argument roles for the pmapped call:
    #   static (0, 1, 6, 7): game, policy network, per-device batch size, sample count
    #   broadcast to every device (2): policies
    #   sharded along the leading axis (3, 4, 5, 8): histories, hands, depths, keys
    generate_data_parallel = jax.pmap(
        generate_data_flow,
        in_axes=(None,) * 3 + (0, 0, 0, None, None, 0),
        static_broadcasted_argnums=(0, 1, 6, 7),
    )

    # Precompute all histories and hands
    histories_buffer, hands_buffer = precompute_all_histories(args.num_cards)

    # Prepare the model directory for checkpointing
    model_dir = f'{args.base_dir}/{args.model_dir}'
    os.makedirs(model_dir, mode=0o755, exist_ok=True)

    losses = []
    progress_bar = tqdm(range(args.num_epochs))
    for epoch in progress_bar:
        # Sample a random pair of distinct pretrained policies
        policies = [policy_ckpts[i] for i in np.random.choice(len(policy_ckpts), 2, False)]

        # Sample random valid depths in [1, num_cards] to sample infostates from;
        # the shape must be a tuple for `jax.random.randint`.
        key, depth_key = jax.random.split(key, 2)
        depths = jax.random.randint(depth_key, (args.batch_size,), 1, args.num_cards + 1)
        histories = histories_buffer[depths - 1]
        hands = hands_buffer[depths - 1]

        # Add a leading device axis so that `jax.pmap` can shard the batch
        depths = jnp.reshape(depths, (num_cores, -1, *depths.shape[1:]))
        histories = jnp.reshape(histories, (num_cores, -1, *histories.shape[1:]))
        hands = jnp.reshape(hands, (num_cores, -1, *hands.shape[1:]))

        # Sample data -- histories of the opponent in infostates at the given depth.
        # Slicing the split keys keeps them as one stacked key array (one per device).
        split_keys = jax.random.split(key, 1 + num_cores)
        key, data_keys = split_keys[0], split_keys[1:]
        batch = generate_data_parallel(
            game,
            policy_network,
            policies,
            histories,
            hands,
            depths,
            per_device_batch_size,
            args.num_samples,
            data_keys,
        )

        # Merge the device and per-device batch dimensions in the returned batch
        batch = jnp.reshape(batch, [-1, *batch.shape[2:]])

        # Perform one gradient update on the given batch
        key, train_key = jax.random.split(key, 2)
        model, optimizer_state, loss_value = train_step(
            model, optimizer, optimizer_state, batch, train_key
        )
        losses.append(loss_value)

        if epoch % 10 == 0:
            progress_bar.set_postfix({'loss': np.mean(losses)})

        if epoch > 0 and epoch % args.ckpt_freq == 0:
            model_name = f'model-{args.run_name}-{args.num_cards:02}-{epoch}-{args.seed}'
            save_model(f'{model_dir}/{model_name}.eqx', hyperparams, model)

    model_name = f'model-{args.run_name}-{args.num_cards:02}-{args.num_epochs}-{args.seed}'

    # Save training curves to the disk. The smoothed curve is a trailing average,
    # so shift it right to align each point with the last epoch of its window.
    loss_curve = np.asarray(losses)
    smoothed = moving_average(loss_curve, 32)
    smoothed_x = np.arange(len(loss_curve) - len(smoothed), len(loss_curve))
    fig, ax = plt.subplots()
    ax.plot(loss_curve, label='Raw')
    ax.plot(smoothed_x, smoothed, label='Smoothed')
    ax.legend()
    ax.grid()
    fig.savefig(f'{model_dir}/{model_name}.png')
    plt.close(fig)

    # Save the trained model to the disk
    save_model(f'{model_dir}/{model_name}.eqx', hyperparams, model)


if __name__ == '__main__':
    # Suppress all Python warnings for the whole run.
    warnings.simplefilter('ignore')

    # Every command-line option as (flag, type, default, help text), registered in
    # this order so `--help` lists them exactly as before.
    cli_options = (
        ('--base_dir', str, os.getcwd(), 'Base directory'),
        ('--batch_size', int, 64, 'Batch size'),
        ('--ckpt_freq', int, 4096, 'Checkpoint frequency'),
        ('--embedder', str, 'mean', 'Type of the embedder'),
        ('--embedding_net', str, '32,3,128', 'Hyperparams of the embedding network'),
        ('--flow_net', str, '2,32,8,3,48', 'Hyperparams of the flow network'),
        ('--model_dir', str, 'goofspiel_models', 'Model directory'),
        ('--num_cards', int, 5, 'Cards in the game'),
        ('--num_epochs', int, 256, 'Number of training epochs'),
        ('--num_filtering_steps', int, 48, 'Number of filtering steps in MCMC'),
        ('--num_samples', int, 128, 'Number of histories to sample in each state'),
        ('--policy_dir', str, 'goofspiel_policies', 'Policy directory'),
        ('--run_name', str, 'flow', 'Name of the training run'),
        ('--seed', int, 0, 'Random seed'),
    )

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)
