import os

# Opt in to hf_transfer-accelerated Hugging Face Hub downloads. This is set
# before the `datasets` import below — presumably so the hub client sees the
# flag when it is first loaded (TODO confirm).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import math
import json
from itertools import chain
from typing import Optional, Union
from pathlib import Path
from dataclasses import dataclass

import ray
from ray.util.actor_pool import ActorPool
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm, trange
import tyro


@dataclass
class Config:
    """Command-line options for building the mixed grounding dataset.

    Parsed by ``tyro.cli`` in the ``__main__`` guard and passed to ``main``.
    """

    # Output root; `main` writes `images/` and `annotations.json` under it.
    out_dir: str = "../../data/data/refcoco_mix_base"
    # Side length of the square every image is resized to; bounding boxes are
    # rescaled to match. `None` leaves images unresized.
    normalize_size: Optional[int] = 1024
    # Number of leading rows of the Open Images bbox dataset to mix in (0 = none).
    mix_open_images: int = 0
    # Number of leading Visual Genome rows to mix in, applied to both the
    # region-description and the object subsets (0 = none). Was once 50000.
    mix_visual_genome: int = 0
    # Also emit one sample per COCO object annotation of the processed images.
    use_coco: bool = False
    # Additionally emit the reverse (location -> description) conversation.
    add_captioning: bool = False
    # Debug mode: the mixing loops in `main` stop after their first row.
    debug: bool = False
    # Number of Ray actors created per dataset; `None` lets
    # `process_dataset_ray` derive it from the available CPUs.
    num_workers: Optional[int] = 32


def resize_image(image, max_width: int = 512, max_height: int = 512):
    """Return ``image`` resized to exactly ``max_width`` x ``max_height``.

    The target size is handed straight to ``image.resize``, so the aspect
    ratio of the input is not preserved.
    """
    target_size = (max_width, max_height)
    resized = image.resize(target_size)
    return resized


def get_conversation(desc, bbox, key: str = "name", add_captioning: bool = False):
    """Build the chat-style training conversations for one described box.

    Args:
        desc: Description of the object (a name or a referring expression).
        bbox: JSON-serialisable bounding box; embedded via ``json.dumps``.
        key: Word used in the prompts for the kind of description.
        add_captioning: When true, also build the reverse conversation that
            asks for the description given the location.

    Returns:
        A list of conversations. Each conversation is a two-message list
        (user prompt with an image placeholder, then the assistant answer).
        The grounding conversation always comes first.
    """

    def text_part(text):
        return {"type": "text", "text": text}

    def user_turn(prompt):
        # A fresh image placeholder per turn, so conversations share no dicts.
        return {"role": "user", "content": [text_part(prompt), {"type": "image"}]}

    def assistant_turn(answer):
        return {"role": "assistant", "content": [text_part(answer)]}

    location = json.dumps(bbox)

    conversations = [
        [
            user_turn(
                f"Detect the object in the image with the following {key}: {desc}"
            ),
            assistant_turn(location),
        ]
    ]
    if not add_captioning:
        return conversations

    conversations.append(
        [
            user_turn(
                f"Describe {key} of the object in the image with the following location: {location}"
            ),
            assistant_turn(desc),
        ]
    )
    return conversations


