import json
from dataclasses import dataclass
from typing import Literal, Optional
from collections import defaultdict
from pprint import pprint
from pathlib import Path

import numpy as np
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig
import torch
import tyro
from tqdm import tqdm


@dataclass
class Config:
    """Command-line options for the grounding evaluation (parsed with ``tyro.cli``)."""

    # RefCOCO-family dataset and split to evaluate.
    subset: Literal["RefCOCO", "RefCOCOPlus", "RefCOCOg"] = "RefCOCO"
    split: Literal["testA", "testB", "test", "val"] = "testA"
    # Optional cap on how many predictions are evaluated (None = no cap).
    num: Optional[int] = None
    # Number of (image, referring expression) pairs per generate() call.
    batch_size: int = 8

    # Hub repository holding the checkpoints, and the subfolder to load from it.
    # Previously evaluated keys:
    #   ft-Phi-3.5-vision-instruct_sft_refcoco-checkpoint-7530
    #   ft-Phi-3.5-vision-instruct_digit_base_refcoco-checkpoint-7530
    repo: str = "HF_REPO"
    key: str = "ft-Qwen2-VL-2B-Instruct_digit_refcoco-checkpoint-7530"


def xyhw_to_xyxy(loc):
    """Convert a 4-element box ``[x, y, dx, dy]`` into ``[x_min, y_min, x_max, y_max]``.

    The third and fourth entries are the extents along x and y respectively:
    they are added to ``x`` and ``y`` to obtain the far corner.
    """
    left, top, extent_x, extent_y = loc
    right = left + extent_x
    bottom = top + extent_y
    return [left, top, right, bottom]


