from torchvision.datasets import CIFAR100
import numpy as np
from PIL import Image



class iCIFAR100(CIFAR100):
    """CIFAR100 subclass that serves an explicitly selected subset of samples.

    Nothing is served until a subset is selected with ``getTrainData`` (by
    sample indices) or ``getTestData`` (by a range of class labels).
    ``__getitem__`` and ``__len__`` then operate on that subset; a non-empty
    training selection takes priority over a test selection.

    ``self.data`` and ``self.targets`` are not assigned in this class; they
    are expected to be provided by the ``CIFAR100`` base class.
    """

    def __init__(self, root,
                 train=True,
                 transform=None,
                 target_transform=None,
                 test_transform=None,
                 target_test_transform=None,
                 download=False):
        """Initialise the base dataset and the (initially empty) selections.

        Args:
            root: Forwarded to ``CIFAR100.__init__``.
            train: Forwarded to ``CIFAR100.__init__``.
            transform: Callable applied (twice) to each training image in
                ``getTrainItem``; also forwarded to the base class.
            target_transform: Callable applied to each training label; also
                forwarded to the base class.
            test_transform: Callable applied to each test image in
                ``getTestItem``.
            target_test_transform: Callable applied to each test label in
                ``getTestItem``.
            download: Forwarded to ``CIFAR100.__init__``.
        """
        super(iCIFAR100, self).__init__(root,
                                        train=train,
                                        transform=transform,
                                        target_transform=target_transform,
                                        download=download)

        self.target_test_transform = target_test_transform
        self.test_transform = test_transform
        # Empty lists mean "no subset selected yet"; the selection methods
        # replace them with arrays.
        self.TrainData = []
        self.TrainLabels = []
        self.TestData = []
        self.TestLabels = []

    @staticmethod
    def _has_samples(samples):
        """Return True when ``samples`` (list or array) holds at least one item."""
        # len() works for both the initial empty list and NumPy arrays, unlike
        # ``samples != []`` which is an elementwise comparison on arrays.
        return len(samples) > 0

    def concatenate(self, datas, labels):
        """Join per-class chunks along axis 0.

        Args:
            datas: Non-empty sequence of arrays to join along axis 0.
            labels: Sequence of label arrays; only the first ``len(datas)``
                entries are used.

        Returns:
            Tuple ``(con_data, con_label)`` of the concatenated arrays.

        Raises:
            IndexError: If ``datas`` is empty or ``labels`` has fewer entries
                than ``datas``.
        """
        count = len(datas)
        if count == 0 or len(labels) < count:
            raise IndexError(
                "concatenate() needs at least one data chunk and a label "
                "chunk for every data chunk"
            )
        # One concatenate call per list avoids recopying the accumulated
        # array on every iteration of a pairwise loop.
        con_data = np.concatenate(datas, axis=0)
        con_label = np.concatenate(labels[:count], axis=0)
        return con_data, con_label

    def getTestData(self, classes):
        """Select as test data every sample whose label is in a class range.

        Args:
            classes: Indexable pair; labels in
                ``range(classes[0], classes[1])`` are selected.

        Stores the result in ``self.TestData`` / ``self.TestLabels``. An empty
        range stores empty arrays.
        """
        targets = np.array(self.targets)  # built once, reused for every label
        datas, labels = [], []
        for label in range(classes[0], classes[1]):
            data = self.data[targets == label]
            datas.append(data)
            labels.append(np.full((data.shape[0]), label))
        if not datas:
            # Empty class range: keep the container types but hold no samples.
            self.TestData = self.data[:0]
            self.TestLabels = np.array([], dtype=int)
            return
        self.TestData, self.TestLabels = self.concatenate(datas, labels)

    def getTrainData(self, data_ids):
        """Select as training data the samples at positions ``data_ids``.

        Args:
            data_ids: Index (or index array) used to index ``self.data`` and
                an array built from ``self.targets``.
        """
        self.TrainData = self.data[data_ids]
        self.TrainLabels = np.array(self.targets)[data_ids]

    def getTrainItem(self, index):
        """Return the training sample at ``index`` as two views of one image.

        Args:
            index: Position within the selected training subset.

        Returns:
            Tuple ``(index, (img1, img2), target)``. ``img1`` and ``img2``
            come from applying ``self.transform`` twice to the same image;
            without a transform both are the untransformed PIL image.
        """
        img = Image.fromarray(self.TrainData[index])
        target = self.TrainLabels[index]

        if self.transform:
            img1 = self.transform(img)
            img2 = self.transform(img)
        else:
            # No transform configured: return the raw image for both views
            # instead of failing on unassigned names.
            img1 = img2 = img

        if self.target_transform:
            target = self.target_transform(target)

        return index, (img1, img2), target

    def getTestItem(self, index):
        """Return the test sample at ``index``.

        Args:
            index: Position within the selected test subset.

        Returns:
            Tuple ``(index, img, target)`` after applying ``test_transform``
            and ``target_test_transform`` when they are set.
        """
        img = Image.fromarray(self.TestData[index])
        target = self.TestLabels[index]

        if self.test_transform:
            img = self.test_transform(img)

        if self.target_test_transform:
            target = self.target_test_transform(target)

        return index, img, target

    def __getitem__(self, index):
        """Return an item from the training subset if non-empty, else test.

        Raises:
            IndexError: If neither a training nor a test subset holds samples.
        """
        if self._has_samples(self.TrainData):
            return self.getTrainItem(index)
        if self._has_samples(self.TestData):
            return self.getTestItem(index)
        raise IndexError("no train or test samples have been selected")

    def __len__(self):
        """Return the size of the active subset (training first), or 0."""
        if self._has_samples(self.TrainData):
            return len(self.TrainData)
        if self._has_samples(self.TestData):
            return len(self.TestData)
        return 0

    def get_image_class(self, label):
        """Return the entries of ``self.data`` whose target equals ``label``."""
        return self.data[np.array(self.targets) == label]

    def get_data_label(self, classes):
        """Collect data, labels and dataset positions for the given classes.

        Args:
            classes: Iterable of class labels, processed in the given order.

        Returns:
            Tuple ``(TrainData, TrainLabels, TrainIds)`` where ``TrainIds``
            holds the positions of the returned samples within
            ``self.targets``. All three are empty arrays when ``classes``
            yields no labels.
        """
        targets = np.array(self.targets)  # built once, reused for every label
        positions = np.arange(len(self.targets))
        datas, labels, ids = [], [], []
        for label in classes:
            mask = targets == label
            data = self.data[mask]
            datas.append(data)
            ids.append(positions[mask])
            labels.append(np.full((data.shape[0]), label))
        if not datas:
            # No classes requested: return empty results rather than failing.
            return self.data[:0], np.array([], dtype=int), positions[:0]
        TrainData, TrainLabels = self.concatenate(datas, labels)
        TrainIds = np.concatenate(ids, axis=0)
        return TrainData, TrainLabels, TrainIds


