import hydra
from omegaconf import DictConfig, OmegaConf
import subprocess
import os
import time
import sys


from itertools import product
from typing import List

# GPU indices available on this machine. Edit to match your hardware; use an
# explicit list (e.g. [0, 2]) for non-contiguous IDs.
# NOTE(review): not referenced elsewhere in this file — the GPU for a run is
# chosen from cfg.gpu_id in main().
GPUS: List[int] = list(range(4))


def build_train_cmd(cfg: "DictConfig", output_dir: str) -> List[str]:
    """Build the argv list that launches the training script for ``cfg.train.strategy``.

    Supported strategies:

    * ``"full"``      -> ``nlp_training/finetune_v2.py``
    * ``"head_only"`` -> ``nlp_training/finetune_v2.py`` plus ``--freeze_base``
    * ``"lora"``      -> ``nlp_training/lora_v2.py`` plus ``--lora_r`` /
      ``--lora_alpha`` when those keys exist on ``cfg.train`` and are not None

    Every strategy receives the same common arguments in the same order.
    Strategy-specific flags follow the common arguments, and ``--debug`` is
    appended last when ``cfg.debug`` is truthy.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``train.strategy``,
            ``train.num_epochs``, ``train.learning_rate``,
            ``train.max_seq_length``, optional ``train.lora_r`` /
            ``train.lora_alpha``, ``model.name``, ``model.batch_size``,
            ``dataset.name``, ``seed``, ``hf_token`` and ``debug``.
        output_dir: Directory forwarded to the script as ``--output_dir``.

    Returns:
        Command list for ``subprocess.run``; every element is a ``str``.

    Raises:
        ValueError: If ``cfg.train.strategy`` is not one of the strategies above.
    """
    project_root = os.path.dirname(os.path.abspath(__file__))
    train = cfg.train
    strategy = train.strategy

    # Resolve the script and any strategy-specific flags first, so an unknown
    # strategy fails before any command is assembled.
    if strategy in ("full", "head_only"):
        script_name = "finetune_v2.py"
        extra_args = ["--freeze_base"] if strategy == "head_only" else []
    elif strategy == "lora":
        script_name = "lora_v2.py"
        extra_args = []
        # LoRA hyper-parameters are optional; forward only the ones that are set.
        for key in ("lora_r", "lora_alpha"):
            value = getattr(train, key, None)
            if value is not None:
                extra_args.extend([f"--{key}", str(value)])
    else:
        raise ValueError(f"Unknown training strategy: {strategy}")

    # str() on every value: subprocess.run rejects non-str argv entries
    # (e.g. a None token would otherwise raise TypeError at launch time).
    # NOTE(review): a None hf_token is forwarded as the literal string "None";
    # confirm the training scripts treat that as "no token".
    cmd = [
        sys.executable, os.path.join(project_root, "nlp_training", script_name),
        "--dataset_name", str(cfg.dataset.name),
        "--model_name", str(cfg.model.name),
        "--output_dir", output_dir,
        "--num_epochs", str(train.num_epochs),
        "--batch_size", str(cfg.model.batch_size),
        "--learning_rate", str(train.learning_rate),
        "--seed", str(cfg.seed),
        "--max_seq_length", str(train.max_seq_length),
        "--hf_token", str(cfg.hf_token),
        *extra_args,
    ]
    if cfg.debug:
        cmd.append("--debug")
    return cmd


def run_experiment(cmd: List[str], gpu_id: int, device: str, check: bool = False):
    """Run one training command as a child process and wait for it to finish.

    When ``device == "cuda"`` the child is pinned to a single GPU by setting
    ``CUDA_VISIBLE_DEVICES`` to ``gpu_id``. Otherwise the parent environment is
    inherited unchanged. The child always runs with this file's directory as
    its working directory.

    Args:
        cmd: Argument list to execute (e.g. the output of ``build_train_cmd``).
        gpu_id: GPU index exposed to the child when ``device == "cuda"``.
        device: Target device; only ``"cuda"`` triggers GPU pinning.
        check: If True, raise ``subprocess.CalledProcessError`` when the child
            exits non-zero. Defaults to False, which keeps the non-raising behavior.

    Returns:
        The ``subprocess.CompletedProcess``, so callers can inspect ``returncode``.
    """
    import shlex

    env = os.environ.copy()
    use_gpu = device == "cuda"
    if use_gpu:
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    project_root = os.path.dirname(os.path.abspath(__file__))

    # Never echo secrets into logs: mask the value that follows --hf_token.
    shown = [
        "***" if i > 0 and cmd[i - 1] == "--hf_token" else str(arg)
        for i, arg in enumerate(cmd)
    ]
    target = f"CUDA (GPU {gpu_id})" if use_gpu else device.upper()
    print(f"Launching on {target}: {shlex.join(shown)}")

    result = subprocess.run(cmd, env=env, cwd=project_root, check=check)
    if result.returncode != 0:
        # Make failures visible instead of silently discarding the exit status.
        print(f"Command exited with return code {result.returncode}")
    return result


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: create the run's output directory and launch one training job.

    Output layout::

        <output_dir>[/debug]/<dataset>/<model, '.'->'_'>_epochs_<N>_seed_<S>/<strategy>/

    A relative ``cfg.output_dir`` is resolved against this file's directory.
    ``run_experiment`` uses the same directory as the training process's cwd,
    so the directory created here is the one the training script writes to,
    even if Hydra has changed this process's working directory.
    """
    project_root = os.path.dirname(os.path.abspath(__file__))
    base_output_dir = os.path.join(project_root, str(cfg.output_dir))
    if cfg.debug:
        base_output_dir = os.path.join(base_output_dir, "debug")
    run_name = (
        f"{cfg.model.name.replace('.', '_')}"
        f"_epochs_{cfg.train.num_epochs}_seed_{cfg.seed}"
    )
    # The trailing "" keeps a final path separator on the directory passed to the script.
    output_dir = os.path.join(
        base_output_dir, str(cfg.dataset.name), run_name, str(cfg.train.strategy), ""
    )
    os.makedirs(output_dir, exist_ok=True)

    cmd = build_train_cmd(cfg, output_dir)
    # Mask the HF token so it does not end up in console / job logs.
    printable = [
        "***" if i > 0 and cmd[i - 1] == "--hf_token" else arg
        for i, arg in enumerate(cmd)
    ]
    print(f"train cmd: {printable} ")

    # One run per invocation: use cfg.gpu_id when present, otherwise GPU 0.
    device = getattr(cfg, "device", "cuda")
    gpu_id = getattr(cfg, "gpu_id", 0)
    run_experiment(cmd, gpu_id, device)


if __name__ == "__main__":
    main()