import os
import requests
from io import BytesIO
import PIL.Image as PilImage
from tqdm import tqdm
from nesim.utils.json_stuff import load_json_as_dict
from clip_retrieval.clip_client import ClipClient


class ImageInput:
    """Wrapper to handle image inputs both from local paths and urls.

    Args:
        path_or_url (str): path or link to image. Values starting with
            ``http://`` or ``https://`` are downloaded; anything else is
            opened as a local file.
        timeout (float): seconds to wait on the server when downloading.
            Not used for local paths.

    Attributes:
        path_or_url (str): the value passed to the constructor.
        pil_image: the image object returned by ``PIL.Image.open``.

    Raises:
        Exception: if the image could not be downloaded or opened from a url.
            The underlying error is attached as ``__cause__``. Errors from
            opening a local path propagate unchanged.
    """

    def __init__(self, path_or_url, timeout=30):

        self.path_or_url = path_or_url
        if self.path_or_url.startswith(("http://", "https://")):
            try:
                # Without a timeout a single unresponsive host would block
                # the caller indefinitely.
                response = requests.get(path_or_url, timeout=timeout)
                # Reject 4xx/5xx replies instead of trying to decode an
                # error page as an image.
                response.raise_for_status()
                self.pil_image = PilImage.open(BytesIO(response.content))
            except Exception as err:
                # ``except Exception`` (not a bare ``except``) lets
                # KeyboardInterrupt / SystemExit pass through untouched.
                raise Exception(
                    f"Could not retrieve image from url:\n{self.path_or_url}"
                ) from err
        else:
            self.pil_image = PilImage.open(path_or_url)


# Dataset description read from "dataset_info.json". The code below indexes it
# by integer position and reads the "folder" and "captions" keys of each entry.
folder_names_map = load_json_as_dict(filename="dataset_info.json")

# Root directory under which one sub-folder per dataset entry is created.
dataset_folder = "./datasets/clip_retrieval_dataset"

# Equivalent to ``mkdir -p`` but without spawning a shell, and it raises on
# failure instead of silently continuing.
os.makedirs(dataset_folder, exist_ok=True)

# NOTE(review): "XXXX" looks like a placeholder for the retrieval backend
# url — confirm and replace before running.
client = ClipClient(
    url="XXXX",
    indice_name="laion5B-L-14",
    num_images=40,
    use_safety_model=False,
)

# For every dataset entry: query the retrieval client once per caption,
# download up to "num_samples" results, then write the successfully fetched
# images to the entry's folder as 0.jpg, 1.jpg, ...
#
# Index-based access is kept on purpose: the container type returned by
# ``load_json_as_dict`` is not visible here, and iterating it directly would
# behave differently if it is a mapping rather than a sequence.
for i in range(len(folder_names_map)):
    entry = folder_names_map[i]

    # Create the folder without going through a shell, so names containing
    # spaces or shell metacharacters are handled safely.
    target_dir = os.path.join(dataset_folder, entry["folder"])
    os.makedirs(target_dir, exist_ok=True)

    pil_images = []
    for caption_data in entry["captions"]:
        results = client.query(text=caption_data["caption"])

        for sample_idx in tqdm(
            range(min(len(results), caption_data["num_samples"])),
            desc=f'Downloading images for caption: "{caption_data["caption"]}"',
        ):
            url = results[sample_idx]["url"]
            try:
                pil_image = ImageInput(path_or_url=url).pil_image
                # Normalise to RGB so every image can be written as JPEG.
                pil_images.append(pil_image.convert("RGB"))
            except Exception:
                # Best-effort download: a broken link or undecodable image is
                # skipped. ``except Exception`` keeps Ctrl-C working.
                print(f"Skipped: {url}")

    # File names are the position among the images that were fetched
    # successfully, so numbering is contiguous even when some were skipped.
    for sample_idx, pil_image in enumerate(pil_images):
        filename = os.path.join(target_dir, f"{sample_idx}.jpg")
        pil_image.save(filename)
        print(f"Saved: {filename}")
