import htcondor
from pathlib import Path
import argparse

# Job bidding constants.
# JOB_BID_SINGLE is the default ``JOB_BID`` of launch_mcq_classifier_job, which
# submits the job with jobprio = JOB_BID - 1000.
JOB_BID_SINGLE = 25
# Not referenced anywhere in this file; presumably the bid intended for
# multi-GPU jobs -- TODO confirm before relying on it.
JOB_BID_MULTI = 400

def launch_mcq_classifier_job(
        dataset_name,
        model_name,
        output_dir,
        wandb_project,
        wandb_run_name,
        num_options=4,
        train_dataset=None,
        test_dataset=None,
        train_ratio=0.5,
        test_ratio=0.5,
        split_test_set=False,
        max_seq_length=512,
        learning_rate=1e-5,
        classifier_lr=1e-4,
        train_batch_size=8,
        eval_batch_size=16,
        gradacc_steps=8,
        num_train_epochs=60,
        eval_steps=100,
        save_steps=500,
        freeze_base_model=False,
        random_init=False,
        only_options=False,
        seed=42,
        warmup_ratio=0.1,
        lr_scheduler="constant_with_warmup",
        option_sampling_strategy="both",
        JOB_MEMORY=32,
        JOB_CPUS=4,
        JOB_GPUS=1,
        JOB_BID=JOB_BID_SINGLE,
        GPU_MEM=None,
):
    """Submit one MCQ classifier training run to the HTCondor scheduler.

    Every training parameter is rendered as a ``--name value`` flag (or a bare
    ``--name`` switch for booleans) and the space-joined result becomes the
    argument string of ``scripts/launch_classifier.sh``.

    Args:
        dataset_name, model_name: Always forwarded.
        output_dir, wandb_project, wandb_run_name: Forwarded unless ``None``.
        train_dataset, test_dataset: Forwarded only when truthy.
        split_test_set, freeze_base_model, random_init, only_options:
            Emitted as bare switches when truthy.
        num_options, train_ratio, test_ratio, max_seq_length, learning_rate,
        classifier_lr, train_batch_size, eval_batch_size, gradacc_steps,
        num_train_epochs, eval_steps, save_steps, seed, warmup_ratio,
        lr_scheduler, option_sampling_strategy: Always forwarded.
        JOB_MEMORY: Memory request in GB (``request_memory``).
        JOB_CPUS: CPU request; values above 48 are clamped to 48.
        JOB_GPUS: GPU request (``request_gpus``).
        JOB_BID: The job is submitted with ``jobprio = JOB_BID - 1000``.
        GPU_MEM: If not ``None``, the minimum ``CUDAGlobalMemoryMb`` the
            target machine must offer.

    Returns:
        The object returned by ``htcondor.Schedd().submit``.

    Side effects:
        Creates the cluster log directory if it is missing, submits the job,
        and prints the resulting cluster and process IDs.
    """
    # The directory that receives the job's .out/.err/.log files must exist
    # before submission.
    log_dir = Path("/fast/XXXX-3/logs/classification/mcq/mmlu")
    log_dir.mkdir(parents=True, exist_ok=True)
    # $(Cluster) and $(Process) are HTCondor submit macros and are kept
    # verbatim in the path.
    log_prefix = str(log_dir / "$(Cluster).$(Process)")

    def flag(name, value):
        return f"--{name} {value}"

    # The flag order below mirrors the order the job has always been launched
    # with; keep it stable.
    cli_parts = [
        flag("dataset_name", dataset_name),
        flag("model_name", model_name),
    ]

    # Dropped only when explicitly None.
    cli_parts.extend(
        flag(name, value)
        for name, value in (
            ("output_dir", output_dir),
            ("wandb_project", wandb_project),
            ("wandb_run_name", wandb_run_name),
        )
        if value is not None
    )

    cli_parts.extend(
        flag(name, value)
        for name, value in (
            ("train_ratio", train_ratio),
            ("test_ratio", test_ratio),
            ("num_train_epochs", num_train_epochs),
            ("gradacc_steps", gradacc_steps),
        )
    )

    # Dropped when falsy (None or an empty string).
    cli_parts.extend(
        flag(name, value)
        for name, value in (
            ("train_dataset", train_dataset),
            ("test_dataset", test_dataset),
        )
        if value
    )

    # Boolean switches carry no value.
    cli_parts.extend(
        f"--{name}"
        for name, enabled in (
            ("split_test_set", split_test_set),
            ("freeze_base_model", freeze_base_model),
            ("random_init", random_init),
            ("only_options", only_options),
        )
        if enabled
    )

    cli_parts.extend(
        flag(name, value)
        for name, value in (
            ("num_options", num_options),
            ("max_seq_length", max_seq_length),
            ("learning_rate", learning_rate),
            ("classifier_lr", classifier_lr),
            ("train_batch_size", train_batch_size),
            ("eval_batch_size", eval_batch_size),
            ("eval_steps", eval_steps),
            ("save_steps", save_steps),
            ("seed", seed),
            ("warmup_ratio", warmup_ratio),
            ("lr_scheduler", lr_scheduler),
            ("option_sampling_strategy", option_sampling_strategy),
        )
    )

    # Machine constraints shared by every submission: a minimum CUDA
    # capability plus a list of hosts that are always excluded.
    excluded_hosts = (
        "g125.internal.cluster.is.localnet",
        "g147.internal.cluster.is.localnet",
        "g136.internal.cluster.is.localnet",
    )
    requirements = " && ".join(
        ["(CUDACapability >= 8.0)"]
        + [f'(TARGET.Machine != "{host}")' for host in excluded_hosts]
    )
    if GPU_MEM is not None:
        # The GPU memory constraint, when present, goes first.
        requirements = f"(TARGET.CUDAGlobalMemoryMb >= {GPU_MEM}) && {requirements}"

    cpu_request = min(JOB_CPUS, 48)
    job_settings = {
        "executable": 'scripts/launch_classifier.sh',
        "arguments": " ".join(cli_parts),

        "output": f"{log_prefix}.out",
        "error": f"{log_prefix}.err",
        "log": f"{log_prefix}.log",

        "request_gpus": f"{JOB_GPUS}",
        "request_cpus": f"{cpu_request}",
        "request_memory": f"{JOB_MEMORY}GB",

        "jobprio": f"{JOB_BID - 1000}",
        "notify_user": "XXXX-12.XXXX-10@tuebingen.mpg.de",
        "notification": "error",

        "requirements": requirements,
    }

    # Hand the description to the scheduler.
    submission = htcondor.Submit(job_settings)
    result = htcondor.Schedd().submit(submission)

    print(
        f"Launched MCQ classifier job with cluster-ID={result.cluster()}, "
        f"proc-ID={result.first_proc()}")

    return result


