import torch
from torch.utils.data import Dataset
import numpy as np
import os
import random
from selection_algorithms import match_greedy
from torchvision.transforms import ToTensor
from torchvision import transforms, datasets
import webdataset as wds
class ImagenetGeneratedDataset(Dataset):
    def __init__(
        self,
        dataset_name,
        count,
        transform,
        train_features,
        method="cosine-near",
        zero_centered=False,
        clip=False,
        leak_dataset=None,
        leak_features=None,
    ):
        """Build the dataset and select ``count`` generated images per class.

        Args:
            dataset_name: Name the helper methods branch on ("imagenet" or
                "cifar10").
            count: Number of generated samples to select per class.
            transform: Transform handed to the ImageFolder and applied in
                ``__getitem__``.
            train_features: Per-class training features, indexed by the class
                id as a string; only read when ``method`` is not "random".
            method: "random" or a strategy name forwarded to ``match_greedy``.
            zero_centered: Forwarded unchanged to ``match_greedy``.
            clip: Selects the CLIP feature archive instead of the default one
                in ``load_features``.
            leak_dataset: Currently unused by the constructor.
            leak_features: Optional extra per-class features appended to the
                loaded ones in ``load_features``.
        """
        # Plain attribute state first; none of the helpers below reassign these.
        self.samples = []
        self.indexes = []
        self.dataset_name = dataset_name
        self.transform = transform
        self.is_clip_features = clip
        self.dinofeatures = None

        # NOTE: the former load_tensordata(dataset_name, leak_dataset) step is
        # disabled, so ``self.tensor_data`` is not populated here.
        self.preload_imagefolder()
        self.dataset_based_constants()
        self.load_features(leak_features)

        # Select per class, turn the selections into index tensors, then
        # narrow the image folder down to the chosen files.
        self.get_subset(train_features, count, method, zero_centered)
        self.fix_index_for_classes()
        print("before fixing")
        self.fix_image_selection()
        print("init done")
        # NOTE: fix_index_for_sampling() is intentionally not called here.


    def dataset_based_constants(self):
        """Set ``self.number_of_classes`` from ``self.dataset_name``.

        Raises:
            ValueError: If ``self.dataset_name`` is not a supported dataset.
                Failing here avoids a later, harder to diagnose
                ``AttributeError`` on the missing ``number_of_classes``.
        """
        class_counts = {
            "cifar10": 10,
            # The "imagenet" dataset used by this class has 100 class folders.
            "imagenet": 100,
        }
        try:
            self.number_of_classes = class_counts[self.dataset_name]
        except KeyError:
            supported = ", ".join(sorted(class_counts))
            raise ValueError(
                f"Unsupported dataset_name {self.dataset_name!r}; "
                f"expected one of: {supported}"
            ) from None

    def preload_imagefolder(self):
        # pass
        if self.dataset_name == "imagenet":
            self.dataset = datasets.ImageFolder(root='../imagenet100sd14', transform=self.transform)
            # dataset = (
            #     wds.WebDataset('./imagenet100sd14.tar')
            #     .decode("pil")  # decode PNG images to PIL
            #     .to_tuple("png", "cls")  # extract image and label
            #     .map_tuple(self.transform, lambda x: int(x))  # apply transform, label to int
            # )

    def fix_index_for_classes(self):
        for i in range(len(self.indexes)):
            self.indexes[i] = torch.where(self.indexes[i])[0]

    def fix_image_selection(self):
        samples = []
        for class_idx in range(100):
            for img_idx in self.indexes[class_idx]:
                path = f'../imagenet100sd14/{class_idx}/{img_idx:04d}.png'
                if os.path.isfile(path):  # check existence
                    samples.append((path, class_idx))
                else:
                    print("Invalid pathh")
                # samples.append((f'../imagenet100sd14/{class_idx}/{img_idx:04d}.png',class_idx))

        self.dataset.samples = samples
        self.dataset.targets = [label for _, label in self.dataset.samples]
        # print("start fixing")
        # allowed_keys = {
        #     f"{class_idx}_{img_idx:04d}"
        #     for class_idx in range(100)
        #     for img_idx in self.indexes[class_idx]
        # }
        # print("make the dataset")       
        # self.dataset = (
        #         wds.WebDataset('./imagenet100sd14.tar')
        #         .select(lambda sample: sample["__key__"] in allowed_keys)
        #         .decode("pil")  # decode PNG images to PIL
        #         .to_tuple("png", "cls")  # extract image and label
        #         .map_tuple(self.transform, lambda x: int(x))  # apply transform, label to int
        #     )
        # print("dataset built")
        
        
    def fix_index_for_sampling(self):
        self.class_separated_indices = self.indexes.copy()
        last_idx = 0
        indexes_concatenated = []
        for index_tensor in self.indexes:
            indexlist = np.array(index_tensor)
            indexvalues = torch.where(torch.tensor(indexlist))[0]
            indexvalues += last_idx
            indexvalues = indexvalues.tolist()
            last_idx += len(indexlist)
            indexes_concatenated.extend(indexvalues)
        self.indexes = indexes_concatenated

    def get_dataset(self):
        """Print the length of the wrapped dataset and return that dataset."""
        wrapped = self.dataset
        print(len(wrapped))
        return wrapped
        
    def load_features(self, leak_features=None):
        clip = self.is_clip_features
        if self.dataset_name == "imagenet":
            if clip:
                self.dinofeatures = np.load("./NOT YET")
            else:
                self.dinofeatures = dict(np.load("../../dino-features/imagenetsd14.npz",allow_pickle=True)) ###FIX PATHS
        elif self.dataset_name == "cifar10":
            if clip:
                self.dinofeatures = dict(np.load("../../dino-features/cifar10sd14_clipfeatures.npz",allow_pickle=True))
            else:
                self.dinofeatures = dict(np.load("../../dino-features/cifar10sd14-4000-dino.npz",allow_pickle=True))
        if leak_features is not None:
            for i in range(self.number_of_classes):
                class_name = f"{i:02d}_" # FIIXX FOR IMAGENET
                class_dir = [k for k in self.dinofeatures.keys() if k.startswith(class_name)][0]
                self.dinofeatures[class_dir] = np.concatenate(
                [self.dinofeatures[class_dir], leak_features[f"{i}"]], axis=0
                )
        
    def __len__(self):
        """Return the number of entries currently held in ``self.indexes``.

        NOTE(review): when ``self.indexes`` is the per-class list this is the
        number of classes, not the number of selected samples -- confirm which
        one callers expect.
        """
        selected = self.indexes
        return len(selected)

    def __getitem__(self, idx):
        idx = self.indexes[idx]
        img = self.tensor_data[0][idx]
        label = self.tensor_data[1][idx].item()
        # img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, label


    def get_number_of_chosen_leaks(self):
        counter = 0
        for i in range(len(self.indexes)):
            idx = self.indexes[i]
            is_leak = self.tensor_data[2][idx].item()
            if is_leak:
                counter += 1
        return counter

    def get_class_generated_features(self,idx):
        genidx = self.class_separated_indices[idx]
        genidx = torch.where(torch.tensor(genidx))[0]
        # print(genidx)
        class_name = f"{idx:02d}_" # FIX FOR IMAGENET
        class_dir = [k for k in self.dinofeatures.keys() if k.startswith(class_name)][0]
        return self.dinofeatures[class_dir][genidx]

    def get_subset(self, train_features,count,method="cosine-near",zero_centered=False):
        self.indexes = []
        print(method)
        if method == "random":
            for class_id in range(self.number_of_classes):
                class_name = f"{class_id}"
                leng = self.dinofeatures[class_name].shape[0]
                selected_indices = random.sample(range(leng), count)
                selection_mask = [i in selected_indices for i in range(leng)]
                self.indexes.append(torch.tensor(selection_mask))

        else:
            for class_id in range(self.number_of_classes):
                class_name = f"{class_id}"
                train_i_features = train_features[f"{class_id}"]
                # leng = (self.tensor_data[1] == class_id).sum().item()
                gen_i_features = self.dinofeatures[class_dir]
                _,selection_indexes = match_greedy(train_i_features,gen_i_features,count,method,zero_centered)
                self.indexes.append(torch.tensor(selection_indexes))
