import argparse, os, random
import torch
import numpy as np

from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from tqdm import tqdm
from utils import QUESTION_TMPL, extract_xy, in_box, load_screenspot, load_screenspot_pro, dump


# ----------------------------- CLI ------------------------------------
p = argparse.ArgumentParser(
    description='Baseline grounding evaluation on ScreenSpot / ScreenSpot-Pro with vLLM.')
p.add_argument('--model_dir', required=True, help='Path to vLLM-compatible model directory')
# `choices` rejects unknown datasets before the (expensive) vLLM engine is built.
p.add_argument('--dataset', type=str, required=True,
               choices=['screenspot', 'screenspot-pro'],
               help='Benchmark to evaluate on')
p.add_argument('--batch_size', type=int, default=32,
               help='Number of examples passed to each vLLM generate() call')
p.add_argument('--tensor_parallel_size', type=int, default=4,
               help='tensor_parallel_size passed to vllm.LLM (number of GPUs)')
# Only used to render the chat-template prompt; the default keeps the previous
# hard-coded processor so existing runs are unchanged.
p.add_argument('--processor', default='Qwen/Qwen2.5-VL-3B-Instruct',
               help='HF id or path of the processor used to build chat prompts')

args = p.parse_args()
if args.batch_size < 1:
    # A step of 0 would otherwise crash range() only after model loading.
    p.error('--batch_size must be >= 1')

model_name = os.path.basename(os.path.normpath(args.model_dir))
file = f'{model_name}_{args.dataset}_baseline.jsonl'

# ---------------------- Reproducibility -------------------------------
# Seed every RNG this process touches; generation itself is greedy
# (temperature=0) and additionally seeded via SamplingParams below.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# --------------------------- vLLM -------------------------------------
vllm = LLM(
    model=args.model_dir,
    tensor_parallel_size=args.tensor_parallel_size,
    gpu_memory_utilization=0.8,
    dtype=torch.bfloat16,
)

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=1024,
    seed=0
)

# -------------------------- Dataset ---------------------------------
print(f'Loading {args.dataset} ...')

if args.dataset == 'screenspot':
    ds = load_screenspot()
elif args.dataset == 'screenspot-pro':
    ds = load_screenspot_pro()
else:
    # Unreachable given argparse `choices`; kept as a safety net.
    raise ValueError(f'Unknown dataset: {args.dataset!r}')

N = len(ds)
processor = AutoProcessor.from_pretrained(args.processor)

# -------------------- Inference & Evaluation --------------------------
# For each batch: build chat prompts, run generation with vLLM, parse the
# predicted point from each response with `extract_xy`, score it with `in_box`
# against the example's bbox, and append one record per example to `records`.
print(f"Running inference on {N} images ...")

records = []

for start in tqdm(range(0, N, args.batch_size)):
    batch_idx = list(range(start, min(start + args.batch_size, N)))
    # Fetch each row exactly once. If `ds` is a Hugging Face Dataset (assumed —
    # the loader lives in utils), every `ds[i]` re-materializes the whole row,
    # image decode included, so repeated per-field indexing is costly.
    rows = [ds[i] for i in batch_idx]

    # build prompts for vLLM
    llm_inputs = []
    img_sizes = []  # img.size per row, reused for scoring below
    for row in rows:
        img = row['image']
        w, h = img.size
        img_sizes.append((w, h))
        prompt = processor.apply_chat_template(
            [{"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": QUESTION_TMPL.format(
                    question=row['instruction'], w=w, h=h
                )},
            ]}],
            tokenize=False,  # vLLM expects the rendered prompt string
            add_generation_prompt=True,
        )
        llm_inputs.append({'prompt': prompt,
                           'multi_modal_data': {'image': img}})

    # vLLM returns one RequestOutput per input, in input order, so zipping
    # with `rows` pairs each response with its example.
    outputs = vllm.generate(llm_inputs, sampling_params)

    # evaluate
    for out, row, img_size in zip(outputs, rows, img_sizes):
        text = out.outputs[0].text.strip()
        pred = extract_xy(text)

        if args.dataset == 'screenspot':
            correct = int(in_box(pred, row['bbox'], img_size))  # relative bbox
            meta = {
                "data_type"  : row["data_type"],
                "data_source": row["data_source"],
            }
        elif args.dataset == 'screenspot-pro':
            correct = int(in_box(pred, row['bbox']))  # absolute bbox
            meta = {
                "ui_type"    : row["ui_type"],
                "id"         : row["id"],
                "application": row["application"],
                "platform"   : row["platform"],
            }
        else:
            # Unreachable: unknown dataset names are rejected at load time.
            continue

        # Dataset-specific fields first, then the shared ones, so the key
        # order of each JSONL record matches the per-dataset layout.
        records.append({
            **meta,
            "instruction": row["instruction"],
            "response"   : text,
            "prediction" : pred,
            "correct"    : correct,
            "pipeline"   : "baseline",
        })

# ---------------------- Write result files ---------------------------
print('Writing results ...')

dump(file, records)
print('Saved.')

# Summary: every record stores `correct` as int 0/1, so the mean is accuracy.
# Guarded so an empty run does not divide by zero.
if records:
    n_correct = sum(r['correct'] for r in records)
    print(f'Accuracy: {n_correct}/{len(records)} = {n_correct / len(records):.2%}')