def process_row(
    image_id,
    image,
    desc,
    bbox,
    out_image_dir,
    normalize_size,
    key: str = "name",
    add_captioning: bool = False,
):
    """Turn one (image, description, box) triple into annotation records.

    The box is rescaled with ``normalize_bbox``, the image is written to
    ``out_image_dir / f"{image_id}.jpg"`` if that file does not exist yet, and
    one record is produced per conversation from ``get_conversation``.

    Args:
        image_id: Identifier used for the image file name and the record ids.
        image: Image object exposing ``size``, ``convert``, ``resize``, ``save``.
        desc: Description of the boxed object.
        bbox: Dict with ``left_x``/``top_y``/``right_x``/``bottom_y`` in the
            coordinates of ``image``.
        out_image_dir: Directory (path object) the image is written to.
        normalize_size: Target square size, or ``None`` to keep the image as is.
        key: Description kind forwarded to ``get_conversation``.
        add_captioning: Forwarded to ``get_conversation``.

    Returns:
        A list of dicts with ``conversations``, ``id`` and ``image`` keys.

    Raises:
        AssertionError: If the rescaled box falls outside the target image.
    """
    # Local import keeps this block self-contained.
    import hashlib

    # The built-in hash() of a str is salted per process, so it yields
    # different ids on every run and in every Ray worker. A digest is stable.
    digest = hashlib.sha256(str(desc).encode("utf-8")).digest()
    qid = int.from_bytes(digest[:8], "big")

    # Rescale against the ORIGINAL dimensions; convert/resize do not change
    # what normalize_bbox reads, so this can happen before touching pixels.
    bbox = normalize_bbox(bbox, image, size=normalize_size)

    # Bound the box by the actual target size (one pixel of rounding slack,
    # i.e. 1025 for the default 1024) instead of a hard-coded constant.
    if normalize_size is not None:
        upper_bound = normalize_size + 1
    else:
        upper_bound = max(image.size) + 1
    assert is_bbox_within(bbox, 0, upper_bound), bbox

    image_filename = f"{image_id}.jpg"
    out_path = out_image_dir / image_filename
    if not out_path.is_file():
        # Only pay for conversion/resizing when the file is actually written;
        # the same image is passed in once per caption/object.
        image = image.convert("RGB")
        if normalize_size is not None:
            image = resize_image(image, normalize_size, normalize_size)
        image.save(out_path)

    conversations = get_conversation(desc, bbox, key=key, add_captioning=add_captioning)
    return [
        {
            "conversations": conversation,
            "id": f"{image_id}_{qid}_{i}",
            "image": image_filename,
        }
        for i, conversation in enumerate(conversations)
    ]


# Short dataset name -> identifier passed to `load_dataset` for that variant.
KEYS = dict(
    refcoco="jxu124/refcoco",
    refcocog="jxu124/refcocog",
    refcocop="jxu124/refcocoplus",
)


def convert_bbox_base(bbox):
    """Label a four-value corner box ``(x1, y1, x2, y2)`` with named edges."""
    left, top, right, bottom = bbox
    return dict(left_x=left, top_y=top, right_x=right, bottom_y=bottom)


def convert_bbox_coco(bbox):
    """Convert an ``(x, y, width, height)`` box into named corner edges."""
    left, top, width, height = bbox
    right = left + width
    bottom = top + height
    return dict(left_x=left, top_y=top, right_x=right, bottom_y=bottom)


def is_bbox_within(dt, minv, maxv):
    """Return True when every value of ``dt`` lies in ``[minv, maxv]``.

    Both bounds are inclusive; an empty mapping is considered within bounds.
    """
    return all(minv <= value <= maxv for value in dt.values())


def normalize_bbox(bbox, image, size: Optional[int] = 1024):
    """Rescale ``bbox`` to a ``size`` x ``size`` image and round it outwards.

    Args:
        bbox: Dict with ``left_x``, ``top_y``, ``right_x``, ``bottom_y``.
        image: Object whose ``size`` attribute is ``(width, height)``; only
            read when ``size`` is not ``None``.
        size: Target square side. ``None`` keeps the scale and only rounds.

    Returns:
        A new dict with integer edges: left/top floored, right/bottom ceiled.
    """
    if size is None:
        x_ratio = y_ratio = 1
    else:
        width, height = image.size
        x_ratio, y_ratio = size / width, size / height

    # (edge name, scale factor, rounding direction) — rounding grows the box.
    edge_rules = (
        ("left_x", x_ratio, math.floor),
        ("top_y", y_ratio, math.floor),
        ("right_x", x_ratio, math.ceil),
        ("bottom_y", y_ratio, math.ceil),
    )
    return {edge: rounder(bbox[edge] * ratio) for edge, ratio, rounder in edge_rules}


