import argparse, os, random
import torch
import numpy as np

from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from tqdm import tqdm
from utils import QUESTION_TMPL, extract_xy, in_box, to_list, compute_KDE, load_screenspot, load_screenspot_pro, dump


# ----------------------------- CLI ------------------------------------
p = argparse.ArgumentParser()
p.add_argument("--model_dir", required=True, help='Path to vLLM-compatible model directory')
p.add_argument('--dataset', type=str, required=True, choices=['screenspot', 'screenspot-pro'])
p.add_argument("--batch_size", type=int, default=32)
p.add_argument("--temperature", type=float, default=0.0)
p.add_argument("--top_k", type=int, default=-1)
p.add_argument("--n_samples", type=int, default=16,
               help="answers sampled per image in each of the two stages")
p.add_argument("--crop_size", type=int, default=1400, help="ReGround LxL")
p.add_argument("--tensor_parallel_size", type=int, default=4,
               help="vLLM tensor_parallel_size")
p.add_argument("--gpu_memory_utilization", type=float, default=0.8,
               help="vLLM gpu_memory_utilization")
p.add_argument("--processor", default='Qwen/Qwen2.5-VL-3B-Instruct',
               help="HF id or path of the processor used for the chat template")
args = p.parse_args()

# Reject values that would otherwise fail deep inside the run
# (range() with step 0, an empty sampling request, a degenerate crop, ...).
for _opt in ("batch_size", "n_samples", "crop_size", "tensor_parallel_size"):
    if getattr(args, _opt) < 1:
        p.error(f"--{_opt} must be a positive integer")

model_name = os.path.basename(os.path.normpath(args.model_dir))
# KDE{2*n_samples}: the final KDE sees up to n_samples points from each stage.
file_kde   = f'{model_name}_{args.dataset}_scaling_T{args.temperature}_K{args.top_k}_L{args.crop_size}_KDE{args.n_samples*2}_0_01.jsonl'

# ---------------------- Reproducibility ------------------------------
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# --------------------------- vLLM -------------------------------------
vllm = LLM(
    model=args.model_dir,
    tensor_parallel_size=args.tensor_parallel_size,
    gpu_memory_utilization=args.gpu_memory_utilization,
    dtype=torch.bfloat16,
)

# The same params are used for both stages: n answers per prompt.
# NOTE(review): with the default temperature=0.0 (greedy decoding) the n
# samples are presumably identical — confirm this default is intended.
sampling_params = SamplingParams(
    n=args.n_samples,
    temperature=args.temperature,
    top_k=args.top_k,
    max_tokens=1024,
    seed=0,
)

# -------------------------- Dataset ---------------------------------
print(f'Loading {args.dataset} ...')

if args.dataset == 'screenspot':
    ds = load_screenspot()
elif args.dataset == 'screenspot-pro':
    ds = load_screenspot_pro()
else:
    # argparse `choices` already rejects this; kept as a defensive guard.
    raise ValueError(f'Unsupported dataset: {args.dataset!r}')

N  = len(ds)
# The processor only renders the chat template. It defaults to a fixed Qwen
# checkpoint rather than --model_dir, presumably because fine-tuned model
# directories may lack processor files — TODO confirm.
processor = AutoProcessor.from_pretrained(args.processor)

# -------------------- Inference & Evaluation --------------------------
# Two-stage pipeline, per image:
#   1. sample n answers on the full image;
#   2. centre an LxL crop on the KDE of the parsable stage-1 points (or use
#      the whole image if none parsed), sample n answers on the crop, shift
#      them back into full-image coordinates, and take the KDE over the
#      points of both stages as the final prediction.
print(f"Running inference on {N} images ...")

records_kde = []

total_generated   = 0  # answers generated across both stages (2 * n_samples per image)
missing_extracted = 0  # answers where extract_xy found no coordinate

# Pipeline tag written into every record. The two datasets use different
# prefixes; kept byte-for-byte for downstream consumers.
# NOTE(review): 'sampling_crop_' vs 'scaling_' looks accidental — confirm.
_PIPELINE_SUFFIX = f"T{args.temperature}_K{args.top_k}_L{args.crop_size}_KDE{args.n_samples*2}_0_01"


def _grounding_prompt(question, width, height):
    """Render the chat prompt asking `question` about a width x height image."""
    return processor.apply_chat_template(
        [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": QUESTION_TMPL.format(
                question=question, w=width, h=height
            )}
        ]}], add_generation_prompt=True)


def _reground_box(center, width, height, crop_size):
    """Return an integer (left, top, right, bottom) square crop window.

    The side is L = min(crop_size, width, height). The window is centred on
    `center` and shifted so it lies entirely inside a width x height image.
    Coordinates are snapped to integers, so the recorded offset is exactly
    the region PIL crops; PIL rounds fractional box coordinates itself, which
    would otherwise make the stored offset disagree with the actual crop.
    """
    side = min(crop_size, width, height)
    half = side // 2
    cx, cy = (int(round(float(v))) for v in center)
    left, top = max(0, cx - half), max(0, cy - half)
    right, bottom = min(width, left + side), min(height, top + side)
    return right - side, bottom - side, right, bottom


