import torch
from torch.utils.data import Dataset
import numpy as np
import random
from selection_algorithms import match_greedy
from torchvision.transforms import ToTensor

class GeneratedDataset(Dataset):
    """Per-class subset of pre-generated images, optionally mixed with leaked real images.

    For every entry of ``model_names`` this loads images/labels from
    ``../../tensordataset/<dataset_name><model_name>_bc.pt`` and DINO features from
    ``../../dino-features/<dataset_name><model_name>.npz``. It then selects ``count``
    samples per class, either uniformly at random (``method="random"``) or via
    ``match_greedy`` against ``train_features``.

    Args:
        dataset_name: ``"cifar10"`` or ``"imagenet"`` (sets the class count); also
            used to build the file paths.
        model_names: suffixes identifying which generator dumps to load.
        count: number of samples to select per class.
        transform: callable applied to each image in ``__getitem__``.
        train_features: mapping ``str(class_id) -> features`` handed to
            ``match_greedy`` (unused when ``method == "random"``).
        method: ``"random"`` or a method name forwarded to ``match_greedy``.
        zero_centered: forwarded to ``match_greedy``.
        clip: stored as ``is_clip_features``; not consulted (only DINO features load).
        leak_dataset: optional iterable of ``(image, label)`` pairs whose images are
            converted with torchvision ``ToTensor`` and appended to their class.
        leak_features: optional mapping ``str(class_id) -> features`` for the leaked
            images, appended to the matching class features.
        expand: every selected sample is exposed ``expand`` consecutive times.
        read_amount: optional per-model cap on samples read per class; ``None``
            reads everything.
    """

    # Number of classes per known dataset name.
    _CLASS_COUNTS = {"cifar10": 10, "imagenet": 100}
    # Fallback for unknown names; matches the class loop previously hard-coded in get_subset.
    _DEFAULT_NUM_CLASSES = 10

    def __init__(self, dataset_name, model_names, count, transform, train_features,
                 method="cosine-near", zero_centered=False, clip=False, leak_dataset=None,
                 leak_features=None, expand=1, read_amount=None):
        self.samples = []  # kept for backward compatibility; not populated by this class
        self.indexes = []
        self.model_names = model_names
        self.expand = expand
        self.read_amount = read_amount
        self.dataset_name = dataset_name
        self.is_clip_features = clip
        # Must run before load_tensordata: the leak merge needs number_of_classes.
        self.dataset_based_constants()
        self.load_tensordata(leak_dataset)
        self.dinofeatures = None
        self.load_features(leak_features)
        self.transform = transform
        self.get_subset(train_features, count, method, zero_centered)
        self.fix_index_for_sampling()
        # Assumes every class selected the same number of samples, giving a
        # 2-D (classes, selected_per_class) array that __len__ relies on.
        self.indexes = np.array(self.indexes)

    def dataset_based_constants(self):
        """Set ``self.number_of_classes`` from ``self.dataset_name``."""
        self.number_of_classes = self._CLASS_COUNTS.get(
            self.dataset_name, self._DEFAULT_NUM_CLASSES
        )

    def _read_limit(self, model_idx):
        """Return the slice stop for model ``model_idx``; ``None`` means read everything."""
        if self.read_amount is None:
            return None
        return self.read_amount[model_idx]

    def fix_index_for_sampling(self):
        """Replace each per-class boolean selection mask with the positions it selects."""
        for class_id, mask in enumerate(self.indexes):
            self.indexes[class_id] = torch.where(mask)[0]

    def load_tensordata(self, leak_dataset):
        """Load generated images/labels for all models, optionally appending leak samples.

        Sets ``self.tensor_data`` to ``(images, labels)``. When ``leak_dataset`` is
        given, it becomes ``(images, labels, is_leak)`` instead. Per the shape notes
        kept from the original code, images are ``[classes, N, C, H, W]`` and labels
        ``[classes, N]`` -- TODO confirm against the files on disk.
        """
        tensorimgs = torch.tensor([])
        tensorlabels = torch.tensor([])
        for i, model_name in enumerate(self.model_names):
            # NOTE: torch.load unpickles the file; only use it on trusted dumps.
            model_tensor = torch.load(
                f"../../tensordataset/{self.dataset_name}{model_name}_bc.pt"
            )
            limit = self._read_limit(i)
            # Dim 1 is the per-class sample axis, so models are stacked along it.
            tensorimgs = torch.cat((tensorimgs, model_tensor[0][:, :limit]), dim=1)
            tensorlabels = torch.cat((tensorlabels, model_tensor[1][:, :limit]), dim=1)
        self.tensor_data = (tensorimgs, tensorlabels.to(torch.long))

        if leak_dataset is not None:
            self._append_leak_samples(leak_dataset)

    def _append_leak_samples(self, leak_dataset):
        """Append ``leak_dataset`` images to their class rows and record an is-leak mask."""
        gen_images, gen_labels = self.tensor_data
        to_tensor = ToTensor()

        # Group leak images by class in one pass (also works for one-shot iterables).
        leak_by_class = {class_id: [] for class_id in range(self.number_of_classes)}
        for img, label in leak_dataset:
            label = int(label)
            if label in leak_by_class:
                leak_by_class[label].append(to_tensor(img))

        per_class_images = []
        per_class_labels = []
        per_class_isleak = []
        for class_id in range(self.number_of_classes):
            class_imgs = gen_images[class_id]
            class_lbls = gen_labels[class_id]
            class_isleak = torch.zeros(len(class_imgs), dtype=torch.bool)

            leak_imgs = leak_by_class[class_id]
            if leak_imgs:
                n_leak = len(leak_imgs)
                class_imgs = torch.cat([class_imgs, torch.stack(leak_imgs)], dim=0)
                class_lbls = torch.cat(
                    [class_lbls, torch.full((n_leak,), class_id, dtype=torch.long)], dim=0
                )
                class_isleak = torch.cat(
                    [class_isleak, torch.ones(n_leak, dtype=torch.bool)], dim=0
                )

            per_class_images.append(class_imgs)
            per_class_labels.append(class_lbls)
            per_class_isleak.append(class_isleak)

        # torch.stack requires every class to end with the same sample count,
        # i.e. the leak set must contribute equally to every class.
        self.tensor_data = (
            torch.stack(per_class_images),
            torch.stack(per_class_labels),
            torch.stack(per_class_isleak),
        )
        print(self.tensor_data[0].shape, self.tensor_data[1].shape, self.tensor_data[2].shape)

    def load_features(self, leak_features=None):
        """Load DINO features for every model, optionally appending leak features.

        ``self.dinofeatures`` maps a class key to the features of all models,
        concatenated along axis 0. The class key is the part of each npz key before
        the first ``_``; class ids are looked up as zero-padded ``"00"``, ``"01"``, ...
        """
        # self.is_clip_features is intentionally not consulted: only DINO is supported.
        grouped = {}
        for i, model_name in enumerate(self.model_names):
            limit = self._read_limit(i)
            # NOTE: allow_pickle=True can execute code from the file; load trusted files only.
            with np.load(
                f"../../dino-features/{self.dataset_name}{model_name}.npz", allow_pickle=True
            ) as model_features:
                for key in model_features.keys():
                    class_name = key.split('_')[0]
                    grouped.setdefault(class_name, []).append(model_features[key][:limit])
        self.dinofeatures = {
            class_name: np.concatenate(parts, axis=0) for class_name, parts in grouped.items()
        }

        if leak_features is not None:
            for i in range(self.number_of_classes):
                class_name = f"{i:02d}"
                self.dinofeatures[class_name] = np.concatenate(
                    [self.dinofeatures[class_name], leak_features[f"{i}"]], axis=0
                )
            print("shape:", self.dinofeatures["03"].shape)

    def __len__(self):
        """Return (classes x selected-per-class) x ``expand``."""
        len_all_indexes = self.indexes.shape[1] * self.indexes.shape[0]
        return len_all_indexes * self.expand

    def calculate_class_of_index(self, idx):
        """Split a flat (un-expanded) index into ``(class_id, position_within_class)``."""
        return divmod(idx, self.indexes.shape[1])

    def __getitem__(self, idx):
        """Return ``(transform(image), int label)`` for flat index ``idx``."""
        if self.expand > 1:
            # Consecutive `expand` indices map to the same underlying sample.
            idx = idx // self.expand
        class_id, position = self.calculate_class_of_index(idx)
        sample_idx = self.indexes[class_id][position]
        img = self.tensor_data[0][class_id][sample_idx]
        label = self.tensor_data[1][class_id][sample_idx].item()
        img = self.transform(img)
        return img, label

    def get_number_of_chosen_leaks(self):
        """Return how many selected samples, over all classes, came from the leak dataset.

        Returns 0 when no leak dataset was merged (``tensor_data`` has no is-leak mask).
        """
        if len(self.tensor_data) < 3:
            return 0
        is_leak = self.tensor_data[2]
        counter = 0
        for class_id, class_indexes in enumerate(self.indexes):
            chosen = torch.as_tensor(class_indexes, dtype=torch.long)
            counter += int(is_leak[class_id][chosen].sum().item())
        return counter

    def get_class_generated_features(self, idx):
        """Return the stored features of the samples selected for class ``idx``."""
        genidx = self.indexes[idx]
        class_name = f"{idx:02d}"
        return self.dinofeatures[class_name][genidx]

    def get_subset(self, train_features, count, method="cosine-near", zero_centered=False):
        """Store one boolean selection mask per class in ``self.indexes``.

        ``method == "random"`` picks ``count`` samples uniformly without replacement.
        Any other method is forwarded to ``match_greedy`` together with the class's
        train and generated features; its second return value is used as the mask.
        """
        self.indexes = []
        print(method)
        for class_id in range(self.number_of_classes):
            class_name = f"{class_id:02d}"
            if method == "random":
                # The mask indexes into this class's row of tensor_data, so size it by that row.
                n_candidates = self.tensor_data[1][class_id].shape[0]
                selected = random.sample(range(n_candidates), count)
                mask = torch.zeros(n_candidates, dtype=torch.bool)
                mask[torch.tensor(selected, dtype=torch.long)] = True
            else:
                print(class_name)
                _, selection_indexes = match_greedy(
                    train_features[f"{class_id}"],
                    self.dinofeatures[class_name],
                    count,
                    method,
                    zero_centered,
                )
                mask = torch.tensor(selection_indexes)
            self.indexes.append(mask)
