import json
import os
import subprocess
import copy
import sys
from concurrent.futures import ProcessPoolExecutor
import time
import logging
# -------------------------------- SETTINGS ------------------------------
# Suffix for result files; passed to main_tune.py as the `suffix_res_file`
# config key.
RES_FILE_SUFFIX = ""
# Base JSON config that every run starts from before overrides are applied.
CONFIG_FILE = "exps/BallIL.json"
# Every parameter set is launched once per seed listed here.
SEEDS = [1993]
# (init_cls, increment) schedules to sweep.  Each pair is looked up in
# `param_grid_diff_tasks` under the key "<init_cls>_<increment>".
# NOTE(review): (10, 10) currently has no "10_10" entry in the grid below.
INCREMENTS = [(10, 10), (20, 20), (40, 40)]
# Written to the config's `dataset` key for every run.
DATASET = "tinyimagenet200"
# Written to the config's `shot` key for every run.
SHOT_DATA = False
# Size of the worker pool, i.e. how many runs may execute at the same time.
MAX_CONCURRENT_PROCESSES = 1
# GPU ids handed out round-robin in submission order.
GPUS = [0]

# Hyper-parameter sets per schedule.  Each entry maps a config section name
# to the overrides merged into that section of CONFIG_FILE.
param_grid_diff_tasks = {
    "20_20": [
        {
            "tinyimagenet200": {
                "w_confu": 1,
                "w_cons": 0.1,
                "w_ball_cls": 400,
                "w_concp": 0.1,
                "blur_r": 0.03,
            },
        },
    ],
    "40_40": [],
}
# ------------------------ LOGGING CONFIGURATION ------------------------
# Log to logs/BallIL_main.log and to stdout.  logging.FileHandler opens its
# file as soon as it is constructed, so the directory must exist beforehand;
# otherwise the script dies with FileNotFoundError on a fresh checkout.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(filename)s] => %(message)s",
    handlers=[
        logging.FileHandler(filename="logs/" + "BallIL_main.log"),
        logging.StreamHandler(sys.stdout),
    ],
)


def run_experiment(params, gpu_id, init_cls, increment, seed):
    """Launch one ``main_tune.py`` run with a patched copy of CONFIG_FILE.

    The base config is read from ``CONFIG_FILE`` and patched with ``params``
    plus the run-specific keys below. It is then handed to the child
    process as JSON in the ``CONFIG_JSON`` environment variable.

    Args:
        params: Mapping of config key -> override. A dict value is merged
            into the config section of the same name (created if absent), and
            that section also receives ``init_cls`` and ``increment``. Any
            other value replaces the key outright.
        gpu_id: Written to the config as ``device = [gpu_id]``.
        init_cls: Written as ``init_cls`` into every dict section overridden.
        increment: Written as ``increment`` into every dict section overridden.
        seed: Written to the config's ``seed`` key.

    Returns:
        The exit status of the ``main_tune.py`` process (0 on success).
        A non-zero status is also logged as an error.
    """
    # json.load returns freshly built objects, so the config can be edited
    # directly without a defensive deep copy.
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)

    for key, value in params.items():
        if isinstance(value, dict):
            # Build a new dict instead of aliasing ``value`` so the caller's
            # params are never mutated. Inject the task split whether or not
            # the section already existed in the base config.
            section = config.get(key)
            merged = dict(section) if isinstance(section, dict) else {}
            merged.update(value)
            merged.update({"init_cls": init_cls, "increment": increment})
            config[key] = merged
        else:
            config[key] = value
        # NOTE(review): stores a one-element list whose string embeds str() of
        # the previous note (e.g. "['...'] {...}" when it was already a list).
        # Format kept as-is because main_tune.py's expectations aren't visible.
        config["note"] = [str(config.get("note", "")) + " " + str(value)]

    config["dataset"] = DATASET
    config["seed"] = seed
    config["device"] = [gpu_id]
    config["resume"] = False
    config["shot"] = SHOT_DATA
    config["print_info"] = False
    config["suffix_res_file"] = RES_FILE_SUFFIX

    # sys.executable keeps the child on the same interpreter/virtualenv as this
    # launcher instead of whatever "python" is first on PATH. The script path
    # is relative, so this must be run from the directory containing it.
    cmd = [sys.executable, "main_tune.py"]
    env = os.environ.copy()
    logging.info(f"Running experiment on seed {seed}, init_cls {init_cls}, increm {increment}, on GPU {gpu_id}, params:{params}")
    env["CONFIG_JSON"] = json.dumps(config)
    result = subprocess.run(cmd, env=env)
    if result.returncode != 0:
        # Best-effort sweep: report the failure but let other runs continue.
        logging.error(
            "main_tune.py exited with code %d (seed %s, init_cls %s, increment %s, GPU %s, params %s)",
            result.returncode, seed, init_cls, increment, gpu_id, params,
        )
    return result.returncode


if __name__ == "__main__":
    # The guard is required for ProcessPoolExecutor under the "spawn" and
    # "forkserver" start methods. Worker processes re-import this module and
    # must not recursively launch the sweep themselves.
    with ProcessPoolExecutor(max_workers=MAX_CONCURRENT_PROCESSES) as executor:
        futures = []
        run_idx = 0
        for init_cls, increment in INCREMENTS:
            task_key = f"{init_cls}_{increment}"
            param_sets = param_grid_diff_tasks.get(task_key)
            if param_sets is None:
                # A schedule listed in INCREMENTS without a grid entry used to
                # abort the whole sweep with KeyError; skip it instead.
                logging.warning("No parameter sets for task %s in param_grid_diff_tasks; skipping.", task_key)
                continue
            for params in param_sets:
                for seed in SEEDS:
                    # Round-robin GPU assignment in submission order.
                    gpu_id = GPUS[run_idx % len(GPUS)]
                    run_idx += 1
                    futures.append(executor.submit(run_experiment, params, gpu_id, init_cls, increment, seed))
                    # Stagger submissions by 3 s (presumably to avoid
                    # simultaneous start-up of runs — intent not visible here).
                    time.sleep(3)
        # Re-raise any exception that escaped a worker.
        for future in futures:
            future.result()
