import os
import torch
import torchvision 
import numpy as np
from PIL import Image
from typing import Any, Tuple
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, ImageFolder, SVHN
from targeted_image_folder import find_classes, TargetedImageFolder


def build_sequential_retain_forget_sets(ds, forget_class, forgotten_classes, subclass=False):
    """Partition ``ds`` into forget / retain / already-forgotten samples.

    Args:
        ds: Iterable yielding ``(img, label, clabel)`` triples.
        forget_class: Class id whose samples go into the forget set.
        forgotten_classes: Container of class ids whose samples go into the
            forgotten set (checked with ``in``).
        subclass: When truthy the comparison uses ``label``; otherwise it
            uses ``clabel``.

    Returns:
        A ``(forget, retain, forgotten)`` tuple of lists of
        ``(img, label, clabel)`` tuples, each in iteration order.
    """
    forget, retain, forgotten = [], [], []
    for img, label, clabel in ds:
        # Only the field used for matching differs between the two modes.
        key = label if subclass else clabel
        if key == forget_class:
            bucket = forget
        elif key in forgotten_classes:
            bucket = forgotten
        else:
            bucket = retain
        bucket.append((img, label, clabel))
    return forget, retain, forgotten
   
    
def build_retain_forget_sets(ds, forget_class, subclass=False):
    """Split ``ds`` into the samples to forget and the samples to retain.

    Args:
        ds: Iterable yielding ``(img, label, clabel)`` triples.
        forget_class: Class id whose samples go into the forget set.
        subclass: When truthy the comparison uses ``label``; otherwise it
            uses ``clabel``.

    Returns:
        A ``(forget, retain)`` tuple of lists of ``(img, label, clabel)``
        tuples, each in iteration order.
    """
    forget, retain = [], []
    for img, label, clabel in ds:
        key = label if subclass else clabel
        destination = forget if key == forget_class else retain
        destination.append((img, label, clabel))
    return forget, retain


# Per-channel mean / std handed to transforms.Normalize in the lists below.
# NOTE(review): presumably computed on the CIFAR-100 training set (see the
# repository linked below) -- confirm before reusing for other data.
CIFAR_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# NOTE: the three names below are plain, module-level Python lists (not
# Compose objects) that the dataset classes in this file read. They are
# shared state: mutating one in place changes the pipeline seen by every
# later user of that list.

# Training-from-scratch pipeline: random crop / flip / rotation augmentation
# to improve model performance
# (details: https://github.com/weiaicunzai/pytorch-cifar100).
transform_train_from_scratch = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]

# Unlearning pipeline: no augmentation, only tensor conversion + normalisation.
transform_unlearning = [
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]

# Evaluation pipeline: same steps as the unlearning pipeline.
transform_test = [
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]


# Per-channel mean / std used to normalise face images in
# get_facial_data_transforms.
# NOTE(review): these values match the widely used ImageNet statistics --
# confirm that is the intended source.
FACE_MEAN = [0.485, 0.456, 0.406]
FACE_STD = [0.229, 0.224, 0.225]

def get_facial_data_transforms(train, unlearning, img_size):
    """Return the preprocessing pipeline for the face-recognition dataset.

    Every split currently uses the same steps (tensor conversion, resize to
    ``img_size`` x ``img_size`` without antialiasing, normalisation with
    ``FACE_MEAN`` / ``FACE_STD``); no flip or rotation augmentation is active.

    Args:
        train: Truthy for the training split, falsy for validation.
        unlearning: Only consulted when ``train`` is truthy; selects the
            ``'unlearn'`` pipeline instead of ``'train'``.
        img_size: Edge length of the square output image.

    Returns:
        A ``transforms.Compose`` instance.
    """
    target_size = (img_size, img_size)

    def _make_pipeline():
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(target_size, antialias=False),
            transforms.Normalize(FACE_MEAN, FACE_STD),
        ])

    # One independent Compose per split so they can diverge again later.
    pipelines = {
        'train': _make_pipeline(),
        'val': _make_pipeline(),
        'unlearn': _make_pipeline(),
    }

    if not train:
        return pipelines['val']
    return pipelines['unlearn' if unlearning else 'train']

