import os
import argparse
from pathlib import Path
import htcondor
import glob

# Default for the JOB_BID parameter of launch_merge_lora_job, which submits
# the job with jobprio = JOB_BID - 1000.
JOB_BID_SINGLE = 25

def launch_merge_lora_job(
        base_model_dir,
        lora_checkpoint_dir,
        JOB_MEMORY=64,
        JOB_CPUS=8,
        JOB_GPUS=1,
        JOB_BID=JOB_BID_SINGLE,
        GPU_MEM=40000,
):
    """Submit one HTCondor job that runs ``launch_merge_lora_job.sh``.

    The shell script receives ``base_model_dir`` and ``lora_checkpoint_dir``
    as its two whitespace-separated arguments. Job stdout, stderr and the
    HTCondor event log are written to the cluster log directory, named
    ``<Cluster>.<Process>.{out,err,log}``.

    Args:
        base_model_dir: Directory of the base model (first script argument).
        lora_checkpoint_dir: Directory of the LoRA checkpoint (second script
            argument).
        JOB_MEMORY: Amount requested in GB; used for both ``request_memory``
            and ``request_disk``.
        JOB_CPUS: Number of CPUs to request.
        JOB_GPUS: Number of GPUs to request.
        JOB_BID: Submitted as ``jobprio = JOB_BID - 1000``.
        GPU_MEM: Minimum ``CUDAGlobalMemoryMb`` required of the target
            machine, or ``None`` to require only ``CUDACapability >= 8.0``.

    Returns:
        The object returned by ``htcondor.Schedd.submit`` so that callers
        can read the cluster and proc IDs of the submitted job.
    """
    # Directory for cluster logs related to this job
    log_dir = Path("/fast/XXXX-11/logs/forecasting/lora_merges")
    log_dir.mkdir(parents=True, exist_ok=True)

    # $(Cluster) and $(Process) are HTCondor submit macros, not Python
    # placeholders, so this is deliberately a plain string.
    log_prefix = str(log_dir / "$(Cluster).$(Process)")

    # NOTE(review): the executable is a relative path, so this presumably
    # has to be launched from the directory containing the script - confirm.
    executable = 'launch_merge_lora_job.sh'

    # The compute-capability constraint always applies; the GPU memory
    # constraint is added only when a minimum was requested.
    requirements = "CUDACapability >= 8.0"
    if GPU_MEM is not None:
        requirements = (
            f"(TARGET.CUDAGlobalMemoryMb >= {GPU_MEM}) && ({requirements})"
        )

    # Construct job description
    job_settings = {
        "executable": executable,
        "arguments": f"{base_model_dir} {lora_checkpoint_dir}",
        "output": f"{log_prefix}.out",
        "error": f"{log_prefix}.err",
        "log": f"{log_prefix}.log",

        "request_gpus": f"{JOB_GPUS}",
        "request_cpus": f"{JOB_CPUS}",
        "request_memory": f"{JOB_MEMORY}GB",
        "request_disk": f"{JOB_MEMORY}GB",

        "jobprio": f"{JOB_BID - 1000}",
        "notify_user": "XXXX-1.XXXX-2@tuebingen.mpg.de",
        "notification": "error",
        "requirements": requirements,
    }

    # Submit job to scheduler
    submit_result = htcondor.Schedd().submit(htcondor.Submit(job_settings))

    print(
        f"Launched LoRA merge job with cluster-ID={submit_result.cluster()}, "
        f"proc-ID={submit_result.first_proc()}")
    print(f"Base model: {base_model_dir}")
    print(f"LoRA checkpoint: {lora_checkpoint_dir}")
    print(f"Merged output will be saved to: {os.path.join(lora_checkpoint_dir, 'merged')}")

    return submit_result


def find_checkpoint_dirs(main_dir):
    """Return the ``checkpoint-*`` subdirectories of *main_dir*, sorted.

    Directories whose name ends in ``-<integer>`` come first, ordered by that
    integer (so ``checkpoint-2`` precedes ``checkpoint-10``). Directories with
    a non-numeric suffix (e.g. ``checkpoint-final``) follow, ordered by name,
    instead of aborting the search with a ``ValueError``.

    Args:
        main_dir: Directory to search; only its direct children are examined.

    Returns:
        A list of directory paths. Empty if nothing matches or *main_dir*
        does not exist.
    """
    # Escape main_dir so glob metacharacters in the path itself ("[", "*",
    # "?") are matched literally; only "checkpoint-*" is meant as a pattern.
    checkpoint_pattern = os.path.join(
        glob.escape(os.fspath(main_dir)), "checkpoint-*"
    )
    checkpoint_dirs = [d for d in glob.glob(checkpoint_pattern) if os.path.isdir(d)]

    def sort_key(path):
        name = os.path.basename(path)
        try:
            # Numbered checkpoints: order by the number after the last "-".
            return (0, int(name.rpartition("-")[2]), name)
        except ValueError:
            # Non-numeric suffix: keep the directory, but after numbered ones.
            return (1, 0, name)

    checkpoint_dirs.sort(key=sort_key)

    return checkpoint_dirs


if __name__ == "__main__":
    # Command-line entry point: pick the checkpoint directories to merge and
    # submit one merge job per directory.
    parser = argparse.ArgumentParser(description="Launch jobs to merge LoRA adapters with a base model")

    parser.add_argument('--base_model_dir', type=str,
                        default="/is/cluster/fast/rolmedo/models/llama-3.1-8b-instruct/",
                        help="Directory containing the base model")

    parser.add_argument('--main_dir', type=str, required=True,
                        help="Main directory containing checkpoint folders")

    parser.add_argument('--gpu_mem', type=int, default=40000,
                        help="Minimum GPU memory required in MB")

    parser.add_argument('--job_memory', type=int, default=64,
                        help="Job memory request in GB")

    parser.add_argument('--job_cpus', type=int, default=1,
                        help="Number of CPUs to request")

    parser.add_argument('--specific_checkpoints', type=str, nargs='+',
                        help="Specific checkpoint names to process (e.g., 'checkpoint-1000')")

    args = parser.parse_args()

    if args.specific_checkpoints:
        # Use only the specified checkpoints, in the order given. Names that
        # do not resolve to a directory are skipped, but reported so that a
        # typo does not silently shrink the set of submitted jobs.
        checkpoint_dirs = []
        for checkpoint_name in args.specific_checkpoints:
            candidate = os.path.join(args.main_dir, checkpoint_name)
            if os.path.isdir(candidate):
                checkpoint_dirs.append(candidate)
            else:
                print(f"Warning: skipping {candidate} (not a directory)")
    else:
        # Find all checkpoint directories in the main directory
        checkpoint_dirs = find_checkpoint_dirs(args.main_dir)

    if not checkpoint_dirs:
        print(f"No checkpoint directories found in {args.main_dir}")
        # SystemExit rather than the exit() helper, which is only injected by
        # the site module and is not guaranteed to exist.
        raise SystemExit(1)

    print(f"Found {len(checkpoint_dirs)} checkpoint directories to process:")
    for checkpoint_dir in checkpoint_dirs:
        print(f"  - {checkpoint_dir}")

    # Launch a merge job for each checkpoint directory
    for lora_dir in checkpoint_dirs:
        launch_merge_lora_job(
            base_model_dir=args.base_model_dir,
            lora_checkpoint_dir=lora_dir,
            JOB_MEMORY=args.job_memory,
            JOB_CPUS=args.job_cpus,
            GPU_MEM=args.gpu_mem,
        )