from PIL import Image
import numpy as np
import torch
import torchvision
import os
import pickle
import random
from torch.utils.data import Dataset, DataLoader
from torchvision import models, utils, datasets, transforms
import sys

def load_data(data, bs):
    """Build shuffled train/test ``DataLoader`` objects for a named dataset.

    Every image is resized to 224x224 and converted to a tensor.

    Args:
        data: Dataset key. One of ``'cifar10'``, ``'cifar10-2k'``,
            ``'cifar100'``, ``'stl10'``, ``'stl10-50k'``, ``'gtsrb'``,
            ``'imagenet'``.
        bs: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Both loaders shuffle and keep the
        last partial batch.

    Raises:
        ValueError: If ``data`` is not one of the supported keys.

    Notes:
        * ``'cifar10-2k'``, ``'gtsrb'`` and ``'imagenet'`` do not use an
          official test split: a single training set is divided with
          ``random_split`` and the remainder becomes the "test" set.
        * ``'cifar10-2k'`` keeps 20000 training samples.
          NOTE(review): the key suggests 2k -- confirm the intended size.
        * ``'stl10-50k'`` reseeds Python's global ``random`` module with 310
          so that the 50000-sample subset of the unlabeled split is
          reproducible.
    """
    supported = ('cifar10', 'cifar10-2k', 'cifar100', 'stl10',
                 'stl10-50k', 'gtsrb', 'imagenet')
    # Fail early with a clear message instead of an UnboundLocalError on
    # train_dataset further down.
    if data not in supported:
        raise ValueError(
            'unknown dataset {!r}; expected one of: {}'.format(data, ', '.join(supported))
        )

    transform_ = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()])

    if data == 'cifar10':
        train_dataset = datasets.CIFAR10('./dataset/cifar10/', train=True, download=True, transform=transform_)
        test_dataset = datasets.CIFAR10('./dataset/cifar10/', train=False, download=True, transform=transform_)

    elif data == 'cifar10-2k':
        # Only the training set is needed: both parts come from splitting it.
        full_train = datasets.CIFAR10('./dataset/cifar10/', train=True, download=True, transform=transform_)
        train_dataset, test_dataset = torch.utils.data.random_split(full_train, [20000, len(full_train)-20000])

    elif data == 'cifar100':
        train_dataset = datasets.CIFAR100('./dataset/cifar100/', train=True, download=True, transform=transform_)
        test_dataset = datasets.CIFAR100('./dataset/cifar100/', train=False, download=True, transform=transform_)

    elif data == 'stl10':
        train_dataset = datasets.STL10('./dataset/stl10', split="train", download=True, transform=transform_)
        test_dataset = datasets.STL10('./dataset/stl10', split="test", download=True, transform=transform_)

    elif data == 'stl10-50k':
        dataset = datasets.STL10('./dataset/stl10', split="unlabeled", download=True, transform=transform_)
        indices = list(range(len(dataset)))
        random.seed(310)
        random.shuffle(indices)
        train_dataset = torch.utils.data.Subset(dataset, indices[:50000])
        test_dataset = torch.utils.data.Subset(dataset, indices[50000:])

    elif data == 'gtsrb':
        # One directory scan is enough; the second ImageFolder was discarded.
        full_train = datasets.ImageFolder('./dataset/GTSRB/Images/', transform = transform_)
        train_dataset, test_dataset = torch.utils.data.random_split(full_train, [39000, len(full_train)-39000])

    else:  # 'imagenet'
        full_train = datasets.ImageFolder('./dataset/imagenet/train/', transform = transform_)
        train_dataset, test_dataset = torch.utils.data.random_split(full_train, [20000, len(full_train)-20000])

    print('dataset: ', len(train_dataset))

    train_loader = DataLoader(train_dataset, batch_size=bs, drop_last=False, shuffle=True)
    test_loader  = DataLoader(test_dataset, batch_size=bs, drop_last=False, shuffle = True)

    return train_loader, test_loader



