"""
Step-3: Judge using VLM

- Evaluate using the served vision-language (VL) model.
- Perform rescaling on both the reference image and the generated image:
  - Determine the shorter side of each image.
  - Compute the average of these two shorter sides.
  - Resize each image so that its shorter side matches the computed average size.
"""

import base64
from collections import defaultdict
import io
import json
from PIL import Image
import logging
import os
import re
from functools import cache
from pathlib import Path
from time import time

import fire

from models import GenModelVllmMultiModal
from summarize import summarize_entry
from utils import read_jsonl, setup_logger, write_jsonl

# Logger shared under the "text2svg" name; judge_entry calls setup_logger
# only when this logger has no handlers yet.
logger = logging.getLogger("text2svg")

# NOTE(review): presumably set to silence the Hugging Face tokenizers
# parallelism/fork warning -- confirm against the model wrapper.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Judge prompt template read by load_prompt_template(). The path is relative,
# so it is resolved against the current working directory when opened.
# Earlier template versions are kept below for reference.
# PROMPT_TEMPLATE = "prompts/judge.v2_with_golden.txt"
# PROMPT_TEMPLATE = "prompts/judge.v3_with_golden.txt"
PROMPT_TEMPLATE = "prompts/judge.v4_with_golden.txt"


@cache
def load_prompt_template():
    """Return the text of the judge prompt template.

    The file named by ``PROMPT_TEMPLATE`` is read once; ``functools.cache``
    serves every later call from memory.
    """
    return Path(PROMPT_TEMPLATE).read_text()


def build_prompt(ex):
    """Build the judge prompt for one sample.

    Args:
        ex: Sample mapping; its ``"prompt"`` entry is the SVG instruction.

    Returns:
        The template text with every ``{THE_SVG_INSTRUCTION}`` placeholder
        replaced by the sample's instruction.
    """
    return load_prompt_template().replace("{THE_SVG_INSTRUCTION}", ex["prompt"])


@cache
def make_xml_re(xml_key):
    """Return a compiled pattern capturing the body of ``<xml_key>...</xml_key>``.

    The capture is non-greedy and compiled with ``re.DOTALL`` so the body may
    span several lines. ``xml_key`` is inserted into the pattern verbatim (it
    is not escaped). Compiled patterns are memoised per key.
    """
    opening = f"<{xml_key}>"
    closing = f"</{xml_key}>"
    return re.compile(opening + "(.+?)" + closing, re.DOTALL)


def re_match_only_one(content, re_expr):
    """Return the single match of ``re_expr`` in ``content``, stripped.

    Args:
        content: Text to search.
        re_expr: Pattern (compiled or string) handed to ``re.findall``.

    Returns:
        The whitespace-stripped match when there is exactly one; ``None``
        when there are zero matches or more than one.
    """
    hits = re.findall(re_expr, content)
    if len(hits) == 1:
        return hits[0].strip()
    return None


def try_parse_int(s):
    """Convert ``s`` to ``int``, returning ``None`` when it cannot be converted.

    Args:
        s: Value to convert; typically a string extracted from the judge
            output, or ``None`` when extraction failed.

    Returns:
        ``int(s)`` on success, otherwise ``None``.
    """
    try:
        return int(s)
    # Only conversion failures are swallowed; a bare ``except`` would also
    # hide KeyboardInterrupt / SystemExit raised while parsing.
    except (TypeError, ValueError, OverflowError):
        return None


def parse_result(gen):
    """Parse the judge model's XML-tagged answer into a structured dict.

    Args:
        gen: Raw text generated by the judge model.

    Returns:
        A dict with the keys:
          - ``"summary"``: body of ``<comparison_summary>`` or ``None``.
          - ``"detailed"``: one entry per scoring section; each is ``None``
            when the section tag is not found exactly once, otherwise a dict
            with ``"review"`` (text or ``None``) and ``"score"`` (int or
            ``None``).
          - ``"review"``: body of ``<final_review>`` or ``None``.
          - ``"score"``: ``<final_score>`` as an int, or ``None`` when it is
            missing or not an integer.
    """

    def extract(text, tag):
        # A tag only counts when it occurs exactly once in ``text``.
        return re_match_only_one(text, make_xml_re(tag))

    detailed = {}
    for name in ("object_text_accuracy", "positioning_stroke", "color_overall"):
        body = extract(gen, name)
        if body is None:
            detailed[name] = None
        else:
            detailed[name] = {
                "review": extract(body, "review"),
                "score": try_parse_int(extract(body, "score")),
            }

    return {
        "summary": extract(gen, "comparison_summary"),
        "detailed": detailed,
        "review": extract(gen, "final_review"),
        "score": try_parse_int(extract(gen, "final_score")),
    }


def parse_failed():
    """Return the placeholder parse result used for samples that were not judged.

    All text fields are ``None`` and the score is ``0.0``. A new dict is
    built on every call.
    """
    return dict(summary=None, detailed=None, review=None, score=0.0)


