import torch
import numpy as np
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader
import torchvision.transforms as transforms
import os
import pickle
from torchvision import datasets
from feature_extractors import extract_clip_features_from_subset, extract_dino_features_from_subset, extract_dino_features_from_subset_imagenet, extract_clip_features_from_subset_imagenet
from generated_dataset import GeneratedDataset
from imagenet_generated_dataset import ImagenetGeneratedDataset
from itertools import chain
from torch.utils.data import random_split

class TransformWrapper(torch.utils.data.Dataset):
    """Dataset view that applies ``transform`` to the image of each sample.

    The wrapped dataset must yield ``(img, label)`` pairs. With ``expand > 1``
    the reported length is multiplied by ``expand`` and ``expand`` consecutive
    indices all map to the same underlying sample (the transform is applied
    anew on every access).
    """

    def __init__(self, dataset, transform, expand=1):
        self.dataset = dataset
        self.transform = transform
        self.expand = expand

    def __len__(self):
        return len(self.dataset)*self.expand

    def __getitem__(self, index):
        # Fold the expanded index range back onto the underlying dataset.
        source_index = index // self.expand if self.expand > 1 else index
        sample, target = self.dataset[source_index]
        return self.transform(sample), target

def _color_jitter():
    """Return the ColorJitter step shared by all augmenting pipelines."""
    return transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)


def _cifar10_normalize():
    """Return the Normalize step used by every cifar10 pipeline."""
    return transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


def _cifar10_tensor_ops():
    """Return fresh ToTensor + Normalize steps for the cifar10 pipelines."""
    return [transforms.ToTensor(), _cifar10_normalize()]


def _cifar10_augment_ops(expand):
    """Return the cifar10 PIL-level augmentations (heavier set when ``expand``)."""
    ops = [transforms.RandomCrop(32, padding=4)]
    if expand:
        ops += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            _color_jitter(),
        ]
    return ops


