import json, os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils import data
import numpy as np
from PIL import Image
import PIL
import tqdm
import torch
import pickle
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
import packaging
import numpy as np


def get_transform():
    """Build the tensor-conversion pipeline used by ``get_dataloader``.

    The pipeline is ``ToTensor`` followed by ``Normalize(mean=0.5, std=0.5)``,
    i.e. every channel value ``x`` becomes ``(x - 0.5) / 0.5``; inputs in
    ``[0, 1]`` therefore end up in ``[-1, 1]``.
    """
    pipeline_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
    return transforms.Compose(pipeline_steps)

class FolderDataset(Dataset):
    """Image dataset over every image file under ``<root>/train`` or ``<root>/test``.

    Image paths are collected recursively and sorted. Every image is decoded to a
    numpy array once at construction time and kept in memory. ``__getitem__``
    rebuilds a PIL image from the cached array and applies ``transform`` if given.
    """

    # Lower-case filename suffixes that get_imgs() treats as images.
    IMG_EXTENSIONS = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')

    def __init__(self, root, train=True, transform=None, name=None):
        """
        Parameters
        ----------
        root : str
            Dataset root containing ``train`` and/or ``test`` sub-folders.
        train : bool
            Use ``<root>/train`` when True, ``<root>/test`` otherwise.
        transform : callable, optional
            Applied to each PIL image in ``__getitem__``.
        name : str, optional
            Free-form label stored on the instance.

        Raises
        ------
        Exception
            If the selected sub-folder does not exist.
        """
        self.root = os.path.join(root, "train" if train else "test")
        self.name = name
        self.transform = transform
        if not os.path.exists(self.root): raise Exception("Folder of class is not found.")
        self.image_list = sorted(self.get_imgs(self.root))
        self.image_cached = [self._load_array(image_path) for image_path in self.image_list]

    @staticmethod
    def _load_array(path):
        """Decode the image at ``path`` into a numpy array and release the file."""
        # The context manager closes the file handle once the pixels are copied out.
        with Image.open(path) as img:
            return np.array(img)

    def get_imgs(self, path_name):
        """Return the paths of all image files anywhere under ``path_name``.

        The whole tree is walked, not just the top level. Matching is by
        case-insensitive filename suffix (see ``IMG_EXTENSIONS``). Returns ``[]``
        (never ``None``) when nothing is found, including when ``path_name`` is
        not a directory.
        """
        imagelist = []
        for parent, _dirnames, filenames in os.walk(path_name):
            imagelist.extend(
                os.path.join(parent, filename)
                for filename in filenames
                if filename.lower().endswith(self.IMG_EXTENSIONS)
            )
        # Return after the walk completes; returning inside the loop would stop
        # after the top-level directory.
        return imagelist

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return image ``index`` as a PIL image, transformed if a transform was set."""
        # NOTE(review): the cached array is cast with np.uint8, which assumes 8-bit
        # image data; wider dtypes would wrap. Confirm inputs are 8-bit.
        image = Image.fromarray(np.uint8(self.image_cached[index]))
        if self.transform: image = self.transform(image)
        return image
    
class ClipFolderDataset(Dataset):
    """CLIP ViT-B/32 image embeddings for every image under ``<root>/train`` or ``<root>/test``.

    Image paths are collected recursively and sorted. Each image is encoded once at
    construction time, and the embeddings are cached as numpy arrays that
    ``__getitem__`` returns directly.
    """

    # Lower-case filename suffixes that get_imgs() treats as images.
    IMG_EXTENSIONS = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')

    def __init__(self, root, train=True, name=None):
        """
        Parameters
        ----------
        root : str
            Dataset root containing ``train`` and/or ``test`` sub-folders.
        train : bool
            Use ``<root>/train`` when True, ``<root>/test`` otherwise.
        name : str, optional
            Free-form label stored on the instance.

        Raises
        ------
        Exception
            If the selected sub-folder does not exist.
        """
        self.root = os.path.join(root, "train" if train else "test")
        self.name = name
        if not os.path.exists(self.root): raise Exception("Folder of class is not found.")
        self.image_list = sorted(self.get_imgs(self.root))

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device)
        with torch.no_grad():
            self.image_cached = [
                self._encode_image(model, preprocess, image_path, device)
                for image_path in self.image_list
            ]

    @staticmethod
    def _encode_image(model, preprocess, path, device):
        """Return the CLIP embedding of one image file as numpy (row 0 of the batch-of-one output)."""
        # Close the file as soon as preprocess has produced its tensor.
        with Image.open(path) as img:
            batch = preprocess(img).unsqueeze(0).to(device)
        return model.encode_image(batch).cpu().numpy()[0, ...]

    def get_imgs(self, path_name):
        """Return the paths of all image files anywhere under ``path_name``.

        The whole tree is walked, not just the top level. Matching is by
        case-insensitive filename suffix (see ``IMG_EXTENSIONS``). Returns ``[]``
        (never ``None``) when nothing is found, including when ``path_name`` is
        not a directory.
        """
        imagelist = []
        for parent, _dirnames, filenames in os.walk(path_name):
            imagelist.extend(
                os.path.join(parent, filename)
                for filename in filenames
                if filename.lower().endswith(self.IMG_EXTENSIONS)
            )
        # Return after the walk completes; returning inside the loop would stop
        # after the top-level directory.
        return imagelist

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return the cached CLIP embedding for image ``index``."""
        return self.image_cached[index]
    