def process_dataset(
    image_dataset,
    image_indices_map,
    out_image_dir,
    normalize_size,
    key: str = "refcoco",
    add_captioning: bool = False,
    debug: bool = False,
):
    """Sequentially convert the train split of one referring-expression dataset.

    Rows whose ``image_id`` is missing from ``image_indices_map`` are skipped.
    Every caption of a remaining row becomes one ``process_row`` call.

    Args:
        image_dataset: Indexable dataset whose rows carry an ``"image"``.
        image_indices_map: Mapping from image id to index in ``image_dataset``.
        out_image_dir: Directory that receives the image files.
        normalize_size: Target square size forwarded to ``process_row``.
        key: Entry of ``KEYS`` selecting which dataset to load.
        add_captioning: Forwarded to ``process_row``.
        debug: Stop after the first row that was actually processed.

    Returns:
        ``(records, image_ids)`` — the annotation records and the image id of
        every processed row (one entry per row, so ids may repeat).
    """
    split = load_dataset(KEYS[key])["train"]
    records = []
    processed_image_ids = []
    for ref_row in tqdm(split, desc=f"Processing {key}"):
        image_id = ref_row["image_id"]
        if image_id not in image_indices_map:
            continue
        processed_image_ids.append(image_id)
        picture = image_dataset[image_indices_map[image_id]]["image"]
        box = convert_bbox_base(ref_row["bbox"])
        for caption in ref_row["captions"]:
            records += process_row(
                image_id,
                picture,
                caption,
                box,
                out_image_dir,
                normalize_size,
                key="reference",
                add_captioning=add_captioning,
            )
        if debug:
            break
    return records, processed_image_ids


# Initialise Ray when this module is loaded; the actors created in
# process_dataset_ray run on this runtime.
ray.init()


@ray.remote
class ImageProcessor:
    """Ray actor that converts referring-expression rows into annotation records.

    Each actor keeps its own references to the image dataset, the
    referring-expression dataset and the id -> index map, so a task only needs
    to carry a row index.
    """

    def __init__(
        self,
        image_dataset,
        dataset,
        image_indices_map,
        out_image_dir,
        normalize_size,
        add_captioning=False,
    ):
        """Store the datasets and the options forwarded to ``process_row``.

        Args:
            image_dataset: Indexable dataset whose rows carry an ``"image"``.
            dataset: Indexable dataset of rows with ``image_id``, ``bbox`` and
                ``captions``.
            image_indices_map: Mapping from image id to index in
                ``image_dataset``.
            out_image_dir: Directory that receives the image files.
            normalize_size: Target square size forwarded to ``process_row``.
            add_captioning: Forwarded to ``process_row``.
        """
        self.image_dataset = image_dataset
        self.dataset = dataset
        self.image_indices_map = image_indices_map
        self.out_image_dir = out_image_dir
        self.normalize_size = normalize_size
        self.add_captioning = add_captioning

    def process_image(self, idx):
        """Process a single row.

        Args:
            idx: Index of the row in ``self.dataset``.

        Returns:
            ``(records, image_id)`` with one or more records per caption, or
            ``(None, None)`` when the row's image id is not in
            ``self.image_indices_map``.
        """
        row = self.dataset[idx]
        image_id = row["image_id"]
        if image_id not in self.image_indices_map:
            return None, None

        image_index = self.image_indices_map[image_id]
        image_row = self.image_dataset[image_index]
        image = image_row["image"]
        bbox = convert_bbox_base(row["bbox"])

        data = []
        for caption in row["captions"]:
            data.extend(
                process_row(
                    image_id,
                    image,
                    caption,
                    bbox,
                    self.out_image_dir,
                    self.normalize_size,
                    key="reference",
                    add_captioning=self.add_captioning,
                )
            )
        # Printed once per processed row.
        print(image_id)
        return data, image_id