class TinyImageNet(Dataset):
    """Tiny ImageNet dataset read from its on-disk directory layout.

    Files read below ``root``:

    * ``train/WNID/...`` -- every ``.JPEG`` found (recursively) under a
      class directory is a training sample of that class.
    * ``val/images/*.JPEG`` plus ``val/val_annotations.txt`` -- tab-separated
      lines whose first field is the image file name and second field is the
      class id.
    * ``wnids.txt`` -- one class id per line.
    * ``words.txt`` -- tab-separated ``class id``, ``comma-separated names``;
      only the first name is kept.

    Args:
        root: Dataset root directory.
        train: Load the training split when true, otherwise the validation
            split.
        transform: Optional callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``.

    Attributes set by the constructor:
        images: list of ``(path, target_index)`` pairs.
        class_to_tgt_idx / tgt_idx_to_class: class id to index and back,
            indices following the sorted class ids.
        class_to_label: class id to its first human-readable name.
        len_dataset: number of samples, always ``len(images)``.
    """

    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        self.set_nids = set()
        with open(wnids_file, 'r') as fo:
            for entry in fo:
                self.set_nids.add(entry.strip("\n"))

        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            for entry in fo:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]

    def _create_class_idx_dict_train(self):
        """Index the class directories found directly under ``train_dir``."""
        classes = sorted(d.name for d in os.scandir(self.train_dir) if d.is_dir())

        self.tgt_idx_to_class = dict(enumerate(classes))
        self.class_to_tgt_idx = {name: i for i, name in enumerate(classes)}

    def _create_class_idx_dict_val(self):
        """Read ``val_annotations.txt`` and index the classes it mentions."""
        val_image_dir = os.path.join(self.val_dir, "images")
        # Fail loudly on a missing image directory rather than silently
        # producing an empty dataset in _make_dataset.
        if not os.path.isdir(val_image_dir):
            raise FileNotFoundError(val_image_dir)

        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            for line in fo:
                line = line.rstrip("\n")
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                words = line.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        classes = sorted(set_of_classes)
        self.class_to_tgt_idx = {name: i for i, name in enumerate(classes)}
        self.tgt_idx_to_class = dict(enumerate(classes))

    def _make_dataset(self, Train=True):
        """Collect ``(path, target_index)`` pairs into ``self.images``."""
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = list(self.class_to_tgt_idx)
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            # Sorted walk and sorted file names give a deterministic order.
            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if not fname.endswith(".JPEG"):
                        continue
                    path = os.path.join(root, fname)
                    if Train:
                        item = (path, self.class_to_tgt_idx[tgt])
                    else:
                        item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                    self.images.append(item)

        # Derive the length from what is actually indexable so __len__ and
        # __getitem__ can never disagree.
        self.len_dataset = len(self.images)

    def return_label(self, idx):
        """Map an iterable of target-index tensors to readable class names."""
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path, tgt = self.images[idx]
        # convert() returns a new, fully loaded image, so the file backing
        # the opened image can be closed as soon as the block exits.
        with Image.open(img_path) as img:
            sample = img.convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, tgt






class DatasetSplit(Dataset):
    """A fixed-size random subset of another dataset.

    ``num_data`` distinct positions are drawn once, at construction time,
    from NumPy's global random state; indexing this object forwards to the
    wrapped dataset at the drawn position.
    """

    def __init__(self, dataset, num_data):
        self.dataset = dataset
        candidate_positions = np.arange(len(dataset))
        # replace=False: no sample is selected twice.
        self.idxs = np.random.choice(candidate_positions, num_data, replace=False)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label



def allocate_imagenet(args):
    """Materialise the Tiny ImageNet training set and save it as three splits.

    The whole training set under ``./Data/tiny-imagenet-200`` is loaded into
    memory, shuffled with ``torch.randperm`` and cut into consecutive
    ``train`` / ``fin`` / ``remain`` parts, which are written to
    ``<args.data_path>/<args.dataset>/allocated_data/data_log.pth``.

    The random flip/rotation is applied once per image while the tensors are
    being collected, so the saved tensors contain that single augmented view.

    Args:
        args: Object providing ``num_train`` and ``fin_num`` (split sizes)
            and ``data_path`` and ``dataset`` (output location).
    """
    augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    dataset_train = TinyImageNet('./Data/tiny-imagenet-200', train=True, transform=augmentation)

    # Pull every sample through the transform and stack into one tensor.
    image_batches = []
    targets = []
    for sample in dataset_train:
        image_batches.append(sample[0].unsqueeze(0))
        targets.append(sample[1])
    X = torch.cat(image_batches, axis=0)
    y = torch.tensor(targets)

    # shuffle the dataset
    order = torch.randperm(len(dataset_train))
    X, y = X[order], y[order]

    # allocate data: [0, train_end) / [train_end, fin_end) / [fin_end, end)
    train_end = args.num_train
    fin_end = args.num_train + args.fin_num
    data_log = {
        "X_train": X[0:train_end],
        "y_train": y[0:train_end],
        "X_fin": X[train_end:fin_end],
        "y_fin": y[train_end:fin_end],
        "X_remain": X[fin_end:],
        "y_remain": y[fin_end:],
    }

    # save data
    out_dir = args.data_path + '/' + args.dataset + '/allocated_data'
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data_log, out_dir + '/data_log.pth')



class ElementWiseTransform():
    """Apply a transform to each element of a batch separately.

    Each element is given a leading dimension of size 1 before being passed
    to ``trans``; the results are concatenated along dimension 0. With
    ``trans=None`` the input is returned untouched.
    """

    def __init__(self, trans=None):
        self.trans = trans

    def __call__(self, x):
        if self.trans is None:
            return x

        transformed = []
        for element in x:
            as_batch_of_one = element.view(1, *element.shape)
            transformed.append(self.trans(as_batch_of_one))
        return torch.cat(transformed)


