import os
import wandb
import argparse
import src.marl_systems as marl_systems
import src.halluc_systems as halluc_systems
from src.utils.utils import process_config, seed_everything, initialize_save_data

# Registry mapping the system name given in the config (``config.system``)
# to the class that is instantiated and trained for it.
SYSTEM = {
    # Systems from src.marl_systems.
    "IQL": marl_systems.IQL,
    "DDPG": marl_systems.DDPG,
    # Systems from src.halluc_systems.
    "Halluc_IQL": halluc_systems.Halluc_IQL,
    "Halluc_DDPG": halluc_systems.Halluc_DDPG,
}

def resolve_device(gpu_device):
    """Translate a requested device into the string stored on the config.

    Args:
        gpu_device: ``None``, an empty string, ``'cpu'``, a bare device index
            such as ``'0'``, or a string that already starts with ``'cuda'``
            (for example ``'cuda:0'``).

    Returns:
        ``'cpu'`` for falsy input or ``'cpu'``; the input unchanged when it
        already starts with ``'cuda'``; otherwise ``'cuda:'`` followed by the
        input.
    """
    if not gpu_device or gpu_device == "cpu":
        return "cpu"
    # Do not prefix twice: 'cuda:0' must not become 'cuda:cuda:0'.
    if gpu_device.startswith("cuda"):
        return gpu_device
    return "cuda:" + gpu_device


def run(args, gpu_device):
    """Train the configured system once per seed.

    Args:
        args: Parsed command-line namespace. ``args.config`` is passed to
            ``process_config``; ``args.seeds`` is an optional list of seeds.
            When it is missing or empty, the single seed from the config is
            used.
        gpu_device: Requested device, see ``resolve_device``.

    Raises:
        KeyError: If ``config.system`` is not a key of ``SYSTEM``.
    """
    config = process_config(args.config)
    config.gpu_device = resolve_device(gpu_device)

    # Resolve the system class once, before any per-seed work, so an unknown
    # name fails immediately instead of after initialize_save_data has run.
    try:
        system_cls = SYSTEM[config.system]
    except KeyError:
        raise KeyError(
            f"Unknown system {config.system!r}; expected one of {sorted(SYSTEM)}"
        ) from None

    # Remember the base values: the loop overwrites them on the shared config.
    base_output_dir = config.logger.output_dir
    base_exp_name = config.exp_name
    seeds = args.seeds or [config.seed]

    for seed in seeds:
        config.seed = int(seed)
        # Each seed gets its own experiment name and output sub-directory.
        config.exp_name = base_exp_name + f"_seed_{seed}"
        config.logger.output_dir = os.path.join(base_output_dir, config.exp_name)
        initialize_save_data(config)
        seed_everything(config.seed)
        system = system_cls(config)
        system.train()


def parse_seed_list(text):
    """Parse a comma-separated string of integers, e.g. ``"1,2,3"``.

    Args:
        text: Raw value of the ``--seeds`` option.

    Returns:
        The list of integers in the order given.

    Raises:
        argparse.ArgumentTypeError: If any item is not an integer, so that
            argparse reports a readable message instead of
            "invalid <lambda> value".
    """
    try:
        return [int(item) for item in text.split(",")]
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a comma-separated list of integers, got {text!r}"
        ) from None


def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with a required positional ``config``
        and the options ``--gpu-device`` (default ``'cpu'``) and ``--seeds``
        (default ``None``).
    """
    parser = argparse.ArgumentParser()
    # The description belongs in ``help``; a ``default`` on a required
    # positional is never used by argparse.
    parser.add_argument("config", type=str, help="path to config file")
    parser.add_argument(
        "--gpu-device",
        type=str,
        default="cpu",
        help="'cpu' or a CUDA device index such as 0 (default: cpu)",
    )
    parser.add_argument(
        "--seeds",
        type=parse_seed_list,
        help="comma-separated list of integer seeds, e.g. 1,2,3",
    )
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    # An empty device string is passed on as None.
    gpu_device = str(args.gpu_device) if args.gpu_device else None

    run(args, gpu_device)