class PinsFaceRecognition(TargetedImageFolder):
    """Face-recognition image folder restricted to a chosen set of classes.

    Images are read from ``<root>/train`` or ``<root>/val``. When
    ``forget_class`` is given, the dataset is narrowed to that single class
    (``unlearning`` truthy) or to every class except it (``unlearning`` falsy).

    Args:
        root: Directory containing the ``train`` and ``val`` sub-folders.
        train: Selects the ``train`` sub-folder when truthy, else ``val``.
        unlearning: Chooses the transform variant and, together with
            ``forget_class``, which classes are kept.
        download: Accepted for signature parity with the other datasets in
            this file; not used here.
        img_size: Edge length passed to ``get_facial_data_transforms``.
        forget_class: Index into the class list returned by ``find_classes``,
            or ``None`` to keep every class.
            NOTE(review): a negative index is not handled specially by the
            slicing below -- callers should pass a non-negative index.
    """

    def __init__(self, root, train, unlearning, download=False, img_size=224, forget_class=None):
        path = os.path.join(root, 'train' if train else 'val')
        transform = get_facial_data_transforms(train, unlearning, img_size)
        self.classes, self.class_to_idx = find_classes(path)
        self.forget_class = forget_class

        if forget_class is None:
            target_classes = self.classes
        elif unlearning:
            target_classes = [self.classes[forget_class]]
        else:
            target_classes = self.classes[:forget_class] + self.classes[forget_class+1:]

        super().__init__(path, target_classes, transform)

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Return ``(image, empty tensor, target)`` for ``index``.

        The empty tensor fills the middle slot so the item has the same
        three-field layout as the other datasets in this file.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y

class Cifar100(CIFAR100):
    """CIFAR-100 wrapper that yields ``(image, empty tensor, label)`` triples.

    Args:
        root: Dataset directory forwarded to ``CIFAR100``.
        train: Forwarded to ``CIFAR100``; also selects the transform list.
        unlearning: With ``train`` truthy, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``.
        download: Forwarded to ``CIFAR100``.
        img_size: Size given to the final ``transforms.Resize`` step.
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if train:
            base = transform_unlearning if unlearning else transform_train_from_scratch
        else:
            base = transform_test
        # Build a fresh list: the base lists are module-level and shared, so
        # appending to them in place would add one more Resize to the pipeline
        # of every dataset created afterwards.
        transform = transforms.Compose([*base, transforms.Resize(img_size)])

        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, empty tensor, label)`` for ``index``.

        The empty tensor fills the middle slot so the item has the same
        three-field layout as ``Cifar20`` items.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y


class Cifar20(CIFAR100):
    """CIFAR-100 with superclass labels; yields ``(image, fine, coarse)``.

    Args:
        root: Dataset directory forwarded to ``CIFAR100``.
        train: Forwarded to ``CIFAR100``; also selects the transform list.
        unlearning: With ``train`` truthy, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``.
        download: Forwarded to ``CIFAR100``.
        img_size: Size given to the final ``transforms.Resize`` step.
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if train:
            base = transform_unlearning if unlearning else transform_train_from_scratch
        else:
            base = transform_test
        # Build a fresh list: the base lists are module-level and shared, so
        # appending to them in place would add one more Resize to the pipeline
        # of every dataset created afterwards.
        transform = transforms.Compose([*base, transforms.Resize(img_size)])

        super().__init__(root=root, train=train, download=download, transform=transform)

        # This map is for the matching of subclasses to the superclasses. E.g., rocket (69) to Vehicle2 (19)
        # Taken from https://github.com/vikram2000b/bad-teaching-unlearning
        self.coarse_map = {
            0: [4, 30, 55, 72, 95],
            1: [1, 32, 67, 73, 91],
            2: [54, 62, 70, 82, 92],
            3: [9, 10, 16, 28, 61],
            4: [0, 51, 53, 57, 83],
            5: [22, 39, 40, 86, 87],
            6: [5, 20, 25, 84, 94],
            7: [6, 7, 14, 18, 24],
            8: [3, 42, 43, 88, 97],
            9: [12, 17, 37, 68, 76],
            10: [23, 33, 49, 60, 71],
            11: [15, 19, 21, 31, 38],
            12: [34, 63, 64, 66, 75],
            13: [26, 45, 77, 79, 99],
            14: [2, 11, 35, 46, 98],
            15: [27, 29, 44, 78, 93],
            16: [36, 50, 65, 74, 80],
            17: [47, 52, 56, 59, 96],
            18: [8, 13, 48, 58, 90],
            19: [41, 69, 81, 85, 89],
        }
        # Reverse lookup (fine label -> coarse label), built once so that
        # __getitem__ does not rescan coarse_map for every sample.
        self._fine_to_coarse = {
            fine: coarse
            for coarse, fines in self.coarse_map.items()
            for fine in fines
        }

    def __getitem__(self, index):
        """Return ``(image, fine_label, coarse_label)`` for ``index``.

        Prints the fine label and fails an assertion when it has no entry in
        ``coarse_map``.
        """
        x, y = super().__getitem__(index)
        coarse_y = self._fine_to_coarse.get(y)
        if coarse_y is None:
            print(y)
            assert coarse_y is not None
        return x, y, coarse_y


