# main.py
# -----------------------------------------------------------------------------
# Entry for draft & verdict stages
# -----------------------------------------------------------------------------

from __future__ import annotations

import argparse
import json
import logging
import pathlib
import random
from dataclasses import asdict
from typing import Any, Dict, Iterable, List, Optional
from PIL import Image

import numpy as np
import torch
from qwen_vl_utils import process_vision_info

from draft import (
    QAEntry,
    ReasoningResult,
    PrefillResult,
    run_reasoning,
    run_prefill,
    prepare_models,
)
from verdict import qwen_verdict, gpt4o_verdict

from prompts import detect_dataset_from_path, get_legacy_prompts, validate_prompts
from evaluate import compute_anls
from utils.post_process import extract_final_boxed_content, clean_think_tags, clean_answer

# Default number of processed entries between periodic writes of partial
# results; used as the default of the batch runners' `flush_every` parameter
# and of the `--flush_every` CLI option.
DEFAULT_FLUSH_EVERY = 10

def setup_logging(verbosity: int = 1) -> None:
    """Configure the root logger with a compact, timestamped line format.

    A verbosity of 0 or below selects WARNING, exactly 1 selects INFO, and
    any other value selects DEBUG.
    """
    if verbosity <= 0:
        chosen_level = logging.WARNING
    elif verbosity == 1:
        chosen_level = logging.INFO
    else:
        chosen_level = logging.DEBUG

    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(message)s",
        level=chosen_level,
    )

def set_all_seeds(seed: Optional[int]) -> None:
    """Seed Python, NumPy and torch (CPU + CUDA) RNGs for reproducibility.

    When ``seed`` is None nothing is seeded and only an informational log
    line is emitted. Otherwise cuDNN is additionally switched to
    deterministic mode with benchmarking disabled.
    """
    if seed is not None:
        # Same order as before: python -> numpy -> torch CPU -> torch CUDA.
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        logging.info(f"Seeds set to {seed}.")
        return

    logging.info("No seed provided (non-deterministic run).")

def iter_entries_auto(path: pathlib.Path) -> Iterable[Dict[str, Any]]:
    """Yield dict entries from a JSON-array file or a JSONL file.

    The format is sniffed from the first non-whitespace character of the
    file: ``[`` selects the JSON-array branch, anything else (including an
    empty file) is read as JSONL.

    * JSON array: the whole document is parsed with ``json.load`` -- it is
      therefore held in memory in full -- and its elements are then yielded
      one by one.
    * JSONL: the file is read line by line; blank lines are skipped.

    In both formats, items that are not JSON objects (dicts) are skipped
    silently. A leading UTF-8 byte-order mark is tolerated.

    Args:
        path: File to read, decoded as UTF-8.

    Yields:
        Each dict found in the file, in file order.

    Raises:
        ValueError: if the array branch parses to something other than a list.
        json.JSONDecodeError: on malformed JSON. For JSONL the message is
            prefixed with the file path and the 1-based line number of the
            offending record; the trailing line/column that the exception
            appends is relative to that single record.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged but strips a leading BOM, which
    # would otherwise be sniffed as the first character and hide a "[".
    with path.open("r", encoding="utf-8-sig") as f:
        first = next(
            (ch for ch in iter(lambda: f.read(1), "") if not ch.isspace()),
            None,
        )
        # Rewind so the parser below sees the file from the beginning.
        f.seek(0)

        if first == "[":
            data = json.load(f)
            if not isinstance(data, list):
                raise ValueError("Top-level JSON must be an array.")
            for obj in data:
                if isinstance(obj, dict):
                    yield obj
            return

        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type, but say which record of which file.
                raise json.JSONDecodeError(
                    f"{path}: invalid JSONL record on line {line_no}: {err.msg}",
                    err.doc,
                    err.pos,
                ) from err
            if isinstance(obj, dict):
                yield obj

def safe_write_json(obj: Any, path: pathlib.Path) -> None:
    """Pretty-print ``obj`` as JSON to ``path`` without clobbering on failure.

    The JSON (UTF-8, non-ASCII kept verbatim, 2-space indent, trailing
    newline) is first written to a sibling ``<name>.tmp`` file and then
    renamed over ``path``. If serialization or writing fails, an existing
    file at ``path`` -- e.g. the previous periodic flush -- is left intact
    instead of being truncated.

    Args:
        obj: Any JSON-serializable value.
        path: Destination file; missing parent directories are created.

    Raises:
        TypeError / ValueError: propagated from ``json.dump`` when ``obj``
            cannot be serialized.
        OSError: propagated from the filesystem operations.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the target so the final rename stays on one filesystem.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as fo:
            json.dump(obj, fo, ensure_ascii=False, indent=2)
            fo.write("\n")
        tmp_path.replace(path)
    finally:
        # After a successful replace the temp file is gone; after a failure
        # remove the partial file so no stray "*.tmp" is left behind.
        if tmp_path.exists():
            tmp_path.unlink()

