import subprocess
import sys
from pathlib import Path

from shared_utils import *

# Names handed to get_model() one at a time by the loop at the bottom of the
# file. Order is significant only in that it is the order models are loaded.
MODEL_NAMES = [
    # Entries without the "_Base" suffix.
    "Gemma2_9B", "Gemma2_27B",
    "Gemma3_1B", "Gemma3_4B", "Gemma3_12B", "Gemma3_27B",
    "Llama3.2_3B", "Llama3.1_8B",
    "Qwen3_8B", "Qwen3_14B", "Qwen3_32B",
    # The "_Base" counterparts (there is no "Qwen3_32B_Base" entry).
    "Gemma2_9B_Base", "Gemma2_27B_Base",
    "Gemma3_1B_Base", "Gemma3_4B_Base", "Gemma3_12B_Base", "Gemma3_27B_Base",
    "Llama3.2_3B_Base", "Llama3.1_8B_Base",
    "Qwen3_8B_Base", "Qwen3_14B_Base",
    # The 70B pair is listed last.
    "Llama3.3_70B", "Llama3.3_70B_Base",
]

# Locate the helper script next to this file (derived from __file__, so the
# result does not depend on the current working directory).
script_dir = Path(__file__).parents[0]
delete_script = script_dir.joinpath("delete_hf_model.py")


# Use subprocess instead of os.system for better control
def delete_model_cache(model_name):
    """Delete a model's cache by running the delete_hf_model.py script.

    The script is executed in a child process with the current Python
    interpreter and is passed ``model_name`` followed by ``--force``.
    Outcomes are reported on stdout; no exception is propagated for the
    failure modes handled below.

    Args:
        model_name: Name forwarded to the script as its first argument.

    Returns:
        None.
    """
    # The child interpreter, not subprocess.run, is what opens the script, so
    # a missing script shows up as a non-zero exit status rather than as
    # FileNotFoundError. Check up front so the dedicated message is reachable.
    if not delete_script.is_file():
        print(f"delete_hf_model.py script not found at {delete_script}")
        return

    command = [
        sys.executable,  # Use the same Python interpreter
        str(delete_script),
        model_name,
        "--force",
    ]
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # check=True turns a non-zero exit status into this exception; the
        # captured streams are attached to it.
        print(f"Error deleting cache for {model_name}: {e}")
        print(f"stdout: {e.stdout}")
        print(f"stderr: {e.stderr}")
    except FileNotFoundError as e:
        # Raised when the executable itself (sys.executable) cannot be started.
        print(f"Could not launch Python interpreter {sys.executable!r}: {e}")
    else:
        print(f"Successfully deleted cache for {model_name}")
        print(result.stdout)


# Smoke-test each model in turn: load it, sample a short continuation, run one
# plain forward pass, print the results, then drop it before the next model.
for model_name in MODEL_NAMES:
    # get_model comes from shared_utils; it is unpacked here as (model, tokenizer).
    model, tokenizer = get_model(model_name)
    # use the model to generate something random
    prompt = "A rhyming couplet:\nHe saw a carrot and had to grab it\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Unpack the whole encoding so generate() receives the attention mask
        # along with input_ids. pad_token_id is set to the EOS id below, so
        # the mask should be passed explicitly rather than left to inference.
        output = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Single forward pass over the prompt; softmax over the last axis of
        # the logits, then print the shape as a sanity check.
        outputs_probs = model(**inputs, output_hidden_states=False)
        probs = torch.nn.functional.softmax(outputs_probs.logits, dim=-1)
        print(probs.shape)
    # Print the prompt one token at a time. tolist() yields plain ints, so no
    # tensor view of the input survives the loop.
    for token_id in inputs.input_ids[0].tolist():
        print(tokenizer.decode(token_id))
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"Generated: {generated_text}")
    print(f"Loaded {model_name}")
    # Drop every reference to the model and to tensors created on its device
    # *before* calling the cleanup helper; anything still referenced here
    # cannot be released by it.
    del model, tokenizer, inputs, output, outputs_probs, probs
    # cleanup_gpu_memory is defined in shared_utils (presumably empties the
    # GPU cache — confirm there).
    cleanup_gpu_memory()
    # delete_model_cache(model_name)
