import subprocess
import sys
import time

# Switch both standard streams to line buffering, so output is flushed
# whenever a newline is written (presumably so progress shows up promptly
# when the output is redirected to a file or pipe).
for _stream in (sys.stdout, sys.stderr):
    _stream.reconfigure(line_buffering=True)

# Shell command that main() launches once enough GPU memory is free.
# Exactly one COMMAND_TO_RUN assignment is active; the commented-out lines
# are alternative runs kept for quick switching.

# COMMAND_TO_RUN = 'python -m scripts.train_selection_heads --model="meta-llama/Llama-3.3-70B-Instruct" --train_limit=2048 --validation_limit=1024 --n_epochs=10 --category="objects" --option_config="distinct" --task="select_one" --prompt_temp_idx=3 -v 2>&1 | tee select_obj_llama.log'
COMMAND_TO_RUN = (
    "python -m scripts.train_selection_heads"
    ' --model="meta-llama/Llama-3.3-70B-Instruct"'
    " --train_limit=2048"
    " --validation_limit=1024"
    " --n_epochs=10"
    ' --category="objects"'
    ' --option_config="distinct"'
    ' --task="select_one"'
    " --prompt_temp_idx=3"
    " --mcqify"
    " -v"
    " 2>&1 | tee select_obj_llama_mcq.log"
)
# COMMAND_TO_RUN = 'python -m scripts.train_selection_heads --model="meta-llama/Llama-3.3-70B-Instruct" --train_limit=2048 --validation_limit=1024 --n_epochs=10 --category="objects" --option_config="distinct" --task="counting" --prompt_temp_idx=1 -v 2>&1 | tee count_obj_llama.log'
# COMMAND_TO_RUN = 'python -m scripts.train_selection_heads --model="meta-llama/Llama-3.3-70B-Instruct" --train_limit=2048 --validation_limit=1024 --n_epochs=10 --category="objects" --option_config="distinct" --task="yes_no" --prompt_temp_idx=3 -v 2>&1 | tee yes_no_obj_llama.log'
# COMMAND_TO_RUN = 'python -m scripts.train_selection_heads --model="meta-llama/Llama-3.3-70B-Instruct" --train_limit=2048 --validation_limit=1024 --n_epochs=10 --category="objects" --option_config="distinct" --task="select_first" --prompt_temp_idx=3 -v 2>&1 | tee select_first_obj_llama.log'


# COMMAND_TO_RUN = 'python -m scripts.train_selection_heads --model="meta-llama/Llama-3.3-70B-Instruct" --train_limit=2048 --validation_limit=1024 --n_epochs=10 --category="objects" --option_config="distinct" --task="select_one" --save_dir="selection/ques_mixed" --prompt_temp_idx=-1 -v 2>&1 | tee ques_mixed.log'

# Index of the GPU to watch: the line of nvidia-smi output that is read.
CUDA_INDEX = 0

# Free memory, in GB, that must be exceeded before the command is launched.
MEM_THRESHOLD = 42

# Pause between two consecutive checks, in seconds (10 minutes).
CHECK_INTERVAL = 60 * 10


def get_gpu_free_memory():
    """Return the free memory, in GB, of the GPU selected by CUDA_INDEX.

    Runs ``nvidia-smi --query-gpu=memory.free`` (one line per GPU, values in
    MiB, no header and no unit suffix), picks line ``CUDA_INDEX`` and divides
    the value by 1024.

    NOTE(review): the index is nvidia-smi's own GPU ordering; it may differ
    from the CUDA device index when CUDA_VISIBLE_DEVICES or CUDA_DEVICE_ORDER
    is set -- confirm for the target machine.

    Returns:
        float: Free memory in GB, or 0.0 when the query fails for any reason
        (nvidia-smi missing, hung, non-zero exit, unparsable output, index
        out of range). The error is printed rather than raised.
    """
    # Upper bound (seconds) on a single nvidia-smi call, so that a hung
    # nvidia-smi cannot block the monitor forever.
    query_timeout = 60

    try:
        query = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
            timeout=query_timeout,
        )

        # One output line per GPU, in nvidia-smi order.
        per_gpu_free_mb = query.stdout.splitlines()
        free_memory_mb = float(per_gpu_free_mb[CUDA_INDEX].strip())

        return free_memory_mb / 1024.0
    except Exception as e:
        # Deliberate best effort: any failure is reported and treated as
        # "no free memory" instead of being raised.
        print(f"Error getting GPU memory: {e}")
        return 0.0


def main():
    """Wait until the watched GPU has enough free memory, then run COMMAND_TO_RUN.

    Polls get_gpu_free_memory() every CHECK_INTERVAL seconds. As soon as the
    reported value is strictly greater than MEM_THRESHOLD (GB), COMMAND_TO_RUN
    is executed once and the function returns after it succeeds.

    Raises:
        SystemExit: With status 1 when the command exits with a non-zero
            status.
    """
    import shutil  # Local import: only needed to locate bash below.

    print(
        f"Starting GPU monitor. Waiting for cuda:{CUDA_INDEX} to have more than {MEM_THRESHOLD}GB free memory."
    )
    print(f"Will check every {CHECK_INTERVAL/60} minutes.")

    # COMMAND_TO_RUN may be a pipeline (e.g. `... 2>&1 | tee file.log`). A
    # shell reports only the exit status of the LAST pipeline stage by
    # default, so a failing first stage would be hidden from check=True.
    # Running the command under bash with `pipefail` makes such a failure
    # visible.
    bash_path = shutil.which("bash")

    while True:
        free_memory = get_gpu_free_memory()
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")

        print(f"[{timestamp}] Free GPU memory: {free_memory:.2f}GB")

        if free_memory > MEM_THRESHOLD:
            print(
                f"GPU has {free_memory:.2f}GB of free memory, which exceeds threshold of {MEM_THRESHOLD}GB."
            )
            print("=" * 80)
            print(f"Running command: {COMMAND_TO_RUN}")
            print("=" * 80)

            try:
                if bash_path is not None:
                    subprocess.run(
                        [bash_path, "-o", "pipefail", "-c", COMMAND_TO_RUN],
                        check=True,
                    )
                else:
                    # bash not found: fall back to the default shell, where
                    # only the last pipeline stage's status is checked.
                    subprocess.run(COMMAND_TO_RUN, shell=True, check=True)
                print("Command completed successfully. Exiting.")
                break
            except subprocess.CalledProcessError as e:
                print(f"Error running command: {e}")
                sys.exit(1)
        else:
            print(
                f"Not enough GPU memory available. Waiting {CHECK_INTERVAL/60} minutes before checking again..."
            )
            time.sleep(CHECK_INTERVAL)


if __name__ == "__main__":
    # Print the banner and the configured command on two lines, then start
    # the monitoring loop.
    print(">>> Running GPU monitor <<<", COMMAND_TO_RUN, sep="\n")
    main()
