from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import torch
import sys 
import json 
import re 

from tqdm import tqdm 


import os
from pathlib import Path

import htcondor
import time
import yaml

# Default bid for jobs started through launch_lm_label_job (it is that
# function's JOB_BID default); the launcher submits it to HTCondor as
# jobprio = JOB_BID - 1000.
# NOTE(review): the trailing "100" looks like an earlier/alternative value
# that was left behind as a reminder -- confirm before relying on it.
JOB_BID_SINGLE = 25 # 100
# Not referenced anywhere in the visible part of this file; presumably the bid
# intended for multi-GPU jobs -- TODO confirm against callers.
JOB_BID_MULTI = 400 

def _write_job_script(job_dir):
    """Copy the shared launch script into *job_dir*, retargeted at the job config.

    The shared script trains from the shared config; the copy is rewritten so
    that it trains from ``<job_dir>/lora_config.yaml`` instead, and is made
    executable.

    Args:
        job_dir: ``Path`` of the (already existing) per-job directory.

    Returns:
        ``Path`` of the job-specific launch script.

    Raises:
        ValueError: if the shared script does not contain the expected
            training command, i.e. the rewrite would have been a no-op.
    """
    original_script_path = "/is/cluster/XXXX-11/forecasting-rl/sft/launch_lora_job1gpu.sh"
    job_script_path = job_dir / "launch_lora_job1gpu.sh"

    shared_train_cmd = "llamafactory-cli train ../configs/lora_forecasting_sft_llama.yaml"
    job_train_cmd = f"llamafactory-cli train {job_dir}/lora_config.yaml"

    script_content = Path(original_script_path).read_text(encoding="utf-8")

    # str.replace() silently returns the input unchanged when the pattern is
    # missing. The job would then train with the *shared* config (and its
    # output_dir), ignoring every override, so fail loudly before submitting.
    if shared_train_cmd not in script_content:
        raise ValueError(
            f"Expected training command {shared_train_cmd!r} not found in "
            f"{original_script_path}; cannot point the job at its own config."
        )

    job_script_path.write_text(
        script_content.replace(shared_train_cmd, job_train_cmd),
        encoding="utf-8",
    )

    # HTCondor runs this file directly as the job executable.
    os.chmod(job_script_path, 0o755)
    return job_script_path


def _write_job_config(job_dir, job_id, config_overrides):
    """Write the job's training config into *job_dir*.

    Without overrides the base config is copied verbatim. With overrides the
    base config is parsed, the overrides are applied on top, ``run_name`` (if
    present after overriding) gets ``_<job_id>`` appended, and the result is
    re-serialised with ``yaml.dump``.

    Args:
        job_dir: ``Path`` of the (already existing) per-job directory.
        job_id: identifier appended to ``run_name``.
        config_overrides: mapping of config keys to new values, or ``None``.

    Returns:
        ``Path`` of the job-specific config file.
    """
    # NOTE: relative path, resolved against the current working directory of
    # the process calling the launcher.
    original_config_path = "configs/lora_forecasting_sft_llama.yaml"
    job_config_path = job_dir / "lora_config.yaml"

    config_content = Path(original_config_path).read_text(encoding="utf-8")

    if config_overrides:
        # safe_load() yields None for an empty / comment-only document; fall
        # back to an empty mapping so the overrides can still be applied.
        config_dict = yaml.safe_load(config_content) or {}
        config_dict.update(config_overrides)

        # Update run_name to include job_id for easier tracking
        if 'run_name' in config_dict:
            config_dict['run_name'] = f"{config_dict['run_name']}_{job_id}"

        config_content = yaml.dump(config_dict, default_flow_style=False)

    job_config_path.write_text(config_content, encoding="utf-8")
    return job_config_path