def _imagenet_tensor_ops():
    """Return fresh ToTensor + Normalize steps for the imagenet pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]


def _imagenet_spatial_ops(expand):
    """Return the imagenet PIL-level steps: random crop/flip/jitter when ``expand``, else resize + center crop."""
    if expand:
        return [
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(0.75, 1.33)),
            transforms.RandomHorizontalFlip(),
            _color_jitter(),
        ]
    return [transforms.Resize(256), transforms.CenterCrop(224)]


def get_transformations(dataset_name,expand=False,is_vit=False):
    """Build the torchvision transform pipelines for ``dataset_name``.

    Args:
        dataset_name: One of ``"cifar10"``, ``"imagenet"``, ``"cifar10expand"``
            or ``"imagenetexpand"``.
        expand: When truthy, the ``"cifar10"`` (non-ViT) and ``"imagenet"``
            training pipelines use the heavier random augmentations. Ignored
            by the ViT branch and by the two ``*expand`` names.
        is_vit: Only consulted for ``"cifar10"``; selects the pipelines that
            resize to 384x384.

    Returns:
        For ``"cifar10expand"`` / ``"imagenetexpand"``: a single ``Compose``
        of augmentations only (no ToTensor / Normalize).
        Otherwise a tuple ``(transform_train, transform_train_pil,
        transform_test)``. ``transform_train`` starts with ``ToPILImage``;
        ``transform_train_pil`` does not.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
            (Previously an ``assert``, which is stripped under ``python -O``
            and then fell through to an ``UnboundLocalError``.)
    """
    if dataset_name == "cifar10expand":
        return transforms.Compose(_cifar10_augment_ops(True))

    if dataset_name == "imagenetexpand":
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            _color_jitter(),
        ])

    if dataset_name == "cifar10":
        bicubic = transforms.InterpolationMode.BICUBIC
        if is_vit:
            size = 384
            # NOTE(review): unlike transform_train_pil, this pipeline does not
            # resize to `size` and `expand` is ignored here -- confirm intended.
            transform_train = transforms.Compose(
                [transforms.ToPILImage(), transforms.RandomCrop(32, padding=4)]
                + _cifar10_tensor_ops()
            )
            transform_train_pil = transforms.Compose(
                [transforms.Resize((size, size), interpolation=bicubic)]
                + _cifar10_tensor_ops()
            )
            transform_test = transforms.Compose(
                [transforms.Resize((size, size), interpolation=bicubic)]
                + _cifar10_tensor_ops()
            )
        else:
            transform_train = transforms.Compose(
                [transforms.ToPILImage()]
                + _cifar10_augment_ops(expand)
                + _cifar10_tensor_ops()
            )
            transform_train_pil = transforms.Compose(
                _cifar10_augment_ops(expand) + _cifar10_tensor_ops()
            )
            # NOTE(review): test images are resized to 224x224 while the
            # training pipelines crop at 32x32 -- confirm this is intended.
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224), interpolation=bicubic),
                _cifar10_normalize(),
            ])
        return transform_train, transform_train_pil, transform_test

    if dataset_name == "imagenet":
        transform_train = transforms.Compose(
            [transforms.ToPILImage()]
            + _imagenet_spatial_ops(expand)
            + _imagenet_tensor_ops()
        )
        transform_train_pil = transforms.Compose(
            _imagenet_spatial_ops(expand) + _imagenet_tensor_ops()
        )
        # The test pipeline is always the deterministic resize + center crop.
        transform_test = transforms.Compose(
            _imagenet_spatial_ops(False) + _imagenet_tensor_ops()
        )
        return transform_train, transform_train_pil, transform_test

    raise ValueError(f"INVALID DATASET NAME: {dataset_name!r}")

def get_validation(train_set,batch_size,is_vit=False,dataset_name="cifar10"):
    """Split ``train_set`` 90/10 and wrap the held-out part in a DataLoader.

    Args:
        train_set: Dataset to split; must support ``len()``.
        batch_size: Batch size of the validation loader.
        is_vit: Forwarded to ``get_transformations`` when picking the
            evaluation transform.
        dataset_name: Dataset whose test transform is applied to the
            validation samples. Defaults to ``"cifar10"``, which was the
            previously hard-coded value.

    Returns:
        ``(train_subset, val_dataloader)``: the 90% training ``Subset``
        (returned without any transform applied here) and a DataLoader over
        the remaining samples wrapped with the test-time transform.
    """
    train_size = int(0.9 * len(train_set))
    val_size = len(train_set) - train_size
    # random_split draws from the global torch RNG; seed it upstream for
    # reproducible splits.
    train_subset, val_subset = random_split(train_set, [train_size, val_size])

    # Index 2 of the returned tuple is the test-time transform.
    val_transform = get_transformations(dataset_name, is_vit=is_vit)[2]
    val_set = TransformWrapper(val_subset, val_transform)
    val_dataloader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
    )
    return train_subset, val_dataloader

def get_real_subdataset(dataset_name, train_set_full, targets, subset_count=500, clip=False, expand_reals=1, index_file="./configs/random"):
    """Select a class-balanced subset of real images and extract its features.

    The chosen indices are cached in ``index_file`` (pickle): if the file
    exists it is loaded, otherwise ``subset_count`` indices per class are
    sampled without replacement and written to it.

    Args:
        dataset_name: ``"cifar10"`` (10 classes) or ``"imagenet"``
            (100 classes are used).
        train_set_full: Full training dataset the indices refer to.
        targets: Per-sample class ids for ``train_set_full``; converted with
            ``np.asarray`` so lists are accepted as well as arrays.
        subset_count: Number of samples wanted per class.
        clip: Use the CLIP feature extractors when true, DINO otherwise.
        expand_reals: When > 1, the subset is wrapped so each sample appears
            ``expand_reals`` times under the ``"<dataset_name>expand"``
            augmentations.
        index_file: Path of the pickle cache holding the selected indices.

    Returns:
        ``(subset_train, real_features)``: the selected dataset and whatever
        the chosen feature extractor returns for it.

    Raises:
        ValueError: If ``dataset_name`` is not supported (previously this
            surfaced later as an unbound ``num_classes``). ``np.random.choice``
            also raises ValueError when a class has fewer than
            ``subset_count`` samples.
    """
    class_counts = {"cifar10": 10, "imagenet": 100}
    if dataset_name not in class_counts:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    num_classes = class_counts[dataset_name]
    # With a plain list, `targets == class_id` is a scalar False and every
    # class would silently come out empty.
    targets = np.asarray(targets)

    if os.path.exists(index_file):
        # NOTE: pickle.load can execute arbitrary code; only point index_file
        # at files produced by this project.
        with open(index_file, "rb") as f:
            indices = pickle.load(f)
    else:
        indices = []
        for class_id in range(num_classes):
            class_indices = np.where(targets == class_id)[0]
            indices.extend(np.random.choice(class_indices, subset_count, replace=False))
        with open(index_file, "wb") as f:
            pickle.dump(indices, f)

    subset_train = Subset(train_set_full, indices)
    if expand_reals > 1:
        # "cifar10expand" / "imagenetexpand" return a single augmentation Compose.
        expand_transform = get_transformations(dataset_name + "expand")
        subset_train = TransformWrapper(subset_train, expand_transform, expand=expand_reals)
        # Re-wrap as a Subset, as the original code did after expansion.
        subset_train = Subset(subset_train, list(range(len(subset_train))))

    if len(indices) < subset_count * num_classes:
        print("Not enough indices found, expanding the dataset with more samples. Using more indices.")
        already_selected = set(indices)  # O(1) membership instead of list scans
        selected_targets = targets[np.asarray(indices, dtype=np.int64)]
        more_indices = []

        for class_id in range(num_classes):
            candidates = [i for i in np.where(targets == class_id)[0] if i not in already_selected]
            # Shuffle for every class (even if nothing is needed) so the
            # global NumPy RNG stream is consumed exactly as before.
            np.random.shuffle(candidates)
            needed = subset_count - int(np.count_nonzero(selected_targets == class_id))
            if needed > 0:
                more_indices.extend(candidates[:needed])

        # The top-up samples are appended without the expand wrapper and are
        # not written back to index_file.
        more_subset = Subset(train_set_full, more_indices)
        subset_train = ConcatDataset([subset_train, more_subset])

    is_imagenet = dataset_name == "imagenet"
    if clip:
        extractor = extract_clip_features_from_subset_imagenet if is_imagenet else extract_clip_features_from_subset
    else:
        extractor = extract_dino_features_from_subset_imagenet if is_imagenet else extract_dino_features_from_subset
    real_features = extractor(subset_train)
    return subset_train, real_features

def get_two_real_subsets(dataset_name, train_set_full, targets, leak_count=400, subset_count=500, clip=False, index_file="./configs/random", leak_index_file=None):
    """Build two disjoint class-balanced real subsets and extract their features.

    The first ("train") subset is selected exactly like in
    ``get_real_subdataset`` and cached in ``index_file``. The second ("leak")
    subset takes up to ``leak_count`` samples per class from the indices not
    used by the first one; it is cached in ``leak_index_file`` when that path
    is given.

    Args:
        dataset_name: ``"cifar10"`` (10 classes) or ``"imagenet"``
            (100 classes are used).
        train_set_full: Full training dataset the indices refer to.
        targets: Per-sample class ids; converted with ``np.asarray``.
        leak_count: Samples per class for the leak subset.
        subset_count: Samples per class for the train subset.
        clip: Use the CLIP feature extractors when true, DINO otherwise.
        index_file: Pickle cache for the train-subset indices.
        leak_index_file: Optional pickle cache for the leak-subset indices.

    Returns:
        ``(subset_leak, subset_train, features_leak, features_train)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported (previously this
            surfaced later as an unbound ``num_classes``).
    """
    class_counts = {"cifar10": 10, "imagenet": 100}
    if dataset_name not in class_counts:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    num_classes = class_counts[dataset_name]
    targets = np.asarray(targets)

    # NOTE: pickle.load can execute arbitrary code; only point the index
    # files at files produced by this project.
    if os.path.exists(index_file):
        with open(index_file, "rb") as f:
            indices_1 = pickle.load(f)
    else:
        indices_1 = []
        for class_id in range(num_classes):
            class_indices = np.where(targets == class_id)[0]
            indices_1.extend(np.random.choice(class_indices, subset_count, replace=False))
        with open(index_file, "wb") as f:
            pickle.dump(indices_1, f)

    if leak_index_file is not None and os.path.exists(leak_index_file):
        with open(leak_index_file, "rb") as f:
            indices_2 = pickle.load(f)
    else:
        train_index_set = set(indices_1)  # O(1) membership instead of list scans
        indices_2 = []
        for class_id in range(num_classes):
            candidates = [i for i in np.where(targets == class_id)[0] if i not in train_index_set]
            np.random.shuffle(candidates)
            indices_2.extend(candidates[:leak_count])

        if leak_index_file is not None:
            with open(leak_index_file, "wb") as f:
                pickle.dump(indices_2, f)

    subset_train = Subset(train_set_full, indices_1)
    subset_leak = Subset(train_set_full, indices_2)

    # Pick the extractor by dataset as get_real_subdataset does, so leak
    # features come from the same extractor as the real features. Previously
    # the non-ImageNet extractors were used regardless of dataset_name.
    is_imagenet = dataset_name == "imagenet"
    if clip:
        extract_features = extract_clip_features_from_subset_imagenet if is_imagenet else extract_clip_features_from_subset
    else:
        extract_features = extract_dino_features_from_subset_imagenet if is_imagenet else extract_dino_features_from_subset

    features_leak = extract_features(subset_leak)
    features_train = extract_features(subset_train)

    return subset_leak, subset_train, features_leak, features_train

def get_train_subset_features(subset,is_clip=False):
    """Return features for ``subset`` using the CLIP extractor when ``is_clip`` is truthy, else the DINO extractor."""
    extractor = extract_clip_features_from_subset if is_clip else extract_dino_features_from_subset
    return extractor(subset)

def get_full_dataset(dataset_name="cifar10", model_names=None, subset_train=None, test_set=None, use_generative=True, cifar10_real_features=None, number_of_generated=1000, batch_size=128, generated_root="./cifar10_generated_images",
                     method="random", zero_centered=True, clip=False, leak_dataset=None, leak_features=None, expand=1, read_amount=None, Prune=False, SEED=42, add_synthetic=False, is_vit=False):
    """Build the training and test DataLoaders, optionally mixing in generated images.

    Args:
        dataset_name: ``"cifar10"`` selects ``GeneratedDataset``; any other
            name accepted by ``get_transformations`` selects
            ``ImagenetGeneratedDataset``.
        model_names: Generator model names forwarded to ``GeneratedDataset``.
            ``None`` (the default) means ``["sd14"]``; a fresh list is built
            per call so no default list is shared between calls.
        subset_train: Real training samples; wrapped with the PIL training
            transform (and repeated ``expand`` times).
        test_set: Dataset for the test loader.
        use_generative: When false, only the real samples are used and the
            returned ``generated_dataset`` is ``None``.
        cifar10_real_features: Real-image features forwarded to the generated
            dataset constructors.
        number_of_generated: Forwarded to the generated dataset constructors.
        batch_size: Training batch size (the test loader always uses 100).
        generated_root: Not used by this function; kept for interface
            compatibility.
        method, zero_centered, clip, leak_dataset, leak_features, Prune:
            Forwarded to the generated dataset constructors.
        expand: Real-sample repetition factor; ``expand > 1`` also selects the
            augmented training transforms.
        read_amount, add_synthetic: Forwarded to ``GeneratedDataset`` only.
        SEED: Seed for the ``torch.Generator`` shared by both loaders.
        is_vit: Forwarded to ``get_transformations``.

    Returns:
        ``(train_loader, test_loader, generated_dataset)``.
    """
    if model_names is None:
        model_names = ["sd14"]

    g = torch.Generator()
    g.manual_seed(SEED)
    transform_train, transform_train_pil, _ = get_transformations(dataset_name, (expand > 1), is_vit=is_vit)

    generated_dataset = None
    if use_generative:
        if dataset_name == "cifar10":
            generated_dataset = GeneratedDataset(
                dataset_name, model_names, number_of_generated, transform_train, cifar10_real_features,
                method=method, zero_centered=zero_centered, clip=clip,
                leak_dataset=leak_dataset, leak_features=leak_features,
                expand=expand, read_amount=read_amount, prune=Prune,
                add_synthetic=add_synthetic,
            )
        else:
            generator = ImagenetGeneratedDataset(
                dataset_name, number_of_generated, transform_train_pil, cifar10_real_features,
                method=method, zero_centered=zero_centered, clip=clip,
                leak_dataset=leak_dataset, leak_features=leak_features, prune=Prune,
            )
            generated_dataset = generator.get_dataset()

        real_train = TransformWrapper(subset_train, transform_train_pil, expand)
        combined_train = ConcatDataset([real_train, generated_dataset])
        # Worker counts kept as in the original per-dataset branches.
        train_workers = 16 if dataset_name == "imagenet" else 8
    else:
        combined_train = TransformWrapper(subset_train, transform_train_pil, expand)
        train_workers = 16

    train_loader = DataLoader(
        combined_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=train_workers,
        pin_memory=True,
        persistent_workers=True,
        generator=g,
    )
    test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=8, generator=g)
    return train_loader, test_loader, generated_dataset
