import torch, json,os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login
# SECURITY: an access token must never be committed to source control.
# The token that used to be hard-coded on this line should be treated as
# leaked and revoked in the Hugging Face account settings.
# The token is now read from the HF_TOKEN environment variable; when it is
# not set, no explicit login is performed and any credentials cached by a
# previous `huggingface-cli login` are used instead.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)


# --- model locations --------------------------------------------------------
# Hub repository id of the base checkpoint the adapter is applied to.
model_id = "meta-llama/Llama-3.2-1B-Instruct"

# Directory holding the LoRA adapter; a leading "~" is expanded to the
# current user's home directory.
# NOTE(review): the directory name mentions "8b" while the base model above
# is the 1B variant -- confirm the adapter was trained against this base.
adapter_ckpt = os.path.expanduser(
    "~/messed_up/llama3_8b_run1/lora-adapter"
)




# --- 4-bit quantisation settings --------------------------------------------
# The original author's note states these mirror the settings used during
# training; that cannot be verified from this script alone.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,                     # load the weights in 4-bit form
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for computation
    bnb_4bit_quant_type="nf4",             # 4-bit quantisation data type
    bnb_4bit_use_double_quant=True,        # enable nested (double) quantisation
)

# --- tokenizer: fast implementation of the base model's tokenizer ----------
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_id,
    use_fast=True,
)
# Pad on the left; per the original author's note this keeps alignment with
# the training setup.
tokenizer.padding_side = "left"

# --- quantised base model, placed entirely on GPU 0 ------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_id,
    # The empty-string key addresses the whole model, so every sub-module is
    # mapped to device index 0.
    device_map={"": 0},
    quantization_config=bnb_cfg,
)

# --- attach the LoRA adapter to the base model, then set up for inference ---
model = PeftModel.from_pretrained(model=base_model, model_id=adapter_ckpt)

# Gradient checkpointing is switched off; per the original author's note
# this is crucial for fast generation.
model.gradient_checkpointing_disable()

# Make sure the key/value cache is used while generating.
model.config.use_cache = True

# Put all modules into evaluation mode.
model.eval()
import  psutil
# Environment sanity checks: CUDA availability, the device holding the first
# model parameter, and the GPU visibility mask (if one is set).
print(f"CUDA available       : {torch.cuda.is_available()}")
print(f"Model param lives on : {next(model.parameters()).device}")
print(f"CUDA_VISIBLE_DEVICES : {os.getenv('CUDA_VISIBLE_DEVICES', '<not set>')}")
# --- quick test prompt ------------------------------------------------------
# Instruction/response layout: the task text follows the "### Instruction:"
# header and the prompt ends right after the "### Response:" header.
prompt = "".join(
    [
        "### Instruction:\n",
        "Generate optimized traffic signal plan for a broken car near west left turn intersection. Scenario #10\n\n",
        "### Response:\n",
    ]
)



# Tokenise the prompt into PyTorch tensors and move them to the device the
# model reports as its own.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Gradients are not needed for generation.
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=600,
        # Greedy decoding. `top_p` and `temperature` only take effect when
        # sampling, so alongside do_sample=False they were ignored and only
        # triggered a transformers warning about unused generation flags;
        # they are therefore no longer passed. To sample instead, set
        # do_sample=True and re-add top_p / temperature.
        do_sample=False,
    )

# Decode the first (and only) returned sequence, dropping special tokens.
print(tokenizer.decode(output[0], skip_special_tokens=True))