def to_QAEntry(d: Dict[str, Any], dataset: Optional[str] = None) -> QAEntry:
    """Build a QAEntry from one raw input record.

    ``image_path`` and ``question`` must both be strings. The answers are
    taken from ``answers`` or, when that is missing/falsy, ``ground_truths``;
    if present they must be a list. All remaining keys are collected into
    ``extra``. ``dataset`` is accepted but not used here.

    Raises:
        ValueError: if a required field is missing/non-string, or if the
            answers value is neither None nor a list.
    """
    core_keys = {"image_path", "question", "answers", "ground_truths"}

    image_path, question = d.get("image_path"), d.get("question")
    answers = d.get("answers") or d.get("ground_truths")

    # Anything that is not a core field rides along untouched.
    extra = {}
    for key, value in d.items():
        if key not in core_keys:
            extra[key] = value

    if not (isinstance(image_path, str) and isinstance(question, str)):
        raise ValueError("Each entry must contain 'image_path' (str) and 'question' (str).")
    if not (answers is None or isinstance(answers, list)):
        raise ValueError("'answers' must be a list if provided.")

    return QAEntry(
        image_path=image_path,
        question=question,
        answers=answers,
        extra=extra,
    )

def run_reasoning_batch(
    in_path: pathlib.Path,
    out_path: pathlib.Path,
    *,
    model_paths: List[str],
    mode: str = "baseline",
    dataset: Optional[str] = None,
    start_idx: int = 0,
    max_entries: Optional[int] = None,
    flush_every: int = DEFAULT_FLUSH_EVERY,
) -> None:
    """Execute per-model generative reasoning over a dataset.

    Models are prepared once via ``prepare_models(model_paths,
    max_tokens=2048)``. Each selected input record is normalized with
    ``to_QAEntry`` and passed to ``run_reasoning``; the returned
    ``ReasoningResult`` is converted with ``asdict`` and collected.

    Args:
        in_path: Input file (JSON array or JSONL).
        out_path: Output JSON file. It is overwritten with the results of
            THIS run only -- results start from an empty list, so entries
            before ``start_idx`` are not carried over from an earlier run.
        model_paths: Model paths handed to ``prepare_models``.
        mode: Forwarded to ``run_reasoning`` as ``mode``.
        dataset: Forwarded to ``to_QAEntry`` and ``run_reasoning``.
        start_idx: 0-based index of the first input record to process.
        max_entries: Stop after this many processed records (None = no cap).
        flush_every: Write partial results every N processed records;
            0 disables periodic flushing (the final write still happens).
    """
    models = prepare_models(model_paths, max_tokens=2048)
    results: List[Dict[str, Any]] = []

    processed = 0
    for idx, raw in enumerate(iter_entries_auto(in_path)):
        if idx < start_idx:
            continue
        if max_entries is not None and processed >= max_entries:
            break

        entry = to_QAEntry(raw, dataset)
        # q_idx is the 1-based position of the record in the input file.
        out: ReasoningResult = run_reasoning(models, entry, mode=mode, q_idx=idx + 1, dataset=dataset)
        results.append(asdict(out))

        processed += 1
        # Guard against flush_every == 0: the bare modulo would raise
        # ZeroDivisionError right after the first (expensive) inference.
        if flush_every and processed % flush_every == 0:
            logging.info(f"[Reasoning-{dataset}] Flushing @ {processed} → {out_path}")
            safe_write_json(results, out_path)
        torch.cuda.empty_cache()

    safe_write_json(results, out_path)
    logging.info(f"[Reasoning-{dataset}] Completed {processed} examples. Output: {out_path}")