class ClipDataset(Dataset):
    """Pairs CLIP ViT-B/32 image embeddings with CLIP text embeddings of their prompts.

    The selected folder (``<root>/train`` or ``<root>/test``) is expected to hold
    ``nums`` images per prompt, named ``#{prompt_index}_{sample_index}.png``.
    Every image and prompt is encoded once at construction time, and the numpy
    results are cached in memory. Item ``i`` is ``(image_embedding, text_embedding)``.
    """

    # CLIP BPE tokenizer, built lazily on the first tokenize() call and shared by all
    # instances. It is class-level so it is not pickled with each dataset
    # (e.g. when DataLoader workers are spawned).
    _shared_tokenizer = None

    def __init__(self, root, prompt, nums=10, train=True, transform=None, name=None):
        """
        Parameters
        ----------
        root : str
            Dataset root containing ``train`` and/or ``test`` sub-folders.
        prompt : str
            Path to a JSON file. Its top-level value is iterated, and item ``ii`` is
            the prompt text for images ``#{ii}_*.png`` (presumably a list of
            strings; confirm against the producer of the file).
        nums : int
            Number of images per prompt.
        train : bool
            Use ``<root>/train`` when True, ``<root>/test`` otherwise.
        transform : callable, optional
            Stored on the instance; not applied by this class.
        name : str, optional
            Free-form label stored on the instance.

        Raises
        ------
        Exception
            If the prompt file decodes to ``null`` or the image folder is missing.
        """
        self.root = os.path.join(root, "train" if train else "test")
        # JSON text is UTF-8 by specification; don't depend on the platform default encoding.
        with open(prompt, "r", encoding="utf-8") as f:
            self.prompts = json.load(f)
        if self.prompts is None: raise Exception("Prompt file load failed.")
        self.nums = nums
        self.name = name
        self.transform = transform
        path = self.root
        if not os.path.exists(path): raise Exception("Folder of class is not found.")

        # Parallel lists: entry k is image "#{ii}_{num}.png" and the text of prompt ii.
        self.image_list = []
        self.prompt_list = []
        for ii, prompt_text in enumerate(self.prompts):
            for num in range(self.nums):
                self.image_list.append(os.path.join(path, f"#{ii}_{num}.png"))
                self.prompt_list.append(prompt_text)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device)
        with torch.no_grad():
            self.image_cached = [
                self._encode_image(model, preprocess, image_path, device)
                for image_path in self.image_list
            ]

            # The nums entries of one prompt share the same text. Encode each prompt
            # once (not nums times) and give every entry its own copy, in the same
            # order as prompt_list.
            self.prompt_cached = []
            for prompt_text in self.prompts:
                token = self.tokenize(prompt_text, truncate=True).to(device)
                embedding = model.encode_text(token).cpu().numpy()[0, ...]
                self.prompt_cached.extend(embedding.copy() for _ in range(self.nums))

    @staticmethod
    def _encode_image(model, preprocess, path, device):
        """Return the CLIP embedding of one image file as numpy (row 0 of the batch-of-one output)."""
        # Close the file as soon as preprocess has produced its tensor.
        with Image.open(path) as img:
            batch = preprocess(img).unsqueeze(0).to(device)
        return model.encode_image(batch).cpu().numpy()[0, ...]

    def tokenize(self, texts, context_length: int = 77, truncate: bool = False):
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize

        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        truncate: bool
            Whether to truncate the text in case its encoding is longer than the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.

        Raises
        ------
        RuntimeError
            If a text is longer than ``context_length`` tokens and ``truncate`` is False.
        """
        # A bare `import packaging` does not load the `packaging.version` submodule,
        # so import it explicitly rather than relying on another module having done so.
        from packaging import version

        if isinstance(texts, str):
            texts = [texts]

        tokenizer = type(self)._shared_tokenizer
        if tokenizer is None:
            # _Tokenizer() builds its BPE vocabulary on construction (in openai/clip it
            # reads the vocab file), so build it once instead of on every call.
            tokenizer = type(self)._shared_tokenizer = _Tokenizer()
        sot_token = tokenizer.encoder["<|startoftext|>"]
        eot_token = tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]

        if version.parse(torch.__version__) < version.parse("1.8.0"):
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    # Keep an end-of-text marker as the final token after truncation.
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(image_embedding, text_embedding)`` for entry ``index``."""
        image = self.image_cached[index]
        text  = self.prompt_cached[index]
        return image, text

