import torch
import numpy as np
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader
import torchvision.transforms as transforms
from torchvision import datasets 
from feature_extractors import extract_clip_features_from_subset, extract_dino_features_from_subset
from generated_dataset import GeneratedDataset
from imagenet_generated_dataset import ImagenetGeneratedDataset
from itertools import chain
from torch.utils.data import random_split

def combined_loader(loader1, loader2):
    """Yield every item of ``loader1`` then ``loader2``, cycling forever.

    Each pass re-iterates both loaders (so re-iterable objects such as
    ``DataLoader`` restart each time). If a whole pass yields nothing — both
    loaders empty, or one-shot iterators already exhausted — the generator
    stops instead of busy-looping forever without producing anything.
    """
    while True:
        produced = False
        for item in chain(loader1, loader2):
            produced = True
            yield item
        if not produced:
            return
        
class TransformWrapper(torch.utils.data.Dataset):
    """Apply ``transform`` to the image of each ``(img, label)`` sample of a dataset.

    With ``expand > 1`` the wrapper is ``expand`` times longer than the
    wrapped dataset and each underlying sample appears ``expand`` times in a
    row (index ``i`` reads base sample ``i // expand``). The transform is
    re-applied on every access, so random augmentations differ per copy.
    """

    def __init__(self, dataset, transform, expand=1):
        self.dataset, self.transform, self.expand = dataset, transform, expand

    def __getitem__(self, index):
        source_index = index // self.expand if self.expand > 1 else index
        sample, target = self.dataset[source_index]
        return self.transform(sample), target

    def __len__(self):
        base_length = len(self.dataset)
        return base_length * self.expand


# REAL TRAIN DATASET: transforms and helpers for the real (non-generated) training data
def _cifar10_train_augmentations(expand):
    """Return fresh CIFAR-10 pre-``ToTensor`` train transforms (augmented when ``expand``)."""
    if expand:
        return [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        ]
    return [transforms.Resize((32, 32), interpolation=transforms.InterpolationMode.BICUBIC)]


def _imagenet_train_augmentations():
    """Return fresh ImageNet pre-``ToTensor`` train transforms."""
    return [
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    ]


