import pickle
import cv2
import numpy as np

def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so 8-bit strings written by Python 2 come back as ``bytes`` objects
    (including dictionary keys such as ``b'data'``).

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only call
    this on files that come from a trusted source.
    """
    with open(file, 'rb') as fo:
        # Named ``contents`` rather than ``dict`` so the builtin is not shadowed.
        contents = pickle.load(fo, encoding='bytes')
    return contents

class CIFAR100Dataset:
    """CIFAR-100 samples partitioned into consecutive groups of classes ("tasks").

    The fine-label indices ``0 .. num_classes - 1`` are cut into blocks of
    ``class_per_task`` consecutive classes.  Only the currently selected task
    is exposed through ``len()`` and indexing, and its labels are re-numbered
    so that they start at 0 inside the task.

    Attributes:
        superclass_dict: maps fine-label index -> fine label name as read from
            the ``meta`` file.  Despite the attribute name these are the
            *fine* label names, not the coarse ones.  (Presumably ``bytes``
            objects, like the dictionary keys -- confirm against the files.)
        data: all samples of the chosen split, one flattened image per row.
        label: fine-label index of every row of ``data``.
        task_num: index of the currently selected task.
        class_per_task: number of consecutive classes in each task.
        label_mapping: class name -> ``{"task_num", "label_num"}``; entries
            accumulate for every task that has been selected so far.
        subset_data / subset_label / subset_classes: samples, task-local
            labels and class names of the current task.
    """

    # Class-level fallback so instances that did not go through ``__init__``
    # still get the historical cap of 16 samples per task.
    max_samples = 16

    def __init__(self, root_dir, train=True, num_tasks=10, max_samples=16):
        """Read one CIFAR-100 split from ``root_dir`` and select task 0.

        Args:
            root_dir: directory containing the ``meta``, ``train`` and
                ``test`` pickle files.
            train: read ``train`` when true, otherwise ``test``.
            num_tasks: how many tasks the classes are divided into; each task
                gets ``num_classes // num_tasks`` classes.
            max_samples: upper bound on the number of samples kept per task,
                or ``None`` to keep every sample.  The default of 16 preserves
                the previous hard-coded limit (which looks like a debugging
                cap) -- pass ``None`` for real training runs.

        Raises:
            ValueError: if ``num_tasks`` is not between 1 and the number of
                classes, or if ``max_samples`` is negative.
        """
        metadata = unpickle(f'{root_dir}/meta')
        self.superclass_dict = dict(enumerate(metadata[b'fine_label_names']))

        # Pick the split file and read its dictionary.
        split = 'train' if train else 'test'
        data_dict = unpickle(f'{root_dir}/{split}')
        # The 100 fine classes are used here; switch to b'coarse_labels'
        # (and the coarse names in ``meta``) to work with the superclasses.
        self.data = data_dict[b'data']
        self.label = np.array(data_dict[b'fine_labels'])

        self.task_num = 0
        self.label_mapping = {}

        if max_samples is not None and max_samples < 0:
            raise ValueError(f"max_samples must be non-negative or None, got {max_samples}")
        self.max_samples = max_samples

        num_classes = len(self.superclass_dict)
        # Explicit check instead of ``assert`` so it survives ``python -O`` and
        # so ``num_tasks=0`` does not surface as a ZeroDivisionError.
        if num_tasks <= 0 or num_tasks > num_classes:
            raise ValueError(
                f"num_tasks must be between 1 and {num_classes}, got {num_tasks}"
            )
        self.class_per_task = num_classes // num_tasks

        self.subset_data, self.subset_label, self.subset_classes = self.extract_subset()

    def extract_subset(self):
        """Collect the samples that belong to the current task.

        Also records, in ``self.label_mapping``, the task number and the
        task-local label of every class of the current task.

        Returns:
            A ``(data, label, subset_classes)`` tuple: the selected rows of
            ``self.data`` (at most ``max_samples`` of them, in file order),
            their labels re-numbered from 0 within the task, and the list of
            class names of the task.  When no sample matches, ``data`` and
            ``label`` are empty arrays instead of an error.

        Raises:
            KeyError: if the current task covers a class index that is not in
                ``self.superclass_dict`` (task number out of range).
        """
        first_class = self.task_num * self.class_per_task
        last_class = first_class + self.class_per_task  # exclusive

        subset_classes = []
        for local_label, class_idx in enumerate(range(first_class, last_class)):
            name = self.superclass_dict[class_idx]
            self.label_mapping[name] = {"task_num": self.task_num, "label_num": local_label}
            subset_classes.append(name)

        # The task's classes are a contiguous index range, so membership is a
        # plain range test on the integer labels (no per-sample name lookups).
        in_task = (self.label >= first_class) & (self.label < last_class)
        selected = np.flatnonzero(in_task)
        if self.max_samples is not None:
            selected = selected[:self.max_samples]

        # Fancy indexing copies the rows and yields a (0, row_length) array
        # when nothing matched, where np.stack([]) would raise.
        data = self.data[selected]
        # Task-local label == offset of the class inside the task's range.
        label = self.label[selected] - first_class

        return data, label, subset_classes

    def set_task(self, task_num=0):
        """Select task ``task_num`` and rebuild the exposed subset."""
        self.task_num = task_num
        self.subset_data, self.subset_label, self.subset_classes = self.extract_subset()

    def __len__(self):
        """Return the number of samples in the current task's subset."""
        return self.subset_data.shape[0]

    def arrayToNumpy(self, array):
        """Turn a flat 3072-element row into a 32x32x3 (channels-last) image.

        The row is interpreted as three consecutive 32x32 planes; the result
        is a transposed view of that data, not a copy.
        """
        planes = array.reshape(3, 32, 32)
        return np.transpose(planes, (1, 2, 0))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` of the current task.

        ``image`` has shape (32, 32, 3) and ``label`` is the task-local label.
        An out-of-range ``idx`` raises ``IndexError``, which is also what
        terminates ``for image, label in dataset`` loops.
        """
        image = self.arrayToNumpy(self.subset_data[idx, :])
        label = self.subset_label[idx]

        return image, label
    
if __name__ == "__main__":
    # Smoke test: for every task, save its first image as test_<task>.png and
    # print the name of that image's class.
    NUM_TASKS = 10
    dataset = CIFAR100Dataset("Dataset/cifar-100-python", num_tasks=NUM_TASKS)
    for task_num in range(NUM_TASKS):
        dataset.set_task(task_num)
        for image, label in dataset:
            # cv2.imwrite expects BGR channel order; the CIFAR documentation
            # describes the stored planes as red, green, blue, so swap them.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.imwrite(f"test_{task_num}.png", image)
            # ``label`` is local to the current task, so it has to be resolved
            # through the task's own class list; indexing the global
            # superclass_dict with it printed task 0's class names for every task.
            print(f"{dataset.subset_classes[label]}")
            # Only the first sample of each task is needed.
            break