import torch
import torch.nn as nn
import PIL.Image as Image
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import os
import numpy as np

# Per-channel (R, G, B) mean and standard deviation used to normalize the
# CIFAR image tensors in this module. NOTE(review): these match the widely
# published ImageNet statistics -- confirm that is the intended source.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

class SplitCIFAR100(datasets.CIFAR100):
    """CIFAR-100 restricted to the classes listed in ``task``.

    The selection is computed once at construction time: ``self.idx`` holds
    the positions (into ``self.data`` / ``self.targets``) of the kept samples,
    in dataset order. Every item is resized to 224x224, converted with
    ``transforms.ToTensor()`` and then passed through the module-level
    ``normalize`` transform.

    Args:
        root: Dataset directory forwarded to ``datasets.CIFAR100``. The
            parent is called with ``download=False``, so nothing is fetched.
        train: Forwarded to ``datasets.CIFAR100`` to choose the split.
        task: Collection of class labels to keep. ``None`` (the default) is
            treated as an empty collection, which yields an empty dataset.
        sub_size: Optional per-class cap on the number of kept samples.
            ``None`` keeps every sample whose label is in ``task``.
    """

    def __init__(self, root="datasets", train=True, task=None, sub_size=None):
        super().__init__(root, train, download=False)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Use a fresh list per instance instead of a shared mutable default.
        self.task = [] if task is None else task
        self.train = train
        self.subset(sub_size)

    def subset(self, sub_size):
        """Rebuild ``self.idx`` from ``self.task``.

        Args:
            sub_size: Maximum number of samples to keep per class; the first
                ``sub_size`` matches of each class in dataset order are
                kept. ``None`` disables the cap.
        """
        if sub_size is None:
            self.idx = [
                position
                for position, label in enumerate(self.targets)
                if label in self.task
            ]
            return

        kept_per_class = dict.fromkeys(self.task, 0)
        self.idx = []
        for position, label in enumerate(self.targets):
            if label in self.task and kept_per_class[label] < sub_size:
                self.idx.append(position)
                kept_per_class[label] += 1

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idx)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th selected sample.

        ``index`` addresses ``self.idx``, not the underlying dataset.
        """
        source = self.idx[index]
        img = Image.fromarray(self.data[source])
        img = self.transform(img)
        # Normalization is applied after ``self.transform`` so it still runs
        # if ``self.transform`` is replaced on the instance.
        img = normalize(img)
        return img, self.targets[source]
    
class SplitCIFAR10(datasets.CIFAR10):
    """CIFAR-10 restricted to the classes listed in ``task``.

    The selection is computed once at construction time: ``self.idx`` holds
    the positions (into ``self.data`` / ``self.targets``) of the kept samples,
    in dataset order. Every item is resized to 224x224, converted with
    ``transforms.ToTensor()`` and then passed through the module-level
    ``normalize`` transform.

    Args:
        root: Dataset directory forwarded to ``datasets.CIFAR10``. The
            parent is called with ``download=False``, so nothing is fetched.
        train: Forwarded to ``datasets.CIFAR10`` to choose the split.
        task: Collection of class labels to keep. ``None`` (the default) is
            treated as an empty collection, which yields an empty dataset.
        sub_size: Optional per-class cap on the number of kept samples.
            ``None`` keeps every sample whose label is in ``task``.
    """

    def __init__(self, root="datasets", train=True, task=None, sub_size=None):
        super().__init__(root, train, download=False)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Use a fresh list per instance instead of a shared mutable default.
        self.task = [] if task is None else task
        self.train = train
        self.subset(sub_size)

    def subset(self, sub_size):
        """Rebuild ``self.idx`` from ``self.task``.

        Args:
            sub_size: Maximum number of samples to keep per class; the first
                ``sub_size`` matches of each class in dataset order are
                kept. ``None`` disables the cap.
        """
        if sub_size is None:
            self.idx = [
                position
                for position, label in enumerate(self.targets)
                if label in self.task
            ]
            return

        kept_per_class = dict.fromkeys(self.task, 0)
        self.idx = []
        for position, label in enumerate(self.targets):
            if label in self.task and kept_per_class[label] < sub_size:
                self.idx.append(position)
                kept_per_class[label] += 1

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idx)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th selected sample.

        ``index`` addresses ``self.idx``, not the underlying dataset.
        """
        source = self.idx[index]
        img = Image.fromarray(self.data[source])
        img = self.transform(img)
        # Normalization is applied after ``self.transform`` so it still runs
        # if ``self.transform`` is replaced on the instance.
        img = normalize(img)
        return img, self.targets[source]