def _train_pipelines(augmentations, mean, std):
    """Build ``(transform_train, transform_train_pil)`` from one augmentation recipe.

    ``augmentations`` is a zero-argument factory; it is called once per
    pipeline so the two ``Compose`` objects never share transform instances.
    ``transform_train`` prepends ``ToPILImage`` for non-PIL inputs.
    """
    transform_train = transforms.Compose(
        [transforms.ToPILImage(), *augmentations(),
         transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    transform_train_pil = transforms.Compose(
        [*augmentations(), transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    return transform_train, transform_train_pil


def get_transformations(dataset_name, expand=False):
    """Return ``(transform_train, transform_train_pil, transform_test)`` for a dataset.

    ``transform_train`` starts with ``ToPILImage`` so it can consume raw
    array/tensor images; ``transform_train_pil`` is the same pipeline without
    that step, for inputs that are already PIL images.

    Args:
        dataset_name: ``"cifar10"`` or ``"imagenet"``.
        expand: CIFAR-10 only. When truthy, the train pipelines use random
            crop / flip / rotation / colour-jitter augmentation instead of a
            plain bicubic resize to 32x32. Ignored for ImageNet.

    Raises:
        ValueError: if ``dataset_name`` is not a supported name. (Previously an
            ``assert``, which is stripped under ``python -O``.)
    """
    if dataset_name == "cifar10":
        # NOTE(review): the train statistics here differ from the test
        # statistics below, so train and test inputs are normalised
        # differently. Preserved as-is; confirm which pair is intended.
        transform_train, transform_train_pil = _train_pipelines(
            lambda: _cifar10_train_augmentations(expand),
            (0.5071, 0.4865, 0.4409),
            (0.2673, 0.2564, 0.2762),
        )
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    elif dataset_name == "imagenet":
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        transform_train, transform_train_pil = _train_pipelines(
            _imagenet_train_augmentations, imagenet_mean, imagenet_std
        )
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
    else:
        raise ValueError(
            f"INVALID DATASET NAME: {dataset_name!r} (expected 'cifar10' or 'imagenet')"
        )
    return transform_train, transform_train_pil, transform_test

def get_validation(train_set, batch_size, dataset_name="cifar10"):
    """Split off 10% of ``train_set`` and wrap it as a validation ``DataLoader``.

    The split is drawn with ``random_split`` using torch's global RNG. The
    validation part is wrapped with the dataset's test transform (from
    ``get_transformations``); the remaining 90% is returned untransformed.

    Args:
        train_set: dataset yielding ``(img, label)`` samples.
        batch_size: batch size of the validation loader.
        dataset_name: which dataset's test transform to use. Defaults to
            ``"cifar10"``, matching the previously hard-coded behaviour.

    Returns:
        ``(train_subset, val_dataloader)``.
    """
    train_size = int(0.9 * len(train_set))
    val_size = len(train_set) - train_size
    train_subset, val_subset = random_split(train_set, [train_size, val_size])
    val_transform = get_transformations(dataset_name)[2]
    val_set = TransformWrapper(val_subset, val_transform)
    # Evaluation does not need shuffling; a fixed order keeps runs comparable.
    val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                num_workers=2, pin_memory=True, persistent_workers=True)
    return train_subset, val_dataloader


def get_real_subdataset(dataset_name, train_set_full, targets, subset_count=500, clip=False):
    """Draw a class-balanced random subset of real data and extract its features.

    For every class id, ``subset_count`` indices are sampled without
    replacement (numpy global RNG) from the positions where
    ``targets == class_id``.

    Args:
        dataset_name: ``"cifar10"`` (10 classes) or ``"imagenet"`` (100 classes).
        train_set_full: the dataset to index into.
        targets: per-sample labels aligned with ``train_set_full``; any
            array-like (list, numpy array, CPU tensor) is accepted.
        subset_count: samples drawn per class.
        clip: use CLIP features when True, DINO features otherwise.

    Returns:
        ``(subset_train, real_features)``.

    Raises:
        ValueError: for an unknown ``dataset_name``, or (from
            ``np.random.choice``) when a class has fewer than ``subset_count``
            samples.
    """
    if dataset_name == "cifar10":
        num_classes = 10
    elif dataset_name == "imagenet":
        num_classes = 100
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'cifar10' or 'imagenet'"
        )

    # With a plain list, `targets == class_id` is a scalar False and no
    # class would ever be found; an array gives the intended element-wise mask.
    targets = np.asarray(targets)
    indices = []
    for class_id in range(num_classes):
        class_indices = np.where(targets == class_id)[0]
        indices.extend(np.random.choice(class_indices, subset_count, replace=False))

    subset_train = Subset(train_set_full, indices)
    if clip:
        real_features = extract_clip_features_from_subset(subset_train)
    else:
        real_features = extract_dino_features_from_subset(subset_train)
    return subset_train, real_features

def get_two_real_subsets(dataset_name, train_set_full, targets, leak_count=400, subset_count=500, clip=False):
    """Split each class into disjoint "leak" and "train" real subsets and extract features.

    For every class the matching indices are shuffled (numpy global RNG);
    the first ``leak_count`` go to the leak subset and the next
    ``subset_count`` to the train subset, so the two never overlap. Slicing
    silently truncates if a class has fewer than ``leak_count + subset_count``
    samples (no error is raised).

    Args:
        dataset_name: ``"cifar10"`` (10 classes) or ``"imagenet"`` (100 classes).
        train_set_full: the dataset to index into.
        targets: per-sample labels aligned with ``train_set_full``; any
            array-like (list, numpy array, CPU tensor) is accepted.
        leak_count: per-class size of the leak subset.
        subset_count: per-class size of the train subset.
        clip: use CLIP features when True, DINO features otherwise.

    Returns:
        ``(subset_leak, subset_train, features_leak, features_train)``.

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    if dataset_name == "cifar10":
        num_classes = 10
    elif dataset_name == "imagenet":
        num_classes = 100
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'cifar10' or 'imagenet'"
        )

    # A list would make `targets == class_id` a scalar False; use an array.
    targets = np.asarray(targets)
    n_total = subset_count + leak_count
    leak_indices = []
    train_indices = []
    for class_id in range(num_classes):
        class_indices = np.where(targets == class_id)[0]
        np.random.shuffle(class_indices)
        leak_indices.extend(class_indices[:leak_count])
        train_indices.extend(class_indices[leak_count:n_total])

    subset_leak = Subset(train_set_full, leak_indices)
    subset_train = Subset(train_set_full, train_indices)

    extract = extract_clip_features_from_subset if clip else extract_dino_features_from_subset
    features_leak = extract(subset_leak)
    features_train = extract(subset_train)
    return subset_leak, subset_train, features_leak, features_train

def get_full_dataset(dataset_name="cifar10", model_names=None, subset_train=None, test_set=None,
                     use_generative=True, cifar10_real_features=None, number_of_generated=1000,
                     batch_size=128, generated_root="./cifar10_generated_images",
                     method="random", zero_centered=True, clip=False, leak_dataset=None,
                     leak_features=None, expand=1, read_amount=None):
    """Build the training loader (real + optional generated data) and the test loader.

    The real subset is wrapped with the PIL train transform and repeated
    ``expand`` times. When ``use_generative`` is set, a generated dataset is
    built (``GeneratedDataset`` for CIFAR-10, ``ImagenetGeneratedDataset``
    otherwise) and concatenated after the real data.

    Args:
        model_names: generator names passed to ``GeneratedDataset``;
            ``None`` means ``["sd14"]`` (a fresh list per call).
        cifar10_real_features: real-image features forwarded to the generated
            dataset (used for ImageNet too, despite the name).
        generated_root: NOTE(review): accepted but not forwarded to either
            generated-dataset class — confirm whether it should be.
        test_set: dataset for the test loader, or ``None`` for no test loader.

    Returns:
        ``(train_loader, test_loader, generated_dataset)``; ``test_loader`` is
        ``None`` when ``test_set`` is ``None`` and ``generated_dataset`` is
        ``None`` when ``use_generative`` is false.
    """
    if model_names is None:
        model_names = ["sd14"]
    transform_train, transform_train_pil, _ = get_transformations(dataset_name, expand=(expand > 1))

    generated_dataset = None
    if use_generative:
        if dataset_name == "cifar10":
            generated_dataset = GeneratedDataset(
                dataset_name, model_names, number_of_generated, transform_train, cifar10_real_features,
                method=method, zero_centered=zero_centered, clip=clip, leak_dataset=leak_dataset,
                leak_features=leak_features, expand=expand, read_amount=read_amount,
            )
        else:
            # NOTE(review): expand/read_amount are not forwarded on this path.
            generator = ImagenetGeneratedDataset(
                dataset_name, number_of_generated, transform_train_pil, cifar10_real_features,
                method=method, zero_centered=zero_centered, clip=clip, leak_dataset=leak_dataset,
                leak_features=leak_features,
            )
            generated_dataset = generator.get_dataset()

    real_train = TransformWrapper(subset_train, transform_train_pil, expand)
    if use_generative:
        combined_train = ConcatDataset([real_train, generated_dataset])
    else:
        combined_train = real_train

    train_loader = DataLoader(combined_train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True, persistent_workers=True)
    if use_generative and dataset_name == "imagenet":
        print("combined loaders")

    # A DataLoader over None would only fail later, at iteration time.
    test_loader = None if test_set is None else DataLoader(
        test_set, batch_size=100, shuffle=False, num_workers=8)
    return train_loader, test_loader, generated_dataset


class ExpandedAugmentedDataset(Dataset):
    """Repeat a dataset ``repeats`` times, applying ``transform`` on every access.

    Index ``i`` reads base sample ``i % len(base_dataset)``, so the dataset is
    laid out as ``repeats`` consecutive copies of the base. ``self.targets``
    follows the same layout as a Python list.
    """

    def __init__(self, base_dataset, transform, repeats=5):
        self.dataset = base_dataset
        self.transform = transform
        self.repeats = repeats
        if hasattr(base_dataset, 'targets'):
            base_targets = self._as_label_list(base_dataset.targets)
        elif hasattr(base_dataset, 'labels'):  # for some older torchvision datasets
            base_targets = self._as_label_list(base_dataset.labels)
        else:
            # Fallback: read every sample once to collect its label.
            base_targets = [label for _, label in base_dataset]
        # List repetition. Tensors/arrays are converted first because
        # `tensor * n` multiplies the label values instead of repeating them.
        self.targets = base_targets * self.repeats

    @staticmethod
    def _as_label_list(labels):
        """Convert a label container (list, tuple, tensor, ndarray) to a list."""
        if isinstance(labels, (torch.Tensor, np.ndarray)):
            return labels.tolist()
        return list(labels)

    def __len__(self):
        return len(self.dataset) * self.repeats

    def __getitem__(self, idx):
        base_idx = idx % len(self.dataset)
        img, label = self.dataset[base_idx]
        img = self.transform(img)
        return img, label
        
def extend_dataset(dataset, repeats=3):
    """Return ``dataset`` repeated ``repeats`` times with random augmentation per access.

    The augmentation is RandomCrop(32, padding=4) + horizontal flip + colour
    jitter. It contains no ``ToTensor``/``Normalize``, so samples come back in
    whatever image type the crop/flip/jitter transforms produce for the input
    (presumably PIL images — TODO confirm with callers).

    Args:
        dataset: base dataset yielding ``(img, label)`` samples.
        repeats: number of copies; defaults to 3, the previously fixed value.
    """
    transform_augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    ])
    return ExpandedAugmentedDataset(dataset, transform_augment, repeats)