class ClipPromptDataset(Dataset):
    """Dataset of pre-computed (image, prompt) pairs read from ``<root>/CLIP-prompts.pkl``.

    The pickle is expected to hold a mapping with parallel ``"images"`` and
    ``"prompts"`` sequences. Item ``i`` is ``(images[i], prompts[i])``.
    """

    def __init__(self, root, name=None):
        """
        Parameters
        ----------
        root : str
            Directory containing ``CLIP-prompts.pkl``.
        name : str, optional
            Free-form label stored on the instance.

        Raises
        ------
        Exception
            "Prompt file load failed." if the file cannot be opened or unpickled;
            the underlying error is chained as ``__cause__``.
        KeyError
            If the unpickled mapping lacks ``"images"`` or ``"prompts"``.
        """
        self.root = root
        pkl_path = os.path.join(self.root, "CLIP-prompts.pkl")
        # NOTE(review): pickle.load can execute arbitrary code. Only load files this
        # project produced itself, never untrusted input.
        try:
            with open(pkl_path, "rb") as fr:
                self.data = pickle.load(fr)
        except Exception as err:
            # Keep the existing message but chain the real cause (missing file,
            # truncated pickle, ...) instead of discarding it. Unlike a bare
            # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
            raise Exception("Prompt file load failed.") from err
        self.image_cached = self.data["images"]
        self.prompt_cached = self.data["prompts"]
        self.name = name

    def __len__(self):
        return len(self.image_cached)

    def __getitem__(self, index):
        """Return ``(image, prompt)`` for entry ``index``."""
        image = self.image_cached[index]
        text  = self.prompt_cached[index]
        return image, text
    
def get_dataloader(path, train=True, name=None):
    """Wrap a ``FolderDataset`` built from ``path`` in a non-shuffling DataLoader.

    Images go through ``get_transform()``. The batch size is 32 for the training
    split and 1 otherwise, with 4 loader worker processes.
    """
    pipeline = get_transform()
    folder_ds = FolderDataset(path, train=train, transform=pipeline, name=name)
    batch_size = 32 if train else 1
    return data.DataLoader(folder_ds, batch_size=batch_size, shuffle=False, num_workers=4)

def get_clip_dataloader(path, train=True, name=None):
    """Wrap a ``ClipFolderDataset`` (cached CLIP image embeddings) in a non-shuffling DataLoader.

    The batch size is 32 for the training split and 1 otherwise, with 4 loader
    worker processes.
    """
    embeddings_ds = ClipFolderDataset(path, train=train, name=name)
    batch_size = 32 if train else 1
    return data.DataLoader(embeddings_ds, batch_size=batch_size, shuffle=False, num_workers=4)

def get_clip_original_prompt_dataloader(path, prompt, train=True, name=None, nums=10):
    """Wrap a ``ClipDataset`` (image/prompt CLIP embedding pairs) in a non-shuffling DataLoader.

    Parameters
    ----------
    path : str
        Dataset root passed to ``ClipDataset``.
    prompt : str
        Path of the JSON prompt file passed to ``ClipDataset``.
    train : bool
        Selects the split. It also picks the batch size: 32 for train, 1 otherwise.
    name : str, optional
        Label forwarded to ``ClipDataset``.
    nums : int
        Images per prompt, forwarded to ``ClipDataset``. It was previously not
        configurable from here; the default of 10 keeps existing calls unchanged.
    """
    dataset = ClipDataset(path, prompt, nums=nums, train=train, name=name)
    dataloader = data.DataLoader(dataset, batch_size=32 if train else 1, shuffle=False, num_workers=4)
    return dataloader

def get_clip_prompt_dataloader(path, bs=1, name=None):
    """Wrap a ``ClipPromptDataset`` loaded from ``path`` in a non-shuffling DataLoader.

    ``bs`` is the batch size; 4 loader worker processes are used.
    """
    return data.DataLoader(
        ClipPromptDataset(path, name=name),
        batch_size=bs,
        shuffle=False,
        num_workers=4,
    )