def run_prefill_batch(
    in_path: pathlib.Path,
    out_path: pathlib.Path,
    *,
    model_paths: List[str],
    mode: str = "decode",
    source_key: str = "models_reasoning",
    running_model: Optional[str] = None,
    dataset: Optional[str] = None,
    start_idx: int = 0,
    max_entries: Optional[int] = None,
    flush_every: int = DEFAULT_FLUSH_EVERY,
) -> None:
    """Execute prefill scoring (self/cross PPL) over a dataset.

    Models are prepared once via ``prepare_models(model_paths,
    max_tokens=2048)``. Each selected input record is normalized with
    ``to_QAEntry`` and passed to ``run_prefill``; the returned
    ``PrefillResult`` is converted with ``asdict`` and collected.

    Args:
        in_path: Input file (JSON array or JSONL).
        out_path: Output JSON file. It is overwritten with the results of
            THIS run only (results start from an empty list).
        model_paths: Model paths handed to ``prepare_models``.
        mode: Either ``"decode"`` or ``"cross"``; forwarded to ``run_prefill``.
        source_key: Forwarded to ``run_prefill`` as ``source``.
        running_model: Forwarded to ``run_prefill`` as ``running_model``.
        dataset: Forwarded to ``to_QAEntry`` and ``run_prefill``.
        start_idx: 0-based index of the first input record to process.
        max_entries: Stop after this many processed records (None = no cap).
        flush_every: Write partial results every N processed records;
            0 disables periodic flushing (the final write still happens).

    Raises:
        ValueError: if ``mode`` is not ``"decode"`` or ``"cross"``.
    """
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # and this must fail before any model is loaded.
    if mode not in {"decode", "cross"}:
        raise ValueError(f"mode must be 'decode' or 'cross', got {mode!r}")
    models = prepare_models(model_paths, max_tokens=2048)
    results: List[Dict[str, Any]] = []

    processed = 0
    for idx, raw in enumerate(iter_entries_auto(in_path)):
        if idx < start_idx:
            continue
        if max_entries is not None and processed >= max_entries:
            break

        entry = to_QAEntry(raw, dataset)
        out: PrefillResult = run_prefill(
            models, entry, mode=mode, source=source_key,
            running_model=running_model, dataset=dataset
        )
        results.append(asdict(out))

        processed += 1
        # Guard against flush_every == 0: the bare modulo would raise
        # ZeroDivisionError right after the first (expensive) scoring call.
        if flush_every and processed % flush_every == 0:
            logging.info(f"[Prefill-{mode}-{dataset}] Flushing @ {processed} → {out_path}")
            safe_write_json(results, out_path)
        torch.cuda.empty_cache()

    safe_write_json(results, out_path)
    logging.info(f"[Prefill-{mode}-{dataset}] Completed {processed} examples. Output: {out_path}")

