import json
import os
import subprocess
import copy
import sys
from concurrent.futures import ProcessPoolExecutor
import time
import logging
# -------------------------------- SETTINGS ------------------------------
# Passed through to the training config as ``suffix_res_file``.
RES_FILE_SUFFIX = ""
# Base JSON config; every run re-reads it and overlays its own parameters.
CONFIG_FILE = "exps/BallIL.json"
# Every parameter set is launched once per seed listed here.
SEEDS = [1993]
# (init_cls, increment) pairs. Each pair selects the entry
# "<init_cls>_<increment>" of ``param_grid_diff_tasks`` below.
INCREMENTS = [
    (5, 5),
    (10, 10),
    (20, 20),
]
# Written to the config as ``dataset`` and ``shot`` respectively.
DATASET = "imagenet100"
SHOT_DATA = False
# Size of the worker pool, and the GPU ids handed out round-robin per run.
MAX_CONCURRENT_PROCESSES = 1
GPUS = [0]

# Parameter sets to try per "<init_cls>_<increment>" setting. Each entry maps
# a config key to overrides merged into (or assigned to) that key.
# NOTE: "10_10" and "20_20" are empty, so those INCREMENTS launch no runs.
param_grid_diff_tasks = {
    "5_5": [
        {
            "imagenet100": {
                "w_confu": 1,
                "w_cons": 0.3,
                "w_ball_cls": 400,
                "w_concp": 0.1,
                "blur_r": 0.03,
            },
        },
    ],
    "10_10": [],
    "20_20": [],
}
# ------------------------ LOGGING CONFIGURATION ------------------------
# Log to both a file under logs/ and stdout.
# logging.FileHandler opens its file as soon as it is constructed, so the
# directory has to exist first; otherwise the script dies at import time with
# FileNotFoundError on a fresh checkout.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(filename)s] => %(message)s",
    handlers=[
        logging.FileHandler(filename=os.path.join("logs", "BallIL_main.log")),
        logging.StreamHandler(sys.stdout),
    ],
)


def run_experiment(params, gpu_id, init_cls, increment, seed):
    """Build a run config from CONFIG_FILE plus ``params`` and launch ``main_tune.py``.

    The base JSON config is re-read on every call. The ``params`` overrides
    and the run-level settings below are applied to it. The result is
    serialised and passed to the child process in the ``CONFIG_JSON``
    environment variable.

    Args:
        params: Mapping of config key -> override. A dict override is merged
            into the dict stored under that key. A fresh dict is used when the
            key is absent or not a dict. The dict also gets ``init_cls`` and
            ``increment`` set. Any other override replaces the key outright.
            If ``params`` is non-empty, ``config["note"]`` becomes a
            one-element list: the previous note followed by the ``str()`` of
            each override, space-separated.
        gpu_id: Written to ``config["device"]`` as a one-element list.
        init_cls: Written as ``init_cls`` into every dict section set by ``params``.
        increment: Written as ``increment`` into every dict section set by ``params``.
        seed: Written to ``config["seed"]``.

    Returns:
        The child process's exit code. A non-zero code is logged as an error,
        not raised.
    """
    # json.load returns a brand-new object, so it can be edited in place
    # without a defensive deepcopy.
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)

    note_parts = []
    for key, value in params.items():
        if isinstance(value, dict):
            section = config.get(key)
            if not isinstance(section, dict):
                # New (or non-dict) key: use a fresh dict so the caller's
                # ``params`` is never aliased/mutated. init_cls/increment are
                # then applied here too, not only to pre-existing sections.
                section = {}
                config[key] = section
            section.update(value)
            section.update({"init_cls": init_cls, "increment": increment})
        else:
            config[key] = value
        note_parts.append(str(value))

    if note_parts:
        # Assemble the note once. Re-wrapping it in a list for every key would
        # embed the previous list's repr ("['...']") inside the text.
        config["note"] = [" ".join([str(config.get("note", ""))] + note_parts)]

    config["dataset"] = DATASET
    config["seed"] = seed
    config["device"] = [gpu_id]
    config["resume"] = False
    config["shot"] = SHOT_DATA
    config["print_info"] = False
    config["suffix_res_file"] = RES_FILE_SUFFIX

    # Run the child under the interpreter executing this launcher (same
    # virtualenv/conda env) rather than whichever "python" is first on PATH.
    cmd = [sys.executable, "main_tune.py"]
    env = os.environ.copy()
    logging.info(f"Running experiment on seed {seed}, init_cls {init_cls}, increm {increment}, on GPU {gpu_id}, params:{params}")
    env["CONFIG_JSON"] = json.dumps(config)
    result = subprocess.run(cmd, env=env, check=False)
    if result.returncode != 0:
        logging.error(
            f"main_tune.py exited with code {result.returncode} "
            f"(seed {seed}, init_cls {init_cls}, increm {increment}, GPU {gpu_id})"
        )
    return result.returncode


def main():
    """Submit one ``run_experiment`` job per (increment pair, param set, seed).

    GPUs from ``GPUS`` are assigned round-robin in submission order. Waits for
    every job and logs each one that raised. Exits with status 1 if any job
    raised.

    Raises:
        KeyError: if an ``INCREMENTS`` pair has no entry in
            ``param_grid_diff_tasks``. This is checked before any job is
            submitted, so no run is started with an incomplete grid.
    """
    missing = [
        "{}_{}".format(init_cls, increment)
        for init_cls, increment in INCREMENTS
        if "{}_{}".format(init_cls, increment) not in param_grid_diff_tasks
    ]
    if missing:
        raise KeyError(f"param_grid_diff_tasks has no entry for: {', '.join(missing)}")

    futures = []
    job_idx = 0
    failures = 0
    with ProcessPoolExecutor(max_workers=MAX_CONCURRENT_PROCESSES) as executor:
        for init_cls, increment in INCREMENTS:
            for params in param_grid_diff_tasks["{}_{}".format(init_cls, increment)]:
                for seed in SEEDS:
                    gpu_id = GPUS[job_idx % len(GPUS)]
                    job_idx += 1
                    futures.append(executor.submit(run_experiment, params, gpu_id, init_cls, increment, seed))
                    # NOTE(review): fixed 3 s pause between submissions, presumably
                    # to stagger start-up; kept as in the original script.
                    time.sleep(3)
        # Collect every result so one failing job doesn't hide later failures.
        for future in futures:
            try:
                future.result()
            except Exception:
                failures += 1
                logging.exception("Experiment job failed")
    if failures:
        logging.error(f"{failures} of {len(futures)} experiment jobs failed")
        sys.exit(1)


# ProcessPoolExecutor workers re-import this module under the "spawn" start
# method (the default on Windows and macOS); without this guard each worker
# would try to start its own pool.
if __name__ == "__main__":
    main()
