import os
import argparse
from pathlib import Path
import htcondor
import glob

# Default bid for a single job. launch_eval_job uses it as the default value
# of its JOB_BID parameter, which is submitted as jobprio = JOB_BID - 1000.
JOB_BID_SINGLE = 25

def launch_eval_job(
        base_save_dir,
        model_dir,
        model_name,
        batch_size=32,
        max_new_tokens=16384,
        data_split="test",
        data="metaculus",
        num_generations=5,
        JOB_MEMORY=128,
        JOB_CPUS=8,
        JOB_GPUS=1,
        JOB_BID=JOB_BID_SINGLE,
        GPU_MEM=40000,
):
    """Submit a single evaluation job to the HTCondor scheduler.

    The job runs ``launch_custom_eval_job.sh`` and receives the evaluation
    settings as space-separated positional arguments, in this order:
    ``base_save_dir model_dir model_name batch_size max_new_tokens
    data_split data num_generations``.

    Args:
        base_save_dir: First script argument; also printed as the location
            where results will be saved.
        model_dir: Second script argument; printed as the model being
            evaluated.
        model_name: Third script argument.
        batch_size: Fourth script argument.
        max_new_tokens: Fifth script argument.
        data_split: Sixth script argument.
        data: Seventh script argument.
        num_generations: Eighth script argument.
        JOB_MEMORY: Per-GPU amount in GB. Multiplied by ``JOB_GPUS`` and used
            for both ``request_memory`` and ``request_disk``.
        JOB_CPUS: Per-GPU CPU count. Multiplied by ``JOB_GPUS``; the request
            is never lower than 32.
        JOB_GPUS: Number of GPUs to request.
        JOB_BID: Submitted as ``jobprio = JOB_BID - 1000``.
        GPU_MEM: Minimum ``CUDAGlobalMemoryMb`` required of the target
            machine. ``None`` drops that constraint; ``CUDACapability >= 8.0``
            is required either way.

    Returns:
        None. The cluster ID, first proc ID, model directory and save
        directory are printed to stdout.
    """
    # All cluster log files for these jobs share one directory.
    log_dir = Path("/fast/XXXX-11/logs/forecasting/evals")
    os.makedirs(log_dir, exist_ok=True)

    # Common stem of the .out/.err/.log files.
    # NOTE(review): $(Cluster)/$(Process) are presumably expanded by HTCondor
    # per job -- confirm against the submit-description documentation.
    log_stem = str(log_dir / "$(Cluster).$(Process)")

    # Every value is followed by a single space, including the last one.
    script_args = "".join(
        f"{value} "
        for value in (
            base_save_dir,
            model_dir,
            model_name,
            batch_size,
            max_new_tokens,
            data_split,
            data,
            num_generations,
        )
    )

    # NOTE(review): max() makes 32 a floor, so any JOB_CPUS * JOB_GPUS below
    # 32 has no effect on the request -- confirm this is intended.
    cpu_count = max(JOB_CPUS * JOB_GPUS, 32)
    # The same figure is used for the memory and the disk request.
    total_gb = JOB_MEMORY * JOB_GPUS

    if GPU_MEM is None:
        requirements = "CUDACapability >= 8.0"
    else:
        requirements = f"(TARGET.CUDAGlobalMemoryMb >= {GPU_MEM}) && (CUDACapability >= 8.0)"

    job_settings = {
        "executable": 'launch_custom_eval_job.sh',
        "arguments": script_args,
        "output": f"{log_stem}.out",
        "error": f"{log_stem}.err",
        "log": f"{log_stem}.log",
        "request_gpus": str(JOB_GPUS),
        "request_cpus": str(cpu_count),
        "request_memory": f"{total_gb}GB",
        "request_disk": f"{total_gb}GB",
        "jobprio": str(JOB_BID - 1000),
        "notify_user": "XXXX-1.XXXX-2@tuebingen.mpg.de",
        "notification": "error",
        "requirements": requirements,
    }

    # Hand the description to the scheduler.
    job_description = htcondor.Submit(job_settings)
    scheduler = htcondor.Schedd()
    submit_result = scheduler.submit(job_description)

    cluster_id = submit_result.cluster()
    proc_id = submit_result.first_proc()
    print(
        f"Launched eval job with cluster-ID={cluster_id}, "
        f"proc-ID={proc_id}")
    print(f"Model: {model_dir}")
    print(f"Results will be saved to: {base_save_dir}")


def _merged_dir_sort_key(merged_dir):
    """Build the sort key for one ``checkpoint-*/merged`` directory path."""
    # Use only the folder that directly contains "merged"; looking for
    # "checkpoint-" in the whole path breaks when main_dir itself has it.
    checkpoint_name = os.path.basename(os.path.dirname(merged_dir))
    suffix = checkpoint_name[len("checkpoint-"):]
    try:
        return (0, int(suffix), "")
    except ValueError:
        # Non-numeric suffix (e.g. "checkpoint-best"): sort after the
        # numbered checkpoints instead of aborting the whole discovery.
        return (1, 0, suffix)