def run_verdict_batch(
    in_path: pathlib.Path,
    out_path: pathlib.Path,
    *,
    verdict_model_path: str,
    annotated_folder: str = "",
    dataset: Optional[str] = None,
    start_idx: int = 0,
    max_entries: Optional[int] = None,
    flush_every: int = DEFAULT_FLUSH_EVERY,
) -> None:
    """Execute verdict over a dataset with existing model reasoning.

    The verdict model is loaded once with ``load_vlm``. For every selected
    input record ``qwen_verdict`` is called (always with ``device="cuda"``)
    and a copy of the record is stored with four added keys:
    ``final_reasoning``, ``final_answer``, ``anls`` (``compute_anls`` of the
    final answer against ``ground_truths`` or, if that is missing/empty,
    ``answers``) and ``dataset``.

    Args:
        in_path: Input file (JSON array or JSONL) whose records carry
            ``question``, ``image_path`` and ``models_reasoning``.
        out_path: Output JSON file. It is overwritten with the results of
            THIS run only (results start from an empty list).
        verdict_model_path: Path handed to ``load_vlm``.
        annotated_folder: Forwarded to ``qwen_verdict``.
        dataset: Forwarded to ``qwen_verdict`` and stored on each result.
        start_idx: 0-based index of the first input record to process.
        max_entries: Stop after this many processed records (None = no cap).
        flush_every: Write partial results every N processed records;
            0 disables periodic flushing (the final write still happens).

    Raises:
        ValueError: if a record lacks ``models_reasoning``, ``question`` or
            ``image_path``.
    """
    from draft import load_vlm

    # Only the model and processor are used; the other two values returned
    # by load_vlm are intentionally ignored.
    model, processor, _tokenizer, _tag = load_vlm(verdict_model_path)
    results: List[Dict[str, Any]] = []
    required_fields = ("models_reasoning", "question", "image_path")

    processed = 0
    for idx, raw in enumerate(iter_entries_auto(in_path)):
        if idx < start_idx:
            continue
        if max_entries is not None and processed >= max_entries:
            break

        # Work on a shallow copy so the added keys do not touch `raw`.
        entry_dict = dict(raw)

        # Ensure required fields exist; report every one of them the same
        # way instead of letting the later lookups fail with a bare KeyError.
        for field in required_fields:
            if field not in entry_dict:
                raise ValueError(f"Entry {idx} missing '{field}' field required for verdict")

        full_response, final_answer = qwen_verdict(
            model=model,
            processor=processor,
            question=entry_dict["question"],
            answers_dict=entry_dict["models_reasoning"],
            orig_img_path=entry_dict["image_path"],
            annotated_folder=annotated_folder,
            dataset=dataset,
            device="cuda"
        )

        entry_dict["final_reasoning"] = full_response
        entry_dict["final_answer"] = final_answer
        entry_dict["anls"] = compute_anls(
            final_answer,
            entry_dict.get("ground_truths", []) or entry_dict.get("answers", [])
        )
        entry_dict["dataset"] = dataset

        results.append(entry_dict)

        processed += 1
        # Guard against flush_every == 0: the bare modulo would raise
        # ZeroDivisionError right after the first (expensive) verdict call.
        if flush_every and processed % flush_every == 0:
            logging.info(f"[verdict-{dataset}] Flushing @ {processed} → {out_path}")
            safe_write_json(results, out_path)
        torch.cuda.empty_cache()

    safe_write_json(results, out_path)
    logging.info(f"[verdict-{dataset}] Completed {processed} examples. Output: {out_path}")


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    ``--in_json``, ``--out_json`` and ``--models`` are required. ``--mode``
    selects the stage (default ``reasoning``). ``verbose`` starts at 1 and
    each ``-v`` adds one, so with ``setup_logging`` the default is INFO and a
    single ``-v`` already selects DEBUG.

    Returns:
        The parsed argument namespace.
    """
    p = argparse.ArgumentParser(
        description="Run the 'Draft' stage with dataset-aware prompts: (1) reasoning, (2) prefill scoring, (3) verdict."
    )
    p.add_argument("--in_json", required=True, help="Input file (JSON array or JSONL).")
    p.add_argument("--out_json", required=True, help="Output JSON path.")
    p.add_argument("--mode",
                   choices=["reasoning", "prefill_decode", "prefill_cross", "verdict"],
                   default="reasoning",
                   help="Draft substage to run.")
    p.add_argument("--models", nargs="+", required=True,
                   help="Model paths (space-separated). For verdict, first one is the verdict model.")

    # Dataset specification
    p.add_argument("--dataset", choices=["infovqa", "hrbench", "museum", "pro"], default=None,
                   help="Dataset type (auto-detected if not specified).")

    # Mode-specific parameters
    p.add_argument("--reasoning_mode", choices=["models", "baseline", "ocr"], default="baseline",
                   help="[reasoning] Sub-mode for reasoning stage.")
    p.add_argument("--source_key", default="models_reasoning",
                   help="[prefill_cross] key where per-model answers live.")
    p.add_argument("--running_model", default=None,
                   help="[prefill_cross] restrict cross-eval target.")
    p.add_argument("--annotated_folder", default="",
                   help="[verdict] folder containing layout-annotated images.")

    # Processing parameters
    p.add_argument("--start_idx", type=int, default=0,
                   help="Start from this 0-based index in the dataset.")
    p.add_argument("--max_entries", type=int, default=None,
                   help="Process at most N entries.")
    p.add_argument("--flush_every", type=int, default=DEFAULT_FLUSH_EVERY,
                   help="Flush partial results every N entries.")
    p.add_argument("--seed", type=int, default=None,
                   help="Random seed (sets torch/cuda/numpy/python).")
    # The count starts at default=1 (INFO), so the first -v already yields 2
    # (DEBUG); the help text states that instead of "-v: INFO, -vv: DEBUG".
    p.add_argument("-v", "--verbose", action="count", default=1,
                   help="Increase logging verbosity (default: INFO, -v: DEBUG).")

    return p.parse_args()

def main() -> None:
    """CLI entry point: parse arguments, set up logging/seeds, run one stage.

    ``--mode`` selects which batch runner is invoked; the input/output paths,
    dataset and the start/max/flush settings are passed to every runner.

    NOTE(review): the ``--dataset`` help text says the dataset is
    auto-detected when omitted, but no detection happens in this function --
    ``None`` is forwarded as-is. Presumably handled downstream; confirm.

    Raises:
        ValueError: for verdict mode without any model path, or for a mode
            that matches none of the known stages.
    """
    args = parse_args()
    setup_logging(args.verbose)
    set_all_seeds(args.seed)

    model_paths = args.models
    # Keyword arguments common to all three batch runners.
    shared = dict(
        in_path=pathlib.Path(args.in_json),
        out_path=pathlib.Path(args.out_json),
        dataset=args.dataset,
        start_idx=args.start_idx,
        max_entries=args.max_entries,
        flush_every=args.flush_every,
    )
    stage = args.mode

    if stage == "reasoning":
        run_reasoning_batch(
            model_paths=model_paths,
            mode=args.reasoning_mode,
            **shared,
        )
        return

    if stage == "prefill_decode":
        run_prefill_batch(model_paths=model_paths, mode="decode", **shared)
        return

    if stage == "prefill_cross":
        run_prefill_batch(
            model_paths=model_paths,
            mode="cross",
            source_key=args.source_key,
            running_model=args.running_model,
            **shared,
        )
        return

    if stage == "verdict":
        if not model_paths:
            raise ValueError("verdict mode requires at least one model path (the verdict model)")
        run_verdict_batch(
            verdict_model_path=model_paths[0],  # first model acts as the verdict model
            annotated_folder=args.annotated_folder,
            **shared,
        )
        return

    raise ValueError(f"Unsupported mode: {args.mode}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()