import argparse
import torch
import os
import wandb
from pathlib import Path
from utils.helper_fn import get_hparam_class, get_obj_class, get_policy_type
import warnings

# Silence the noisy warning categories for the whole process.
for _category in (DeprecationWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_category)

# Opt in to the wandb "core" feature before any run is created.
wandb.require("core")

def call_agent():
    """Run one sweep trial and close its wandb run.

    Builds the experiment object from the module-level ``obj_class``,
    ``env_name``, ``env_args``, ``hparam_space`` and ``args`` (assigned in the
    ``__main__`` section of this script) and calls its ``wandb_search()``.

    Raises:
        TimeoutError: re-raised after the active wandb run, if there is one,
            has been annotated with the failure and closed with a non-zero
            exit code.
    """
    try:
        obj = obj_class(env_name, env_args, hparam_space, device=args.device)
        obj.wandb_search()
    except TimeoutError as e:
        # The timeout may fire before a run has been initialised; touching
        # wandb.run.summary in that case would raise AttributeError and hide
        # the original TimeoutError.
        if wandb.run is not None:
            wandb.run.summary["status"] = "crashed"
            wandb.run.summary["failure_reason"] = str(e)
            # A non-zero exit code keeps the run from being recorded as a
            # clean finish.
            wandb.run.finish(exit_code=1)
        # Bare raise preserves the original traceback.
        raise
    else:
        # Finish the wandb experiment normally if no issues
        wandb.finish()


def parse_args():
    """Parse the command-line options of the sweep launcher.

    Unknown command-line arguments are ignored (``parse_known_args``).

    Returns:
        argparse.Namespace: the parsed options, with defaults filled in.
    """
    parser = argparse.ArgumentParser()

    # training-aid hyperparameters
    parser.add_argument("--wandb_project_name", type=str, default="LLM-Instruct-PPO")
    parser.add_argument("--sweep_id", type=str, default="07vw9adw", help="sweep id for wandb,"
                                                                         " only used in agent mode")
    parser.add_argument("--role", type=str, default="sweep", choices=["sweep", "agent", "run_single"])
    parser.add_argument("--task", type=str, default="SimGlucoseEnv-adult1",
                        help="remember to change this for different tasks! "
                             "Wandb sweep won't work correctly if this is not changed!")
    parser.add_argument("--log_dir", type=str, default="sweep_log/")
    parser.add_argument("--n_test_in_train", type=int, default=20)
    parser.add_argument("--epoch", type=int, default=2)
    parser.add_argument("--num_actions", type=int, default=11)
    parser.add_argument("--step_per_epoch", type=int, default=10 * 288)
    parser.add_argument("--obs_window", type=int, default=12)
    # argparse does not run `type` on non-string defaults, so the default must
    # already be an int (the float literal 1e5 would leak through as 100000.0).
    parser.add_argument("--buffer_size", type=int, default=100_000)
    parser.add_argument("--policy_name", type=str, default="LLM-Instruct",
                        choices=["LLM-Instruct", "LLM", "PPO"],
                        help="remember to change this for different tasks! "
                             "Wandb sweep won't work correctly if this is not changed!")
    parser.add_argument("--inference_mode", type=str, default="local",
                        choices=["API", "local"],
                        help="different ways of calling actorLM in LLM Policy and LLM-Instruct Policy. ")
    # The default value must itself be an accepted choice so it can be passed
    # explicitly; the older "transformer" spelling stays accepted so existing
    # command lines keep parsing.
    parser.add_argument("--transformers_mode", type=str, default="transformers",
                        choices=["transformers", "transformer", "vllm"],
                        help="different ways of inferencing local actorLM in LLM Policy and LLM-Instruct Policy. ")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # parse_known_args returns (namespace, leftover_args); only the namespace is used.
    args = parser.parse_known_args()[0]
    return args


if __name__ == "__main__":
    # NOTE: args, obj_class, env_name, env_args and hparam_space are assigned at
    # module level on purpose: call_agent() reads them as globals.
    args = parse_args()
    hparam_class = get_hparam_class(args.policy_name, offline=False)
    obj_class = get_obj_class(args.policy_name, offline=False)

    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    policy_type = get_policy_type(args.policy_name, offline=False)
    env_args = {"discrete": policy_type == "discrete",
                "n_act": args.num_actions}

    env_name = args.task
    # "<task>-<policy>" names both the log sub-directory and the sweep.
    run_name = f"{env_name}-{args.policy_name}"
    log_dir = os.path.join(args.log_dir, run_name)
    hparam_space = hparam_class(args.policy_name,
                                log_dir,
                                1,  # number of training envs
                                args.n_test_in_train,  # number of test envs
                                args.epoch,
                                args.step_per_epoch,  # number of training steps per epoch
                                args.buffer_size,
                                args.obs_window,
                                args.inference_mode,
                                args.transformers_mode,
                                args.num_actions,
                                linear=False
                                )
    search_space = hparam_space.get_search_space()

    print("All prepared. Start to experiment")
    # One if/elif/else chain: with two separate `if` statements the "agent"
    # role fell through to the final `else` and raised NotImplementedError
    # once wandb.agent() returned.
    if args.role == "sweep":
        # Register a new grid sweep, then serve it from this process.
        sweep_configuration = {
            "method": "grid",
            "project": args.wandb_project_name,
            "name": run_name,
            "metric": {"goal": "maximize", "name": "test/returns_stat/mean"},
            "parameters": search_space
        }
        sweep_id = wandb.sweep(sweep_configuration, project=args.wandb_project_name)
        wandb.agent(sweep_id=sweep_id, function=call_agent, project=args.wandb_project_name)
    elif args.role == "agent":
        # Attach to an existing sweep given by --sweep_id.
        wandb.agent(sweep_id=args.sweep_id, function=call_agent, project=args.wandb_project_name)
    elif args.role == "run_single":
        # Single trial with one randomly sampled configuration, no sweep.
        obj = obj_class(env_name, env_args, hparam_space, device=args.device)
        config_dict = hparam_space.sample(mode="random")
        obj.search_once({**config_dict, "wandb_project_name": args.wandb_project_name})
    else:
        print("role must be one of [sweep, agent, run_single], get {}".format(args.role))
        raise NotImplementedError