from re import S
import torch
from torch.utils.data import Dataset
import numpy as np
import random
from selection_algorithms import match_greedy
from torchvision.transforms import ToTensor
import time
from feature_extractors import extract_clip_text_features

class GeneratedDataset(Dataset):
    """Per-class generated images, sub-sampled by feature matching.

    On construction the dataset:

    1. loads generated image/label tensors for every model in ``model_names``
       (stored so that ``tensor_data[0][class_id][sample]`` is one image),
    2. optionally appends real "leak" samples (``leak_dataset``/``leak_features``),
    3. loads per-class feature arrays used for selection,
    4. optionally appends synthetic duplicate samples (``add_synthetic``),
    5. selects samples per class with ``match_greedy`` (see :meth:`get_subset`).

    After selection, ``self.indexes`` is a ``(num_selected_classes, per_class)``
    int64 array of positions into each class's stored samples.
    """

    # Class names used to build the CLIP text prompts (CIFAR-10 label order).
    _CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                        "dog", "frog", "horse", "ship", "truck")

    def __init__(self, dataset_name, model_names, count, transform, train_features, method="cosine-near",
                 zero_centered=False, clip=False, leak_dataset=None, leak_features=None, expand=1,
                 read_amount=None, prune=False, add_synthetic=False):
        """Load data and features, then select the per-class subset.

        Args:
            dataset_name: ``"cifar10"`` or ``"imagenet"``; sets ``number_of_classes``.
            model_names: names of the models whose ``.pt`` tensors and ``.npz``
                features are loaded, in this order.
            count: forwarded to ``match_greedy`` for each class (presumably the
                number of samples to select per class — confirm in match_greedy).
            transform: callable applied to each image in ``__getitem__``.
            train_features: mapping ``str(class_id) -> array`` of reference features.
            method: selection method name forwarded to ``match_greedy``.
            zero_centered: forwarded to ``match_greedy``.
            clip: if True, select on CLIP features instead of the default features.
            leak_dataset: optional iterable of ``(img, label)`` pairs appended per
                class; ``img`` must be accepted by torchvision ``ToTensor``.
            leak_features: optional mapping ``str(class_id) -> array`` appended to
                the selection features.
            expand: every selected sample is served ``expand`` consecutive times.
            read_amount: per-model cap on samples read per class; ``None`` reads all.
            prune: forwarded to ``match_greedy`` by :meth:`get_subset`.
            add_synthetic: if True, run :meth:`init_synthetic_experiment`.
        """
        self.samples = []  # not used within this class; kept for external callers
        self.indexes = []
        self.model_names = model_names
        self.prune = prune
        self.expand = expand
        self.read_amount = read_amount
        self.dataset_name = dataset_name
        self.is_clip_features = clip
        self.dataset_based_constants()
        self.load_tensordata(leak_dataset)
        self.dinofeatures = None
        self.load_features(leak_features)
        if add_synthetic:
            self.init_synthetic_experiment(train_features)
        self.transform = transform
        self.get_subset(train_features, count, method, zero_centered)

    def dataset_based_constants(self):
        """Set ``self.number_of_classes`` from ``self.dataset_name``.

        ``"cifar10"`` -> 10, ``"imagenet"`` -> 100; any other name leaves the
        attribute unset (later leak-related code would then fail).
        """
        # NOTE(review): get_subset / get_all_features_separately /
        # init_synthetic_experiment iterate a hard-coded 10 classes and the text
        # prompts are CIFAR-10 names, so "imagenet" is only partially supported.
        if self.dataset_name == "cifar10":
            self.number_of_classes = 10
        elif self.dataset_name == "imagenet":
            self.number_of_classes = 100

    def _read_limit(self, model_idx):
        """Return the slice stop for model ``model_idx``; ``None`` means read everything."""
        if self.read_amount is None:
            return None
        return self.read_amount[model_idx]

    def fix_index_for_sampling(self):
        """Convert per-class selection masks in ``self.indexes`` into index tensors.

        If ``self.indexes`` is a list and any entry contains a value greater than
        1, the entries are taken to be integer indices already and left as-is.
        Otherwise each entry is treated as a 0/1 (or boolean) mask and replaced
        by the positions of its non-zero elements.
        """
        # NOTE(review): an index list whose values are all <= 1 (e.g. [0] or
        # [0, 1]) is indistinguishable from a mask here and will be converted.
        if isinstance(self.indexes, list) and any(bool((entry > 1).any()) for entry in self.indexes):
            return
        for i, entry in enumerate(self.indexes):
            self.indexes[i] = torch.where(entry)[0]

    def load_tensordata(self, leak_dataset):
        """Load generated image/label tensors of all models into ``self.tensor_data``.

        Each model file is indexed as ``[0]`` = images and ``[1]`` = labels and is
        truncated along dim 1 to the model's read limit; models are concatenated
        along dim 1. Labels are cast to ``torch.long``.

        If ``leak_dataset`` is given, its samples are appended per class and
        ``self.tensor_data`` becomes ``(images, labels, is_leak)``.
        """
        tensorimgs = torch.tensor([])
        tensorlabels = torch.tensor([])
        for i, model_name in enumerate(self.model_names):
            model_tensor = torch.load(f"<path>/{model_name}.pt")
            limit = self._read_limit(i)
            tensorimgs = torch.cat((tensorimgs, model_tensor[0][:, :limit]), dim=1)
            tensorlabels = torch.cat((tensorlabels, model_tensor[1][:, :limit]), dim=1)
        self.tensor_data = (tensorimgs, tensorlabels.to(torch.long))

        if leak_dataset is not None:
            self._append_leak_samples(leak_dataset)

    def _append_leak_samples(self, leak_dataset):
        """Append ``leak_dataset`` samples to their classes and add an is-leak flag tensor."""
        gen_images, gen_labels = self.tensor_data
        to_tensor = ToTensor()

        # Iterate the dataset once (not once per class): loading happens a single
        # time and one-shot iterables are consumed correctly.
        leak_by_class = [[] for _ in range(self.number_of_classes)]
        for img, label in leak_dataset:
            for class_id in range(self.number_of_classes):
                if label == class_id:
                    leak_by_class[class_id].append(to_tensor(img))
                    break

        per_class_images, per_class_labels, per_class_isleak = [], [], []
        for class_id in range(self.number_of_classes):
            class_imgs = gen_images[class_id]
            class_lbls = gen_labels[class_id]
            class_isleak = torch.zeros(len(class_imgs), dtype=torch.bool)

            leak_imgs = leak_by_class[class_id]
            if leak_imgs:
                n_leak = len(leak_imgs)
                class_imgs = torch.cat([class_imgs, torch.stack(leak_imgs)], dim=0)
                class_lbls = torch.cat([class_lbls, torch.full((n_leak,), class_id, dtype=torch.long)], dim=0)
                class_isleak = torch.cat([class_isleak, torch.ones(n_leak, dtype=torch.bool)], dim=0)

            per_class_images.append(class_imgs)
            per_class_labels.append(class_lbls)
            per_class_isleak.append(class_isleak)

        # torch.stack requires every class to end up with the same sample count.
        self.tensor_data = (
            torch.stack(per_class_images),
            torch.stack(per_class_labels),
            torch.stack(per_class_isleak),
        )

    def _load_npz_features(self, path_for_model):
        """Load per-class feature arrays from one ``.npz`` file per model.

        ``path_for_model(model_name)`` returns the file path. Keys are grouped by
        the text before the first ``"_"`` (the class name); arrays of the same
        class are truncated to the model's read limit and concatenated along
        axis 0 in ``self.model_names`` order.
        """
        grouped = {}
        for i, model_name in enumerate(self.model_names):
            limit = self._read_limit(i)
            # allow_pickle=True can execute code from a crafted file: load trusted files only.
            with np.load(path_for_model(model_name), allow_pickle=True) as model_features:
                for key in model_features.keys():
                    class_name = key.split('_')[0]
                    grouped.setdefault(class_name, []).append(model_features[key][:limit])
        return {class_name: np.concatenate(parts, axis=0) for class_name, parts in grouped.items()}

    def load_features(self, leak_features=None):
        """Load selection features into ``self.dinofeatures`` (and CLIP/text features).

        ``self.dino_features_fixed`` keeps a shallow copy of the non-CLIP features.
        With ``clip=True`` the selection features are the CLIP features instead.
        ``leak_features[str(i)]`` rows are appended to class ``f"{i:02d}"``.
        """
        clip = self.is_clip_features
        self.dinofeatures = self._load_npz_features(lambda model_name: f"<path>/{model_name}.npz")
        self.dino_features_fixed = self.dinofeatures.copy()
        self.load_clip_features()
        if clip:
            self.dinofeatures = self.clipfeatures.copy()
        if leak_features is not None:
            for i in range(self.number_of_classes):
                class_name = f"{i:02d}"
                self.dinofeatures[class_name] = np.concatenate(
                    [self.dinofeatures[class_name], leak_features[f"{i}"]], axis=0
                )

    def __len__(self):
        """Total samples served: classes x selected-per-class x ``expand``."""
        return self.indexes.shape[0] * self.indexes.shape[1] * self.expand

    def calculate_class_of_index(self, idx):
        """Split a flat (un-expanded) index into ``(class_id, position_within_class)``."""
        class_length = self.indexes.shape[1]
        return idx // class_length, idx % class_length

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for flat index ``idx``."""
        if self.expand > 1:
            # Consecutive indices map to the same underlying sample.
            idx = idx // self.expand
        class_id, position = self.calculate_class_of_index(idx)
        sample_idx = self.indexes[class_id][position]
        img = self.tensor_data[0][class_id][sample_idx]
        label = self.tensor_data[1][class_id][sample_idx].item()
        return self.transform(img), label

    def load_clip_features(self):
        """Load CLIP image features into ``self.clipfeatures`` and CIFAR-10 text features."""
        self.clipfeatures = self._load_npz_features(lambda model_name: f"<path>/{model_name}.npz")
        sentences = [f"a photo of a {cifar_class}." for cifar_class in self._CIFAR10_CLASSES]
        self.textfeatures = extract_clip_text_features(sentences, device="cuda:0")

    def get_number_of_chosen_leaks(self):
        """Count selected samples whose is-leak flag (``tensor_data[2]``) is True."""
        is_leak = self.tensor_data[2]
        return sum(
            int(is_leak[class_id][torch.as_tensor(self.indexes[class_id])].sum())
            for class_id in range(self.number_of_classes)
        )

    def get_class_generated_features(self, idx, is_dino=False):
        """Return selected feature rows of class ``idx``.

        With ``is_dino=True`` rows come from ``dino_features_fixed``, otherwise
        from the current selection features ``dinofeatures``.
        """
        genidx = self.indexes[idx]
        source = self.dino_features_fixed if is_dino else self.dinofeatures
        return source[f"{idx:02d}"][genidx]

    def get_class_generated_tensors(self, idx):
        """Return the selected image tensors of class ``idx``."""
        return self.tensor_data[0][idx][self.indexes[idx]]

    def get_all_features_separately(self, class_id, train_features):
        """Stack generated+train features of ``class_id`` and of all other classes.

        Returns:
            ``(all_class_features, all_other_features)`` as torch tensors; either
            is an empty ``(0, dim)`` tensor when there is nothing to stack.
        """
        cls_rows, other_rows = [], []
        for i in range(10):
            chunk = [np.asarray(self.dinofeatures[f"{i:02d}"]),
                     np.asarray(train_features[f"{i}"])]
            (cls_rows if i == class_id else other_rows).extend(chunk)

        dim = self.dinofeatures["00"].shape[1]

        def stack(rows):
            return torch.from_numpy(np.vstack(rows)) if rows else torch.empty((0, dim))

        return stack(cls_rows), stack(other_rows)

    def get_subset(self, train_features, count, method="Covariance_matching", zero_centered=False, prune=True):
        """Select samples per class with ``match_greedy`` and store them in ``self.indexes``.

        ``method == "Text_matching"`` or ``clip=True`` switches the selection
        features to the CLIP features.

        Note: the ``prune`` argument is ignored; ``self.prune`` is always used.
        It is kept only for interface compatibility.
        """
        prune = self.prune
        self.indexes = []
        using_clip_features = bool(method == "Text_matching" or self.is_clip_features)
        if using_clip_features:
            self.dinofeatures = self.clipfeatures

        for class_id in range(10):
            class_name = f"{class_id:02d}"
            all_features_separately = self.get_all_features_separately(class_id, train_features)
            _, selection_indexes = match_greedy(
                train_features[f"{class_id}"],
                self.dinofeatures[class_name],
                count,
                method,
                zero_centered,
                text_feature=self.textfeatures[class_id],
                gen_clip_features=self.clipfeatures[class_name],
                prune=prune,
                using_clip_features=using_clip_features,
                all_features_separately=all_features_separately,
            )
            self.indexes.append(torch.tensor(selection_indexes))

        self.fix_index_for_sampling()
        self.indexes = np.array(self.indexes, dtype=np.int64)

    def _select_one_per_class(self, train_features, features, method, using_clip_features):
        """Run ``match_greedy`` with count 1 for each of the 10 classes.

        ``match_greedy``'s selection is treated as a mask; the positions of its
        non-zero entries are returned per class.
        """
        best = []
        for class_id in range(10):
            class_name = f"{class_id:02d}"
            _, selection_indexes = match_greedy(
                train_features[f"{class_id}"],
                features[class_name],
                1,
                method,
                False,
                text_feature=self.textfeatures[class_id],
                gen_clip_features=self.clipfeatures[class_name],
                prune=True,
                using_clip_features=using_clip_features,
                all_features_separately=None,
            )
            best.append(torch.where(torch.tensor(selection_indexes))[0])
        return best

    def init_synthetic_experiment(self, train_features):
        """Append synthetic duplicates to every class.

        For each class, one sample is chosen by ``"closest-to-text"`` (on CLIP
        features) and one by ``"nearest-to-center"``. Each is duplicated
        ``SYNTH_COUNT`` times at the end of the class's images, labels,
        selection features and CLIP features. If is-leak flags exist, they are
        extended with False for the duplicates.

        Returns:
            0
        """
        SYNTH_COUNT = 4000

        if self.is_clip_features:
            self.dinofeatures = self.clipfeatures.copy()
        center_best_indices = self._select_one_per_class(
            train_features, self.dinofeatures, "nearest-to-center", bool(self.is_clip_features)
        )
        text_best_indices = self._select_one_per_class(
            train_features, self.clipfeatures, "closest-to-text", True
        )

        tensorimgs, tensorlabels = self.tensor_data[0], self.tensor_data[1]
        num_slots = 2 * SYNTH_COUNT
        tensor_imgs_initial = torch.zeros(
            (tensorimgs.shape[0], num_slots, *tensorimgs.shape[2:]), dtype=tensorimgs.dtype
        )
        tensor_labels_initial = torch.zeros((tensorlabels.shape[0], num_slots), dtype=tensorlabels.dtype)

        def repeat_rows(x):
            # SYNTH_COUNT copies of x stacked along its first dimension.
            return x.repeat(SYNTH_COUNT, *([1] * (x.dim() - 1)))

        for class_id in range(10):
            class_name = f"{class_id:02d}"
            text_idx = text_best_indices[class_id]
            center_idx = center_best_indices[class_id]

            for features in (self.dinofeatures, self.clipfeatures):
                base = features[class_name]
                features[class_name] = np.concatenate([
                    base,
                    np.tile(base[text_idx].reshape(1, -1), (SYNTH_COUNT, 1)),
                    np.tile(base[center_idx].reshape(1, -1), (SYNTH_COUNT, 1)),
                ], axis=0)

            # First SYNTH_COUNT slots: text-selected sample; remaining slots: center-selected.
            tensor_imgs_initial[class_id][:SYNTH_COUNT] = repeat_rows(tensorimgs[class_id][text_idx])
            tensor_labels_initial[class_id][:SYNTH_COUNT] = repeat_rows(tensorlabels[class_id][text_idx])
            tensor_imgs_initial[class_id][SYNTH_COUNT:] = repeat_rows(tensorimgs[class_id][center_idx])
            tensor_labels_initial[class_id][SYNTH_COUNT:] = repeat_rows(tensorlabels[class_id][center_idx])

        new_data = [
            torch.cat([tensorimgs, tensor_imgs_initial], dim=1),
            torch.cat([tensorlabels, tensor_labels_initial], dim=1),
        ]
        if len(self.tensor_data) > 2:
            # Keep is-leak flags aligned with the samples; synthetic copies are not leaks.
            isleak = self.tensor_data[2]
            new_data.append(torch.cat([isleak, torch.zeros((isleak.shape[0], num_slots), dtype=torch.bool)], dim=1))
        self.tensor_data = tuple(new_data)
        # NOTE(review): with clip=True this aliases the CLIP features, not the non-CLIP ones.
        self.dino_features_fixed = self.dinofeatures
        return 0
