# Copyright (c) 2023
# Copyright holder of the paper "End-to-End Meta-Bayesian Optimisation with Transformer Neural Processes".
# Submitted to NeurIPS 2023 for review.
# All rights reserved.


import os
import pickle
from datetime import datetime

import numpy as np

from nap.environment.hpo import get_hpo_specs
from nap.environment.mip import get_cond_mip_specs
from nap.environment.objectives import get_HPO_domain
from nap.policies.policies import iclr2020_NeuralAF
from nap.RL.ppo import PPO
from gym.envs.registration import register
import torch

# Pin CUDA work to GPU 0 when a GPU is present. Without this guard,
# torch.cuda.set_device raises on CPU-only hosts before anything else runs.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# Directory containing this script, and its "nap" sub-directory.
_script_dir = os.path.dirname(os.path.realpath(__file__))
rootdir = _script_dir
rootdir_nap = os.path.join(_script_dir, "nap")

dims, num_dims, cat_dims, min_points, cat_alphabet, train_datasets, valid_datasets, test_datasets, train_gps, _, _ = \
    get_cond_mip_specs(rootdir)
# Number of categories for each categorical dimension, in cat_alphabet's key order.
num_classes = [len(cat_alphabet[k]) for k in cat_alphabet]


def _to_numpy(value):
    """Return ``value`` as a NumPy array, detaching and moving torch tensors to CPU first.

    Non-tensor values (e.g. plain Python floats) are passed through ``np.asarray``.
    """
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


if not train_gps:
    # np.concatenate below would otherwise fail with an unhelpful message.
    raise ValueError("get_cond_mip_specs returned no training GPs; cannot derive GP hyperparameters.")

# Collect the fitted hyperparameters of every training GP; they are averaged below
# and passed to the environment through env_spec["f_opts"].
gp_params = {
    "cont_kern_ls": [],
    "cat_kern_ls": [],
    "ls": [],
    "lamda": [],
    "mean_const": [],
    "lik_noise": [],
}
for gp_path in train_gps:
    # Only CPU copies of the hyperparameters are needed. map_location="cpu" also lets
    # checkpoints that were saved on a GPU be read on hosts without one.
    # NOTE(review): torch.load unpickles the full model object, so only load trusted files.
    # On PyTorch versions whose default is weights_only=True, this may need weights_only=False.
    gp_model = torch.load(gp_path, map_location="cpu")
    base_kernel = gp_model.covar_module.base_kernel
    gp_params["cont_kern_ls"].append(_to_numpy(base_kernel.continuous_kern.lengthscale))
    gp_params["cat_kern_ls"].append(_to_numpy(base_kernel.categorical_kern.lengthscale))
    gp_params["ls"].append(_to_numpy(base_kernel.lengthscale))
    # lamda may be a float or a tensor; normalise it like the other fields.
    gp_params["lamda"].append(_to_numpy(base_kernel.lamda))
    gp_params["lik_noise"].append(_to_numpy(gp_model.likelihood.noise))
    gp_params["mean_const"].append(_to_numpy(gp_model.mean_module.constant))

# Per-dimension averages across GPs. This assumes each lengthscale array has a leading
# axis that concatenation stacks along -- TODO confirm.
cont_kern_ls = np.concatenate(gp_params["cont_kern_ls"]).mean(0).tolist()
cat_kern_ls = np.concatenate(gp_params["cat_kern_ls"]).mean(0).tolist()
# NOTE(review): this is the base kernel's *lengthscale*, yet env_spec passes it as
# "outputscale". Confirm that is intended (covar_module.outputscale may be what's wanted).
ls = np.concatenate(gp_params["ls"]).mean().item()
# Scalar averages over all GPs (and over all elements of each value).
lamda = np.array(gp_params["lamda"]).mean().item()
lik_noise = np.concatenate(gp_params["lik_noise"]).mean().item()
mean_const = np.array(gp_params["mean_const"]).mean().item()