def process_dataset_ray(
    image_dataset,
    image_indices_map,
    out_image_dir,
    normalize_size,
    key: str = "refcoco",
    add_captioning: bool = False,
    debug: bool = False,
    num_workers: Optional[int] = None,
):
    """Convert the train split of one referring-expression dataset with Ray.

    Parallel counterpart of ``process_dataset``: rows are distributed over a
    pool of ``ImageProcessor`` actors and results are collected in row order.

    Args:
        image_dataset: Indexable dataset whose rows carry an ``"image"``.
        image_indices_map: Mapping from image id to index in ``image_dataset``.
        out_image_dir: Directory that receives the image files.
        normalize_size: Target square size forwarded to the actors.
        key: Entry of ``KEYS`` selecting which dataset to load.
        add_captioning: Forwarded to the actors.
        debug: Only process a small prefix of the dataset (at most one row per
            worker) so a debug run finishes quickly.
        num_workers: Number of actors; ``None`` uses the CPUs Ray reports as
            available, capped by the dataset length.

    Returns:
        ``(records, image_ids)`` — a flat list of annotation records and a
        tuple with the image id of every row that produced records. Both are
        empty when no row matched.
    """
    dataset = load_dataset(KEYS[key])["train"]

    if num_workers is None:
        # Ray reports resource quantities as floats and leaves the key out
        # when nothing is free; range() below needs a plain int.
        available_cpus = int(ray.available_resources().get("CPU", 1))
        num_workers = min(len(dataset), available_cpus)
    # Always keep at least one actor so the pool can make progress.
    num_workers = max(1, int(num_workers))

    # Create actors (workers) with the dataset assigned to each.
    workers = [
        ImageProcessor.remote(
            image_dataset,
            dataset,
            image_indices_map,
            out_image_dir,
            normalize_size,
            add_captioning,
        )
        for _ in range(num_workers)
    ]
    pool = ActorPool(workers)

    row_indices = list(range(len(dataset)))
    if debug:
        # Mirror process_dataset's debug short-circuit instead of silently
        # processing the full dataset.
        row_indices = row_indices[:num_workers]

    records = []
    image_ids = []
    for row_records, image_id in pool.map(
        lambda actor, idx: actor.process_image.remote(idx), row_indices
    ):
        # Skipped rows come back as (None, None); rows without captions
        # produce an empty list. Neither contributes anything.
        if not row_records:
            continue
        records.extend(row_records)
        image_ids.append(image_id)

    return records, tuple(image_ids)