class SplitMnistFashion(datasets.FashionMNIST):
    """Fashion-MNIST restricted to the classes listed in ``task``.

    ``self.idx`` holds the positions (into ``self.data`` / ``self.targets``)
    of the kept samples, in dataset order. Items are converted to tensors
    with ``transforms.ToTensor()``; no further normalization is applied.

    Args:
        root: Dataset directory forwarded to ``datasets.FashionMNIST``
            (called with ``download=True``).
        train: Forwarded to ``datasets.FashionMNIST`` to choose the split.
        task: Collection of integer class labels to keep. ``None`` (the
            default) is treated as an empty collection, which yields an
            empty dataset.
    """

    def __init__(self, root="datasets", train=True, task=None):
        super().__init__(root, train, download=True)
        self.transform = transforms.ToTensor()
        # Use a fresh list per instance instead of a shared mutable default.
        self.task = [] if task is None else task
        self.train = train
        # Convert the labels to plain Python ints once. Testing 0-d tensors
        # one by one is slow, and because tensors hash by identity they would
        # never match the members of a set/frozenset ``task``.
        labels = torch.as_tensor(self.targets).tolist()
        self.idx = [
            position
            for position, label in enumerate(labels)
            if label in self.task
        ]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idx)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th selected sample.

        ``index`` addresses ``self.idx``, not the underlying dataset. The
        target is returned exactly as stored in ``self.targets`` (no
        conversion is applied here).
        """
        source = self.idx[index]
        # ``.numpy()`` assumes ``self.data`` entries are tensors.
        img = Image.fromarray(self.data[source].numpy(), mode='L')
        img = self.transform(img)
        return img, self.targets[source]
        
class SplitMNIST(datasets.MNIST):
    """MNIST restricted to the classes listed in ``task``.

    ``self.idx`` holds the positions (into ``self.data`` / ``self.targets``)
    of the kept samples, in dataset order. Items are converted to tensors
    with ``transforms.ToTensor()``; no further normalization is applied.

    Args:
        root: Dataset directory forwarded to ``datasets.MNIST`` (called with
            ``download=True``).
        train: Forwarded to ``datasets.MNIST`` to choose the split.
        task: Collection of integer class labels to keep. ``None`` (the
            default) is treated as an empty collection, which yields an
            empty dataset.
    """

    def __init__(self, root="datasets", train=True, task=None):
        super().__init__(root, train, download=True)
        self.transform = transforms.ToTensor()
        # Use a fresh list per instance instead of a shared mutable default.
        self.task = [] if task is None else task
        self.train = train
        # Convert the labels to plain Python ints once. Testing 0-d tensors
        # one by one is slow, and because tensors hash by identity they would
        # never match the members of a set/frozenset ``task``.
        labels = torch.as_tensor(self.targets).tolist()
        self.idx = [
            position
            for position, label in enumerate(labels)
            if label in self.task
        ]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idx)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th selected sample.

        ``index`` addresses ``self.idx``, not the underlying dataset. The
        target is returned exactly as stored in ``self.targets`` (no
        conversion is applied here).
        """
        source = self.idx[index]
        # ``.numpy()`` assumes ``self.data`` entries are tensors.
        img = Image.fromarray(self.data[source].numpy(), mode='L')
        img = self.transform(img)
        return img, self.targets[source]