def find_merged_checkpoints(main_dir):
    """Find all merged checkpoint directories under *main_dir*.

    Args:
        main_dir: Directory that contains ``checkpoint-*`` folders.

    Returns:
        A list of paths matching ``<main_dir>/checkpoint-*/merged`` that are
        directories. Checkpoints with a numeric suffix come first, in
        ascending numeric order; checkpoints with a non-numeric suffix follow,
        ordered alphabetically by suffix. The list is empty when nothing
        matches.
    """
    # Look for checkpoint-*/merged directories
    merged_pattern = os.path.join(main_dir, "checkpoint-*/merged")
    merged_dirs = [d for d in glob.glob(merged_pattern) if os.path.isdir(d)]

    # Sort by checkpoint number
    merged_dirs.sort(key=_merged_dir_sort_key)

    return merged_dirs


def _build_arg_parser():
    """Create the command-line parser for the evaluation launcher."""
    parser = argparse.ArgumentParser(description="Launch evaluation jobs for merged models")

    parser.add_argument('--main_dir', type=str, required=True,
                       help="Main directory containing checkpoint folders with merged models")

    parser.add_argument('--model_name', type=str, default="llama3.1-8b-ins",
                       help="Name of the model (e.g., 'llama3.1-8b-ins')")

    parser.add_argument('--batch_size', type=int, default=32,
                       help="Batch size for evaluation")

    parser.add_argument('--max_new_tokens', type=int, default=16384,
                       help="Maximum new tokens for generation")

    parser.add_argument('--data_split', type=str, default="test",
                       help="Data split to use (train or test)")

    parser.add_argument('--data', type=str, default="metaculus",
                       choices=['metaculus', 'halawi'],
                       help="Which dataset to use")

    parser.add_argument('--num_generations', type=int, default=5,
                       help="Number of generations per prompt")

    parser.add_argument('--gpu_mem', type=int, default=40000,
                       help="Minimum GPU memory required in MB")

    parser.add_argument('--job_memory', type=int, default=64,
                       help="Job memory request in GB")

    parser.add_argument('--job_cpus', type=int, default=4,
                       help="Number of CPUs to request")

    parser.add_argument('--job_gpus', type=int, default=1,
                       help="Number of GPUs to request")

    parser.add_argument('--specific_checkpoints', type=str, nargs='+',
                       help="Specific checkpoint numbers to evaluate (e.g., '500 1000')")

    return parser


def _checkpoint_dir_name(merged_dir):
    """Return the name of the ``checkpoint-*`` folder containing *merged_dir*."""
    # Use the direct parent folder rather than searching the whole path for
    # "checkpoint-", which picks the wrong part when main_dir contains it.
    return os.path.basename(os.path.dirname(os.path.normpath(merged_dir)))


def _select_checkpoints(all_merged_dirs, requested):
    """Pick the merged dirs for the requested checkpoint numbers.

    Returns a ``(selected, missing)`` pair: the matching directories in the
    order the numbers were requested, and the requested numbers that matched
    nothing. A number given more than once is only considered once, so no
    checkpoint is selected twice.
    """
    selected = []
    missing = []
    for number in dict.fromkeys(requested):
        wanted = f"checkpoint-{number}"
        matches = [d for d in all_merged_dirs if _checkpoint_dir_name(d) == wanted]
        if matches:
            selected.extend(matches)
        else:
            missing.append(number)
    return selected, missing


def main(argv=None):
    """Launch one evaluation job per merged checkpoint found in ``--main_dir``.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv``.

    Raises:
        SystemExit: With status 1 when no merged checkpoint exists or none is
            left after applying ``--specific_checkpoints``.
    """
    args = _build_arg_parser().parse_args(argv)

    # Find all merged checkpoint directories
    all_merged_dirs = find_merged_checkpoints(args.main_dir)

    if not all_merged_dirs:
        print(f"No merged checkpoint directories found in {args.main_dir}")
        # The bare exit() helper comes from the site module and is not meant
        # for programs; raising SystemExit needs no import.
        raise SystemExit(1)

    # Filter to specific checkpoints if requested
    if args.specific_checkpoints:
        merged_dirs, missing = _select_checkpoints(
            all_merged_dirs, args.specific_checkpoints)
        for number in missing:
            print(f"Warning: no merged checkpoint found for checkpoint-{number}")
    else:
        merged_dirs = all_merged_dirs

    if not merged_dirs:
        print("No matching merged checkpoints found after filtering.")
        raise SystemExit(1)

    print(f"Found {len(merged_dirs)} merged checkpoints to evaluate:")
    for merged_dir in merged_dirs:
        print(f"  - {merged_dir}")

    # Launch an evaluation job for each merged checkpoint
    for model_dir in merged_dirs:
        # Outputs go next to the "merged" folder, inside its checkpoint folder
        base_save_dir = os.path.join(
            args.main_dir,
            _checkpoint_dir_name(model_dir),
            f"outputs/{args.data}_{args.data_split}"
        )

        # Ensure output directory exists
        os.makedirs(base_save_dir, exist_ok=True)

        launch_eval_job(
            base_save_dir=base_save_dir,
            model_dir=model_dir,
            model_name=args.model_name,
            batch_size=args.batch_size,
            max_new_tokens=args.max_new_tokens,
            data_split=args.data_split,
            data=args.data,
            num_generations=args.num_generations,
            JOB_MEMORY=args.job_memory,
            JOB_CPUS=args.job_cpus,
            JOB_GPUS=args.job_gpus,
            GPU_MEM=args.gpu_mem,
        )


if __name__ == "__main__":
    main()