import argparse
import jax
import wandb
import jax.numpy as jnp
from jax import random
from rich.traceback import install


from utils.jax import jax_debug_wrapper
from utils.logging import init_logger, log_results
from meta.es_step import es_train_step
from meta.meta import (
    create_mmap_train_state,
    get_candidate_fitness_fun,
)


def wandb_callback(metrics):
    """Send ``metrics`` to Weights & Biases through ``wandb.log``."""
    wandb.log(metrics)


def make_train(args):
    """Build the training function for the experiment described by ``args``.

    Args:
        args: Parsed command-line namespace. Read directly here:
            ``dataset_path``, ``data_type``, ``env_name``, ``reward_agent1``,
            ``reward_agent2``, ``train_steps``, ``n_devices`` and
            ``contrained_net``. The whole namespace is also forwarded to
            ``create_mmap_train_state``, ``get_candidate_fitness_fun`` and
            ``es_train_step``.

    Returns:
        A function ``_train_fn(rng)`` that runs the complete training loop.
    """

    def _train_fn(rng):
        """Run ``args.train_steps`` ES updates starting from PRNG key ``rng``.

        Returns:
            A ``(metrics, mmap_train_state)`` tuple: the metrics returned by
            the last ``es_train_step`` call (``None`` when
            ``args.train_steps`` is 0) and the final train state.
        """

        # --- Initialize LPMD ---
        rng, _rng = jax.random.split(rng)
        # The train state is created with JIT compilation disabled.
        with jax.disable_jit():
            mmap_train_state = create_mmap_train_state(_rng, args)

        # --- Initialize candidate_fitness ---

        dataset_path = f"{args.dataset_path}/{args.data_type}/{args.env_name}/dataset_{args.reward_agent1}vs{args.reward_agent2}.npz"
        # NOTE(security): allow_pickle=True unpickles the file contents, which
        # can execute arbitrary code. Only point --dataset_path at trusted data.
        # NOTE(review): ``.item()`` requires the file to hold a 0-d object
        # array; a real zip ``.npz`` archive would not provide it. Confirm
        # how the dataset file is written.
        dataset = jnp.load(dataset_path, allow_pickle=True).item()

        candidate_fitness_fun = get_candidate_fitness_fun(
            args, mmap_train_state, dataset
        )

        # Bind ``metrics`` up front so the return below is well defined even
        # when the loop body never runs (``args.train_steps == 0``).
        metrics = None

        # --- TRAIN LOOP ---
        for gen in range(args.train_steps):
            # --- Update LPMD ---
            # Use a fresh subkey for every generation.
            rng, _rng = jax.random.split(rng)
            jax.debug.print("Training step {x}", x=gen)
            mmap_train_state, metrics = es_train_step(
                _rng,
                mmap_train_state,
                args.n_devices,
                args.contrained_net,
                candidate_fitness_fun,
                args,
            )
            log_results(metrics, mmap_train_state, gen)
        return metrics, mmap_train_state

    return _train_fn


def run_training_experiment(args):
    """Run one training experiment end to end.

    Calls ``init_logger(args)`` when ``args.log`` is truthy, builds the
    training function with ``make_train``, runs it with a PRNG key derived
    from ``args.seed``, and finally calls ``wandb.finish()``. The training
    function's return value is discarded; nothing is returned.
    """
    logging_enabled = args.log
    if logging_enabled:
        init_logger(args)

    # Build the trainer first, then seed it; the result is not used further.
    make_train(args)(random.PRNGKey(args.seed))

    wandb.finish()


