import os
import subprocess
import glob

def has_multiple_safetensors(dir_path):
    """Return True when ``dir_path`` directly contains a ``*.safetensors`` file.

    NOTE: despite the function name, a single matching file is sufficient.
    The glob pattern is not recursive, so files in subdirectories are not
    counted.
    """
    pattern = os.path.join(dir_path, "*.safetensors")
    return bool(glob.glob(pattern))

def _parse_int(text):
    """Return ``int(text)``, or None when ``text`` is not a valid integer."""
    try:
        return int(text)
    except ValueError:
        return None


def extract_model_name(path):
    """Build a descriptive evaluation name for the model stored at ``path``.

    The first ``/``-separated component containing "llama" or "qwen"
    (case-insensitive) is taken as the model identifier, and the component
    right after it as the training configuration (empty when the model
    identifier is the last component, in which case the returned name ends
    with a trailing "-").

    Args:
        path: Directory path using ``/`` as separator, e.g.
            ``.../llama3.1-8b/lora256-samples2000/checkpoint-1700``.

    Returns:
        The name ``<model>-<config>[-checkpoint-<n>]``, or None when the
        model should be skipped:
          * no path component mentions "llama" or "qwen",
          * the checkpoint number is <= 1000,
          * the LoRA size in the training configuration is < 256,
          * the sample count in the training configuration is <= 500.
        A number that cannot be parsed disables the corresponding filter
        instead of raising.
    """
    parts = path.split('/')

    # Locate the model identifier and the training configuration after it.
    model_size = None
    training_config = ""
    for i, part in enumerate(parts):
        lowered = part.lower()
        if "llama" in lowered or "qwen" in lowered:
            model_size = part  # e.g. llama3.1-8b
            training_config = parts[i + 1] if i + 1 < len(parts) else ""  # e.g. lora256-samples2000
            break
    if model_size is None:
        # Without a recognised model component no name can be built.
        print(f"Skipping {path} because no llama/qwen component was found")
        return None

    # Checkpoint handling: the number follows the last "checkpoint-" and
    # ends at the next "-" or "/".
    checkpoint = ""
    if "checkpoint" in path:
        checkpoint_num = _parse_int(
            path.split("checkpoint-")[-1].split("-")[0].split("/")[0]
        )
        # A path that merely contains "checkpoint" without a number is
        # treated as a non-checkpoint directory.
        if checkpoint_num is not None:
            if checkpoint_num <= 1000:
                return None
            checkpoint = f"-checkpoint-{checkpoint_num}"

    # LoRA filter: the size is the text after the last "lora" up to the
    # next "-" in the training configuration.
    if "lora" in training_config:
        lora_rank = _parse_int(
            training_config.split("lora")[-1].split("-")[0].split("/")[0]
        )
        if lora_rank is not None and lora_rank < 256:
            return None

    if "samples" in training_config:
        # Sample-specific models: the name keeps the first and the last
        # "-"-separated fields of the training configuration.
        config_fields = training_config.split("-")
        samples = config_fields[-1]
        sample_count = _parse_int(samples.replace("samples", ""))
        if sample_count is not None:
            if sample_count <= 500:
                print(f"Skipping {path} because samples <= 500")
                return None
            return f"{model_size}-{config_fields[0]}-{samples}{checkpoint}"
        # The last field is not a sample count; fall through to the
        # regular naming below.

    # Regular models.
    return f"{model_size}-{training_config}{checkpoint}"

def launch_eval(model_dir, model_name, task):
    """Run ``jobs_eval.py`` for one model/task pair and wait for it to finish.

    Args:
        model_dir: Directory holding the model weights; passed as
            ``--model_dir``.
        model_name: Name passed as ``--model``. When None, nothing is
            launched.
        task: Task identifier passed as ``--task``.

    Returns:
        None. A non-zero exit status of the child process is reported on
        stdout but does not raise.
    """
    if model_name is None:
        return

    cmd = [
        "python", "jobs_eval.py",
        f"--model_dir={model_dir}",
        f"--model={model_name}",
        "--max_new_tokens=32768",
        # Datasets used previously: HuggingFaceH4/MATH-500, manifold_mcq.
        "--data=metaculus",
        "--data_split=test",
        f"--task={task}",
        "--num_generations=5",
        "--n_gpus=1",
    ]

    # Announce the command before the blocking call so the message shows up
    # when the evaluation starts rather than after it has finished.
    print("Launching eval with command: ", cmd)
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print(f"Eval command exited with non-zero status {result.returncode}")



# Number of evaluations launched so far; shared by all recursive calls.
FOUND = 0


def process_directory(dir_path):
    """Recursively scan ``dir_path`` and launch evaluations for model dirs.

    A directory qualifies when ``has_multiple_safetensors`` accepts it and
    ``extract_model_name`` returns a name. At most one evaluation is
    launched in total (tracked through the module-level ``FOUND`` counter),
    but every subdirectory is still visited.
    """
    global FOUND

    if has_multiple_safetensors(dir_path):
        model_name = extract_model_name(dir_path)
        if model_name is not None:
            # Other task names used previously: "mcq_forecasting",
            # "mmlu-pro", "math".
            for task in ("forecasting",):
                if FOUND >= 1:
                    continue
                launch_eval(dir_path, model_name, task)
                FOUND += 1

    # Descend into every subdirectory.
    for entry in os.listdir(dir_path):
        child_path = os.path.join(dir_path, entry)
        if os.path.isdir(child_path):
            process_directory(child_path)
            # break

# Base directory containing all model variants.
# Only one ``base_dir2`` assignment is active at a time; the commented-out
# entries below are alternative roots kept for quick switching.
# base_dir2 = "/fast/XXXX-3/forecasting/sft/llama3.1-8b/full"
# base_dir2 = "/fast/XXXX-3/forecasting/merged_models/sft/qwen25-32b"
# base_dir2 = "/fast/XXXX-3/forecasting/sft/qwen25-3b/full"
# base_dir2 = "/fast/XXXX-3/forecasting/sft/qwen25-1.5b/full"
# base_dir2 = "/fast/XXXX-3/forecasting/merged_models/sft/llama33-70b"
# base_dir2 = "/fast/XXXX-3/forecasting/sft/qwen25-14b/full"


base_dir2 = "/fast/XXXX-3/models/Qwen3-4B"

# base_dir2 = "/fast/rolmedo/models/qwen2.5-3b-it"
# base_dir2 = "/fast/rolmedo/models/qwen2.5-14b-it"
# base_dir2 = "/fast/rolmedo/models/llama-3.1-8b-it"
# base_dir2 = "/fast/rolmedo/models/llama-3.3-70b-instruct"
# base_dir2 = "/fast/rolmedo/models/r1-llama-70b"

# Start recursive processing. This call sits at module level, so it runs
# whenever the file is executed or imported.
process_directory(base_dir2)