def b64_to_img(b64):
    """Decode a base64-encoded image and open it with PIL.

    Args:
        b64: Base64 text (or bytes) of an encoded image file.

    Returns:
        The image object returned by ``Image.open`` on the decoded bytes.
    """
    raw = io.BytesIO(base64.b64decode(b64))
    return Image.open(raw)


def check_image_size(image_b64):
    """Tell whether a base64 image is usable as judge input.

    An image is rejected when it cannot be decoded/opened, when its width or
    height is below 28 pixels, or when the ratio of its longer side to its
    shorter side exceeds 8.
    NOTE(review): these limits presumably mirror the judge model's input
    constraints -- confirm.

    Args:
        image_b64: Base64-encoded image.

    Returns:
        ``True`` if the image passes every check, ``False`` otherwise
        (including on any exception).
    """
    min_side_px = 28
    aspect_limit = 8

    try:
        w, h = b64_to_img(image_b64).size
        if min(w, h) < min_side_px:
            return False
        # Longer side over shorter side, whichever orientation the image has.
        return max(w, h) / min(w, h) <= aspect_limit
    except Exception:
        # Best effort: anything that goes wrong means "not usable".
        return False


def resize_image_base64(img, shorter_side_size):
    """Rescale ``img`` by its shorter side and return it as base64 JPEG.

    Both dimensions are multiplied by ``shorter_side_size / min(width, height)``
    and truncated to integers, so the aspect ratio is kept up to truncation.

    Args:
        img: PIL image to resize.
        shorter_side_size: Target length of the shorter side (may be a float).

    Returns:
        UTF-8 base64 string of the resized image, converted to RGB and
        encoded as JPEG.
    """
    width, height = img.size
    ratio = shorter_side_size / float(min(width, height))
    target_dims = (int(width * ratio), int(height * ratio))

    # JPEG has no alpha channel, hence the conversion to RGB before saving.
    rgb = img.resize(target_dims, Image.LANCZOS).convert("RGB")

    with io.BytesIO() as sink:
        rgb.save(sink, format="JPEG")
        jpeg_bytes = sink.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")


def scale_images_align_shorter_side(gen_img_b64, ref_img_b64):
    """Resize the generated and reference images to a common shorter side.

    The target is the mean of the two images' shorter sides; each image is
    then rescaled so that its own shorter side matches that target.

    Args:
        gen_img_b64: Base64 of the generated (rendered) image.
        ref_img_b64: Base64 of the reference image.

    Returns:
        ``(scaled_gen_b64, scaled_ref_b64)`` -- both as base64 JPEG strings.
    """
    decoded = [b64_to_img(gen_img_b64), b64_to_img(ref_img_b64)]
    target_side = sum(min(im.size) for im in decoded) / 2.0
    scaled_gen_b64, scaled_ref_b64 = (
        resize_image_base64(im, target_side) for im in decoded
    )
    return scaled_gen_b64, scaled_ref_b64


def perform_vl_judge(
    model,
    generation_file,
    save_file,
    num_runs=5,
    limit=None,
):
    """Judge every sample with the VL model and write the scored results.

    Steps:
      1. Read samples from ``generation_file`` (optionally only the first
         ``limit``).
      2. For each sample whose generated image is missing or fails
         ``check_image_size``, record a failed placeholder with score 0.0.
         Otherwise rescale the generated/reference pair and queue the judge
         prompt ``num_runs`` times.
      3. Run the model once over all queued prompts and parse each output;
         only integer final scores are kept.
      4. Per sample, sort the kept scores, drop the lowest and highest when
         more than two exist, and store the mean as ``"score"`` (0.0 when no
         score was kept).
      5. Write all records to ``save_file`` as JSONL.

    Args:
        model: Object providing ``preprocess_prompt(prompt,
            list_of_b64_images=...)`` and ``generate(prompts)``.
        generation_file: JSONL file with ``"prompt"``,
            ``"rendered_model_gen"`` and ``"rendered_image_b64"`` per sample.
        save_file: Destination JSONL path.
        num_runs: Number of judge runs per valid sample.
        limit: If not ``None``, only the first ``limit`` samples are used.
    """
    samples = read_jsonl(generation_file)
    if limit is not None:
        samples = samples[:limit]
    logger.info(f"Loaded samples = {len(samples)}")

    generations_file = Path(save_file).resolve()

    # One output record per sample, in input order; list position == sample index.
    records = [
        {
            **sample,
            "render_status": None,
            "failed_reason": "",
            "judge_output": [],
            "judge_parsed": [],
            "scores": [],
        }
        for sample in samples
    ]

    status_tally = defaultdict(int)
    prompts = []
    valid_idx = []  # valid_idx[k] is the sample index that prompts[k] belongs to

    logger.info("Processing Image (generated & reference)...")
    for idx, sample in enumerate(samples):
        gen_b64 = sample["rendered_model_gen"]
        ref_b64 = sample["rendered_image_b64"]
        record = records[idx]

        if gen_b64 is None:
            failure = "Not Generated"
        elif not check_image_size(gen_b64):
            failure = "Bad Ratio or Size"
        else:
            failure = None

        if failure is not None:
            # Not judged: store a single placeholder entry carrying score 0.0.
            status_tally[failure] += 1
            placeholder = parse_failed()
            record["render_status"] = False
            record["failed_reason"] = failure
            record["judge_output"].append(None)
            record["judge_parsed"].append(placeholder)
            record["scores"].append(placeholder["score"])
            continue

        status_tally["Success"] += 1
        record["render_status"] = True
        assert ref_b64 is not None, "Reference image must exist if generation exists."

        scaled_gen_img, scaled_ref_img = scale_images_align_shorter_side(gen_b64, ref_b64)
        llm_input = model.preprocess_prompt(
            build_prompt(sample),
            # Image order matters: generated image first, reference second.
            list_of_b64_images=[scaled_gen_img, scaled_ref_img],
        )

        # The same prepared input is queued once per judge run.
        prompts.extend([llm_input] * num_runs)
        valid_idx.extend([idx] * num_runs)

    logger.warning(f"Status: {dict(status_tally)}")

    if prompts:
        logger.info(f"{len(prompts) = }")
        logger.info("Example prompt:")
        logger.info("-" * 20)
        logger.info(prompts[0]["prompt"])
        logger.info("-" * 20)
        logger.info(f"{len(prompts[0]['multi_modal_data']) = }")

        logger.info("Generating...")
        started = time()
        judge_outputs = model.generate(prompts)
        judge_parsed = [parse_result(text) for text in judge_outputs]
        elapsed = time() - started
        logger.info(f"Generation DONE: in {elapsed:.2f} s")

        for idx, raw, parsed in zip(valid_idx, judge_outputs, judge_parsed):
            record = records[idx]
            record["judge_output"].append(raw)
            record["judge_parsed"].append(parsed)

            # Runs whose final score could not be parsed do not contribute.
            final_score = parsed["score"]
            if isinstance(final_score, int):
                record["scores"].append(final_score)

    for record in records:
        kept = sorted(record["scores"])
        if len(kept) > 2:  # drop the lowest and the highest score
            kept = kept[1:-1]
        record["score"] = sum(kept) / len(kept) if kept else 0.0

    write_jsonl(generations_file, records)
    logger.info(f"Generations => {generations_file}, #={len(records)}")


