
import os
import pickle
import random
import shutil
from typing import Callable, List, Optional, Union

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class LhqDataset(Dataset):
    """Map-style dataset over LHQ images and their line-based text captions.

    Each sample is addressed by an id. The image is read from
    ``{image_folder_path}/{id}.jpg`` and the caption from
    ``{caption_folder_path}/{id}.txt``. ``__getitem__`` returns a dict that
    always has ``"id"`` and, depending on the ``get_img`` / ``get_cap`` flags,
    ``"image"`` and/or ``"caption"``. When a transform is set, it is applied
    to that whole dict.
    """

    def __init__(self, image_folder_path: str, caption_folder_path: str,
                 id_file: Union[str, os.PathLike, List] = "clip_dissection/lhq/idx/subsample_100.pickle",
                 transforms: Optional[Callable] = None,
                 get_img=True,
                 get_cap=True,):
        """Create the dataset.

        Args:
            image_folder_path: Directory containing ``<id>.jpg`` images.
            caption_folder_path: Directory containing ``<id>.txt`` captions.
            id_file: Either a list of ids, used as-is, or a path to a
                pickle file whose unpickled content is used as the id sequence.
            transforms: Optional callable applied to each sample dict in
                ``__getitem__``. Its return value becomes the sample.
            get_img: Whether samples include the loaded ``"image"``.
            get_cap: Whether samples include the loaded ``"caption"``.

        Raises:
            TypeError: If ``id_file`` is neither a list nor a path.
        """
        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, (str, os.PathLike)):
            # SECURITY: pickle.load can execute arbitrary code. Only load id
            # files from trusted sources.
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
                print(f"Loaded ids from {id_file}", flush=True)
        else:
            # Fail fast. Otherwise self.ids stays unset and the first
            # len()/indexing call fails with a confusing AttributeError.
            raise TypeError(
                f"id_file must be a list of ids or a path, got {type(id_file).__name__}"
            )
        self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        # NOTE(review): samples use the key "caption", not "text". Confirm
        # which consumer reads column_names before changing either.
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        """Return the number of ids currently in the dataset."""
        return len(self.ids)

    def __getitem__(self, index: int):
        """Build the sample dict for ``self.ids[index]``.

        The caption (a list of non-empty lines) is wrapped in an outer
        single-element list. Presumably this matches a downstream
        batch/collate format; confirm before changing it.
        """
        sample_id = self.ids[index]
        sample = {"id": sample_id}
        if self.get_img:
            sample["image"] = self._load_image(sample_id)
        if self.get_cap:
            sample["caption"] = [self._load_caption(sample_id)]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample

    def _load_image(self, id: int):
        """Load ``{image_folder_path}/{id}.jpg`` as an RGB PIL image."""
        image_path = f"{self.image_folder_path}/{id}.jpg"
        with open(image_path, 'rb') as f:
            # convert() forces the pixel data to load before the file closes.
            img = Image.open(f).convert("RGB")
        return img

    def _load_caption(self, id: int):
        """Return the stripped, non-empty lines of ``{caption_folder_path}/{id}.txt``."""
        caption_path = f"{self.caption_folder_path}/{id}.txt"
        # Explicit encoding so the result does not depend on the platform
        # locale (PEP 597).
        with open(caption_path, 'r', encoding="utf-8") as f:
            text = f.read()
        stripped = (line.strip() for line in text.split("\n"))
        return [line for line in stripped if line]

    def subsample(self, n: Optional[int] = 10000):
        """Keep ``n`` ids picked at an equal stride, in place; return ``self``.

        ``n`` of ``None`` or ``-1`` leaves the dataset untouched. ``n == 0``
        empties it.

        Raises:
            ValueError: If ``n`` is negative (other than -1) or exceeds the
                current length.
        """
        if n is None or n == -1:
            return self
        ori_len = len(self)
        if n < 0 or n > ori_len:
            raise ValueError(f"n must be in [0, {ori_len}] (or -1/None), got {n}")
        if n == 0:
            # Avoid the ZeroDivisionError that ori_len // 0 would raise.
            self.ids = self.ids[:0]
        else:
            # Equal-interval subsample. The stride is floor(len / n), so the
            # first n picks are always available.
            self.ids = self.ids[::ori_len // n][:n]
        print(f"LHQ dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        """Set the per-sample transform (replacing any previous one); return ``self``."""
        self.transforms = transform
        return self


def generate_idx(data_folder = "", save_path = ""):
    """Collect image ids from ``data_folder``, pickle them to ``save_path``, and return them.

    An id is a file name ending in ``.jpg`` or ``.png`` with that final
    extension removed (dots inside the stem are kept). Ids are returned in
    ``os.listdir`` order. Note that ``LhqDataset._load_image`` in this file
    only looks up ``<id>.jpg``.

    Args:
        data_folder: Directory to scan (not recursive).
        save_path: Destination pickle file. Its parent directory is created
            if needed; a bare file name is written to the current directory.

    Returns:
        The list of ids that was pickled.
    """
    all_ids = [
        os.path.splitext(name)[0]
        for name in os.listdir(data_folder)
        if name.endswith((".jpg", ".png"))
    ]
    save_dir = os.path.dirname(save_path)
    # os.makedirs("") raises FileNotFoundError, so only create real dirs.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(save_path, "wb") as f:
        pickle.dump(all_ids, f)
    print("all_ids generated")
    return all_ids

def random_sample(all_ids, sample_num = 110, save_root = "", image_folder = "lhq_1024_jpg"):
    """Randomly pick ``sample_num`` ids and copy their ``.jpg`` files into a folder.

    Args:
        all_ids: Iterable of image ids to sample from without replacement.
        sample_num: Number of ids to pick. ``random.sample`` raises
            ValueError if this exceeds the population size.
        save_root: Parent directory. Images are copied into
            ``save_root/<sample_num>``. The default ``""`` means the current
            directory (not the filesystem root).
        image_folder: Directory holding the source ``<id>.jpg`` files.
            Defaults to the previously hard-coded ``lhq_1024_jpg``.

    Returns:
        The list of chosen ids.
    """
    # Materialize so any iterable works (random.sample needs a sequence).
    population = list(all_ids)
    chosen_ids = random.sample(population, sample_num)
    # os.path.join avoids f"{''}/110" -> "/110", which targets the fs root.
    save_dir = os.path.join(save_root, str(sample_num))
    os.makedirs(save_dir, exist_ok=True)
    for sample_id in chosen_ids:
        img_path = os.path.join(image_folder, f"{sample_id}.jpg")
        shutil.copy(img_path, save_dir)

    return chosen_ids




    