def _str_to_bool(value):
    """Convert a command-line token to ``bool`` for use as an argparse ``type``.

    ``type=bool`` would call ``bool("False")``, which is ``True`` because the
    string is non-empty, so a boolean option could never be switched off.
    This converter parses the text instead.

    Args:
        value: The raw option value (matched case-insensitively, surrounding
            whitespace ignored). An actual ``bool`` is returned unchanged.

    Returns:
        ``True`` for ``true/t/yes/y/1/on`` and ``False`` for
        ``false/f/no/n/0/off``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is none of the above.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse the command-line options and launch the training experiment.

    The parsed namespace is handed to ``jax_debug_wrapper`` together with
    ``run_training_experiment``; the callable it returns is then invoked with
    the same namespace.
    """
    arg_parser = argparse.ArgumentParser()
    # Parent dataset
    arg_parser.add_argument("--env_name", type=str, default="hopper")
    arg_parser.add_argument("--data_type", type=str, default="basic")
    arg_parser.add_argument("--reward_agent1", type=str, default="2100")
    arg_parser.add_argument("--reward_agent2", type=str, default="2100")
    arg_parser.add_argument("--dataset_path", type=str)

    # Child dataset
    arg_parser.add_argument("--num_data_points", type=int, default=5120)

    # Inner training
    arg_parser.add_argument("--num_eval_agents", type=int, default=100)
    arg_parser.add_argument("--loss_type", type=str, default="borpo2")
    arg_parser.add_argument("--reference_agent", type=str, default=None)
    arg_parser.add_argument("--partial", action="store_true")
    arg_parser.add_argument("--off_policy", type=int, default=1)
    arg_parser.add_argument("--noise", type=float, default=0.0)
    arg_parser.add_argument("--shuffle_agents", type=int, default=0)
    arg_parser.add_argument("--judge_temp", type=float, default=101.5)
    arg_parser.add_argument("--update_epochs_multiplier", type=float, default=1.0)

    # Experiment
    arg_parser.add_argument("--seed", type=int, default=0)
    arg_parser.add_argument("--log", action="store_true")
    arg_parser.add_argument("--train_steps", type=int, default=256)
    arg_parser.add_argument("--n_devices", type=int, default=1)
    arg_parser.add_argument("--num_mini_batches", type=int, default=1)
    arg_parser.add_argument("--debug", action="store_true")
    arg_parser.add_argument("--debug_nans", action="store_true")
    arg_parser.add_argument("--wandb_project", type=str)
    arg_parser.add_argument("--wandb_entity", type=str)
    arg_parser.add_argument("--wandb_group", type=str)
    arg_parser.add_argument("--main_folder_path", type=str)

    # ES
    arg_parser.add_argument("--rank_transform", type=int, default=0)
    arg_parser.add_argument("--temporally_aware", type=int, default=0)
    arg_parser.add_argument("--parametrised_reward_model", type=int, default=0)
    arg_parser.add_argument("--add_logsimoid_bias", type=int, default=0)
    arg_parser.add_argument("--es_checkpoint", type=str, default=None)
    # Parsed with _str_to_bool rather than ``bool`` so that
    # ``--contrained_net False`` really yields False. The option name keeps
    # its existing spelling because other code reads ``args.contrained_net``.
    arg_parser.add_argument("--contrained_net", type=_str_to_bool, default=True)
    arg_parser.add_argument("--mmap_net_width", type=int, default=128)
    arg_parser.add_argument("--lpmd_opt", type=str, default="Adam")
    arg_parser.add_argument("--lpmd_max_grad_norm", type=float, default=0.5)
    arg_parser.add_argument("--es_lrate_init", type=float, default=0.001)  # 0.001
    arg_parser.add_argument("--es_lrate_decay", type=float, default=0.999)  # 0.999
    arg_parser.add_argument("--es_lrate_limit", type=float, default=1e-3)
    arg_parser.add_argument("--es_sigma_init", type=float, default=0.03)  # 0.1
    arg_parser.add_argument("--es_sigma_decay", type=float, default=0.999)
    arg_parser.add_argument("--es_sigma_limit", type=float, default=0.01)
    arg_parser.add_argument("--es_mean_decay", type=float, default=0.0)
    arg_parser.add_argument(
        "--num_agents",
        help="Meta-train batch size, doubled for antithetic task sampling when using ES",
        type=int,
        default=64,
    )
    args = arg_parser.parse_args()

    experiment_fn = jax_debug_wrapper(args, run_training_experiment)
    experiment_fn(args)


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # ``install`` comes from ``rich.traceback``; it is called before ``main``
    # so that it is already in effect while the experiment runs.
    install()
    main()