class Cifar10(CIFAR10):
    """CIFAR-10 wrapper that yields ``(image, empty tensor, label)`` triples.

    Args:
        root: Dataset directory forwarded to ``CIFAR10``.
        train: Forwarded to ``CIFAR10``; also selects the transform list.
        unlearning: With ``train`` truthy, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``.
        download: Forwarded to ``CIFAR10``.
        img_size: Size given to the final ``transforms.Resize`` step.
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if train:
            base = transform_unlearning if unlearning else transform_train_from_scratch
        else:
            base = transform_test
        # Build a fresh list: the base lists are module-level and shared, so
        # appending to them in place would add one more Resize to the pipeline
        # of every dataset created afterwards.
        transform = transforms.Compose([*base, transforms.Resize(img_size)])

        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, empty tensor, label)`` for ``index``.

        The empty tensor fills the middle slot so the item has the same
        three-field layout as ``Cifar20`` items.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y


class UnLearningData(Dataset):
    """Concatenation of a forget set and a retain set with a binary flag.

    Indices ``[0, forget_len)`` map to the forget data, the remaining indices
    to the retain data. Each item is ``(x, flag)`` where ``x`` is the first
    field of the underlying sample and ``flag`` is ``1`` for forget samples
    and ``0`` for retain samples.
    """

    def __init__(self, forget_data, retain_data):
        super().__init__()
        self.forget_data = forget_data
        self.retain_data = retain_data
        # Lengths are captured once at construction time.
        self.forget_len = len(forget_data)
        self.retain_len = len(retain_data)

    def __len__(self):
        return self.retain_len + self.forget_len

    def __getitem__(self, index):
        if index >= self.forget_len:
            # Shift into the retain set's own index space.
            retain_sample = self.retain_data[index - self.forget_len]
            return retain_sample[0], 0
        forget_sample = self.forget_data[index]
        return forget_sample[0], 1


def get_surrogate(args, opt, num_classes, original_model=None):
    """Build a ``DataLoader`` over the configured surrogate dataset.

    Args:
        args: Namespace; reads ``args.net``, ``args.b`` (batch size) and, when
            ``args.net`` is not ``'ViT'``, ``args.dataset``.
        opt: Mapping with the keys ``'surrogate_dataset'``, ``'mode'`` and
            ``'surrogate_quantity'`` (``-1`` selects everything, or 10 shards
            for the Gaussian-noise dataset).
        num_classes: Number of class indices considered when computing the
            per-sample weights in ``'HR'`` mode.
        original_model: Model used in ``'HR'`` mode to pseudo-label the
            surrogate images. Inputs are moved with ``.cuda()``, so the model
            has to live on a CUDA device.

    Returns:
        A ``DataLoader``. In ``'HR'`` mode it iterates a
        ``custom_Dset_surrogate`` of ``(image, pseudo_label, logits)`` through
        a ``WeightedRandomSampler``; in every other mode it iterates the
        surrogate dataset directly with shuffling.

    Raises:
        KeyError: If ``opt['surrogate_dataset']`` has no entry in the
            mean / std tables.
    """
    surrogate_dataset = opt['surrogate_dataset']
    mode = opt['mode']
    surrogate_quantity = int(opt['surrogate_quantity'])

    # Per-dataset channel statistics used for normalisation.
    mean = {
        'subset_tiny': (0.485, 0.456, 0.406),
        'subset_Imagenet': (0.4914, 0.4822, 0.4465),
        'subset_rnd_img': (0.5969, 0.5444, 0.4877),
        'subset_COCO': (0.4717, 0.4486, 0.4089),
        'subset_gaussian_noise': (0, 0, 0),
    }
    std = {
        'subset_tiny': (0.229, 0.224, 0.225),
        'subset_Imagenet': (0.229, 0.224, 0.225),
        'subset_rnd_img': (0.3366, 0.3260, 0.3411),
        'subset_COCO': (0.2754, 0.2708, 0.2852),
        'subset_gaussian_noise': (1, 1, 1),
    }

    transform_list = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean[surrogate_dataset], std[surrogate_dataset]),
    ]
    transform_list_test = [
        transforms.ToTensor(),
        transforms.Normalize(mean[surrogate_dataset], std[surrogate_dataset]),
    ]

    # Spatial steps that run before the common tail above.
    if args.net == 'ViT':
        train_prefix = [
            transforms.Resize(224, antialias=True),
            transforms.RandomCrop(224, padding=28),
        ]
        test_prefix = [transforms.Resize(224, antialias=True)]
    else:
        if args.dataset == 'TinyImagenet':
            crop_size, crop_padding = 64, 8
        else:
            crop_size, crop_padding = 32, 4
        train_prefix = [transforms.RandomCrop(crop_size, padding=crop_padding)]
        # NOTE(review): the "test" pipeline also applies a *random* crop here;
        # kept as in the original -- confirm this is intended.
        test_prefix = [transforms.RandomCrop(crop_size, padding=crop_padding)]

    transform_dset = transforms.Compose(train_prefix + transform_list)
    transform_test = transforms.Compose(test_prefix + transform_list_test)

    if surrogate_dataset != "subset_gaussian_noise":
        # Only 'CR' mode uses the augmented (training) pipeline.
        folder_transform = transform_dset if mode == 'CR' else transform_test
        image_folder = torchvision.datasets.ImageFolder(
            root=os.path.join('data', 'surrogate_data/', surrogate_dataset + '_split'),
            transform=folder_transform,
        )
        if surrogate_quantity == -1:
            subset = image_folder
        else:
            # Keep only the samples of the first `kept_classes` class indices.
            kept_classes = min(surrogate_quantity, len(image_folder.classes))
            idx = [
                i
                for i, (_, class_idx) in enumerate(image_folder.imgs)
                if 0 <= class_idx < kept_classes
            ]
            subset = torch.utils.data.Subset(image_folder, idx)
    else:
        # Dataset assembled from pre-generated .pt tensors, one file per shard.
        if surrogate_quantity == -1:
            surrogate_quantity = 10

        shards = []
        for i in range(surrogate_quantity):
            fname = f"data/surrogate_data/{surrogate_dataset}_split/{i}/gaussian_noise_{i}.pt"
            print(fname)
            # NOTE: torch.load unpickles the file; only load trusted files.
            imgs = torch.load(fname)
            labels = torch.zeros(imgs.shape[0])
            shards.append(torch.utils.data.TensorDataset(imgs, labels))
        subset = torch.utils.data.ConcatDataset(shards)

    loader_surrogate = DataLoader(subset, batch_size=args.b, shuffle=True, num_workers=0)

    if mode == 'HR':
        # Put every child module except the last one into eval mode. The
        # original code did this through a temporary Sequential "backbone"
        # whose extracted features were never used; the unused feature pass
        # is dropped, the eval-mode side effect is kept.
        for child in list(original_model.children())[:-1]:
            child.eval()

        # Pseudo-label the whole surrogate set with the original model.
        logits = []
        dset = []
        labels = []
        for img, _ in loader_surrogate:
            with torch.no_grad():
                output = original_model(img.cuda())
                logits.append(output.detach().cpu())
                labels.append(torch.argmax(output, dim=1).detach().cpu())
                dset.append(img)

        logits = torch.cat(logits)
        dset = torch.cat(dset)
        labels = torch.cat(labels)

        dataset_wlogits = custom_Dset_surrogate(dset, labels, logits)
        print('LEN surrogate', len(dataset_wlogits))

        # For each sample, the number of samples sharing its pseudo-label.
        class_sample_count = torch.zeros_like(labels)
        for class_idx in range(num_classes):
            in_class = labels == class_idx
            class_sample_count[in_class] = int(in_class.sum())
        # Correct for undersampled outputs: very rare classes are treated as
        # if they had 5 samples so their weight does not explode.
        class_sample_count[class_sample_count < 3] = 5

        weights = 1.0 / class_sample_count.float()

        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights, num_samples=len(dataset_wlogits), replacement=True
        )
        loader_surrogate = DataLoader(
            dataset_wlogits, batch_size=args.b, num_workers=4, sampler=sampler
        )
    return loader_surrogate


class custom_Dset_surrogate(Dataset):
    """Dataset of parallel ``(sample, label, logits)`` containers.

    Args:
        dset: Indexable container of samples; its ``shape[0]`` is the length.
        labels: Indexable container of labels, aligned with ``dset``.
        logits: Indexable container of logits, aligned with ``dset``.
        transf: Optional callable applied to the sample on access.
    """

    def __init__(self, dset, labels, logits, transf=None):
        self.dset = dset
        self.labels = labels
        self.logits = logits
        self.transf = transf

    def __len__(self):
        return self.dset.shape[0]

    def __getitem__(self, index):
        """Return ``(sample, label, logits)`` at ``index``."""
        sample, target, sample_logits = (
            self.dset[index],
            self.labels[index],
            self.logits[index],
        )
        if self.transf:
            sample = self.transf(sample)
        return sample, target, sample_logits

# Per-channel mean / std handed to transforms.Normalize in the lists below.
TINYIMAGENET_MEAN = (0.4802, 0.4480, 0.3975)
TINYIMAGENET_STD = (0.2770, 0.2691, 0.2821)

# NOTE: both names below are plain, module-level Python lists (not Compose
# objects). They are shared state: mutating one in place changes the pipeline
# seen by every later user of that list.

# Training pipeline: random crop + horizontal flip augmentation.
transform_train_tiny = [transforms.RandomCrop(64, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(TINYIMAGENET_MEAN, TINYIMAGENET_STD)]

# Evaluation pipeline: tensor conversion + normalisation only.
transform_test_tiny = [
    transforms.ToTensor(),
    transforms.Normalize(TINYIMAGENET_MEAN, TINYIMAGENET_STD),
]

class TinyImagenet(Dataset):
    """
    Defines Tiny Imagenet in the same way as the other PyTorch datasets.

    Data is read from numbered ``.npy`` shards located under
    ``<root>/TINYIMG/processed`` and concatenated in memory. Items are
    ``(image, empty tensor, target)`` triples.

    Args:
        root: Directory that contains (or will receive) ``TINYIMG/processed``.
        train: Loads the ``train`` shards when truthy, else the ``val`` shards.
        unlearning: With ``train`` truthy, selects the non-augmented transform
            list instead of the training one.
        download: When truthy and ``root`` is missing or empty, the processed
            archive is downloaded and unpacked into ``root``.
        img_size: Size given to the final ``transforms.Resize`` step.
    """

    # Number of shards per split and per array kind (x / y).
    _NUM_SHARDS = 20

    def __init__(self, root: str, train: bool = True, unlearning: bool = True, download: bool = False, img_size: int = 64) -> None:
        self.root = root
        self.train = train
        self.unlearning = unlearning
        if train and not unlearning:
            base = transform_train_tiny
        else:
            base = transform_test_tiny
        # Build a fresh list: the base lists are module-level and shared, so
        # appending to them in place would add one more Resize to the pipeline
        # of every dataset created afterwards.
        self.transform = transforms.Compose([*base, transforms.Resize(img_size)])

        self.download = download

        if download:
            if os.path.isdir(root) and len(os.listdir(root)) > 0:
                print('Download not needed, files already on disk.')
            else:
                # Imported lazily (optional dependency) and aliased so it does
                # not shadow the `download` parameter.
                from onedrivedownloader import download as fetch_archive

                print('Downloading dataset')
                ln = "https://unimore365-my.sharepoint.com/:u:/g/personal/263133_unimore_it/EVKugslStrtNpyLGbgrhjaABqRHcE3PB_r2OEaV7Jy94oQ?e=9K29aD"
                fetch_archive(ln, filename=os.path.join(root, 'tiny-imagenet-processed.zip'), unzip=True, unzip_path=root, clean=True)

        split = 'train' if self.train else 'val'
        self.data = self._load_shards(root, 'x', split)
        self.targets = self._load_shards(root, 'y', split)

    @staticmethod
    def _load_shards(root, prefix, split):
        """Load and concatenate every ``<prefix>_<split>_NN.npy`` shard."""
        shards = [
            np.load(os.path.join(
                root, 'TINYIMG/processed/%s_%s_%02d.npy' % (prefix, split, num + 1)))
            for num in range(TinyImagenet._NUM_SHARDS)
        ]
        # Concatenate the list directly: wrapping it in np.array first fails
        # when shards differ in length and needlessly copies the data.
        return np.concatenate(shards)

    def __len__(self):
        return len(self.data)

    def getit(self, index):
        """Return the transformed ``(image, target)`` pair at ``index``."""
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        # NOTE(review): the 255 scaling assumes stored pixel values in [0, 1]
        # -- confirm against the shard format.
        img = Image.fromarray(np.uint8(255 * img))

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __getitem__(self, index):
        """Return ``(image, empty tensor, target)`` for ``index``."""
        x, y = self.getit(index)
        return x, torch.Tensor([]), y
    
    
# Per-channel mean / std handed to transforms.Normalize in the lists below.
SVHN_MEAN = (0.4376821, 0.4437697, 0.47280442)
SVHN_STD = (0.19803012, 0.20101562, 0.19703614)

# NOTE(review): the two lists below are not referenced anywhere in the
# visible part of this file -- the Svhn class builds its pipeline from the
# CIFAR transform lists instead. Confirm which statistics are intended.

# Training pipeline: horizontal flip augmentation.
transform_train_svhn = [
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(SVHN_MEAN, SVHN_STD)]

# Evaluation pipeline: tensor conversion + normalisation only.
transform_test_svhn = [
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(SVHN_MEAN, SVHN_STD),
]



class Svhn(SVHN):
    """SVHN wrapper that yields ``(image, empty tensor, label)`` triples.

    Args:
        root: Dataset directory forwarded to ``SVHN``.
        train: Selects the ``'train'`` split when truthy, else ``'test'``;
            also selects the transform list.
        unlearning: With ``train`` truthy, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``.
        download: Forwarded to ``SVHN``.
        img_size: Size given to the final ``transforms.Resize`` step.
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        # NOTE(review): these are the CIFAR transform lists (CIFAR mean / std),
        # not SVHN-specific ones. Left unchanged because models trained with
        # this pipeline depend on it -- confirm whether that is intended.
        if train:
            base = transform_unlearning if unlearning else transform_train_from_scratch
        else:
            base = transform_test
        # Build a fresh list: the base lists are module-level and shared, so
        # appending to them in place would add one more Resize to the pipeline
        # of every dataset created afterwards.
        transform = transforms.Compose([*base, transforms.Resize(img_size)])

        super().__init__(root=root, split='train' if train else 'test', download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, empty tensor, label)`` for ``index``.

        The empty tensor fills the middle slot so the item has the same
        three-field layout as ``Cifar20`` items.
        """
        x, y = super().__getitem__(index)
        return x, torch.Tensor([]), y