import copy
import json
import math
import re
from pathlib import Path
from typing import Optional, Callable

import torch

from ..utils import smart_resize


class TvcgDataset(torch.utils.data.Dataset):
    """Map-style dataset of TVCG conversations with grounded object regions.

    Each annotation row is read by ``__getitem__`` as:

    * ``"conversation"`` -- list of messages; ``conv[0]["content"][0]["image"]``
      holds an image path and ``conv[1]["content"][0]["text"]`` the response.
    * ``"regions"`` -- mapping from region key to an ``[x1, y1, x2, y2]`` box.
    * ``"image_size"`` -- original image size as ``(width, height)``.

    Items are ``(conversation, regions, image_size)``, or whatever
    ``postprocess_fn(conversation, regions, image_size)`` returns when set.
    """

    # Matches object placeholder tokens such as ``<|obj3|>``. The pipes must be
    # escaped: unescaped, ``|`` is regex alternation and the pattern would strip
    # every ``<``/``>`` and bare ``objN`` while leaving ``||`` behind.
    _OBJ_TOKEN_PATTERN = re.compile(r"<\|obj\d+\|>")

    def __init__(
        self,
        ann_path: str,
        image_dir: str,
        baseline_opts: str = "",
        max_image_size: int = 672,
        postprocess_fn: Optional[Callable] = None,
        debug: bool = False,
    ):
        """Load annotations and configure baseline options.

        Args:
            ann_path: Path to a JSON file containing a list of annotation rows.
            image_dir: Directory holding the images; only the file name of
                each row's image path is kept and re-rooted here.
            baseline_opts: Comma-separated baseline flags. ``"no_copy"``
                strips object tokens from responses and drops all regions.
            max_image_size: Upper bound on either image side; larger images
                are scaled down (see ``resize_image``).
            postprocess_fn: Optional callable applied to each item as
                ``postprocess_fn(conversation, regions, image_size)``.
            debug: Currently unused; kept for interface compatibility.

        Raises:
            NotADirectoryError: If ``image_dir`` is not an existing directory.
        """
        self.name: str = "tvcg"
        self.postprocess_fn = postprocess_fn

        self.max_image_size = max_image_size

        # Tolerate whitespace around options, e.g. "foo, no_copy".
        baseline_opts_li = [opt.strip() for opt in baseline_opts.split(",")]

        self.no_copy = False
        if "no_copy" in baseline_opts_li:
            print("running baseline: no_copy")
            self.name = "tvcg_no_copy"
            self.no_copy = True

        self.image_dir = Path(image_dir)
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if not self.image_dir.is_dir():
            raise NotADirectoryError(
                f"image_dir is not a directory: {self.image_dir}"
            )

        # JSON text is UTF-8 by specification; don't rely on the locale default.
        with open(ann_path, encoding="utf-8") as f:
            data = json.load(f)
        print(f"TVCG dataset: loaded {len(data)} items")

        self.data = data

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.data)

    def resize_image(self, image_size):
        """Return the ``(width, height)`` the image should be resized to.

        Sizes within ``max_image_size`` on both sides are returned unchanged.
        Otherwise the image is scaled down uniformly so its longer side fits
        and the result is passed through ``smart_resize`` (presumably snapping
        to a model-compatible grid -- confirm in ``..utils``).
        """
        width, height = image_size
        if width > self.max_image_size or height > self.max_image_size:
            scaling_factor = min(
                self.max_image_size / width, self.max_image_size / height
            )
            width = int(width * scaling_factor)
            height = int(height * scaling_factor)

            # smart_resize takes and returns (height, width).
            height, width = smart_resize(height, width)
        return width, height

    def resize_bbox(self, bbox, prev_size, new_size):
        """Rescale an ``[x1, y1, x2, y2]`` box from ``prev_size`` to ``new_size``.

        Sizes are ``(width, height)``. The top-left corner is floored and the
        bottom-right corner ceiled, so the scaled box never shrinks inward.
        """
        x1, y1, x2, y2 = bbox
        return [
            math.floor(x1 * new_size[0] / prev_size[0]),
            math.floor(y1 * new_size[1] / prev_size[1]),
            math.ceil(x2 * new_size[0] / prev_size[0]),
            math.ceil(y2 * new_size[1] / prev_size[1]),
        ]

    def __getitem__(self, idx: int):
        """Build the sample for row ``idx``.

        Returns:
            ``(conversation, regions, image_size)`` with the image path
            re-rooted under ``image_dir``, regions rescaled to the (possibly
            resized) ``image_size`` tuple, and -- for the ``no_copy``
            baseline -- object tokens stripped and regions emptied. If
            ``postprocess_fn`` is set, its result is returned instead.
        """
        # Work on a private copy so neither this method nor postprocess_fn can
        # mutate the cached annotations.
        row = copy.deepcopy(self.data[idx])

        bbox_dt = row["regions"]
        conv = row["conversation"]

        # Keep only the file name and resolve it under image_dir.
        image_entry = conv[0]["content"][0]
        image_entry["image"] = str(self.image_dir / Path(image_entry["image"]).name)

        # JSON yields a list while resize_image returns a tuple; normalise so the
        # comparison reflects an actual size change (list != tuple is always True).
        orig_size = tuple(row["image_size"])
        image_size = self.resize_image(orig_size)
        if image_size != orig_size:
            bbox_dt = {
                k: self.resize_bbox(bbox, orig_size, image_size)
                for k, bbox in bbox_dt.items()
            }

        if self.no_copy:
            # Baseline: remove object-reference tokens and drop all regions.
            response = conv[1]["content"][0]
            response["text"] = self._OBJ_TOKEN_PATTERN.sub("", response["text"])
            bbox_dt = {}

        if self.postprocess_fn is not None:
            return self.postprocess_fn(conv, bbox_dt, image_size)
        return conv, bbox_dt, image_size


def is_consecutive(lst):
    """Return True when *lst* holds distinct values forming an unbroken run.

    The run may appear in any order, e.g. ``[3, 1, 2]``. An empty input is
    treated as not consecutive and yields ``False``.
    """
    if not lst:
        return False  # flip to True if an empty run should count as consecutive
    lowest, highest = min(lst), max(lst)
    size = len(lst)
    # The span must cover exactly `size` slots and no value may repeat.
    spans_exactly = highest - lowest + 1 == size
    return spans_exactly and len(set(lst)) == size
