import pickle
import cv2
import os
import numpy as np
import glob
from sklearn.model_selection import train_test_split

def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    ``encoding='bytes'`` makes ``str`` objects that were pickled by Python 2
    come back as ``bytes``.

    Warning: unpickling can execute arbitrary code; only call this on files
    from a trusted source.
    """
    with open(file, 'rb') as fo:
        # Return directly instead of binding to a name like ``dict``, which
        # would shadow the builtin.
        return pickle.load(fo, encoding='bytes')

class ImageNetRDataset:
    """Class-incremental image dataset built from a directory of class folders.

    Each immediate sub-directory of ``root_dir`` is one class, and every entry
    inside it is treated as an image of that class. Each class folder is split
    with ``train_test_split``. Only the portion selected by ``train`` is kept.

    Classes are partitioned into ``num_tasks`` consecutive groups of
    ``num_classes // num_tasks`` classes, and ``set_task`` selects the active
    group. Leftover classes (``num_classes % num_tasks``) never belong to a task.

    Args:
        root_dir: Directory whose immediate sub-directories are the classes.
        train: Keep the training split if True, otherwise the test split.
        num_tasks: Number of tasks to divide the classes into.
        seed: ``random_state`` passed to ``train_test_split`` for every class.
        test_size: Fraction of each class folder held out as the test split.

    Raises:
        ValueError: If there are fewer class folders than ``num_tasks``.
    """

    def __init__(self, root_dir, train=True, num_tasks=10, seed=42, test_size=0.1):
        # glob returns entries in arbitrary, filesystem-dependent order. Sorting
        # makes class indices and the seeded splits reproducible. Escaping
        # protects paths that contain glob metacharacters such as '['.
        folders = sorted(glob.glob(os.path.join(glob.escape(root_dir), "*", "")))
        self.train = train

        # Maps a global class index to its folder name (normpath drops the
        # trailing separator that the directory-only pattern leaves on matches).
        self.superclass_dict = dict(
            enumerate(os.path.basename(os.path.normpath(f)) for f in folders)
        )

        self.data = []
        self.label = []
        for class_idx, folder in enumerate(folders):
            image_paths = sorted(glob.glob(os.path.join(glob.escape(folder), "*")))
            # All paths in a folder share one label, so only the paths need
            # splitting. The split depends only on len(image_paths) and the seed.
            # NOTE: a folder too small to split (e.g. empty) makes
            # train_test_split raise ValueError.
            paths_train, paths_test = train_test_split(
                image_paths, test_size=test_size, random_state=seed
            )
            kept = paths_train if train else paths_test
            self.data.extend(kept)
            self.label.extend([class_idx] * len(kept))

        self.task_num = 0
        # Maps a class name to {"task_num", "label_num"}. Entries are filled in
        # as tasks are visited.
        self.label_mapping = {}

        num_classes = len(self.superclass_dict)
        self.class_per_task = num_classes // num_tasks
        if self.class_per_task <= 0:
            raise ValueError(
                f"need at least num_tasks={num_tasks} class folders under "
                f"{root_dir!r}, found {num_classes}"
            )

        self.subset_data, self.subset_label, self.subset_classes = self.extract_subset()

    def _task_range(self, task_num):
        """Return the ``range`` of global class indices that make up ``task_num``.

        Raises:
            ValueError: If ``task_num`` does not map to a complete group of classes.
        """
        start = task_num * self.class_per_task
        stop = start + self.class_per_task
        if task_num < 0 or stop > len(self.superclass_dict):
            raise ValueError(
                f"task_num={task_num} is out of range; valid tasks are 0.."
                f"{len(self.superclass_dict) // self.class_per_task - 1}"
            )
        return range(start, stop)

    def extract_subset(self):
        """Select the samples of the classes that belong to ``self.task_num``.

        Also records every class of the task in ``self.label_mapping``.

        Returns:
            A tuple ``(data, label, subset_classes)``:
            - ``data``: image paths, as a NumPy string array.
            - ``label``: task-local labels (0 .. class_per_task - 1), as an
              integer array.
            - ``subset_classes``: the task's class names, in task-local label
              order.

        Raises:
            ValueError: If ``self.task_num`` is out of range (checked before any
                state is modified).
        """
        task_classes = self._task_range(self.task_num)

        subset_classes = []
        for local_label, class_idx in enumerate(task_classes):
            name = self.superclass_dict[class_idx]
            self.label_mapping[name] = {"task_num": self.task_num, "label_num": local_label}
            subset_classes.append(name)

        # A task's classes form a contiguous index range, so a sample's
        # task-local label is simply its offset from the start of that range.
        start = task_classes.start
        selected = [
            (path, lbl - start)
            for path, lbl in zip(self.data, self.label)
            if lbl in task_classes
        ]
        # Unlike np.stack, np.array also handles a task with no samples.
        data = np.array([path for path, _ in selected], dtype=str)
        label = np.array([lbl for _, lbl in selected], dtype=int)
        return data, label, subset_classes

    def set_task(self, task_num=0):
        """Make ``task_num`` the active task and rebuild its subset.

        Raises:
            ValueError: If ``task_num`` is out of range. The current task is kept.
        """
        self._task_range(task_num)  # validate before mutating any state
        self.task_num = task_num
        self.subset_data, self.subset_label, self.subset_classes = self.extract_subset()

    def __len__(self):
        """Return the number of samples in the active task."""
        return len(self.subset_data)

    def __getitem__(self, idx):
        """Load sample ``idx`` of the active task.

        Returns:
            A tuple ``(image, label, train)``: the image converted to RGB
            channel order, its task-local label, and the dataset's ``train``
            flag.

        Raises:
            IndexError: If ``idx`` is out of range (this also ends iteration).
            OSError: If OpenCV cannot read the image file.
        """
        path = str(self.subset_data[idx])
        label = self.subset_label[idx]

        image = cv2.imread(path)
        if image is None:  # cv2.imread signals failure by returning None
            raise OSError(f"could not read image {path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image, label, self.train
    
if __name__ == "__main__":
    # Smoke test: print each task's classes and save its first image.
    num_tasks = 10
    dataset = ImageNetRDataset("Dataset/imagenet-a", num_tasks=num_tasks)

    for t in range(num_tasks):
        dataset.set_task(t)
        print(dataset.subset_classes)
        # __getitem__ returns (image, label, train); show only the first sample.
        for image, label, _ in dataset:
            print(image.shape, label)
            # The dataset yields RGB, but cv2.imwrite expects BGR.
            cv2.imwrite("test.png", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            break