import json
import os
import subprocess
import copy
import sys
from concurrent.futures import ProcessPoolExecutor
import time
import logging
# -------------------------------- SETTINGS ------------------------------
# Suffix for result files; not referenced in the visible part of this script.
RES_FILE_SUFFIX = ""
# Base JSON configuration that run_experiment loads before applying overrides.
CONFIG_FILE = "exps/BallIL.json"
# Every parameter set is launched once per seed.
SEEDS = [1993]
# (init_cls, increment) pairs; each pair selects the grid entry "<init_cls>_<increment>".
INCREMENTS = [(5, 5), (10, 10), (20, 20)]
# Size of the worker pool, i.e. how many experiments may run at the same time.
MAX_CONCURRENT_PROCESSES = 1
# Device ids handed out to submitted experiments in round-robin order.
GPUS = [0]

# Parameter sets per task setting. Each list item maps a top-level config key
# (here "cifar100") to the values merged into that section of the base config.
# An empty list means nothing is launched for that setting.
param_grid_diff_tasks = {
    "5_5": [
        {
            "cifar100": {
                "w_confu": 1,
                "w_cons": 0.5,
                "w_ball_cls": 400,
                "w_concp": 0.01,
                "blur_r": 0.08,
            },
        },
    ],
    "10_10": [],
    "20_20": [],
}
# ------------------------ LOGGING CONFIGURATION ------------------------
# FileHandler opens its file as soon as it is constructed and raises
# FileNotFoundError when the parent directory is missing, so create it first.
os.makedirs("logs", exist_ok=True)

# Every record goes both to the log file and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(filename)s] => %(message)s",
    handlers=[
        logging.FileHandler(filename="logs/BallIL_tune_CIFAR100.log"),
        logging.StreamHandler(sys.stdout),
    ],
)


def run_experiment(params, gpu_id, init_cls, increment, seed):
    """Launch one ``main_tune.py`` run with ``params`` applied to the base config.

    The base configuration is re-read from ``CONFIG_FILE`` on every call, the
    overrides are applied, and the result is handed to the child process as
    JSON through the ``CONFIG_JSON`` environment variable.

    Args:
        params: Mapping of top-level config key to override value. If the key
            already exists in the base config, the value is merged into that
            section with ``dict.update`` (both sides must therefore be dicts)
            and the section also receives ``init_cls`` and ``increment``.
            Otherwise the value is stored under the new key unchanged.
        gpu_id: Device id; written to the config as ``"device": [gpu_id]``.
        init_cls: Stored as ``"init_cls"`` in every merged section.
        increment: Stored as ``"increment"`` in every merged section.
        seed: Stored under the config's ``"seed"`` key.

    Returns:
        None. A non-zero exit status of the child is logged as an error
        instead of being raised, so one failed run does not abort the sweep.
    """
    # The config is loaded fresh on each call, so it can be modified in place.
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config_updated = json.load(f)

    for key, value in params.items():
        if key in config_updated:
            config_updated[key].update(value)
            config_updated[key].update({"init_cls": init_cls, "increment": increment})
        else:
            config_updated[key] = value
        # Append the override to the run note. The note is re-wrapped in a
        # one-element list on every iteration, exactly as before; a base
        # config without a "note" entry no longer raises KeyError.
        previous_note = config_updated.get("note", "")
        config_updated["note"] = [f"{previous_note} {value}"]

    config_updated["resume"] = False
    config_updated["print_info"] = False
    config_updated["device"] = [gpu_id]
    config_updated["seed"] = seed

    env = os.environ.copy()
    env["CONFIG_JSON"] = json.dumps(config_updated)
    # Use the interpreter running this script so the child gets the same
    # environment; fall back to PATH lookup if it cannot be determined.
    cmd = [sys.executable or "python", "main_tune.py"]

    logging.info(f"Running experiment on seed {seed}, init_cls {init_cls}, increm {increment}, on GPU {gpu_id}, params:{params}")
    completed = subprocess.run(cmd, env=env)
    if completed.returncode != 0:
        # Report the failure; previously a crashed run went unnoticed.
        logging.error(
            "Experiment failed with exit code %s (seed %s, init_cls %s, increm %s, GPU %s, params:%s)",
            completed.returncode, seed, init_cls, increment, gpu_id, params,
        )


def main():
    """Submit one experiment per (task setting, parameter set, seed) and wait.

    GPUs from ``GPUS`` are assigned to submissions in round-robin order.
    Any exception raised inside a worker is re-raised here when its result
    is collected.
    """
    with ProcessPoolExecutor(max_workers=MAX_CONCURRENT_PROCESSES) as executor:
        futures = []
        idx = 0
        for init_cls, increment in INCREMENTS:
            for params in param_grid_diff_tasks["{}_{}".format(init_cls, increment)]:
                for seed in SEEDS:
                    gpu_id = GPUS[idx % len(GPUS)]
                    idx += 1
                    futures.append(executor.submit(run_experiment, params, gpu_id, init_cls, increment, seed))
                    # Pause between submissions (presumably to stagger
                    # start-up of the runs -- confirm before changing).
                    time.sleep(3)
        for future in futures:
            future.result()


# Worker processes started with the "spawn" or "forkserver" start method
# re-import this module; without the guard each worker would try to start
# the whole sweep again.
if __name__ == "__main__":
    main()
