import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
from torch.utils.data import Dataset
import numpy as np
import pickle
# Base directory for the torchvision downloads and for the per-class pickle
# caches ("<name>_cl_train" / "<name>_cl_test") that DATASET writes and reads.
root = "./dataset"

class DATASET():
    """Per-class image store and loader factory for continual-learning runs.

    ``idx`` selects the dataset: 0 = MNIST, 1 = SVHN, 2 = CIFAR-100. The first
    time a dataset is used, it is downloaded under ``root``, normalized,
    grouped by label and pickled to ``<root>/<name>_cl_train`` and
    ``<root>/<name>_cl_test``. Later instantiations just unpickle those files.

    After construction, ``train_img`` / ``test_img`` are lists indexed by class
    label, each holding that class's image tensors.

    Args:
        idx: dataset selector (0, 1 or 2).
        minibatch_size: batch size of every loader returned by ``get_loader``.
        setsize: maximum number of images drawn per class by ``sample``.
        cls_idx: class labels to use, in order. A class's position in this
            list becomes its label in ``get_loader``. Defaults to ``[0, ..., 9]``.
        task_size: cap on the number of training samples. 50000 means
            "no cap".
        glance: cap on the number of samples in the glance loader. 50000 means
            "no cap".

    Raises:
        ValueError: if ``idx`` is not 0, 1 or 2.
    """

    # Sentinel value of ``task_size`` / ``glance`` meaning "do not truncate".
    _NO_LIMIT = 50000

    def __init__(self, idx, minibatch_size=128, setsize=64, cls_idx=None, task_size=50000, glance=50000):
        self.minibatch_size = minibatch_size
        # Fresh list per instance instead of a shared mutable default.
        self.cls_idx = list(range(10)) if cls_idx is None else cls_idx
        self.set_size = setsize
        self.task_size = task_size
        self.glance = glance

        # Dataset selector -> (cache-file prefix, method that builds the cache).
        sources = {
            0: ("mnist", self.mnist_file),
            1: ("svhn", self.svhn_file),
            2: ("cifar100", self.cifar100_file),
        }
        if idx not in sources:
            raise ValueError("idx must be 0 (MNIST), 1 (SVHN) or 2 (CIFAR-100), got %r" % (idx,))
        prefix, build_cache = sources[idx]

        train_file = os.path.join(root, prefix + "_cl_train")
        test_file = os.path.join(root, prefix + "_cl_test")
        # Build the cache on first use. Both files must exist, otherwise
        # the cache is regenerated.
        if not (os.path.exists(train_file) and os.path.exists(test_file)):
            build_cache()

        # NOTE: pickle.load can execute arbitrary code. These files are expected
        # to be the ones written by the *_file methods, never untrusted input.
        with open(train_file, "rb") as fb:
            self.train_img = pickle.load(fb)
        with open(test_file, "rb") as fb:
            self.test_img = pickle.load(fb)

    def _write_class_split(self, train, test, num_classes, prefix, prepare=None):
        """Group ``train`` / ``test`` samples by label and pickle them under ``root``.

        Writes ``<root>/<prefix>_cl_train`` and ``<root>/<prefix>_cl_test``.
        Each file holds a list of ``num_classes`` lists of images, indexed by
        the sample's integer target.

        Args:
            train: iterable dataset yielding ``(image, target)`` pairs.
            test: same as ``train``, for the test split.
            num_classes: number of per-class buckets to allocate.
            prefix: cache-file name prefix.
            prepare: optional function applied to each image before storing.
        """
        print(len(train))
        print(len(test))
        os.makedirs(root, exist_ok=True)
        for split, suffix in ((train, "train"), (test, "test")):
            grouped = [[] for _ in range(num_classes)]
            for image, target in split:
                if prepare is not None:
                    image = prepare(image)
                grouped[target].append(image)
            with open(os.path.join(root, prefix + "_cl_" + suffix), "wb") as fp:
                pickle.dump(grouped, fp)
        print("Done")

    def svhn_file(self):
        """Download SVHN, normalize it, and write the per-class caches under ``root``."""
        path = os.path.join(root, 'svhn')
        mean = [0.4377, 0.4438, 0.4728]
        std = [0.198, 0.201, 0.197]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

        train = datasets.SVHN(path, split='train', download=True, transform=transform)
        test = datasets.SVHN(path, split='test', download=True, transform=transform)
        self._write_class_split(train, test, 10, "svhn")

    def cifar100_file(self):
        """Download CIFAR-100, normalize it, and write the per-class caches under ``root``."""
        path = os.path.join(root, 'cifar100')
        # NOTE(review): these look like commonly quoted CIFAR-10 statistics.
        # Confirm they are intended for CIFAR-100.
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

        train = datasets.CIFAR100(path, train=True, download=True, transform=transform)
        test = datasets.CIFAR100(path, train=False, download=True, transform=transform)
        self._write_class_split(train, test, 100, "cifar100")

    def mnist_file(self):
        """Download MNIST, pad it by 2 px, normalize it, and cache it per class.

        Each single-channel image is expanded to 3 identical channels so the
        cached tensors have the same channel count as the SVHN/CIFAR caches.
        """
        path = os.path.join(root, 'mnist')
        # Mean and std computed with the padding included.
        # Without padding they would be mean=(0.1307,), std=(0.3081,).
        mean = (0.1,)
        std = (0.2752,)
        transform = transforms.Compose([
            transforms.Pad(padding=2, fill=0), transforms.ToTensor(), transforms.Normalize(mean, std)])

        train = datasets.MNIST(path, train=True, download=True, transform=transform)
        test = datasets.MNIST(path, train=False, download=True, transform=transform)
        self._write_class_split(
            train, test, 10, "mnist",
            prepare=lambda image: image.expand(3, image.size(1), image.size(2)))

    def get_loader(self):
        """Build ``(glance_loader, train_loader, test_loader)`` for ``cls_idx``.

        Every sample is an ``(image, label)`` pair, where ``label`` is the position
        of the source class within ``cls_idx``. The glance and train loaders
        shuffle; the test loader does not.

        ``glance`` truncates the glance data. ``task_size`` truncates both the
        train and glance data. The value 50000 disables the respective cap.
        """
        train_set = []
        test_set = []
        for label, cls in enumerate(self.cls_idx):
            train_set.extend((img, label) for img in self.train_img[cls])
            test_set.extend((img, label) for img in self.test_img[cls])
        print("Done")

        # NOTE(review): train_set is ordered class by class, so these prefix
        # slices keep the earliest classes of cls_idx rather than a balanced
        # subset. Confirm this is intended.
        glance_train = train_set if self.glance == self._NO_LIMIT else train_set[:self.glance]
        if self.task_size != self._NO_LIMIT:
            train_set = train_set[:self.task_size]
            print(len(train_set))
            glance_train = glance_train[:self.task_size]

        def make_loader(data, shuffle):
            return torch.utils.data.DataLoader(data, batch_size=self.minibatch_size, shuffle=shuffle,
                                               num_workers=2, pin_memory=True)

        train_loader0 = make_loader(glance_train, True)
        train_loader = make_loader(train_set, True)
        test_loader = make_loader(test_set, False)
        return train_loader0, train_loader, test_loader

    def sample(self, train=False, device="cuda"):
        """Draw up to ``set_size`` random images from every class in ``cls_idx``.

        Args:
            train: draw from the training images if True, else from the test images.
            device: device the result is moved to (default ``"cuda"``).

        Returns:
            Tensor of shape ``(1, num_images, C*H*W)``, with each image flattened
            (C*H*W is 3072 for 3x32x32 images). Images appear in ``cls_idx`` order.
        """
        pool = self.train_img if train else self.test_img
        set_input = []
        for cls in self.cls_idx:
            images = pool[cls]
            picks = torch.randperm(len(images))[:self.set_size]
            set_input.extend(images[i] for i in picks.tolist())
        stacked = torch.stack(set_input)
        return stacked.view(1, len(set_input), -1).to(device)