class IndexedTensorDataset():
    """Dataset over parallel ``x``/``y`` sequences that also returns the index.

    Each item is ``(image, label, idx)`` where ``image`` is ``x[idx]`` turned
    into a float32 tensor with its last axis moved to the front.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        raw = self.x[idx]
        label = self.y[idx]
        # permute(2, 0, 1): channels-last (H, W, C) becomes (C, H, W).
        image = torch.tensor(raw, dtype=torch.float32)
        image = image.permute(2,0,1)
        return image, label, idx

    def __len__(self):
        return len(self.x)


class Dataset():
    """In-memory dataset over parallel ``x``/``y`` sequences.

    NOTE: from this point on the name ``Dataset`` refers to this class and
    no longer to ``torch.utils.data.Dataset`` imported at the top of the file.

    Args:
        x: Indexable collection of samples.
        y: Indexable collection of labels, aligned with ``x``.
        transform: Optional callable applied to the sample after it has been
            converted with ``Image.fromarray``.
        fitr: Optional filter callable applied to the raw sample first.
    """

    def __init__(self, x, y, transform=None, fitr=None):
        self.x = x
        self.y = y
        self.transform = transform
        self.fitr = fitr

    def __getitem__(self, idx):
        sample = self.x[idx]
        label = self.y[idx]

        # low pass filtering (runs on the raw sample, before augmentation)
        if self.fitr is not None:
            sample = self.fitr(sample)

        # data augmentation (the transform receives a PIL image)
        if self.transform is not None:
            pil_image = Image.fromarray(sample)
            sample = self.transform(pil_image)

        return sample, label

    def __len__(self):
        return len(self.x)


class IndexedDataset():
    """Array-backed image dataset whose items carry their own position.

    Each item is ``(image, label, position)``: ``x[idx]`` is converted with
    ``Image.fromarray`` and optionally passed through ``transform``;
    ``position`` comes from an int64 index array built at construction.
    """

    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        # Positions 0 .. len(x)-1, stored as int64.
        self.ii = np.arange(len(x), dtype=np.int64)
        self.transform = transform

    def __getitem__(self, idx):
        image = Image.fromarray(self.x[idx])
        label = self.y[idx]
        position = self.ii[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label, position

    def __len__(self):
        return len(self.x)


def datasetCIFAR10(root='./path', train=True, transform=None):
    """Return torchvision's CIFAR10 dataset for ``root``, downloading if needed.

    ``train`` and ``transform`` are forwarded unchanged; ``download`` is
    always enabled.
    """
    options = dict(root=root, train=train, transform=transform, download=True)
    return torchvision.datasets.CIFAR10(**options)

def datasetCIFAR100(root='./path', train=True, transform=None):
    """Return torchvision's CIFAR100 dataset for ``root``, downloading if needed.

    ``train`` and ``transform`` are forwarded unchanged; ``download`` is
    always enabled.
    """
    options = dict(root=root, train=train, transform=transform, download=True)
    return torchvision.datasets.CIFAR100(**options)

def datasetTinyImageNet(root='./path', train=True, transform=None):
    """Load a pickled Tiny ImageNet split from ``root`` as a ``Dataset``.

    Reads ``tiny-imagenet_train.pkl`` when ``train`` is true, otherwise
    ``tiny-imagenet_val.pkl``. The pickle must hold a mapping with the keys
    ``'data'`` and ``'targets'``.

    NOTE: unpickling can execute arbitrary code, so only point ``root`` at
    files from a trusted source.
    """
    file_name = 'tiny-imagenet_train.pkl' if train else 'tiny-imagenet_val.pkl'
    pickle_path = os.path.join(root, file_name)
    with open(pickle_path, 'rb') as handle:
        dat = pickle.load(handle)
    return Dataset(dat['data'], dat['targets'], transform)



class Loader():
    """``DataLoader`` wrapper that can also be drawn from endlessly.

    ``iter(loader)`` gives one ordinary pass over the wrapped ``DataLoader``.
    ``next(loader)`` uses a separate, persistent iterator that is restarted
    whenever it runs out, so repeated calls keep cycling through the data.
    """

    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False, num_workers=4):
        self.loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_workers,
        )
        # Created lazily on the first next() call.
        self.iterator = None

    def __iter__(self):
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)

    def __next__(self):
        if self.iterator is None:
            self.iterator = iter(self.loader)

        try:
            return next(self.iterator)
        except StopIteration:
            # Current pass is exhausted: begin a new one and serve from it.
            self.iterator = iter(self.loader)
            return next(self.iterator)