def launch_lm_label_job(
        JOB_MEMORY,
        JOB_CPUS,
        JOB_GPUS=1,
        JOB_BID=JOB_BID_SINGLE,
        GPU_MEM=None,
        job_id=None,
        config_overrides=None,
):
    """Prepare a per-job script and config, then submit the job to HTCondor.

    A job directory is created under the project's ``sft/jobs`` folder. It
    receives a copy of the launch script (retargeted at the job's own config)
    and a copy of the base training config with ``config_overrides`` applied.
    The script copy is then submitted as the job executable.

    Args:
        JOB_MEMORY: memory per GPU in GB; the request is
            ``JOB_MEMORY * JOB_GPUS``.
        JOB_CPUS: CPU cores per GPU; the request is
            ``max(JOB_CPUS * JOB_GPUS, 32)``, i.e. never fewer than 32 cores.
        JOB_GPUS: number of GPUs to request.
        JOB_BID: bid for the job, submitted as ``jobprio = JOB_BID - 1000``.
        GPU_MEM: if given, minimum ``CUDAGlobalMemoryMb`` of the target
            machine. ``CUDACapability >= 8.0`` is always required.
        job_id: name of the job directory and ``run_name`` suffix. Defaults to
            ``job_<unix-seconds>``; note that two default-named launches
            within the same second share a directory.
        config_overrides: mapping of config keys to values that replace or
            extend the base training config. When falsy the base config is
            copied unchanged (and ``run_name`` is not suffixed).

    Returns:
        The ``job_id`` that was used.

    Raises:
        ValueError: if the shared launch script lacks the expected training
            command and therefore cannot be pointed at the job config.
        OSError: if a script/config file cannot be read or written.
    """
    # Name/prefix for cluster logs related to this job
    LOG_PATH = "/fast/XXXX-11/logs/forecasting/sft/lora/"

    cluster_logs_dir = Path(LOG_PATH)
    # The output/error/log files below live in this directory, so make sure
    # it exists before the job is handed to the scheduler.
    cluster_logs_dir.mkdir(parents=True, exist_ok=True)
    # $(Cluster) / $(Process) are left literal here for HTCondor to expand.
    cluster_job_log_name = str(cluster_logs_dir / "$(Cluster).$(Process)")

    # Create a unique job directory for this run
    if job_id is None:
        job_id = f"job_{int(time.time())}"

    job_dir = Path(f"/is/cluster/XXXX-11/forecasting-rl/sft/jobs/{job_id}")
    job_dir.mkdir(parents=True, exist_ok=True)

    job_script_path = _write_job_script(job_dir)
    _write_job_config(job_dir, job_id, config_overrides)

    # Construct job description
    job_settings = {
        "executable": str(job_script_path),  # Use the job-specific script

        "output": f"{cluster_job_log_name}.out",
        "error": f"{cluster_job_log_name}.err",
        "log": f"{cluster_job_log_name}.log",

        "request_gpus": f"{JOB_GPUS}",
        "request_cpus": f"{max(JOB_CPUS*JOB_GPUS, 32)}",  # at least 32 cores
        "request_memory": f"{JOB_MEMORY*JOB_GPUS}GB",  # total memory for the job

        "jobprio": f"{JOB_BID - 1000}",
        "notify_user": "XXXX-12.XXXX-10@tuebingen.mpg.de",
        "notification": "error",
    }

    if GPU_MEM is not None:
        job_settings["requirements"] = f"(TARGET.CUDAGlobalMemoryMb >= {GPU_MEM}) && (CUDACapability >= 8.0)"
    else:
        job_settings["requirements"] = "CUDACapability >= 8.0"

    job_description = htcondor.Submit(job_settings)

    # Submit job to scheduler
    schedd = htcondor.Schedd()
    submit_result = schedd.submit(job_description)

    print(
        f"Launched experiment with cluster-ID={submit_result.cluster()}, "
        f"proc-ID={submit_result.first_proc()}, job_id={job_id}")

    return job_id

if __name__ == "__main__":
    # Minimum GPU memory passed to the launcher as GPU_MEM; the launcher
    # compares it against TARGET.CUDAGlobalMemoryMb in the job requirements.
    GPU_MEM = 65000

    # Job 1: rank-64 LoRA run. These keys replace/extend the base config.
    lora64_overrides = {
        "max_samples": None,
        "lora_rank": 64,
        "output_dir": "/fast/XXXX-3/forecasting/sft/llama3.1-8b/lora64/",
        "run_name": "llama3.1-8b-forecasting-lora64",
    }
    launch_lm_label_job(
        job_id="llama3_lora64",
        config_overrides=lora64_overrides,
        JOB_MEMORY=64,
        JOB_CPUS=1,
        JOB_GPUS=1,
        JOB_BID=JOB_BID_SINGLE,
        GPU_MEM=GPU_MEM,
    )

    # Job 2 (disabled template): copy the call above and change, e.g.
    #     JOB_MEMORY=80,
    #     job_id="llama3_lora32_run1",
    #     config_overrides={
    #         "lora_rank": 32,
    #         "output_dir": "/fast/XXXX-3/forecasting/sft/llama3.1-8b/lora32_run1/",
    #         "run_name": "llama3.1-8b-forecasting-lora32-run1",
    #     }
