
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from time import time

from algorithms.fact_MAHDDRQN.core import FactMAHDDRQN

from common.imports import *
from common.utils import set_random_seed, set_torch, str2bool
from environments.utils import *

# Registry of runnable algorithms: name accepted on the command line -> class
# that is instantiated by main().
ALGORITHMS: Dict[str, Type[Any]] = {
    'FactMAHDDRQN': FactMAHDDRQN,
}

def main(args):
    """Validate the CLI arguments, build the vectorized env and launch the algorithm.

    Args:
        args: Parsed command-line namespace. This function reads
            ``time_limit``, ``n_envs``, ``alg``, ``env_id``, ``seed``,
            ``n_threads``, ``th_deterministic`` and ``cuda``; the namespace is
            also handed to the environment helpers and to the algorithm class.

    Raises:
        ValueError: If ``time_limit`` is above 1440, ``n_envs`` is below 1,
            ``alg`` is not a key of ``ALGORITHMS``, or ``env_id`` does not
            start with one of the known prefixes (``BP``, ``OSD``, ``CT``).
    """
    start_time = time()

    # Explicit raises (not ``assert``) so the checks still run under ``python -O``.
    # NOTE(review): 1440 presumably means minutes (24h) -- confirm the unit.
    if not (args.time_limit <= 1440):
        raise ValueError(f"Invalid time limit: {args.time_limit}. Timeout limit is : 1440")
    if not (args.n_envs >= 1):
        raise ValueError(f"Invalid n° of environments: {args.n_envs}. Must be >= 1")
    if args.alg not in ALGORITHMS:
        raise ValueError(
            f"Unsupported algorithm: {args.alg}. Supported algorithms are: {list(ALGORITHMS)}"
        )

    run_name = f"{args.alg}_{args.env_id}_{args.seed}"

    # Pick the parameter helper from the environment-id prefix.
    if args.env_id.startswith('BP'):
        env_args = get_boxpushing_params(args)
    elif args.env_id.startswith('OSD'):
        env_args = get_osd_params(args)
    elif args.env_id.startswith('CT'):
        env_args = get_capturetarget_params(args)
    else:
        raise ValueError(f"{args.env_id} does not exist.")

    # Update args with environment arguments.
    # A key present in both mappings makes this call raise TypeError
    # (duplicate keyword argument), so collisions fail loudly.
    args = ap.Namespace(**vars(args), **env_args)

    args.max_steps = env_args['terminate_step']
    venv = gym.vector.AsyncVectorEnv([make_env(args, idx) for idx in range(args.n_envs)])
    # action is (n_agents, n_envs)

    # Set random seed and Torch configuration
    set_random_seed(args.seed)
    set_torch(args.n_threads, args.th_deterministic, args.cuda)

    # CUDA is used only when it is both requested and available.
    device = th.device("cuda" if th.cuda.is_available() and args.cuda else "cpu")

    # Run the specified algorithm.
    # NOTE(review): the instance is discarded, so the constructor presumably
    # drives the whole run -- confirm in the algorithm class.
    ALGORITHMS[args.alg](venv, run_name, start_time, args, device)

if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    parser = ap.ArgumentParser()

    # Cluster
    # NOTE(review): main() rejects values above 1440; the help text below looks
    # copied from another tool -- confirm the intended meaning and unit.
    parser.add_argument("--time-limit", type=float, default=1300, help="Time limit for the action ranking")

    # Environment
    # BP-MA-v0, OSD-S-v4, OSD-D-v7, OSD-T-v0, OSD-T-v1, OSD-F-v0, CT-MA-v0
    parser.add_argument("--env-id", type=str, default="OSD-S-v4", help="Id of the environment.")
    parser.add_argument("--n-envs", type=int, default=10, help="Number of asynchronous envs to run.")
    parser.add_argument("--grid-size", type=int, default=10, help="Grid size for Box Pushing.")

    # Experiment
    # DecMAHDDRQN
    parser.add_argument("--alg", type=str, default='FactMAHDDRQN', help="Algorithm to run")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--total-timesteps", type=int, default=5000000, help="Total timesteps for the experiment")
    parser.add_argument("--learning-starts", type=int, default=20000, help="When to start learning")
    parser.add_argument("--eval-freq", type=int, default=1000, help="Total timesteps between deterministic evals")

    # Logger
    parser.add_argument("--verbose", type=str2bool, default=True, help="Toggles prints")
    parser.add_argument("--exp-tag", type=str, default='', help="Tag for logging the experiment")
    # The help text here used to duplicate the one of --exp-tag.
    parser.add_argument("--track", type=str2bool, default=False, help="Toggles experiment tracking")
    parser.add_argument("--wandb-project", type=str, default="", help="Wandb's project name.")
    parser.add_argument("--wandb-entity", type=str, default="", help="Entity (team) of wandb's project.")
    parser.add_argument("--wandb-mode", type=str, default="online", help="Online or offline wandb mode.")

    # Torch
    parser.add_argument("--th-deterministic", type=str2bool, default=True, help="Enable deterministic in Torch.")
    parser.add_argument("--cuda", type=str2bool, default=False, help="Enable CUDA by default.")
    parser.add_argument("--n-threads", type=int, default=1, help="Max number of torch threads.")

    # parse_known_args() returns (namespace, leftovers); unrecognised flags are
    # dropped silently instead of aborting, so a mistyped option is ignored.
    main(parser.parse_known_args()[0])
