import os, sys
import subprocess
from multiprocessing import Pool
from pathlib import Path
import time
import random
import threading
import argparse

# Treat the current working directory as the project root and make it
# importable. NOTE(review): this assumes the script is launched from the
# counterfactualCLIP/ directory - confirm against how it is invoked.
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)

# Path form of the project root. The log directory below and the runner
# script path used in the __main__ section are resolved relative to it.
# (Previously this name was used without ever being defined, which raised
# NameError as soon as the module was executed.)
current_dir = Path(project_root)

# define the parameter combinations
alphas = [0.4]
lam_hats = [1.1]
lams = [0.8]
batch_sizes = [200]
select_scene_nums = [50]

# ablation parameter type
# ablation_param = {0:"alpha", 1: "lam", 2:"lam_hat", 3:"select_scene_num"}
# param = ablation_param[3]

# fixed parameters
model = "ViT-B-32"  # "ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-H-14"
dataset = "waterbirds" # "waterbirds", "urbancars", "cocogbv1", "cocogbv2", "imagenet_a", "imagenet_w", "nico"
scene_type = "outer_cz" # "outer_cz", "inner_cz", "virtual_cz", "random_cz"

# create the log directory (<project_root>/param_logs/)
log_dir = current_dir / "param_logs"
log_dir.mkdir(exist_ok=True)

# run the experiment
def run_experiment(cmd):
    """Run *cmd* through the system shell and block until it exits.

    The command's exit status is not inspected, so a non-zero exit does not
    raise; the function always returns ``None``.
    """
    subprocess.run(args=cmd, shell=True)

def build_command(script_path, log_file, gpu_id, model, dataset, scene_type,
                  alpha, lam_hat, lam, batch_size, select_scene_num):
    """Return the shell command line for one experiment run.

    The command pins the child process to ``gpu_id`` through
    ``CUDA_VISIBLE_DEVICES`` (so the child is always told ``--cuda_id 0``),
    passes the experiment parameters to ``script_path`` and redirects both
    stdout and stderr into ``log_file``.
    """
    return (
        # set the visible GPU
        f'CUDA_VISIBLE_DEVICES={gpu_id} python "{script_path}" '
        # set the parameters
        f"--model {model} --batch_size {batch_size} --select_scene_num {select_scene_num} "
        f"--dataset {dataset} --scene_type {scene_type} --alpha {alpha} "
        f"--lam_hat {lam_hat} --lam {lam} --cuda_id 0 "
        # redirect the output to the log file
        f'> "{log_file}" 2>&1'
    )


def format_elapsed(elapsed_seconds):
    """Format a non-negative duration in seconds as 'M minutes S seconds'."""
    minutes, seconds = divmod(int(elapsed_seconds), 60)
    return f"{minutes} minutes {seconds} seconds"


def main():
    """Parse the command line and run every parameter combination in parallel.

    One thread is started per parameter combination. Combinations are assigned
    to GPUs round-robin, and a per-GPU semaphore limits how many child
    processes run on the same GPU at once. Each child's output goes to its own
    log file inside ``log_dir``.
    """
    parser = argparse.ArgumentParser(description="Run counterfactualCLIP parameter experiments in parallel")
    parser.add_argument('--alphas', type=float, nargs='+', default=[0.4], help='List of alpha parameters')
    parser.add_argument('--lam_hats', type=float, nargs='+', default=[1.1], help='List of lam_hat parameters')
    parser.add_argument('--lams', type=float, nargs='+', default=[0.8], help='List of lam parameters')
    parser.add_argument('--batch_sizes', type=int, nargs='+', default=[200], help='List of batch_size parameters')
    # NOTE(review): this default (2) differs from the module-level
    # select_scene_nums value (50) - confirm which one is intended.
    parser.add_argument('--select_scene_nums', type=int, nargs='+', default=[2], help='List of select_scene_num parameters')
    parser.add_argument('--model', type=str, default='ViT-B-32', help='Model name')
    parser.add_argument('--dataset', type=str, default='waterbirds', help='Dataset name')
    parser.add_argument('--scene_type', type=str, default='outer_cz', help='Scene type')
    parser.add_argument('--gpu_list', type=int, nargs='+', default=[0,1,2], help='List of available GPUs')
    parser.add_argument('--max_per_gpu', type=int, default=3, help='Maximum processes per GPU')
    args = parser.parse_args()

    # A semaphore created with 0 would block every worker forever and a
    # negative value raises inside threading; reject both up front.
    if args.max_per_gpu < 1:
        parser.error("--max_per_gpu must be at least 1")

    time_start = time.time()
    print("begin to run the parallel tasks...")

    # generate all parameter combinations
    param_combinations = [
        (alpha, lam_hat, lam, batch_size, select_scene_num)
        for alpha in args.alphas
        for lam_hat in args.lam_hats
        for lam in args.lams
        for batch_size in args.batch_sizes
        for select_scene_num in args.select_scene_nums
    ]

    # set the runner script path
    script_path = current_dir / "run_counterfactualCLIP.py"
    # make sure the log directory exists
    os.makedirs(log_dir, exist_ok=True)

    # one semaphore per GPU caps the number of concurrent runs on that GPU
    semaphores = {gpu: threading.Semaphore(args.max_per_gpu) for gpu in args.gpu_list}

    def run_cmd_with_semaphore(cmd, gpu_id, log_file, alpha, lam_hat, lam, batch_size, select_scene_num):
        """Wait for a free slot on ``gpu_id``, run ``cmd`` and report failures."""
        with semaphores[gpu_id]:
            print(f"[begin] GPU {gpu_id} | alpha={alpha}, lam_hat={lam_hat}, lam={lam}, batch_size={batch_size}, select_scene_num={select_scene_num}, model={args.model}, scene_type={args.scene_type}")
            result = subprocess.run(cmd, shell=True)  # run the command
        # Surface failed runs instead of silently treating them as completed.
        if result.returncode != 0:
            print(f"[failed] GPU {gpu_id} | exit code {result.returncode}, see {log_file}")

    threads = []  # list of threads
    for i, (alpha, lam_hat, lam, batch_size, select_scene_num) in enumerate(param_combinations):
        # assign runs to GPUs round-robin
        gpu_id = args.gpu_list[i % len(args.gpu_list)]
        log_file = os.path.join(
            log_dir,
            f"log_alpha{alpha}_lam_hat{lam_hat}_lam{lam}_batch_size{batch_size}_select_scene_num{select_scene_num}_scene_type{args.scene_type}_model{args.model}.txt",
        )
        cmd = build_command(
            script_path, log_file, gpu_id, args.model, args.dataset, args.scene_type,
            alpha, lam_hat, lam, batch_size, select_scene_num,
        )
        t = threading.Thread(
            target=run_cmd_with_semaphore,
            args=(cmd, gpu_id, log_file, alpha, lam_hat, lam, batch_size, select_scene_num),
        )
        threads.append(t)
        t.start()

    for t in threads:
        t.join()

    print(f"tasks completed, time used: {format_elapsed(time.time() - time_start)}")


if __name__ == "__main__":
    main()