for start in tqdm(range(0, N, args.batch_size)):
    batch_idx = list(range(start, min(start + args.batch_size, N)))
    # Index the dataset once per sample and reuse the result in every stage.
    # If `ds` decodes images lazily on access (assumption — depends on the
    # loaders in utils), repeated ds[i] lookups would re-decode every time.
    samples = [ds[i] for i in batch_idx]

    # stage 1: full-image sampling
    llm_inputs = []
    for sample in samples:
        img = sample["image"]
        w, h = img.size
        llm_inputs.append({'prompt': _grounding_prompt(sample['instruction'], w, h),
                           'multi_modal_data': {'image': img}})

    outputs = vllm.generate(llm_inputs, sampling_params)

    # stage 2: crop + sampling
    llm_inputs = []
    metas      = []
    for out, sample in zip(outputs, samples):
        img = sample['image']
        W, H = img.size

        texts1 = [g.text.strip() for g in out.outputs]
        preds1 = [extract_xy(t) for t in texts1]
        pts1   = [c for c in preds1 if c[0] is not None]

        if pts1:
            # ReGround: L x L crop around the stage-1 KDE estimate
            Lx, Ty, Rx, By = _reground_box(compute_KDE(pts1, img.size), W, H, args.crop_size)
        else:
            # nothing parsable in stage 1: re-query on the full image
            Lx, Ty, Rx, By = 0, 0, W, H

        crop = img.crop((Lx, Ty, Rx, By))
        llm_inputs.append({'prompt': _grounding_prompt(sample['instruction'], crop.width, crop.height),
                           'multi_modal_data': {'image': crop}})
        metas.append({
            "dx": Lx, "dy": Ty,
            "box": [Lx, Ty, Rx, By],
            "pts1": pts1,
            "preds1": preds1,
            "texts1": texts1,
        })

    outputs = vllm.generate(llm_inputs, sampling_params)

    # evaluate
    for m, out, sample in zip(metas, outputs, samples):
        texts2  = [g.text.strip() for g in out.outputs]
        coords2 = [extract_xy(t) for t in texts2]
        total_generated   += len(m['preds1']) + len(coords2)
        missing_extracted += sum(1 for c in m['preds1'] if c[0] is None)
        missing_extracted += sum(1 for c in coords2 if c[0] is None)

        # shift crop-local predictions back into full-image coordinates
        pts2g = [(m["dx"] + x, m["dy"] + y)
                 for x, y in (c for c in coords2 if c[0] is not None)]

        bbox = sample['bbox']
        img_size = sample['image'].size

        # metric: KDE over the points of both stages
        # NOTE(review): if no answer in either stage parsed, this passes an
        # empty list; the outcome then depends on compute_KDE — confirm.
        kde_pred = compute_KDE(m["pts1"] + pts2g, img_size)

        if args.dataset == 'screenspot':
            correct_kde = int(in_box(kde_pred, bbox, img_size))  # relative bbox
            record = {
                "data_type"  : sample["data_type"],
                "data_source": sample["data_source"],
            }
            pipeline = "sampling_crop_" + _PIPELINE_SUFFIX
        elif args.dataset == 'screenspot-pro':
            correct_kde = int(in_box(kde_pred, bbox))  # absolute bbox
            record = {
                "ui_type"    : sample["ui_type"],
                "id"         : sample["id"],
                "application": sample["application"],
                "platform"   : sample["platform"],
            }
            pipeline = "scaling_" + _PIPELINE_SUFFIX
        else:
            continue  # unreachable: other datasets are rejected at load time

        record.update({
            "instruction"     : sample["instruction"],
            "init_response"   : m["texts1"][0],
            "response"        : texts2[0],
            "init_predictions": [to_list(c) for c in m["preds1"]],
            "predictions"     : [to_list(c) for c in coords2],
            "crop_box"        : m["box"],
            "crop_offset"     : [m["dx"], m["dy"]],
            "prediction"      : to_list(kde_pred),
            "correct"         : correct_kde,
            "pipeline"        : pipeline,
        })
        records_kde.append(record)

# ---------------------- Write result files ---------------------------
print('Writing results ...')

dump(file_kde, records_kde)
print('Saved.')

# --------------------- Extraction summary ----------------------------
# Counts cover both stages, i.e. up to 2 * n_samples answers per image.
print(f"Extraction failures: {missing_extracted} / {total_generated}")
if total_generated:
    miss_pct = 100 * missing_extracted / total_generated
    print(f"{miss_pct:.2f}% failed to format answer")
else:
    # e.g. an empty dataset: avoid ZeroDivisionError after results were saved
    print("No answers were generated; failure rate is undefined.")