def parse_cli_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The parsed ``argparse.Namespace``. ``job_memory`` is always an int on
        return: when it was not given on the command line it is set to
        16 GB for every requested CPU.
    """
    parser = argparse.ArgumentParser(description="Launch MCQ classifier training job")

    # Data parameters
    parser.add_argument("--dataset_name", type=str, default="mmlu_pro",
                        help="Name of the dataset to use")
    parser.add_argument("--train_dataset", type=str, default=None,
                        help="Optional: Separate dataset to use for training")
    parser.add_argument("--test_dataset", type=str, default=None,
                        help="Optional: Separate dataset to use for testing")
    parser.add_argument("--num_options", type=int, default=4,
                        help="Number of options for MCQ")
    parser.add_argument("--train_ratio", type=float, default=0.5,
                        help="Proportion of data for training")
    parser.add_argument("--test_ratio", type=float, default=0.5,
                        help="Proportion of data for testing")
    parser.add_argument("--split_test_set", action="store_true",
                        help="Split test set into train/test when no training set is available")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--only_options", action="store_true",
                        help="Include only options in the prompt (no question)")
    parser.add_argument("--option_sampling_strategy", type=str, default="both",
                        choices=["both", "correct", "incorrect"],
                        help="Strategy for sampling additional options")

    # Model parameters
    parser.add_argument("--model_name", type=str, default="microsoft/deberta-v3-large",
                        help="Name of the pretrained model")
    parser.add_argument("--max_seq_length", type=int, default=512,
                        help="Maximum sequence length for tokenization")
    parser.add_argument("--freeze_base_model", action="store_true",
                        help="Freeze base model and only train classifier head")
    parser.add_argument("--random_init", action="store_true",
                        help="Initialize model with random weights")

    # Training parameters
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Directory to save model outputs (if None, auto-generated)")
    parser.add_argument("--learning_rate", type=float, default=5e-6,
                        help="Learning rate for optimization (backbone)")
    parser.add_argument("--classifier_lr", type=float, default=1e-4,
                        help="Learning rate for classifier head (default: 1e-4)")
    parser.add_argument("--train_batch_size", type=int, default=8,
                        help="Batch size for training")
    parser.add_argument("--eval_batch_size", type=int, default=16,
                        help="Batch size for evaluation")
    parser.add_argument("--gradacc_steps", type=int, default=8,
                        help="Number of gradient accumulation steps")
    parser.add_argument("--num_train_epochs", type=int, default=60,
                        help="Number of training epochs")
    parser.add_argument("--eval_steps", type=int, default=100,
                        help="Steps between evaluations")
    parser.add_argument("--save_steps", type=int, default=500,
                        help="Steps between model saves")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Ratio of total steps for warmup")
    parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup",
                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                        help="Learning rate scheduler type")

    # Wandb parameters
    parser.add_argument("--wandb_project", type=str, default="cawn",
                        help="Wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Wandb run name (defaults to output_dir basename)")

    # Job parameters.
    # --job_memory and --job_cpus are passed through unchanged as the job's
    # total memory/CPU request, so the help text describes totals (the old
    # "per CPU" / "per GPU" wording did not match what is submitted).
    parser.add_argument("--job_memory", type=int, default=None,
                        help="Total memory to request for the job in GB "
                             "(default: 16 GB per requested CPU)")
    parser.add_argument("--job_cpus", type=int, default=4,
                        help="Total number of CPUs to request for the job "
                             "(the launcher caps the request at 48)")
    parser.add_argument("--job_gpus", type=int, default=1,
                        help="Number of GPUs")
    parser.add_argument("--gpu_mem", type=int, default=35000,
                        help="Minimum GPU memory in MB")

    args = parser.parse_args(argv)

    # Scale the default memory request with the CPU count.
    if args.job_memory is None:
        args.job_memory = 16 * args.job_cpus

    return args


def main(argv=None):
    """Parse the command line and submit the corresponding cluster job.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        Whatever ``launch_mcq_classifier_job`` returns for the submission.
    """
    args = parse_cli_args(argv)

    # JOB_BID is intentionally not passed, so the launcher's default applies.
    return launch_mcq_classifier_job(
        dataset_name=args.dataset_name,
        model_name=args.model_name,
        output_dir=args.output_dir,
        wandb_project=args.wandb_project,
        wandb_run_name=args.wandb_run_name,
        num_options=args.num_options,
        train_dataset=args.train_dataset,
        test_dataset=args.test_dataset,
        train_ratio=args.train_ratio,
        test_ratio=args.test_ratio,
        split_test_set=args.split_test_set,
        max_seq_length=args.max_seq_length,
        learning_rate=args.learning_rate,
        classifier_lr=args.classifier_lr,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradacc_steps=args.gradacc_steps,
        num_train_epochs=args.num_train_epochs,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        freeze_base_model=args.freeze_base_model,
        random_init=args.random_init,
        only_options=args.only_options,
        seed=args.seed,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler=args.lr_scheduler,
        option_sampling_strategy=args.option_sampling_strategy,
        JOB_MEMORY=args.job_memory,
        JOB_CPUS=args.job_cpus,
        JOB_GPUS=args.job_gpus,
        GPU_MEM=args.gpu_mem)


if __name__ == "__main__":
    main()