# Specify the environment. This dict is also passed as the kwargs of the registered
# gym environment; its "T" sets the episode length (max_episode_steps).
env_spec = {
    "env_id": "MetaBO-fixed-v0",
    "f_type": "MIP",
    "D": dims,
    "f_opts": {
        "kernel": "Mixture",
        "min_regret": 1e-20,
        "data": train_datasets,
        "cat_dims": cat_dims,
        "num_classes": num_classes,
        "cont_dims": num_dims,
        "shuffle_and_cutoff": True,
        # GP hyperparameters averaged over the training GPs.
        "continuous_kern_lengthscale": cont_kern_ls,
        "categorical_kern_lengthscale": cat_kern_ls,
        "outputscale": ls,
        "lamda": lamda,
        "likelihood_noise": lik_noise,
        "mean_constant": mean_const,
    },
    # The policy options index "timestep" and "budget" as the last two features
    # (t_idx=-2, T_idx=-1), so keep them at the end of this list.
    "features": ["posterior_mean", "posterior_std", "incumbent", "timestep_perc", "timestep", "budget"],
    "T": 24,
    "n_init_samples": 0,
    "pass_X_to_pi": False,
    "local_af_opt": False,
    "cardinality_domain": 200,
    "remove_seen_points": False,  # only True for testing
    # will be set individually for each new function to the sampled hyperparameters
    "kernel": "Mixture",
    "kernel_lengthscale": None,
    "kernel_variance": None,
    "noise_variance": None,
    "use_prior_mean_function": False,
    "reward_transformation": "neg_log10",  # true maximum not known
}

# PPO hyperparameters. The total step budget is n_iterations * batch_size.
n_iterations = 500
batch_size = 1200
n_workers = 10
arch_spec = [200] * 4  # four hidden layers of width 200, shared by policy and value nets

# Options forwarded to the policy constructor. t_idx / T_idx locate the
# "timestep" and "budget" entries at the end of env_spec["features"].
_policy_options = {
    "activations": "relu",
    "arch_spec": arch_spec,
    "exclude_t_from_policy": True,
    "exclude_T_from_policy": True,
    "use_value_network": True,
    "t_idx": -2,
    "T_idx": -1,
    "arch_spec_value": arch_spec,
}

ppo_spec = dict(
    batch_size=batch_size,
    max_steps=n_iterations * batch_size,
    minibatch_size=batch_size // 20,
    n_epochs=4,
    lr=1e-4,
    epsilon=0.15,
    value_coeff=1.0,
    ent_coeff=0.0,
    gamma=0.98,
    **{"lambda": 0.98},  # "lambda" is a keyword, so it cannot be a dict() keyword argument
    loss_type="GAElam",
    normalize_advs=True,
    n_workers=n_workers,
    env_id=env_spec["env_id"],
    seed=0,
    argmax=False,
    env_seeds=list(range(n_workers)),  # one seed per worker
    policy_options=_policy_options,
)

# Register the environment under env_spec["env_id"] so it can be instantiated by id.
# The full env_spec is forwarded as constructor kwargs, and "T" caps the episode length.
register(
    id=env_spec["env_id"],
    entry_point="nap.environment.function_gym:MetaBOEnv",
    kwargs=env_spec,
    max_episode_steps=env_spec["T"],
    reward_threshold=None,
)

# Logs and weights for this run go into a timestamped folder; use it for evaluation afterwards.
_run_stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
logpath = os.path.join(rootdir_nap, "log", "TRAIN/MIP", env_spec["env_id"], _run_stamp)

# Set up the policy.
def policy_fn(observation_space, action_space, deterministic):
    """Construct the iclr2020_NeuralAF policy for the given spaces.

    Args:
        observation_space: Observation space passed straight to the policy.
        action_space: Action space passed straight to the policy.
        deterministic: Requested determinism. It is overridden to True when
            ppo_spec["argmax"] is truthy.

    Returns:
        A new iclr2020_NeuralAF instance configured with ppo_spec["policy_options"].

    A module-level ``def`` (unlike a lambda bound to a name) can be pickled by
    reference, in case the factory is shipped to worker processes.
    """
    return iclr2020_NeuralAF(
        observation_space=observation_space,
        action_space=action_space,
        deterministic=True if ppo_spec["argmax"] else deterministic,
        options=ppo_spec["policy_options"],
    )

# Do training. This is guarded so that importing the module defines the configuration
# and registers the environment without launching a run. Importing happens, for example,
# when worker processes are started with the "spawn" method.
if __name__ == "__main__":
    print("Training on {}.\nFind logs, weights, and learning curve at {}\n\n".format(env_spec["env_id"], logpath))
    ppo = PPO(policy_fn=policy_fn, params=ppo_spec, logpath=logpath, save_interval=100)
    ppo.train()
