import sys
import os

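# Debugging aid: uncomment to make CUDA kernel launches synchronous, so a CUDA
# error is reported at the call that caused it instead of at a later one.
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"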

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import hydra
from hydra.utils import instantiate
from easydict import EasyDict
from omegaconf import OmegaConf
import yaml
import domains
import concurrent.futures
import time
import datetime
import uuid
from utils import *


def run_experiment(cfg, args, seed):
    """Build every component for one seeded run, snapshot the config, then run it.

    Args:
        cfg: The original Hydra config object. It is written verbatim to
            ``<save_root>/configs/<timestamp>_<id>/config.yaml`` with
            ``OmegaConf.save``.
        args: Resolved, attribute-accessible copy of the config (built in
            ``main``). The domain, Q-function, IL component, selection
            experiment and dataset are all instantiated from it.
        seed: Seed handed to ``control_seed`` and forwarded to ``init_exp``.
    """
    # Paths and RNG seeding are set up before anything is instantiated.
    set_path(args)
    control_seed(seed)

    domain_cfg = args.domain
    env = instantiate(domain_cfg.domain)
    q_function = instantiate(domain_cfg.exp.qfunction)
    il_module = instantiate(domain_cfg.exp.il)
    runner = instantiate(args.selection)

    dataset = get_dataset(env, args.dataset, args.root)

    runner.init_exp(
        env,
        dataset,
        il_module,
        q_function,
        domain_cfg.exp,
        args.root,
        args.selection_params,
        args.general,
        seed,
        args.dataset,
    )

    # NOTE(review): assumes init_exp sets `save_root` on the experiment -- confirm.
    # The timestamp plus a short random suffix gives runs started in the same
    # second (e.g. parallel seeds) distinct snapshot directories.
    # NOTE(review): `cfg` is saved unchanged, so with parallel runs every
    # snapshot records the base `general.seed`, not this run's `seed`.
    stamp = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    suffix = uuid.uuid4().hex[:8]
    snapshot_dir = f"{runner.save_root}/configs/{stamp}_{suffix}"
    os.makedirs(snapshot_dir, exist_ok=True)
    OmegaConf.save(cfg, os.path.join(snapshot_dir, "config.yaml"))

    runner.run()



@hydra.main(config_path="config", config_name="config", version_base=None)
def main(hydra_cfg):
    """Hydra entry point: run one experiment, or several seeds in parallel.

    The Hydra config is resolved to plain YAML and reloaded as an ``EasyDict``
    so it can be read with attribute access.

    - When ``general.parallel_num`` is 1, a single run uses ``general.seed``.
    - Otherwise ``parallel_num`` runs are launched in a process pool, with
      consecutive seeds starting at ``general.seed``.

    Raises:
        RuntimeError: If any parallel run raised. It is raised after every
            run has finished and is chained to the first failure's exception.
            Each failing seed is also reported on stderr.
    """
    # Resolve interpolations once, then convert to a plain attribute-dict.
    yaml_config = OmegaConf.to_yaml(hydra_cfg, resolve=True)
    args = EasyDict(yaml.safe_load(yaml_config))

    if args.general.parallel_num == 1:
        run_experiment(hydra_cfg, args, args.general.seed)
        return

    seeds = range(args.general.seed, args.general.seed + args.general.parallel_num)

    failures = []
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.general.parallel_num
    ) as executor:
        future_to_seed = {
            executor.submit(run_experiment, hydra_cfg, args, seed): seed
            for seed in seeds
        }
        # concurrent.futures.wait() never re-raises worker exceptions, so
        # fetch each result explicitly; otherwise crashed seeds go unnoticed.
        for future in concurrent.futures.as_completed(future_to_seed):
            seed = future_to_seed[future]
            try:
                future.result()
            except Exception as exc:  # includes BrokenProcessPool
                print(f"[main] run with seed {seed} failed: {exc!r}", file=sys.stderr)
                failures.append((seed, exc))

    if failures:
        failed_seeds = sorted(seed for seed, _ in failures)
        raise RuntimeError(
            f"{len(failures)} of {len(future_to_seed)} runs failed "
            f"(seeds {failed_seeds})"
        ) from failures[0][1]

    
if __name__ == "__main__":
    # Python sets __package__ to None for `python <file>.py`; give it the
    # package name in that case. Under `python -m ...` __package__ is already
    # a string, and main() must still run there too.
    # NOTE(review): this assignment happens after all imports above, so it
    # does not affect them.
    if __package__ is None:
        __package__ = "RLLF"
    main()
    

# 