def resize_image(image, max_width: int = 512, max_height: int = 512):
    """Shrink ``image`` to fit within ``max_width`` x ``max_height``, keeping its aspect ratio.

    The longer side is capped at its bound (never enlarged). The other side is
    derived from the aspect ratio and truncated to an int. If that derived side
    still exceeds its own bound (possible when the bounds are not square), that
    side becomes the constraint instead. Neither side drops below 1 pixel.

    Parameters:
    - image: object exposing ``size`` as ``(width, height)`` and a
      ``resize((width, height))`` method, e.g. a PIL image.
    - max_width (int): maximum output width.
    - max_height (int): maximum output height.

    Returns:
    - The result of ``image.resize`` for the computed size.

    Raises:
    - ValueError: if the image has a non-positive width or height.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(f"Cannot resize an image of size {image.size}")
    aspect_ratio = width / height
    if width > height:
        # Landscape: width is the primary constraint.
        new_width = min(max_width, width)
        new_height = int(new_width / aspect_ratio)
        # With non-square bounds the derived height can still overflow its bound.
        if new_height > max_height:
            new_height = max_height
            new_width = int(new_height * aspect_ratio)
    else:
        # Portrait or square: height is the primary constraint.
        new_height = min(max_height, height)
        new_width = int(new_height * aspect_ratio)
        if new_width > max_width:
            new_width = max_width
            new_height = int(new_width / aspect_ratio)
    # Extreme aspect ratios can truncate a side to 0, which is not a usable size.
    return image.resize((max(1, new_width), max(1, new_height)))


def get_data(subset, split):
    """Load one split of an ``lmms-lab/<subset>`` dataset via ``datasets.load_dataset``.

    Parameters:
    - subset (str): dataset name, e.g. "RefCOCO", "RefCOCOPlus", "RefCOCOg".
    - split (str): split name, e.g. "val", "test", "testA", "testB".

    Returns:
    - The object returned by ``load_dataset`` for that split.

    Raises:
    - ValueError: if (subset, split) is one of the known-invalid combinations.
    """
    # Subset names are compared case-insensitively. The CLI spells it
    # "RefCOCOPlus", so an exact-case match against "RefCOCOplus" never fired.
    invalid_combs = {
        ("refcocoplus", "test"),
        ("refcocog", "testA"),
        ("refcocog", "testB"),
    }
    # Raise instead of assert so the check survives `python -O`.
    if (subset.lower(), split) in invalid_combs:
        raise ValueError(f"Invalid combination of subset={subset} and split={split}")
    return load_dataset(f"lmms-lab/{subset}", split=split)


def compute_iou(box1, box2):
    """
    Compute the Intersection over Union (IoU) of two bounding boxes.

    Parameters:
    - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max].
    - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max].

    Returns:
    - float: IoU of box1 and box2. Returns 0.0 when the union is empty,
      i.e. both boxes have zero area.

    An inverted box (x_max < x_min or y_max < y_min), such as a malformed
    model prediction, is treated as having zero area.
    """
    # Coordinates of the intersection rectangle.
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)

    # Clamp each side so an inverted box cannot add a negative area to the union.
    box1_area = max(0, box1[2] - box1[0]) * max(0, box1[3] - box1[1])
    box2_area = max(0, box2[2] - box2[0]) * max(0, box2[3] - box2[1])

    union_area = box1_area + box2_area - intersection_area
    if union_area <= 0:
        # Both boxes are degenerate: there is no overlap to measure.
        return 0.0

    return intersection_area / union_area


def main():
    """Evaluate a fine-tuned vision-language model on RefCOCO-style grounding.

    Reads a ``Config`` from the command line and loads checkpoint ``args.key``
    from ``args.repo`` together with its base model's processor. It then asks
    the model for a bounding box for every (image, referring expression) pair.
    Mean IoU and accuracy at several IoU thresholds are written to
    ``./stats/<subset>/<split>/<key>.json``.

    Raises:
    - ValueError: if ``batch_size`` is smaller than 1.
    """
    args = tyro.cli(Config)
    if args.batch_size < 1:
        # A non-positive batch size would make the batching loop below spin forever.
        raise ValueError(f"batch_size must be >= 1, got {args.batch_size}")

    # Model-family flags derived from the checkpoint name.
    special_opts = {
        "is_qwen2_vl": "Qwen2-VL" in args.key,
        "is_phi3v": "Phi-3.5-vision" in args.key,
        "is_llava_phi": "llava-phi" in args.key,
        "is_llava_ov": "llava-onevision" in args.key,
        "is_llava_gemma": "llava-gemma" in args.key,
    }

    root = Path(f"./stats/{args.subset}/{args.split}")
    root.mkdir(exist_ok=True, parents=True)

    model_cls = AutoModelForCausalLM
    if special_opts["is_qwen2_vl"]:
        from transformers import Qwen2VLForConditionalGeneration

        model_cls = Qwen2VLForConditionalGeneration
    model = model_cls.from_pretrained(
        args.repo,
        subfolder=args.key,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")
    model.eval()

    # The fine-tuned checkpoints are paired with their base model's processor.
    processor_name = "microsoft/Phi-3.5-vision-instruct"
    if special_opts["is_qwen2_vl"]:
        processor_name = "Qwen/Qwen2-VL-2B-Instruct"

    data = get_data(args.subset, args.split)

    processor = AutoProcessor.from_pretrained(processor_name, trust_remote_code=True)
    # Loop-invariant; previously reassigned on every batch.
    processor.tokenizer.padding_side = "left"

    # Images are resized to a square of this side before prompting, and the
    # coordinates in the model's JSON answer are mapped back from this grid.
    canvas = 1024

    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Detect the object in the image with the following reference: <|DESC|>",
                },
                {"type": "image"},
            ],
        }
    ]

    def run_old_processor(conversation, image_key: str = "<|image|>"):
        """Render ``conversation`` for processors that take plain-string turn contents.

        The image placeholder is prepended to the first turn's text, and each
        turn's content is flattened to its first (text) part. The passed-in
        conversation is left unmodified.
        """
        flat = [{**turn, "content": turn["content"][0]["text"]} for turn in conversation]
        flat[0]["content"] = f"{image_key}\n" + flat[0]["content"]
        return processor.tokenizer.apply_chat_template(
            flat, tokenize=False, add_generation_prompt=True
        )

    if special_opts["is_phi3v"]:
        prompt = run_old_processor(conversation, "<|image_1|>")
    else:
        prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    def process(prompts, images):
        """Process each (prompt, image) pair separately and collate them into one left-padded batch."""
        pad_token_id = processor.tokenizer.pad_token_id
        input_ids = []
        attention_mask = []
        out_images = []
        others = []
        for prompt, image in zip(prompts, images):
            inputs = processor(
                text=prompt,
                images=image,
                return_tensors="pt",
            ).to("cuda")
            # Reverse each sequence so that end-padding followed by a flip back
            # yields left padding.
            input_ids.append(inputs.input_ids[0].flip(0))
            attention_mask.append(inputs.attention_mask[0].flip(0))
            _image = inputs.pixel_values
            if not special_opts["is_qwen2_vl"]:
                # Presumably drops a leading per-call batch dim of 1; Qwen2-VL
                # pixel_values are kept whole.
                _image = _image[0]
            out_images.append(_image)
            others.append(inputs)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_token_id
        ).flip(-1)
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask, batch_first=True, padding_value=0
        ).flip(-1)
        out_images = torch.stack(out_images, dim=0)
        # Any remaining processor outputs are stacked from each sample's first entry.
        inputs = {
            k: torch.stack([o[k][0] for o in others], dim=0)
            for k in others[0].keys()
            if k not in ["input_ids", "attention_mask", "pixel_values"]
        }
        return {
            **inputs,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": out_images,
        }

    def parse_box(response, orig_size):
        """Turn a model response into an ``[x, y, w, h]`` box in original-image pixels.

        Only the first line of the response is parsed, as JSON with
        ``left_x``/``top_y``/``right_x``/``bottom_y`` on the ``canvas`` grid.
        On any parse failure the zero-width box ``[0, 1, 0, 1]`` is returned,
        so the sample scores no overlap rather than aborting the run.
        """
        text = response.strip().split("\n")[0].strip()
        try:
            obj = json.loads(text)
            x1 = obj["left_x"] / canvas * orig_size[0]
            y1 = obj["top_y"] / canvas * orig_size[1]
            x2 = obj["right_x"] / canvas * orig_size[0]
            y2 = obj["bottom_y"] / canvas * orig_size[1]
        except Exception as e:  # best-effort: malformed model output must not crash the eval
            print(e)
            return [0, 1, 0, 1]
        return [x1, y1, x2 - x1, y2 - y1]

    def run_batch(images, descs):
        """Predict one ``[x, y, w, h]`` box per (image, description) pair."""
        prompts = [prompt.replace("<|DESC|>", desc) for desc in descs]
        orig_sizes = [image.size for image in images]
        images = [image.resize((canvas, canvas)) for image in images]

        inputs = process(prompts, images)
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=60,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
        # Keep only the newly generated tokens (prompts are left-padded to equal length).
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
        responses = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return [parse_box(r, size) for r, size in zip(responses, orig_sizes)]

    all_outs = []

    def flush(items):
        """Run one generation batch over ``items`` and record (prediction, target) pairs."""
        preds = run_batch([item[0] for item in items], [item[1] for item in items])
        all_outs.extend(zip(preds, [item[2] for item in items]))

    batch = []
    for row in tqdm(data, total=len(data)):
        # Count queued-but-unrun samples too, so we stop queuing once enough are pending.
        if args.num is not None and len(all_outs) + len(batch) >= args.num:
            break
        # Each referring expression of a row is a separate sample sharing the row's box.
        for ans in row["answer"]:
            batch.append([row["image"], ans, row["bbox"]])

        while len(batch) >= args.batch_size:
            flush(batch[: args.batch_size])
            batch = batch[args.batch_size :]

    if args.num is not None:
        # Rows add several samples at once, so only run what is still needed.
        batch = batch[: max(0, args.num - len(all_outs))]
    if batch:
        flush(batch)
    if args.num is not None:
        # A full batch flushed inside the loop may also have overshot the cap.
        del all_outs[args.num :]

    print("eval")
    thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
    scores = defaultdict(list)
    for pred, target in all_outs:
        iou = compute_iou(xyhw_to_xyxy(pred), xyhw_to_xyxy(target))
        scores["iou"].append(iou)
        for th in thresholds:
            scores[f"acc@{th}"].append(int(iou > th))

    stats = {k: float(np.mean(v)) for k, v in scores.items()}
    pprint(stats)

    with open(root / f"{args.key}.json", "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=4)


if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
