import torch
import torchvision

from torchvision import transforms as T
import torchvision

import numpy as np
from PIL import Image
from pathlib import Path
import os
import warnings
import cv2
import gc
import json
from dataset_utils import enrich_dataset

def get_save_dir(args):
    """Resolve the output directory for a run.

    Outside debug mode ``args.save_dir`` is returned unchanged. In debug mode
    the numbered sub-directories ``000``, ``001``, ... of ``args.save_dir`` are
    probed in order and the first acceptable one is returned:

    * a sub-directory that does not exist yet is always accepted;
    * if ``args.skip_used`` is truthy, every existing sub-directory is skipped;
    * otherwise an existing sub-directory is reused unless it contains one of
      the marker entries listed in ``used_markers`` below.

    Args:
        args: namespace-like object exposing ``debug`` and ``save_dir``, plus
            ``skip_used`` (only read in debug mode).

    Returns:
        The chosen directory path: a ``str`` built with ``os.path.join`` in
        debug mode, otherwise ``args.save_dir`` as-is. Nothing is created on
        disk by this function.
    """
    if not args.debug:
        return args.save_dir

    base = str(args.save_dir)
    # Entries whose presence marks a sub-directory as already used
    # (presumably outputs saved by an earlier run).
    used_markers = ("unet", "pytorch_lora_weights.safetensors", "conv_in.pth")
    run_idx = 0
    while True:
        candidate = os.path.join(base, f"{run_idx:03}")
        if not os.path.exists(candidate):
            return candidate
        if not args.skip_used:
            entries = os.listdir(candidate)
            if not any(marker in entries for marker in used_markers):
                # Existing directory without any marker entry: reuse it.
                return candidate
        run_idx += 1

class ImageDataset(torch.utils.data.Dataset):
    """Dataset mixing an "IC-Light" image set with a "preservation" image set.

    On construction ``enrich_dataset`` is called. It returns the dataset
    directory plus the two image directories, whose images are then loaded
    into memory. Each item is ``(image_tensor, emb_idx)``. ``emb_idx`` is
    ``tensor([1])`` for an IC-Light image and ``tensor([0])`` for a
    preservation image. The dataset length equals the number of IC-Light
    images.
    """

    def __init__(
            self,
            image_dir,
            width, height,
            orig_obj_prompt=None,
            ic_light_prompt=None,
            save_dir=None,
            generate_using_iclight=True,
            generate_cp_iclight=None,
            generate_n=32,
            generate_obj_similar=32,
            generate_cp_dir=None,
            prob=0.9,
            with_preservation=None,
            add_raw_images_ratio=0.0,
            fixed_place_prompt=None
        ):
        """Enrich the dataset on disk via ``enrich_dataset``, then load it.

        Args:
            image_dir: source image directory; forwarded to ``enrich_dataset``
                and stored as ``self.image_dir`` (a ``Path``).
            width, height: output size (in pixels) of every returned image.
            orig_obj_prompt: forwarded as ``preservation_prompt``.
            ic_light_prompt: forwarded as ``iclight_prompt``.
            save_dir: forwarded as ``save_dir``.
            generate_using_iclight: forwarded as the boolean
                ``do_generate_iclight`` (``generate_using_iclight != 0``).
            generate_cp_iclight: forwarded as ``just_cp_iclight_dir``.
            generate_n: forwarded as ``generate_iclight_n``.
            generate_obj_similar: forwarded as ``generate_preservation_n``.
            generate_cp_dir: forwarded as ``just_cp_preservation_dir``.
            prob: forwarded to ``enrich_dataset``. Also the probability that
                ``__getitem__`` returns the IC-Light image rather than a
                random preservation image.
            with_preservation: forwarded as ``with_preservation``.
            add_raw_images_ratio: forwarded as ``add_raw_images_ratio``.
            fixed_place_prompt: forwarded as ``fixed_place_prompt``.
        """
        self.save_dir, self.save_iclight_dir, self.save_preservation_dir = enrich_dataset(
            image_dir=image_dir,
            save_dir=save_dir,

            generate_iclight_n=generate_n,
            iclight_prompt=ic_light_prompt,
            just_cp_iclight_dir=generate_cp_iclight,
            do_generate_iclight=(generate_using_iclight != 0),
            add_raw_images_ratio=add_raw_images_ratio,

            generate_preservation_n=generate_obj_similar,
            preservation_prompt=orig_obj_prompt,
            just_cp_preservation_dir=generate_cp_dir,

            prob=prob,
            with_preservation=with_preservation,
            fixed_place_prompt=fixed_place_prompt,
        )

        self.prob = prob
        self.image_dir = Path(image_dir)
        self.width = width
        self.height = height

        self.transform = T.Compose([
            T.ToTensor(),
            # torchvision's Resize takes its size as (height, width).
            T.Resize((self.height, self.width)),
        ])

        self.images_iclight = self.read_all_images(self.save_iclight_dir)
        self.images_preservation = self.read_all_images(self.save_preservation_dir)

    def read_all_images(self, d, exts=("png", "jpg")):
        """Load every image directly inside ``d`` whose suffix is in ``exts``.

        Matching is non-recursive and case-sensitive (``*.png``, ``*.jpg``).
        Files are sorted by path. Each image is decoded and copied inside a
        ``with`` block, so no file handle stays open after loading.

        Returns:
            A list of PIL images. The list is empty when ``d`` is None or
            contains no matching files.
        """
        if d is None:
            # NOTE(review): guards against enrich_dataset returning None for a
            # directory it did not produce (e.g. no preservation set) — confirm.
            return []
        d = Path(d)
        paths = sorted(p for ext in exts for p in d.glob(f"*.{ext}"))
        images = []
        for p in paths:
            with Image.open(p) as img:
                # copy() forces a full decode, so the data outlives the handle.
                images.append(img.copy())
        return images

    def __len__(self):
        """Number of IC-Light images (the preservation set does not count)."""
        return len(self.images_iclight)

    def __getitem__(self, idx):
        """Return ``(image_tensor, emb_idx)`` for sample ``idx``.

        With probability ``self.prob`` the ``idx``-th IC-Light image is used
        and ``emb_idx`` is ``tensor([1])``. Otherwise a uniformly random
        preservation image is used and ``emb_idx`` is ``tensor([0])``. If the
        preservation set is empty, the IC-Light image is always used instead
        of failing in ``np.random.choice(0)``.
        """
        use_iclight = not self.images_preservation or np.random.rand() < self.prob
        if use_iclight:
            image = self.images_iclight[idx]
            emb_idx = torch.tensor([1])
        else:
            random_idx = np.random.choice(len(self.images_preservation))
            image = self.images_preservation[random_idx]
            emb_idx = torch.tensor([0])
        image = self.transform(image)
        return image, emb_idx
