#!/usr/bin/env python3

import argparse
import glob
import os
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
import orbax.checkpoint
from flax.training.train_state import TrainState
from sbx import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv

from envs.goofspiel import GoofspielGymEnv


def make_env(num_cards: int, opponent: TrainState | None, seed: int) -> Callable:
    """Build a zero-argument factory for a monitored Goofspiel environment.

    The environment itself is not constructed here; construction is deferred to the
    returned callable, so each call of that callable yields a fresh environment.

    Args:
        num_cards: Passed to `GoofspielGymEnv` as its first positional argument.
        opponent: Passed to `GoofspielGymEnv` as `opponent`; may be `None`.
        seed: Passed to `GoofspielGymEnv` as `seed`.

    Returns:
        A callable that creates a `GoofspielGymEnv` wrapped in `Monitor`.
    """

    def build_env():
        environment = GoofspielGymEnv(num_cards, opponent=opponent, seed=seed)
        return Monitor(environment)

    return build_env


def load_policy_network(num_cards: int, seed: int) -> Callable:
    """Return the `apply` function of a freshly built PPO actor network.

    A scratch environment (no opponent) and a PPO model are created only to obtain the
    actor for a game with `num_cards` cards; neither object is returned.

    Args:
        num_cards: Passed to `GoofspielGymEnv` as its first positional argument.
        seed: Used for both the scratch environment and the PPO model.

    Returns:
        The `apply` attribute of the model's actor. Presumably callers pass network
        parameters to it themselves (e.g. ones restored from a checkpoint) -- confirm.
    """
    scratch_env = GoofspielGymEnv(num_cards, opponent=None, seed=seed)
    actor = PPO('MlpPolicy', scratch_env, seed=seed, verbose=0).policy.actor

    return actor.apply


def load_policy_ckpts(policy_dir: str, num_cards: int) -> list[dict[str, Any]]:
    """Restore the saved policies for a `num_cards` game and return their parameters.

    Looks for `policies-<num_cards, zero-padded to 2 digits>-*.ckpt` inside `policy_dir`
    and restores the first match (in sorted order) with an Orbax `PyTreeCheckpointer`.

    Args:
        policy_dir: Directory that contains the checkpoint.
        num_cards: Number of cards encoded in the checkpoint name.

    Returns:
        The `'params'` entry of every restored checkpoint element, in stored order.

    Raises:
        FileNotFoundError: If nothing in `policy_dir` matches the expected name.
    """
    pattern = f'{policy_dir}/policies-{num_cards:02}-*.ckpt'

    # Assume there is only one file with `num_cards` in the name; sorting merely makes
    # the pick deterministic should that assumption ever be violated.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No policy checkpoint matches {pattern!r}')

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Resolve to an absolute path so a relative `policy_dir` is not handed to Orbax as-is
    ckpts = checkpointer.restore(os.path.abspath(matches[0]))

    return [ckpt['params'] for ckpt in ckpts]


