import os.path
import sys
from typing import Any, Callable, List, Optional, Tuple

import tqdm
from PIL import Image

from torch.utils.data import Dataset
import pickle
from torchvision import transforms
# import torch
# import torchvision
# import re


class SamDataset(Dataset):
    """Map-style dataset over SAM images and their text captions.

    Each item is a dict ``{"id": ...}`` extended with ``"image"`` (a PIL RGB
    image) when ``get_img`` is true and ``"text"`` (a one-element list holding
    the cleaned caption, or ``[None]`` when the caption file is missing) when
    ``get_cap`` is true. If ``transforms`` is set, it receives that dict and
    its return value is what ``__getitem__`` yields.

    When ``resolution`` is set, images are served from the sibling folder
    ``f"{image_folder_path}_{resolution}"``; a missing or unreadable image
    there is regenerated from the original folder (resized) and written back.
    """

    def __init__(self, image_folder_path: str, caption_folder_path: str, id_file: Any = "data/sam/clip_filtered_ids.pickle", id_dict_file: Optional[str] = None, transforms: Optional[Callable] = None,
                 resolution=None,
                 get_img=True,
                 get_cap=True,):
        """Build the dataset.

        Args:
            image_folder_path: Folder holding the original ``sa_<id>.jpg`` files.
            caption_folder_path: Folder holding ``sa_<id>.txt`` caption files.
            id_file: Either a list of sample ids (used as-is) or a path
                (``str`` or ``os.PathLike``) to a pickle file containing them.
            id_dict_file: Optional path to a pickle mapping id -> subfolder
                name; when given, image paths become
                ``<folder>/<subfolder>/sa_<id>.jpg``.
            transforms: Optional callable applied to each sample dict. Note
                this parameter shadows the torchvision ``transforms`` module,
                but only inside ``__init__``.
            resolution: If not None, passed to ``transforms.Resize`` when
                populating the resized image cache folder.
            get_img: Whether samples include the ``"image"`` entry.
            get_cap: Whether samples include the ``"text"`` entry.

        Raises:
            TypeError: If ``id_file`` is neither a list nor a path.
        """
        if id_dict_file is not None:
            with open(id_dict_file, 'rb') as f:
                print(f"Loading id_dict from {id_dict_file}", flush=True)
                # NOTE(review): pickle.load can execute arbitrary code; only
                # point this at trusted, locally produced files.
                self.id_dict = pickle.load(f)
                print(f"Loaded id_dict from {id_dict_file}", flush=True)
        else:
            self.id_dict = None

        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, (str, os.PathLike)):
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
                print(f"Loaded ids from {id_file}", flush=True)
        else:
            # Previously self.ids was silently left unset here, failing later
            # with an unrelated AttributeError.
            raise TypeError(
                f"id_file must be a list of ids or a path to a pickle file, got {type(id_file).__name__}"
            )

        self.resolution = resolution
        self.ori_image_folder_path = image_folder_path
        if self.resolution is not None:
            self.image_folder_path = f"{image_folder_path}_{resolution}"
            os.makedirs(self.image_folder_path, exist_ok=True)
        else:
            self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        """Return the number of sample ids."""
        return len(self.ids)

    def __getitem__(self, index: int):
        """Return the sample dict for ``self.ids[index]`` (after ``transforms``).

        Loading errors propagate to the caller; there is deliberately no
        fallback to another index.
        """
        sample_id = self.ids[index]
        ret = {"id": sample_id}
        if self.get_img:
            ret["image"] = self._load_image(sample_id)
        if self.get_cap:
            ret["text"] = [self._load_caption(sample_id)]
        if self.transforms is not None:
            ret = self.transforms(ret)
        return ret

    def define_resolution(self, resolution: Optional[int]):
        """Switch the image cache folder to ``resolution``.

        ``None`` reverts to the original image folder, matching how
        ``__init__`` treats ``resolution=None``.
        """
        self.resolution = resolution
        if resolution is None:
            self.image_folder_path = self.ori_image_folder_path
        else:
            self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
        print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")

    def _image_path(self, folder: str, id) -> str:
        """Return the path of image ``id`` under ``folder``.

        Uses the ``id_dict`` subfolder when one was loaded; raises KeyError
        if the id is absent from it.
        """
        if self.id_dict is not None:
            return f"{folder}/{self.id_dict[id]}/sa_{id}.jpg"
        return f"{folder}/sa_{id}.jpg"

    def _load_image(self, id: int) -> "Image.Image":
        """Load image ``id`` as RGB, regenerating the cached copy if needed.

        Raises:
            FileNotFoundError: If neither the cached nor the original image exists.
        """
        image_path = self._image_path(self.image_folder_path, id)
        try:
            with open(image_path, 'rb') as f:
                return Image.open(f).convert("RGB")
        except Exception:
            # Missing, partially written, or corrupt cached copy: fall through
            # and rebuild it. (A bare ``except:`` also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass

        ori_image_path = self._image_path(self.ori_image_folder_path, id)
        if not os.path.exists(ori_image_path):
            # Explicit raise instead of ``assert`` (which is stripped under -O).
            raise FileNotFoundError(f"SAM image not found: {ori_image_path}")
        with open(ori_image_path, 'rb') as f:
            img = Image.open(f).convert("RGB")
        if self.resolution is not None:
            # An int resolution resizes the shorter side, keeping aspect ratio.
            img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
        # Never re-encode over the original JPEG (happens when resolution is
        # None and the cache folder *is* the original folder).
        if image_path != ori_image_path:
            os.makedirs(os.path.dirname(image_path), exist_ok=True)
            img.save(image_path)
        return img

    def _load_caption(self, id: int) -> Optional[str]:
        """Return the cleaned caption for ``id``, or None if its file is missing.

        Sentences are split on '.', empty ones and those containing
        "black and white" (too many false predictions) are dropped, and the
        rest are re-joined with ". " and terminated with a period.
        """
        caption_path = f"{self.caption_folder_path}/sa_{id}.txt"
        try:
            with open(caption_path, 'r', encoding="utf-8") as f:
                content = f.read()
        except FileNotFoundError:
            return None
        sentences = [
            sentence.strip()
            for sentence in content.split('.')
            if sentence.strip() and "black and white" not in sentence
        ]
        caption = ". ".join(sentences)
        if caption and not caption.endswith('.'):
            caption += '.'
        return caption

    def with_transform(self, transform):
        """Set the per-sample transform and return ``self`` for chaining."""
        self.transforms = transform
        return self

    def subsample(self, n: Optional[int] = 10000):
        """Keep ``n`` ids taken at equal intervals; ``None`` or -1 keeps all.

        Raises:
            ValueError: If ``n`` is not in ``[1, len(self)]``.
        """
        if n is None or n == -1:
            return self
        ori_len = len(self)
        if not 0 < n <= ori_len:
            # Explicit check: ``assert`` is stripped under -O, and n == 0 used
            # to crash with ZeroDivisionError.
            raise ValueError(f"subsample size must be in [1, {ori_len}] (or None/-1 to keep all), got {n}")
        step = ori_len // n
        self.ids = self.ids[::step][:n]
        print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
        return self


if __name__ == "__main__":
    # Smoke test: walk every caption of the filtered training split (image
    # loading disabled) so unreadable caption files surface as exceptions.
    from custom_datasets.sam_caption.mypath import MyPath

    dataset = SamDataset(
        image_folder_path=MyPath.db_root_dir("sam_images"),
        caption_folder_path=MyPath.db_root_dir("sam_captions"),
        id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"),
        id_dict_file=MyPath.db_root_dir("sam_id_dict"),
    )
    dataset.get_img = False
    for sample in tqdm.tqdm(dataset):
        _ = sample['text']