def main(args):
    """Build the mixed grounding dataset described by ``args`` and save it.

    Steps:
      1. Convert every dataset in ``KEYS`` with ``process_dataset_ray``
         (``process_dataset`` is the sequential equivalent).
      2. Optionally add one sample per COCO object of the images used above.
      3. Optionally mix in leading rows of Open Images and Visual Genome.
      4. Write all records to ``<out_dir>/annotations.json``; images are
         written to ``<out_dir>/images`` by ``process_row``.

    Args:
        args: A ``Config`` instance (or any object with the same attributes).
    """
    out_dir = Path(args.out_dir)

    out_image_dir = out_dir / "images"
    out_image_dir.mkdir(exist_ok=True, parents=True)

    # The category table is only needed for the COCO pass. Load it up front so
    # a missing file fails before the long-running conversion, but do not
    # require it when the pass is disabled.
    coco_cat = None
    if args.use_coco:
        category_path = Path(__file__).parent / "coco_category.json"
        with open(category_path, encoding="utf-8") as f:
            coco_cat = json.load(f)

    image_dataset = load_dataset("detection-datasets/coco")["train"]
    image_indices_map = {v: i for i, v in enumerate(image_dataset["image_id"])}

    data = []
    image_ids = []
    for key in KEYS:
        _data, _image_ids = process_dataset_ray(
            image_dataset,
            image_indices_map,
            out_image_dir,
            args.normalize_size,
            key=key,
            add_captioning=args.add_captioning,
            debug=args.debug,
            num_workers=args.num_workers,
        )
        data.extend(_data)
        image_ids.extend(_image_ids)

    if args.use_coco:
        # image_ids holds one entry per referring-expression row, so an image
        # can appear many times (also across datasets). Visit each image once,
        # in first-seen order, to avoid emitting duplicate COCO samples.
        unique_image_ids = list(dict.fromkeys(image_ids))
        for image_id in tqdm(unique_image_ids, desc="Processing coco"):
            image_index = image_indices_map[image_id]
            image_row = image_dataset[image_index]
            image = image_row["image"]
            for category, bbox in zip(
                image_row["objects"]["category"], image_row["objects"]["bbox"]
            ):
                bbox = convert_bbox_coco(bbox)
                name = coco_cat[str(category)]
                data.extend(
                    process_row(
                        image_id,
                        image,
                        name,
                        bbox,
                        out_image_dir,
                        args.normalize_size,
                        key="name",
                        add_captioning=args.add_captioning,
                    )
                )
            if args.debug:
                break

    if args.mix_open_images > 0:
        open_images = load_dataset("vikhyatk/openimages-bbox")["train"]
        # Never index past the end of the dataset.
        num_rows = min(args.mix_open_images, len(open_images))
        pbar = trange(num_rows, desc="Processing open images")
        for i in pbar:
            row = open_images[i]
            image = row["image"]
            w, h = image.size
            for obj in row["objects"]:
                # Box edges are multiplied by the image size to get pixels.
                bbox = {
                    "left_x": obj["xmin"] * w,
                    "top_y": obj["ymin"] * h,
                    "right_x": obj["xmax"] * w,
                    "bottom_y": obj["ymax"] * h,
                }
                name = obj["label"].lower()
                data.extend(
                    process_row(
                        f"open_image_{i}",
                        image,
                        name,
                        bbox,
                        out_image_dir,
                        args.normalize_size,
                        key="name",
                        add_captioning=args.add_captioning,
                    )
                )
            if args.debug:
                break

    if args.mix_visual_genome > 0:
        # NOTE(review): both subsets below store images as visual_genome_{i}
        # and process_row keeps an existing file, so this assumes row i is the
        # same image in both subsets — confirm.
        visual_genome = load_dataset(
            "ranjaykrishna/visual_genome",
            "region_descriptions_v1.2.0",
            trust_remote_code=True,
        )["train"]

        num_rows = min(args.mix_visual_genome, len(visual_genome))
        pbar = trange(num_rows, desc="Processing visual genome regions")
        for i in pbar:
            row = visual_genome[i]
            image = row["image"]
            for obj in row["regions"]:
                bbox = {
                    "left_x": obj["x"],
                    "top_y": obj["y"],
                    "right_x": obj["x"] + obj["width"],
                    "bottom_y": obj["y"] + obj["height"],
                }
                reference = obj["phrase"]
                data.extend(
                    process_row(
                        f"visual_genome_{i}",
                        image,
                        reference,
                        bbox,
                        out_image_dir,
                        args.normalize_size,
                        key="reference",
                        add_captioning=args.add_captioning,
                    )
                )
            if args.debug:
                break

        visual_genome = load_dataset(
            "ranjaykrishna/visual_genome",
            "objects_v1.2.0",
            trust_remote_code=True,
        )["train"]

        num_rows = min(args.mix_visual_genome, len(visual_genome))
        pbar = trange(num_rows, desc="Processing visual genome objects")
        for i in pbar:
            row = visual_genome[i]
            image = row["image"]
            for obj in row["objects"]:
                bbox = {
                    "left_x": obj["x"],
                    "top_y": obj["y"],
                    "right_x": obj["x"] + obj["w"],
                    "bottom_y": obj["y"] + obj["h"],
                }
                # Only the first listed name is used.
                name = obj["names"][0].lower()
                data.extend(
                    process_row(
                        f"visual_genome_{i}",
                        image,
                        name,
                        bbox,
                        out_image_dir,
                        args.normalize_size,
                        key="name",
                        add_captioning=args.add_captioning,
                    )
                )
            if args.debug:
                break

    print(f"total_size: {len(data)}")
    # save
    with open(out_dir / "annotations.json", "w") as f:
        json.dump(data, f, indent=4)

    print("done")


if __name__ == "__main__":
    # Parse the command line into a Config, then run the build.
    cli_args = tyro.cli(Config)
    main(cli_args)
