import os, json
import pandas as pd
from tqdm import tqdm
from utils.utils import *
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN

def _load_split(data_path, split):
    """Load and concatenate every ``<split>*.parquet`` file in *data_path*.

    Files are read in sorted filename order so that the positional row index
    produced by ``ignore_index=True`` (used as the resume key in the results
    JSON) is stable across runs; ``os.listdir`` returns entries in arbitrary
    order.

    Raises:
        ValueError: if no matching parquet file exists in *data_path*.
    """
    parquet_files = sorted(
        os.path.join(data_path, name)
        for name in os.listdir(data_path)
        if name.endswith(".parquet") and name.startswith(split)
    )
    if not parquet_files:
        raise ValueError(
            f"No parquet files starting with {split!r} found in {data_path}"
        )
    return pd.concat(
        [pd.read_parquet(path) for path in parquet_files], ignore_index=True
    )


def _load_existing_results(json_path):
    """Return previously saved records from *json_path*, keyed by int ``idx``.

    Returns an empty dict when the file is missing, is not valid JSON, or does
    not contain a JSON list. List entries that are not dicts with an ``idx``
    key are ignored.
    """
    if not os.path.exists(json_path):
        return {}
    with open(json_path, "r", encoding="utf-8") as f:
        try:
            results = json.load(f)
        except json.JSONDecodeError:
            print(f"⚠️ Warning: JSON decode failed for {json_path}, starting fresh.")
            return {}
    if not isinstance(results, list):
        # Valid JSON but not the list-of-records layout this script writes.
        print(f"⚠️ Warning: {json_path} does not contain a JSON list, starting fresh.")
        return {}
    return {
        int(r["idx"]): r for r in results if isinstance(r, dict) and "idx" in r
    }


def _build_prompt(question, conv_mode):
    """Render the single-turn conversation prompt for *question*."""
    prompt_ctx = (
        f"{DEFAULT_IMAGE_TOKEN}\n{question} "
        "Please use a single-word or phrase to answer the question in english."
    )
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], prompt_ctx)
    # An empty assistant turn so the rendered prompt ends where the model
    # is expected to continue.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def _generate_with_retries(model, prompt, img_bytes, max_retries=3):
    """Query *model* up to *max_retries* times until it yields a non-blank answer.

    On the final retry (only when there has been at least one earlier attempt)
    the prompt is suffixed with ``"Answer: "`` as a fallback nudge.

    Returns:
        ``(pred, raw_out)`` from the latest attempt whose decode returned a
        non-``None`` prediction, or ``(None, None)`` if no attempt did.
        ``pred`` may be an empty/blank string if every decode was blank.
    """
    pred, raw_out = None, None
    for attempt in range(max_retries):
        attempt_prompt = prompt
        if attempt > 0 and attempt == max_retries - 1:
            attempt_prompt = f"{prompt}Answer: "
        model.generate_prompt(attempt_prompt)
        out, inputs, _ = model.get_answer(img_bytes)
        if out is None:
            continue

        # NOTE(review): semantics of temperature=1 depend on the model wrapper.
        decoded, _ = model.decode_outputs(inputs, temperature=1)
        if decoded is None:
            continue
        # Keep pred and raw_out from the same attempt.
        pred, raw_out = decoded, out
        if decoded.strip() != "":
            break
    return pred, raw_out


def _gt_answers(answers, dataset_name):
    """Normalise the ground-truth answers of one row to lowercase strings.

    For ``"VQAv2"`` each answer is a mapping with an ``"answer"`` key (which
    is also stripped); for other datasets each answer is used directly.
    """
    if dataset_name == "VQAv2":
        return [a["answer"].strip().lower() for a in answers]
    return [a.lower() for a in answers]


def _save_results(json_path, results_dict):
    """Atomically write *results_dict* to *json_path* as a list sorted by idx.

    The payload is written to a sibling temp file first and then moved into
    place with ``os.replace``, so a crash mid-write cannot leave a truncated
    checkpoint that the next run would discard.
    """
    payload = [results_dict[k] for k in sorted(results_dict)]
    tmp_path = f"{json_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, json_path)


def eval_model(model, args):
    """Run *model* over a parquet split and checkpoint predictions to JSON.

    Results are written to ``<args.save_base>/<args.dataset_name>/<model.name>.json``
    after every evaluated row. Rows whose saved record already has a
    non-empty ``pred`` are skipped, so an interrupted run can be resumed.

    Args:
        model: object exposing ``name``, ``generate_prompt(prompt)``,
            ``get_answer(img_bytes) -> (raw_out, inputs, _)`` and
            ``decode_outputs(inputs, temperature=...) -> (pred, _)``.
        args: namespace with ``data_path``, ``split``, ``save_base``,
            ``dataset_name`` and ``conv_mode``.

    Raises:
        ValueError: if no ``<split>*.parquet`` file exists in ``args.data_path``.
    """
    df = _load_split(args.data_path, args.split)

    json_dir = os.path.join(args.save_base, args.dataset_name)
    os.makedirs(json_dir, exist_ok=True)
    json_path = os.path.join(json_dir, f"{model.name}.json")

    results_dict = _load_existing_results(json_path)

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating"):
        idx = int(idx)

        # Resume: skip rows that already have a usable prediction.
        if idx in results_dict and results_dict[idx].get("pred", "") != "":
            continue

        question = row["question"]
        img_bytes = row["image"]["bytes"]

        prompt = _build_prompt(question, args.conv_mode)
        pred, raw_out = _generate_with_retries(model, prompt, img_bytes)
        if pred is None:
            # No attempt produced a decodable output; leave the row for a rerun.
            continue

        prev = results_dict.get(idx, {})
        record = dict(
            idx=idx,
            question=question,
            # Fall back to an earlier run's values if this run got blanks.
            pred=pred if pred else prev.get("pred", ""),
            raw_out=raw_out if raw_out else prev.get("raw_out", ""),
            gt_answers=_gt_answers(row["answers"], args.dataset_name),
        )
        print(record)
        results_dict[idx] = record

        # ---- Save JSON after every row (checkpoint for resuming) ----
        _save_results(json_path, results_dict)
