# draft.py
# -----------------------------------------------------------------------------
# Draft stage: (1) generative reasoning per model; (2) consensus scoring.
# Two core functions: run_reasoning() and run_prefill()
# -----------------------------------------------------------------------------

from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
from PIL import Image
import re
import json
import torch
import time

from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText, Glm4vForConditionalGeneration

from evaluate import compute_anls
from utils.post_process import (
    extract_final_boxed_content,
    clean_think_tags,
    clean_answer,
    glm_extract_boxed,
)
from prompts import get_prompt, detect_dataset_from_path, SCORING_PROMPT

@dataclass
class QAEntry:
    """Input entry for pipeline processing.

    Describes one (image, question) pair, optionally with reference answers
    and additional per-entry data.
    """
    # Path of the image; handed unchanged to model.answer() / model.prefill_nll().
    image_path: str
    # Question text asked about the image.
    question: str
    # Reference answers used for ANLS scoring; None is treated as "no references".
    answers: Optional[List[str]] = None
    # Free-form side data. run_prefill() in cross mode looks up previously
    # generated answers in this mapping by key.
    extra: Optional[Dict[str, Any]] = None

@dataclass
class PerModelRecord:
    """Per-model output for generative reasoning."""
    # Raw text returned by the model's answer() call (stored before post-processing).
    reasoning: str
    # Final answer string after post-processing of the raw text.
    answer: str
    # Score returned by compute_anls() for `answer` against the reference answers.
    anls: float

@dataclass
class ReasoningResult:
    """Output of run_reasoning: per-model reasoning and answers."""
    # Image path and question copied from the input entry.
    image_path: str
    question: str
    # One PerModelRecord per model, keyed by the model's tag.
    models_reasoning: Dict[str, PerModelRecord]
    # The three fields below keep their defaults when produced by
    # run_reasoning(). NOTE(review): presumably filled in by a later fusion
    # step outside this module -- confirm.
    final_reasoning: str = "NO_FUSION"
    final_answer: str = "NO_FUSION"
    anls: float = 0.0
    # Reference answers the per-model ANLS scores were computed against.
    ground_truths: Optional[List[str]] = None

@dataclass
class PrefillResult:
    """Output of run_prefill: self and cross prefill scores."""
    # Image path and question copied from the input entry.
    image_path: str
    question: str
    # Tags of the scoring models, in the order they were passed to run_prefill().
    models: List[str]
    # Per-model answer payloads keyed by model tag. In decode mode each payload
    # has the keys "reasoning", "answer" and "anls"; in cross mode it is
    # whatever mapping was read from entry.extra.
    answers: Dict[str, Dict[str, Any]]
    # Model tag -> value returned by that model's prefill_nll() on its own answer.
    self_ppl: Dict[str, float]
    # Answer-source tag -> {scoring-model tag -> prefill_nll() value}.
    # Left as None in decode mode.
    cross_ppl: Optional[Dict[str, Dict[str, float]]] = None

# Model Loading
def load_vlm(model_path: str, device="cuda", dtype="bfloat16"):
    """Load a VLM selected by substring match on the lower-cased model path.

    The checks run in a fixed order and the first match wins:
    "qwen2.5-vl", "glm", "ovis", "mimo", "intern".

    Args:
        model_path: Checkpoint path or hub id handed to ``from_pretrained``.
        device: Only used by the "intern" branch (as ``device_map``); the
            other branches use ``device_map="auto"`` or ``.cuda()``.
        dtype: Name of a ``torch`` dtype attribute. Only the "qwen2.5-vl" and
            "mimo" branches honour it; the others always load in bfloat16.

    Returns:
        Tuple ``(model, processor, tokenizer, tag)`` where tag is one of
        "qwen", "glm", "ovis", "mimo", "internvl". The processor is ``None``
        for the "ovis" branch.

    Raises:
        ValueError: If the path matches none of the supported families.
    """
    path_key = model_path.lower()

    if "qwen2.5-vl" in path_key:
        vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=getattr(torch, dtype),
            device_map="auto",
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
        ).eval()
        pixel_floor = 1280 * 28 * 28
        pixel_ceiling = 16384 * 28 * 28
        processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=pixel_floor, max_pixels=pixel_ceiling
        )
        tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tok.padding_side = "left"
        # The processor is made to share the left-padding tokenizer.
        processor.tokenizer = tok
        return vlm, processor, tok, "qwen"

    if "glm" in path_key:
        processor = AutoProcessor.from_pretrained(model_path)
        vlm = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="auto"
        )
        return vlm, processor, processor.tokenizer, "glm"

    if "ovis" in path_key:
        vlm = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
        ).cuda()
        # This branch takes the tokenizer from the model and returns no processor.
        tok = vlm.text_tokenizer
        vlm.eval()
        return vlm, None, tok, "ovis"

    if "mimo" in path_key:
        vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=getattr(torch, dtype),
            device_map="auto",
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
        ).eval()
        processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=0, max_pixels=4096 * 28 * 28
        )
        tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tok.padding_side = "left"
        processor.tokenizer = tok
        return vlm, processor, tok, "mimo"

    if "intern" in path_key:
        processor = AutoProcessor.from_pretrained(model_path)
        vlm = AutoModelForImageTextToText.from_pretrained(
            model_path, device_map=device, torch_dtype=torch.bfloat16
        )
        return vlm, processor, processor.tokenizer, "internvl"

    raise ValueError(f"Unsupported model path: {model_path}")