def evaluate(agent: Callable, env: GoofspielGymEnv, num_episodes: int, key: jax.Array) -> float:
    """Play `num_episodes` episodes with `agent` and return the mean episode reward.

    Actions are sampled from the agent's logits after illegal actions are masked out.

    Args:
        agent: Called with a single-observation batch; the first element of its result
            must expose a `logits` attribute.
        env: Environment to play in. Besides `reset`/`step`, this reads
            `env._state.legal_actions_k_hot(0)` for the legal-action mask (the `0` is
            presumably the learning player's index -- confirm).
        num_episodes: Number of episodes to play; must be positive.
        key: JAX PRNG key; split once per step to draw the action sample.

    Returns:
        Sum of all step rewards divided by `num_episodes`.

    Raises:
        ValueError: If `num_episodes` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError(f'num_episodes must be positive, got {num_episodes}')

    total_reward = 0.0

    def predict(observation: jax.Array, mask: jax.Array, key: jax.Array) -> jax.Array:
        # Use the argument rather than the enclosing scope's `obs`, so the helper does
        # not silently depend on a variable that is rebound by the loop below.
        dist = agent(observation[jnp.newaxis, :])[0]
        # Illegal actions get -inf logits so they can never be sampled
        logits = jnp.where(mask > 0, dist.logits, -jnp.inf)

        return jax.random.categorical(key, logits)

    for _ in range(num_episodes):
        obs, done = env.reset()[0], False

        while not done:
            key, sample_key = jax.random.split(key, 2)
            action = predict(obs, env._state.legal_actions_k_hot(0), sample_key)

            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            total_reward += reward

    return total_reward / num_episodes


def main(args: argparse.Namespace) -> None:
    """Train a Goofspiel policy with PPO self-play and save every checkpointed policy.

    Each self-play iteration runs `args.train_iters` calls to `model.learn` and records
    the actor state after each call. Between iterations the environments are rebuilt
    with the most recent policy as the opponent. The full list of recorded policies is
    written with Orbax to
    `<base_dir>/<policy_dir>/policies-<num_cards>-<seed>.ckpt` once training finishes.

    Args:
        args: Parsed command-line options. Reads `num_cards`, `num_envs`, `seed`,
            `self_play_iters`, `train_iters`, `train_timesteps`, `base_dir` and
            `policy_dir`.
    """
    print(f'Starting training with {args.num_cards} cards ...')

    # Initialize two-player Goofspiel with a random policy
    envs = SubprocVecEnv(
        [make_env(args.num_cards, None, args.seed + i) for i in range(args.num_envs)]
    )

    policies = []
    try:
        # Initialize PPO with a randomly initialized MLP network
        model = PPO(
            'MlpPolicy',
            envs,
            n_steps=args.train_timesteps // args.num_envs,
            learning_rate=5e-5,
            batch_size=128,
            n_epochs=24,
            seed=args.seed,
            verbose=1,
        )

        for i in range(1, args.self_play_iters + 1):
            print(f'Starting self-play iteration {i} ...')

            for _ in range(args.train_iters):
                model.learn(total_timesteps=args.train_timesteps, progress_bar=True)

                # Periodically checkpoint the current policy
                policies.append(model.policy.actor_state)

            if i < args.self_play_iters:
                # NOTE: For some reason, `envs.set_attr('_opponent', policies[-1])` is
                # throwing `PicklingError`, so we terminate current environments and
                # create new ones with a new opponent. Unfortunately, this approach is
                # a little slower.

                # Terminate environment processes to free up memory
                envs.close()

                # Initialize two-player Goofspiel with the previous policy
                # NOTE(review): the seed stride is `self_play_iters`, not `num_envs`, so
                # seeds of consecutive iterations overlap when num_envs > self_play_iters.
                # Left unchanged to keep existing runs reproducible.
                envs = SubprocVecEnv(
                    [
                        make_env(
                            args.num_cards,
                            policies[-1],
                            args.seed + i * args.self_play_iters + j,
                        )
                        for j in range(args.num_envs)
                    ]
                )
                model.set_env(envs)
    finally:
        # Always terminate the environment processes, including when training fails
        # part-way. If the environments were already closed above, this second call is
        # expected to be a no-op -- confirm against SubprocVecEnv.close().
        envs.close()

    # Resolve the target to an absolute path (a relative `--base_dir` would otherwise
    # reach Orbax unresolved) and make sure its parent directory exists.
    ckpt_path = os.path.abspath(
        f'{args.base_dir}/{args.policy_dir}/policies-{args.num_cards:02}-{args.seed}.ckpt'
    )
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    # Serialize and save the final list of policies
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    checkpointer.save(ckpt_path, policies, force=True)


if __name__ == '__main__':
    # Command-line options as (flag, type, default, help) rows; they are registered in
    # this order, which is also the order shown by `--help`.
    cli_options = (
        ('--base_dir', str, os.getcwd(), 'Base directory'),
        ('--num_cards', int, 5, 'Cards in the game'),
        ('--num_envs', int, 2, 'Parallel environments to run'),
        ('--policy_dir', str, 'goofspiel_policies', 'Policy directory'),
        ('--seed', int, 0, 'Random seed'),
        ('--self_play_iters', int, 5, 'Self-play iterations'),
        ('--train_iters', int, 5, 'Training iterations per self-play iteration'),
        ('--train_timesteps', int, 2048, 'Training timesteps per training iteration'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, description in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=description)
    args = parser.parse_args()

    main(args)
