from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
from torchvision import transforms
import torch
import os 

class ForgetMeNotDataset(Dataset):
    """
    Dataset pairing concept images with a prompt and a concept-token mask.

    Every image found under ``data/<concept>`` is associated with a prompt that
    mentions the concept. Each item carries the transformed image, the
    tokenized prompt, and a 0/1 mask marking the prompt positions occupied by
    the concept tokens.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with ``input_ids`` (presumably a Hugging Face
            CLIP tokenizer -- confirm against the caller).
        size: Target side length used for the resize and the crop.
        center_crop: Use a center crop when True, otherwise a random crop.
        use_added_token: When True, the concept is written in the prompt as
            consecutive placeholder tokens ``<s1><s2>...`` instead of its name.
        use_pooler: When True, positions holding the ``<|endoftext|>`` token id
            are flagged in the mask as well.
        multi_concept: Iterable of ``(concept, type, num_tok)`` tuples, where
            ``concept`` is a directory name under ``data``, ``type`` is
            ``"object"`` or ``"style"``, and ``num_tok`` is the number of
            placeholder tokens reserved for the concept.

    Raises:
        ValueError: If ``multi_concept`` is missing, a concept directory does
            not exist, or a concept type is unknown.
    """

    def __init__(
        self,
        tokenizer,
        size=512,
        center_crop=False,
        use_added_token=False,
        use_pooler=False,
        multi_concept=None,
    ):
        # The default used to fail later with an opaque "NoneType is not
        # iterable"; fail early with an actionable message instead.
        if multi_concept is None:
            raise ValueError(
                "multi_concept must be an iterable of (concept, type, num_tok) tuples."
            )

        self.use_added_token = use_added_token
        self.use_pooler = use_pooler
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_images_path = []
        self.instance_prompt = []

        # Placeholder tokens are numbered continuously across concepts.
        token_idx = 1
        for concept, concept_type, num_tok in multi_concept:
            root = Path("data", concept)
            if not root.exists():
                raise ValueError(f"Instance {root} images root doesn't exists.")

            # Sorted for a reproducible sample order; sub-directories are
            # skipped because they cannot be opened as images.
            image_paths = sorted(path for path in root.iterdir() if path.is_file())
            self.instance_images_path += image_paths

            if use_added_token:
                target_snippet = "".join(
                    f"<s{token_idx + i}>" for i in range(num_tok)
                )
            else:
                target_snippet = concept.replace("-", " ")

            if concept_type == "object":
                prompt = f"a photo of {target_snippet}"
            elif concept_type == "style":
                prompt = f"a photo in the style of {target_snippet}"
            else:
                raise ValueError("unknown concept type!")
            self.instance_prompt += [(prompt, target_snippet)] * len(image_paths)

            if use_added_token:
                token_idx += num_tok

        self.num_instance_images = len(self.instance_images_path)
        self._length = self.num_instance_images

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the number of images collected across all concepts."""
        return self._length

    @staticmethod
    def _concept_mask(prompt_ids, concept_ids, length, pooler_token_id=None):
        """Return a 0/1 list of ``length`` flagging concept (and pooler) positions.

        Every occurrence of the ``concept_ids`` sub-sequence inside
        ``prompt_ids`` is flagged. When ``pooler_token_id`` is given, every
        position holding that id is flagged too.
        """
        mask = [0] * length
        span = len(concept_ids)
        for start, tok_id in enumerate(prompt_ids):
            # ``span`` guard: an empty concept must match nothing (and must
            # not raise, as indexing concept_ids[0] would).
            if span and prompt_ids[start:start + span] == concept_ids:
                mask[start:start + span] = [1] * span
            # NOTE(review): if the tokenizer pads with the end-of-text id,
            # padding positions are flagged as well -- confirm this is intended.
            if pooler_token_id is not None and tok_id == pooler_token_id:
                mask[start] = 1
        return mask

    def __getitem__(self, index):
        """Return the example at ``index`` (wrapped modulo the dataset size).

        The returned dict holds ``instance_prompt`` (str), ``instance_images``
        (transformed image tensor), ``instance_prompt_ids`` (tokenized prompt
        as returned by the tokenizer with ``return_tensors="pt"``) and
        ``concept_positions`` (integer 0/1 tensor with a leading axis of 1).
        """
        slot = index % self.num_instance_images
        instance_prompt, target_tokens = self.instance_prompt[slot]

        # convert() loads the pixel data and returns a new image, so the file
        # handle can be released as soon as the block exits.
        with Image.open(self.instance_images_path[slot]) as image:
            instance_image = image.convert("RGB")

        max_length = self.tokenizer.model_max_length
        prompt_input_ids = self.tokenizer(
            instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        ).input_ids

        concept_ids = self.tokenizer(
            target_tokens,
            add_special_tokens=False,
        ).input_ids

        # Only look up the pooler token when it is actually used.
        pooler_token_id = None
        if self.use_pooler:
            pooler_token_id = self.tokenizer(
                "<|endoftext|>",
                add_special_tokens=False,
            ).input_ids[0]

        # Reuse the ids that were just computed instead of tokenizing the
        # same prompt a second time.
        concept_positions = self._concept_mask(
            prompt_input_ids[0].tolist(),
            concept_ids,
            max_length,
            pooler_token_id,
        )

        return {
            "instance_prompt": instance_prompt,
            "instance_images": self.image_transforms(instance_image),
            "instance_prompt_ids": prompt_input_ids,
            "concept_positions": torch.tensor(concept_positions)[None],
        }


def collate_fn(examples):
    """Collate dataset examples into a single batch dict.

    Always produces ``instance_prompts`` (list of str), ``input_ids``
    (per-example id tensors concatenated on dim 0), ``pixel_values`` (stacked
    float tensor) and ``concept_positions`` (boolean tensor concatenated on
    dim 0).

    When the examples also carry the refine fields
    (``instance_refine_prompt_ids``, ``refine_concept_positions``,
    ``instance_refine_prompt``), the batch additionally contains
    ``instance_refine_prompts``, ``refine_input_ids`` and
    ``refine_concept_positions``. Examples without refine fields are collated
    without those keys instead of raising ``KeyError``.

    Args:
        examples: List of example dicts as produced by the datasets in this
            module.

    Returns:
        dict: The collated batch.
    """
    pixel_values = torch.stack([example["instance_images"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {
        "instance_prompts": [example["instance_prompt"] for example in examples],
        "input_ids": torch.cat(
            [example["instance_prompt_ids"] for example in examples], dim=0
        ),
        "pixel_values": pixel_values,
        "concept_positions": torch.cat(
            [example["concept_positions"] for example in examples], dim=0
        ).bool(),
    }

    # ``any`` rather than ``all``: a batch mixing both example kinds is a
    # caller error and should still surface as a KeyError below.
    has_refine = any("instance_refine_prompt_ids" in example for example in examples)
    if has_refine:
        batch["instance_refine_prompts"] = [
            example["instance_refine_prompt"] for example in examples
        ]
        batch["refine_input_ids"] = torch.cat(
            [example["instance_refine_prompt_ids"] for example in examples], dim=0
        )
        batch["refine_concept_positions"] = torch.cat(
            [example["refine_concept_positions"] for example in examples], dim=0
        ).bool()

    return batch

class MyDataset(Dataset):
    """
    Dataset pairing images of one concept with a concept prompt and a
    "refine" prompt that names a neutral replacement concept.

    Images are the ``.jpg`` files found in ``./data/<dataset_type>/<concept>``.
    Each item carries the transformed image, both tokenized prompts, and for
    each prompt a 0/1 mask marking the positions of its concept tokens.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with ``input_ids`` (presumably a Hugging Face
            CLIP tokenizer -- confirm against the caller).
        dataset_type: ``"celebs"`` or ``"artists"``; selects the image root,
            the prompt wording and the replacement concept.
        concept: Concept name; used both as the sub-directory name and as the
            text inserted into the prompt.

    Raises:
        ValueError: If ``dataset_type`` is not a supported value.
    """

    # dataset_type -> (image root, replacement concept, prompt prefix, prompt suffix)
    _DATASET_LAYOUT = {
        "celebs": ("./data/celebs", "a person", "an image of ", ""),
        "artists": ("./data/artists", "normal", "An artwork in ", " style."),
    }

    def __init__(
        self,
        tokenizer,
        dataset_type,
        concept
    ):
        # An unsupported type previously surfaced as an UnboundLocalError on
        # the directory variable; report the real problem instead.
        try:
            base_dir, refine_concept, prefix, suffix = self._DATASET_LAYOUT[dataset_type]
        except KeyError:
            raise ValueError(
                f"unknown dataset_type {dataset_type!r}; "
                f"expected one of {sorted(self._DATASET_LAYOUT)}"
            ) from None

        image_dir = os.path.join(base_dir, concept)
        prompt_entry = (prefix + concept + suffix, concept)
        refine_entry = (prefix + refine_concept + suffix, refine_concept)

        # Sorted so the sample order does not depend on the filesystem.
        self.instance_images_path = [
            os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if name.endswith('.jpg')
        ]
        # One (prompt, concept) entry per image; all entries are identical.
        self.instance_prompt = [prompt_entry] * len(self.instance_images_path)
        self.instance_refine_prompt = [refine_entry] * len(self.instance_images_path)

        self.tokenizer = tokenizer
        self.num_instance_images = len(self.instance_images_path)
        self._length = self.num_instance_images

        self.use_pooler = False

        self.image_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomResizedCrop(
                (512, 512), antialias=True
            ),
            # NOTE(review): these look like the CLIP preprocessing mean/std --
            # confirm against the model this feeds.
            transforms.Normalize(
                [0.48145466, 0.4578275, 0.40821073],
                [0.26862954, 0.26130258, 0.27577711]),
        ])

    def __len__(self):
        """Return the number of ``.jpg`` images found for the concept."""
        return self._length

    @staticmethod
    def _concept_mask(prompt_ids, concept_ids, length, pooler_token_id=None):
        """Return a 0/1 list of ``length`` flagging concept (and pooler) positions.

        Every occurrence of the ``concept_ids`` sub-sequence inside
        ``prompt_ids`` is flagged. When ``pooler_token_id`` is given, every
        position holding that id is flagged too.
        """
        mask = [0] * length
        span = len(concept_ids)
        for start, tok_id in enumerate(prompt_ids):
            # ``span`` guard: an empty concept must match nothing (and must
            # not raise, as indexing concept_ids[0] would).
            if span and prompt_ids[start:start + span] == concept_ids:
                mask[start:start + span] = [1] * span
            if pooler_token_id is not None and tok_id == pooler_token_id:
                mask[start] = 1
        return mask

    def _encode_prompt(self, prompt, concept_text, pooler_token_id):
        """Tokenize ``prompt`` and return ``(input_ids, concept_positions)``.

        ``input_ids`` is the tokenizer output requested with
        ``return_tensors="pt"``; ``concept_positions`` is an integer 0/1
        tensor with a leading axis of 1.
        """
        max_length = self.tokenizer.model_max_length
        input_ids = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        ).input_ids
        concept_ids = self.tokenizer(
            concept_text,
            add_special_tokens=False,
        ).input_ids
        # Reuse the ids just computed instead of tokenizing the prompt twice.
        mask = self._concept_mask(
            input_ids[0].tolist(), concept_ids, max_length, pooler_token_id
        )
        return input_ids, torch.tensor(mask)[None]

    def __getitem__(self, index):
        """Return the example at ``index`` (wrapped modulo the dataset size).

        Keys: ``instance_prompt``, ``instance_refine_prompt``,
        ``instance_images``, ``instance_prompt_ids``,
        ``instance_refine_prompt_ids``, ``concept_positions`` and
        ``refine_concept_positions``.
        """
        slot = index % self.num_instance_images
        instance_prompt, target_tokens = self.instance_prompt[slot]
        instance_refine_prompt, refine_tokens = self.instance_refine_prompt[slot]

        # convert() loads the pixel data and returns a new image, so the file
        # handle can be released as soon as the block exits.
        with Image.open(self.instance_images_path[slot]) as image:
            instance_image = image.convert("RGB")

        # Only look up the pooler token when it is actually used.
        pooler_token_id = None
        if self.use_pooler:
            pooler_token_id = self.tokenizer(
                "<|endoftext|>",
                add_special_tokens=False,
            ).input_ids[0]

        prompt_ids, concept_positions = self._encode_prompt(
            instance_prompt, target_tokens, pooler_token_id
        )
        refine_prompt_ids, refine_concept_positions = self._encode_prompt(
            instance_refine_prompt, refine_tokens, pooler_token_id
        )

        return {
            "instance_prompt": instance_prompt,
            "instance_refine_prompt": instance_refine_prompt,
            "instance_images": self.image_transforms(instance_image),
            "instance_prompt_ids": prompt_ids,
            "instance_refine_prompt_ids": refine_prompt_ids,
            "concept_positions": concept_positions,
            "refine_concept_positions": refine_concept_positions,
        }