def prepare_models(model_paths, max_tokens):
    """Load every model in *model_paths* and wrap each in a PrefillingModel.

    Args:
        model_paths: Iterable of checkpoint paths accepted by load_vlm().
        max_tokens: Forwarded to each PrefillingModel as its fourth
            positional argument.

    Returns:
        List of PrefillingModel wrappers, one per path, in input order.

    Raises:
        ValueError: Propagated from load_vlm() for an unsupported path.
    """
    # Imported here rather than at module level, as in the original code.
    from model import PrefillingModel

    prefill_models = []
    for path in model_paths:
        # Load inside the loop: each path needs its own model/processor/tokenizer.
        model, processor, tokenizer, tag = load_vlm(path)
        prefill_models.append(
            PrefillingModel(model, processor, tokenizer, max_tokens, tag)
        )
    return prefill_models


def _post_process_raw_response(raw_text: str, tag: str, direct_qa: bool) -> str:
    """Reduce a raw model response to its cleaned final answer.

    Args:
        raw_text: Full text produced by the model.
        tag: Model family tag; only the presence of "glm" changes the path.
        direct_qa: Selects how GLM output is handled (see below).

    Returns:
        The extracted answer after clean_answer().

    Extraction paths:
        * GLM with direct_qa=False: glm_extract_boxed() on the raw text only.
        * GLM with direct_qa=True: boxed-content extraction, think-tag
          cleanup, then glm_extract_boxed() on the result.
        * Every other tag: boxed-content extraction followed by think-tag
          cleanup, regardless of direct_qa.
    """
    is_glm = "glm" in tag

    if is_glm and not direct_qa:
        extracted = glm_extract_boxed(raw_text)
    else:
        extracted = clean_think_tags(extract_final_boxed_content(raw_text))
        if is_glm:
            # Only reachable with direct_qa=True: GLM gets an extra boxed pass.
            extracted = glm_extract_boxed(extracted)

    return clean_answer(extracted)

def run_reasoning(
    models: List[Any],
    entry: QAEntry,
    *,
    mode: str = "models",
    q_idx: Optional[int] = None,
    dataset: Optional[str] = None,
) -> ReasoningResult:
    """
    Run generative reasoning for each model on a single (image, question) pair.

    Args:
        models: List of model wrappers with .tag and .answer() methods
        entry: QAEntry with image_path, question, and optional answers
        mode: Only "baseline" (full image) is implemented in this function.
            Any other value, including the default "models", raises
            ValueError as soon as a model would have to be run.
        q_idx: Optional question index (currently unused)
        dataset: Evaluated dataset, forwarded to get_prompt()

    Returns:
        ReasoningResult with per-model reasoning and normalized answers,
        keyed by model tag. final_reasoning / final_answer / anls keep their
        dataclass defaults.

    Raises:
        ValueError: If `mode` is not "baseline" and `models` is non-empty.
    """
    img_path = entry.image_path
    question = entry.question
    gt_list = entry.answers or []

    print(f"Question: {question}")

    per_model: Dict[str, PerModelRecord] = {}

    for model in models:
        tag = model.tag

        if mode != "baseline":
            # Previously this fell through with `raw_text` / `answer` never
            # assigned and crashed with UnboundLocalError; fail explicitly.
            raise ValueError(
                f"Unsupported mode for run_reasoning: {mode!r} "
                "(only 'baseline' is implemented)"
            )

        prompt_template = get_prompt(task="reason", dataset=dataset, model_tag=tag)

        raw_text = model.answer(
            img_path, question, prompt_tpl=prompt_template, return_prompt=True
        )
        answer = _post_process_raw_response(raw_text, tag, direct_qa=True)

        per_model[tag] = PerModelRecord(
            reasoning=raw_text,
            answer=answer,
            anls=compute_anls(answer, gt_list),
        )

        # Release cached GPU memory between models.
        torch.cuda.empty_cache()

    return ReasoningResult(
        image_path=img_path,
        question=question,
        models_reasoning=per_model,
        ground_truths=gt_list,
    )

