import torch
from torch.utils.data import Dataset
import numpy as np
import os
from selection_algorithms import match_greedy
from torchvision.transforms import ToTensor
from torchvision import transforms, datasets
from feature_extractors import get_imagenet100_features
import webdataset as wds
class ImagenetGeneratedDataset(Dataset):
    """Select a per-class subset of generated images and expose it as an ImageFolder.

    Construction performs these steps:

    1. Loads the generated-image ``ImageFolder``.
    2. Loads the CLIP feature archive and the per-class text features.
    3. Calls ``match_greedy`` once per class to pick the images to keep.
    4. Rewrites the ImageFolder's ``samples``/``targets``/``imgs`` so it only
       lists the selected images.

    Retrieve the resulting ImageFolder with :meth:`get_dataset`.
    """

    def __init__(self, dataset_name, count, transform, train_features,
                 method="Covariance_matching", zero_centered=False, clip=False,
                 leak_dataset=None, leak_features=None, prune=True):
        """Build the dataset and run the selection immediately.

        Args:
            dataset_name: ``"imagenet"`` or ``"cifar10"``. Only ``"imagenet"``
                has a generated ImageFolder configured; other names raise
                ValueError.
            count: forwarded to ``match_greedy`` for every class.
            transform: transform given to the underlying ImageFolder.
            train_features: mapping keyed by the unpadded class id string
                (``"0"``, ``"1"``, ...) to that class's real-image features.
            method: selection method forwarded to ``match_greedy``;
                ``"Text_matching"`` forces CLIP features to be used.
            zero_centered: forwarded to ``match_greedy``.
            clip: if True, CLIP features are used as the matching features.
            leak_dataset: currently unused.
            leak_features: passed to :meth:`load_features`, which ignores it.
            prune: stored as ``self.prune`` and forwarded to ``match_greedy``.
        """
        self.samples = []
        self.indexes = []
        self.dataset_name = dataset_name
        self.dataset_path = "<dataset_path>"
        self.transform = transform
        self.prune = prune
        self.is_clip_features = clip
        # Validate the dataset name before touching the filesystem.
        self.set_dataset_based_constants()
        self.preload_imagefolder()

        self.dinofeatures = None
        self.load_features(leak_features)
        self.get_subset(train_features, count, method, zero_centered)
        self.fix_index_for_classes()
        self.fix_image_selection()

    def set_dataset_based_constants(self):
        """Set ``self.number_of_classes`` from ``self.dataset_name``.

        Raises:
            ValueError: if the dataset name is not ``"cifar10"`` or ``"imagenet"``.
        """
        if self.dataset_name == "cifar10":
            self.number_of_classes = 10
        elif self.dataset_name == "imagenet":
            self.number_of_classes = 100
        else:
            raise ValueError(f"Unsupported dataset_name: {self.dataset_name!r}")

    def preload_imagefolder(self):
        """Create ``self.dataset`` as an ImageFolder over the generated images.

        Raises:
            ValueError: if no generated ImageFolder is configured for
                ``self.dataset_name`` (only ``"imagenet"`` is). Without it,
                ``self.dataset`` would never exist and construction would
                fail later with an AttributeError.
        """
        if self.dataset_name != "imagenet":
            raise ValueError(
                f"No generated ImageFolder configured for dataset {self.dataset_name!r}"
            )
        self.dataset = datasets.ImageFolder(root='<generated path>', transform=self.transform)

    def fix_index_for_classes(self):
        """Replace each per-class selection in ``self.indexes`` with integer indices.

        Each entry becomes the positions of its non-zero (``True``) elements,
        via ``torch.where(selection)[0]``. The list is updated in place.
        """
        for i, selection in enumerate(self.indexes):
            self.indexes[i] = torch.where(selection)[0]

    def fix_image_selection(self):
        """Restrict the wrapped ImageFolder to the selected generated images.

        Rebuilds ``self.dataset.samples`` as ``(path, class_idx)`` pairs for
        every selected index of every class whose file exists. Missing files
        are reported with a print and skipped. ``targets`` and the ``imgs``
        alias are kept in sync with the new sample list.
        """
        samples = []
        for class_idx in range(self.number_of_classes):
            for img_idx in self.indexes[class_idx]:
                # NOTE(review): anonymised placeholder path; presumably it is
                # meant to be built from class_idx and img_idx.
                path = '<generated path>'
                if os.path.isfile(path):
                    samples.append((path, class_idx))
                else:
                    print("Invalid path")

        self.dataset.samples = samples
        # ImageFolder keeps ``imgs`` as an alias of ``samples``; re-point it so
        # it does not keep listing the unfiltered images.
        self.dataset.imgs = samples
        self.dataset.targets = [label for _, label in samples]

    def fix_index_for_sampling(self):
        """Flatten per-class masks in ``self.indexes`` into one list of global indices.

        A shallow copy of the per-class list is kept in
        ``self.class_separated_indices``. For each entry, the positions of its
        non-zero elements are offset by the total length of all previous
        entries, so the result indexes into the concatenation of the classes.

        NOTE(review): this expects each entry to be a mask over that class's
        images. After :meth:`fix_index_for_classes` the entries are integer
        indices instead. Confirm the intended call order.
        """
        self.class_separated_indices = self.indexes.copy()
        last_idx = 0
        indexes_concatenated = []
        for index_tensor in self.indexes:
            indexlist = np.array(index_tensor)
            indexvalues = torch.where(torch.tensor(indexlist))[0]
            indexvalues += last_idx
            indexes_concatenated.extend(indexvalues.tolist())
            last_idx += len(indexlist)
        self.indexes = indexes_concatenated

    def get_dataset(self):
        """Return the wrapped (filtered) ImageFolder."""
        return self.dataset

    def load_features(self, leak_features=None):
        """Load CLIP features of the generated images and per-class text features.

        ``self.clipfeatures`` is a dict built from the CLIP feature file.
        :meth:`get_subset` reads it with zero-padded class ids such as
        ``"00"``. When ``self.is_clip_features`` is set, the same dict is also
        used as ``self.dinofeatures``.

        Args:
            leak_features: currently unused.
        """
        # NOTE(review): allow_pickle=True executes pickled objects on load;
        # only point this at trusted feature files.
        self.clipfeatures = dict(np.load('<clip features path>', allow_pickle=True))
        if self.is_clip_features:
            self.dinofeatures = self.clipfeatures
        self.textfeatures = get_imagenet100_features()

    def __len__(self):
        """Return the number of samples in the wrapped ImageFolder."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return item ``idx`` of the wrapped ImageFolder."""
        return self.dataset[idx]

    def get_number_of_chosen_leaks(self):
        """Count entries of ``self.indexes`` flagged as leaks in ``self.tensor_data[2]``.

        NOTE(review): ``self.tensor_data`` is never assigned in this class, so
        this raises AttributeError unless it is set externally. Each entry of
        ``self.indexes`` is used as an index and the result is passed to
        ``.item()``, so scalar entries are expected. Confirm.
        """
        counter = 0
        for idx in self.indexes:
            if self.tensor_data[2][idx].item():
                counter += 1
        return counter

    def get_class_generated_features(self):
        """Return ``{class_id: features of the selected generated images}``.

        Rows are taken from ``self.dinofeatures[f"{class_id:02d}"]`` using
        that class's entry of ``self.indexes``.
        """
        features = {}
        for class_id in range(self.number_of_classes):
            idxes = self.indexes[class_id]
            features[class_id] = self.dinofeatures[f"{class_id:02d}"][idxes]
        return features

    def get_subset(self, train_features, count, method="Covariance_matching", zero_centered=False, prune=True):
        """Run ``match_greedy`` per class and store its selections in ``self.indexes``.

        Args:
            train_features: mapping keyed by the unpadded class id string.
            count: forwarded to ``match_greedy``.
            method: forwarded to ``match_greedy``; ``"Text_matching"``, or
                ``self.is_clip_features``, switches the matching features to
                the CLIP features.
            zero_centered: forwarded to ``match_greedy``.
            prune: ignored; ``self.prune`` (set in ``__init__``) is always used.

        Raises:
            ValueError: if no matching features are available, i.e.
                ``self.dinofeatures`` is None and CLIP features were not
                selected.
        """
        # The ``prune`` argument is intentionally overridden by the instance setting.
        prune = self.prune
        self.indexes = []
        using_clip_features = method == "Text_matching" or self.is_clip_features
        if using_clip_features:
            self.dinofeatures = self.clipfeatures
        if self.dinofeatures is None:
            raise ValueError(
                "No matching features loaded: DINO features are unavailable; "
                "use clip=True or method='Text_matching'"
            )

        for class_id in range(self.number_of_classes):
            class_name = f"{class_id:02d}"
            train_i_features = train_features[f"{class_id}"]
            gen_i_features = self.dinofeatures[class_name]
            gen_i_clip_features = self.clipfeatures[class_name]
            text_features_class = self.textfeatures[class_id]
            _, selection_indexes = match_greedy(
                train_i_features,
                gen_i_features,
                count,
                method,
                zero_centered,
                text_feature=text_features_class,
                gen_clip_features=gen_i_clip_features,
                prune=prune,
                using_clip_features=using_clip_features,
            )

            self.indexes.append(torch.tensor(selection_indexes))