def judge_entry(
    judge_model: str = "Qwen/Qwen2.5-VL-72B-Instruct",
    input_file: str = None,
    output_file: str = None,
    metric_file: str = None,
    tp: int = 1,
    num_runs: int = 5,
    limit: int = None,
):
    """Run the VLM judge over a generation file and summarize the scores.

    The run configuration is dumped to ``config_judge.json`` next to
    ``output_file``, the judge model is loaded, every sample is judged via
    ``perform_vl_judge``, and ``summarize_entry`` is then called on the
    result.

    Args:
        judge_model: Model identifier passed to ``GenModelVllmMultiModal``.
        input_file: JSONL file with the rendered generations (required).
        output_file: JSONL file to write the judged samples to (required).
        metric_file: Path handed to ``summarize_entry`` (required).
        tp: Passed to ``GenModelVllmMultiModal`` as ``tp``.
            NOTE(review): presumably the tensor-parallel size -- confirm.
        num_runs: Number of judge runs per sample.
        limit: If not ``None``, only the first ``limit`` samples are judged.

    Raises:
        TypeError: If ``input_file``, ``output_file`` or ``metric_file`` is
            left as ``None`` (raised by ``Path``).
    """
    input_file = Path(input_file)
    save_file = Path(output_file)
    metric_file = Path(metric_file)

    # Create the output folder before it is handed to setup_logger, so logger
    # setup does not depend on the folder already existing (whether
    # setup_logger creates it itself is not visible from this file).
    save_folder = save_file.parent
    save_folder.mkdir(parents=True, exist_ok=True)

    if not logging.getLogger("text2svg").hasHandlers():
        setup_logger(save_file.parent, console_output=True)

    logger.info("===============[Start:judge_resize]===============")
    logger.info(f"{save_folder = }")

    configs_dumped = {
        "judge_model": judge_model,
        "input_file": str(input_file),
        "output_file": str(output_file),
        "metric_file": str(metric_file),
        "tp": tp,
        "num_runs": num_runs,
        "limit": limit,
    }

    with save_folder.joinpath("config_judge.json").open("w") as f:
        json.dump(configs_dumped, f, indent=2, ensure_ascii=False)

    kwargs = dict(
        model=judge_model,
        temperature=None,
        max_new_tokens=4096,
    )

    logger.info(f"{kwargs = }")
    model = GenModelVllmMultiModal(tp=tp, **kwargs)

    # Always release the model, even when judging fails part-way through.
    try:
        perform_vl_judge(model, input_file, save_file, num_runs=num_runs, limit=limit)
    finally:
        model.close()

    summarize_entry(output_file, metric_file)

    logger.info("===============[Success:judge_resize]===============")


# Command-line entry point: python-fire exposes judge_entry's keyword
# arguments as command-line flags.
if __name__ == "__main__":
    fire.Fire(judge_entry)