def _self_prefill_scores(
    models: List[Any],
    img_path: str,
    question: str,
    answers: Dict[str, Dict[str, Any]],
) -> Dict[str, float]:
    """Score each model's own answer with that same model's prefill_nll()."""
    self_ppl: Dict[str, float] = {}
    for model in models:
        if model.tag not in answers:
            continue
        ppl = model.prefill_nll(img_path, question, answers[model.tag]["answer"])
        self_ppl[model.tag] = ppl
        print(f"[Self-PPL] {model.tag}: {ppl:.4f}")
    return self_ppl


def run_prefill(
    models: List[Any],
    entry: QAEntry,
    *,
    mode: str = "decode",
    source: str = "models_reasoning",
    running_model: Optional[str] = None,
    dataset: Optional[str] = None,
) -> PrefillResult:
    """
    Compute self/cross prefill PPL scores.

    Args:
        models: List of PrefillingModel instances with .tag, .answer() and
            .prefill_nll()
        entry: QAEntry; cross mode reads existing answers from
            entry.extra[source]
        mode: "decode" (generate with the "qa" prompt, then self-score).
            Any other value is treated as "cross" (score existing answers).
        source: Key to read existing answers from entry.extra; falls back to
            entry.extra["models_reasoning"] when that key is missing/empty
        running_model: If set, only models with this tag score the answers
            of the *other* sources; otherwise every model scores every
            answer that is not its own
        dataset: Evaluated dataset, forwarded to get_prompt()

    Returns:
        PrefillResult with self_ppl and, in cross mode, cross_ppl laid out
        as {answer_source_tag: {scoring_model_tag: score}}. The scores are
        whatever prefill_nll() returns.

    Raises:
        ValueError: In cross mode when no answers are available, including
            when entry.extra is None.
    """
    img_path = entry.image_path
    question = entry.question
    print(f"Prefill for: {question}")

    if mode == "decode":
        answers: Dict[str, Dict[str, Any]] = {}
        gt_list = entry.answers or []

        for model in models:
            # Get dataset-specific QA prompt from prompts.py
            prompt_template = get_prompt(task="qa", dataset=dataset, model_tag=model.tag)

            raw_text = model.answer(
                img_path, question, prompt_tpl=prompt_template, return_prompt=True
            )
            answer = _post_process_raw_response(raw_text, model.tag, direct_qa=True)

            answers[model.tag] = {
                "reasoning": raw_text,
                "answer": answer,
                "anls": compute_anls(answer, gt_list),
            }

        # Decode mode only produces self scores; cross_ppl stays None.
        return PrefillResult(
            image_path=img_path,
            question=question,
            models=[m.tag for m in models],
            answers=answers,
            self_ppl=_self_prefill_scores(models, img_path, question, answers),
        )

    # mode == "cross": use existing answers for cross-evaluation.
    # QAEntry.extra defaults to None, so normalise it to an empty mapping;
    # a missing payload then surfaces as the ValueError below instead of an
    # AttributeError on None.
    extra = getattr(entry, "extra", None) or {}
    answers = extra.get(source) or extra.get("models_reasoning")

    if not answers:
        raise ValueError(f"No answers found in entry.extra['{source}'] for cross prefill")

    self_ppl = _self_prefill_scores(models, img_path, question, answers)

    cross_ppl: Dict[str, Dict[str, float]] = {}

    if running_model:
        # Single model cross-evaluation: pick the scorer(s) once, then score
        # every answer that did not come from running_model itself.
        scorers = [m for m in models if m.tag == running_model]
        for answer_source in answers:
            if answer_source == running_model:
                continue
            for model in scorers:
                ppl = model.prefill_nll(img_path, question, answers[answer_source]["answer"])
                cross_ppl.setdefault(answer_source, {})[running_model] = ppl
                print(f"[Cross-PPL] {running_model} on {answer_source}: {ppl:.4f}")
    else:
        # Full cross-evaluation matrix
        for answer_source in answers:
            for model in models:
                if model.tag == answer_source:
                    continue  # Skip self-evaluation
                ppl = model.prefill_nll(img_path, question, answers[answer_source]["answer"])
                cross_ppl.setdefault(answer_source, {})[model.tag] = ppl
                print(f"[Cross-PPL] {model.tag} on {answer_source}: {ppl:.4f}")

    return PrefillResult(
        image_path=img_path,
        question=question,
        models=[m.tag for m in models],
        answers=answers,
        self_ppl=self_ppl,
        cross_ppl=cross_ppl,
    )

def _load_json_data(path: Path) -> List[Dict[str, Any]]:
    """Load JSON array or JSONL file"""
    text = path.read_text(encoding="utf-8").strip()
    
    if text.startswith("["):
        # JSON array
        return json.loads(text)
    else:
        # JSONL format
        return [json.loads(line) for line in text.splitlines() if line.strip()]

def _save_json_data(data: Any, path: Path) -> None:
    """Save data as JSON with proper formatting"""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)