import os.path
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image

from torchvision.datasets.vision import VisionDataset
import pickle
import torch
import torchvision
import re
# from torchvision.datasets import CocoDetection
# from utils.clip_filter import Clip_filter
from tqdm import tqdm

class CocoDetection(VisionDataset):
    """COCO dataset that yields one dictionary per image.

    Every sample is a dict holding the image id under ``"id"`` and, depending
    on the ``get_img`` / ``get_cap`` flags, the RGB image under ``"image"``
    and the loaded annotations (wrapped in a one-element list) under
    ``"caption"``.

    Args:
        root: Directory that each image's ``file_name`` is joined onto.
        annFile: Annotation file handed to ``pycocotools.coco.COCO``.
        transform: Forwarded to ``VisionDataset.__init__``.
        target_transform: Forwarded to ``VisionDataset.__init__``.
        transforms: Forwarded to ``VisionDataset.__init__``; when
            ``self.transforms`` is set it is called with the whole sample
            dict in ``__getitem__``.
        get_img: When true, ``__getitem__`` loads the image.
        get_cap: When true, ``__getitem__`` loads the annotations.
    """

    def __init__(
            self,
            root: str,
            annFile: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
            get_img: bool = True,
            get_cap: bool = True
    ) -> None:
        # NOTE(review): __getitem__ calls self.transforms with a single dict
        # argument. VisionDataset presumably wraps transform/target_transform
        # into a two-argument callable -- confirm before passing either here.
        super().__init__(root, transforms, transform, target_transform)
        # Imported lazily so pycocotools is only needed once a dataset is built.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        # NOTE(review): samples use the keys "id"/"image"/"caption", not
        # "text" -- verify which consumer relies on column_names.
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def _load_image(self, id: int) -> Image.Image:
        """Open the image registered under ``id`` and return it as RGB."""
        path = self.coco.loadImgs(id)[0]["file_name"]
        with open(os.path.join(self.root, path), 'rb') as f:
            # convert() runs while the file handle is still open.
            img = Image.open(f).convert("RGB")

        return img

    def _load_target(self, id: int) -> List[Any]:
        """Return every annotation record attached to image ``id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Any:
        """Build the sample dict for position ``index``.

        Returns:
            The sample dict, or whatever ``self.transforms`` returns for it
            when a transform is set.
        """
        img_id = self.ids[index]
        sample = {"id": img_id}
        if self.get_img:
            sample["image"] = self._load_image(img_id)
        if self.get_cap:
            # The target is deliberately kept wrapped in a one-element list.
            sample["caption"] = [self._load_target(img_id)]

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def subsample(self, n: Optional[int] = 10000):
        """Keep ``n`` ids picked at equal intervals, in place.

        Args:
            n: Number of images to keep. ``None`` or ``-1`` keeps everything.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If ``n`` is not in ``1..len(self)``.
        """
        if n is None or n == -1:
            return self
        ori_len = len(self)
        # Explicit check instead of ``assert``: it survives ``python -O`` and
        # also rejects 0 (division by zero) and negative counts (which would
        # produce a reversed, truncated id list).
        if not 0 < n <= ori_len:
            raise ValueError(
                f"n must be between 1 and the dataset size {ori_len}, got {n}"
            )
        # equal interval subsample
        step = ori_len // n
        self.ids = self.ids[::step][:n]
        print(f"COCO dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        """Replace ``self.transforms`` with ``transform`` and return ``self``."""
        self.transforms = transform
        return self

    def __len__(self) -> int:
        """Return the number of image ids currently selected."""
        return len(self.ids)


class CocoCaptions(CocoDetection):
    """``CocoDetection`` variant whose targets are plain caption values."""

    def _load_target(self, id: int) -> List[str]:
        """Return the ``"caption"`` field of every annotation of image ``id``."""
        captions = []
        for annotation in super()._load_target(id):
            captions.append(annotation["caption"])
        return captions


class CocoCaptions_clip_filtered(CocoCaptions):
    """COCO captions restricted to image ids that survive two filters.

    The surviving ids are cached as a pickle in ``id_file``. If that file
    exists and ``regenerate`` is false the ids are loaded from it; otherwise
    ``naive_filter`` and ``clip_filter`` are run and the result is written
    back to ``id_file``.

    Args:
        root: Forwarded to ``CocoCaptions``.
        annFile: Forwarded to ``CocoCaptions``.
        transform: Forwarded to ``CocoCaptions``.
        target_transform: Forwarded to ``CocoCaptions``.
        transforms: Forwarded to ``CocoCaptions``.
        regenerate: Recompute the ids even if ``id_file`` already exists.
        id_file: Pickle path used as the id cache. An empty value disables
            caching: ids are recomputed and nothing is written.
    """

    # Not referenced by any code visible in this file.
    positive_prompt = ["painting", "drawing", "graffiti"]

    def __init__(
            self,
            root: str,
            annFile: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
            regenerate: bool = False,
            id_file: Optional[str] = ""
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)
        use_cache = bool(id_file)
        if use_cache:
            # os.makedirs("") raises FileNotFoundError, so only create the
            # parent directory when the path actually has one.
            cache_dir = os.path.dirname(id_file)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)

        if use_cache and os.path.exists(id_file) and not regenerate:
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only point id_file at caches this project produced itself.
            with open(id_file, "rb") as f:
                self.ids = pickle.load(f)
        else:
            self.ids, naive_filtered_num = self.naive_filter()
            # NOTE(review): clip_filter is not defined on this class or on
            # any base class visible in this file; unless it is provided
            # elsewhere this line raises AttributeError.
            self.ids, clip_filtered_num = self.clip_filter(0.7)

            print(f"naive Filtered {naive_filtered_num} images")
            print(f"Clip Filtered {clip_filtered_num} images")

            if use_cache:
                with open(id_file, "wb") as f:
                    pickle.dump(self.ids, f)
                    print(f"Filtered ids saved to {id_file}")
        print(f"COCO filtered dataset size: {len(self)}")

    def naive_filter(self, filter_prompt="painting"):
        """Drop images having a caption that contains ``filter_prompt``.

        The match is a case-insensitive substring test.

        Args:
            filter_prompt: Keyword whose presence removes an image.

        Returns:
            ``(kept_ids, removed_count)`` where ``removed_count`` counts
            images (not captions) that were dropped.
        """
        # Lowercase the keyword as well; captions are lowercased below, so an
        # upper-case keyword could otherwise never match.
        needle = filter_prompt.lower()
        kept_ids = []
        naive_filtered_num = 0
        for img_id in self.ids:
            captions = self._load_target(img_id)
            if any(needle in caption.lower() for caption in captions):
                naive_filtered_num += 1
            else:
                kept_ids.append(img_id)
        return kept_ids, naive_filtered_num



def get_validation_set():
    """Pick 100 filtered COCO image ids with no person/animal annotations.

    Collects every image id annotated with a category whose supercategory is
    ``"person"`` or ``"animal"``, removes those ids from the ids of
    ``CocoCaptions_clip_filtered``, randomly samples 100 of the remainder and
    pickles the list to ``coco/coco_clip_filtered_subset100.pickle``.

    Raises:
        ValueError: From ``random.sample`` if fewer than 100 ids remain.
    """
    # Imported here so the function also works when this module is imported;
    # these names are otherwise only bound when the file runs as a script.
    import random

    from mypath import MyPath

    coco_instance = CocoDetection(root="coco_2017/train2017/", annFile="coco_2017/annotations/instances_train2017.json")
    discard_cat_ids = coco_instance.coco.getCatIds(supNms=["person", "animal"])
    discard_img_ids = set()
    for cat_id in discard_cat_ids:
        discard_img_ids.update(coco_instance.coco.catToImgs[cat_id])

    coco_clip_filtered = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
                                regenerate=False)
    # Sorted so that a seeded RNG yields the same subset on every run.
    candidate_ids = sorted(set(coco_clip_filtered.ids) - discard_img_ids)
    new_ids = random.sample(candidate_ids, 100)

    out_path = "coco/coco_clip_filtered_subset100.pickle"
    # The output directory is not guaranteed to exist yet.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(new_ids, f)

if __name__ == "__main__":
    # Kept at module scope: get_validation_set() may resolve MyPath and
    # random as module globals.
    from mypath import MyPath
    import random
    get_validation_set()

    # Captions are plain strings only with CocoCaptions on the caption
    # annotation file; the instances annotations yield dicts, which cannot be
    # passed to f.write().
    dataset = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    chosen_id = random.sample(dataset.ids, 100)
    captions = []
    for img_id in chosen_id:
        img_captions = dataset._load_target(img_id)
        # Skip images without any caption instead of failing on [0].
        if img_captions:
            captions.append(img_captions[0])

    with open("coco_caption_sample.txt", "w") as f:
        f.writelines(f"{caption}\n" for caption in captions)
    print(len(dataset))
    print(dataset[0])