import os, json
from PIL import Image
from torch.utils.data import Dataset

import numpy as np
import torch

import torchvision
import random

class ImageNetDataset(Dataset):
    """ImageNet-style image-folder dataset with a seeded class ordering.

    ``data_dir`` must contain ``train/<class>/`` and ``val/<class>/``
    sub-directories and must end with a path separator, because paths are
    built by plain string concatenation.  Images from both splits are merged
    into one pool.  The label of an image is the position of its class in
    ``self.class_names``.

    Args:
        data_dir: Root directory, e.g. ``'/data/imagenet/'``.
        dataset_seed: Seed for the ``random.Random`` instance that shuffles the
            (initially sorted) class order.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
        specific_classes: Selects how the class order is built.
            ``'old'`` moves a fixed list of 40 synsets to the front of the
            shuffled order.  ``'categories'`` restricts the classes to five
            hand-picked groups of 20 names, resolved through
            ``data_dir + 'imagenet_map.json'``, with the groups in a seeded
            order.  Any other value keeps the full shuffled class list; the
            default is the string ``'None'``.

    Attributes:
        data: list of image file paths.
        labels: ``np.ndarray`` of integer labels aligned with ``data``.
        task_ids: float array filled with -1 for every sample.
    """

    # Synsets moved to the front of the class order when specific_classes == 'old'.
    _OLD_INITIAL_CLASSES = (
        'n01443537', 'n07714990',
        'n06359193', 'n01614925',
        'n13133613', 'n02109961',
        'n02130308', 'n02165456',
        'n04086273', 'n02504458',
        'n02509815', 'n02129604',
        'n01644373', 'n02219486',
        'n03977966', 'n02690373',
        'n02701002', 'n01860187',
        'n03777568', 'n02276258',
        'n07747607', 'n03345487',
        'n02951358', 'n07745940',
        'n03794056', 'n07749582',
        'n03874599', 'n04074963',
        'n02391049', 'n04243546',
        'n01843383', 'n04399382',
        'n01514668', 'n11939491',
        'n03841143', 'n04146614',
        'n02007558', 'n01484850',
        'n03590841', 'n01773797',
    )

    # Class-name groups used when specific_classes == 'categories'.  The names
    # are keys of imagenet_map.json.  The tuple order below is the order that
    # gets shuffled by the seeded RNG, so reordering it changes the class order
    # produced for a given seed.
    _CATEGORY_GROUPS = (
        # vehicles
        ('Model_T', 'ambulance', 'amphibian', 'beach_wagon', 'bicycle-built-for-two',
         'bobsled', 'cab', 'convertible', 'crane', 'dogsled',
         'fire_engine', 'forklift', 'garbage_truck', 'go-kart', 'golfcart',
         'grille', 'half_track', 'harvester', 'horse_cart', 'jeep'),
        # birds
        ('African_grey', 'American_coot', 'American_egret', 'European_gallinule', 'albatross',
         'bald_eagle', 'bee_eater', 'bittern', 'black_grouse', 'black_stork',
         'black_swan', 'brambling', 'bulbul', 'bustard', 'chickadee',
         'cock', 'coucal', 'flamingo', 'dowitcher', 'drake'),
        # instruments
        ('French_horn', 'accordion', 'acoustic_guitar', 'banjo', 'bassoon',
         'cello', 'chime', 'cornet', 'drum', 'drumstick',
         'electric_guitar', 'flute', 'gong', 'grand_piano', 'harmonica',
         'harp', 'maraca', 'marimba', 'oboe', 'ocarina'),
        # dogs
        # NOTE(review): 'Boston_bull' appears twice in this group, so its
        # images are loaded twice under two different labels.  One entry was
        # probably meant to be another breed (alphabetically the next one would
        # be 'Doberman').  Left unchanged to keep existing class orders
        # reproducible; confirm and fix.
        ('Afghan_hound', 'African_hunting_dog', 'Airedale', 'American_Staffordshire_terrier', 'Boston_bull',
         'Appenzeller', 'Australian_terrier', 'Bedlington_terrier', 'Bernese_mountain_dog', 'Blenheim_spaniel',
         'Border_collie', 'Border_terrier', 'Boston_bull', 'Bouvier_des_Flandres', 'Brabancon_griffon',
         'Brittany_spaniel', 'Cardigan', 'Chesapeake_Bay_retriever', 'Chihuahua', 'Dandie_Dinmont'),
        # clothing
        ('Loafer', 'abaya', 'academic_gown', 'apron', 'bib',
         'bikini', 'brassiere', 'breastplate', 'bulletproof_vest', 'cardigan',
         'chain_mail', 'cloak', 'clog', 'cowboy_boot', 'cowboy_hat',
         'cuirass', 'diaper', 'fur_coat', 'gown', 'hoopskirt'),
    )

    def __init__(self, data_dir, dataset_seed, transform=None, specific_classes='None'):
        self.transform = transform

        self.data = []
        self.labels = []

        self.class_names = np.sort([d for d in os.listdir(data_dir + 'train/')]).tolist()

        # A private RNG, so the class order depends only on dataset_seed.
        class_gen = random.Random(dataset_seed)
        class_gen.shuffle(self.class_names)

        if specific_classes == 'old':
            initial_classes = list(self._OLD_INITIAL_CLASSES)
            initial_set = set(initial_classes)
            remaining_classes = [c for c in self.class_names if c not in initial_set]
            self.class_names = initial_classes + remaining_classes
        elif specific_classes == 'categories':
            # class_gen is reused after the shuffle above, so the group order
            # is seeded as well.
            self.class_names = self._category_class_names(data_dir, class_gen)

        # For every class: train images first, then val images.
        for y, clas in enumerate(self.class_names):
            for split in ('train/', 'val/'):
                class_dir = data_dir + split + clas
                for img in os.listdir(class_dir):
                    self.data.append(class_dir + '/' + img)
                    self.labels.append(y)

        self.labels = np.array(self.labels)
        # Placeholder task id of -1 for every sample; presumably assigned
        # later by the caller.
        self.task_ids = -1 * np.ones(len(self.labels))

    @classmethod
    def _category_class_names(cls, data_dir, class_gen):
        """Resolve ``_CATEGORY_GROUPS`` to directory names and concatenate them in a seeded group order.

        ``imagenet_map.json`` maps each human-readable class name to the
        directory name used under ``train/`` and ``val/``.
        """
        with open(data_dir + 'imagenet_map.json', 'r') as f:
            mapping = json.load(f)
        groups = [[mapping[name] for name in group] for group in cls._CATEGORY_GROUPS]
        class_gen.shuffle(groups)
        return [synset for group in groups for synset in group]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for sample ``index``."""
        f_name = self.data[index]
        label = self.labels[index]

        img = Image.open(f_name).convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label, index
    

class TinyImageNetDataset(Dataset):
    """Tiny-ImageNet dataset with the train and val splits merged.

    Expects this layout under ``data_dir``, which must end with a path
    separator:

    * ``train/<wnid>/images/*``: training images.
    * ``val/images/*``: validation images.
    * ``val/val_annotations.txt``: tab-separated labels for the validation
      images.  Column 0 is the file name and column 1 is the class wnid.

    The class order is the sorted ``train/`` listing shuffled by
    ``random.Random(dataset_seed)``.  An image's label is its class's position
    in ``self.class_names``.

    Args:
        data_dir: Dataset root.
        dataset_seed: Seed for the class-order shuffle.
        transform: Optional callable applied to each PIL image.

    Raises:
        KeyError: if a validation annotation names a wnid that has no
            ``train/`` directory.
    """

    def __init__(self, data_dir, dataset_seed, transform=None):
        self.transform = transform

        self.data = []
        self.labels = []

        self.class_names = np.sort([d for d in os.listdir(data_dir + 'train/')]).tolist()
        class_gen = random.Random(dataset_seed)
        class_gen.shuffle(self.class_names)
        # Build the wnid -> label lookup once, instead of searching the list
        # for every annotation line.
        class_to_label = {name: y for y, name in enumerate(self.class_names)}

        for y, clas in enumerate(self.class_names):
            image_dir = data_dir + 'train/' + clas + '/images/'
            for img in os.listdir(image_dir):
                self.data.append(image_dir + img)
                self.labels.append(y)

        with open(data_dir + 'val/val_annotations.txt', 'r') as f:
            for line in f:
                items = line.rstrip('\r\n').split('\t')
                if len(items) < 2:
                    # Skip blank or malformed lines, such as a trailing newline.
                    continue
                self.data.append(data_dir + 'val/images/' + items[0])
                self.labels.append(class_to_label[items[1]])

        self.labels = np.array(self.labels)
        # Placeholder task id of -1 for every sample.
        self.task_ids = -1 * np.ones(len(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label, task_id)`` for sample ``index``."""
        f_name = self.data[index]
        label = self.labels[index]
        task = self.task_ids[index]

        img = Image.open(f_name).convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label, task
    
class Places365Dataset(Dataset):
    """Places365 image-folder dataset with train and val merged into one pool.

    ``data_dir`` must end with a path separator and contain ``train/<class>/``
    and ``val/<class>/``.  The class order is the sorted listing of
    ``train/`` shuffled with ``random.Random(dataset_seed)``.  An image's
    label is the position of its class in ``self.class_names``.

    Args:
        data_dir: Dataset root.
        dataset_seed: Seed for the class-order shuffle.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, data_dir, dataset_seed, transform=None):
        self.transform = transform

        names = sorted(os.listdir(data_dir + 'train/'))
        random.Random(dataset_seed).shuffle(names)
        self.class_names = names

        paths, targets = [], []
        for label, class_name in enumerate(names):
            # Every class contributes its train images first, then its val images.
            for split in ('train/', 'val/'):
                folder = data_dir + split + class_name
                for file_name in os.listdir(folder):
                    paths.append(folder + '/' + file_name)
                    targets.append(label)

        self.data = paths
        self.labels = np.array(targets)
        # Placeholder task id of -1 for every sample.
        self.task_ids = -np.ones(len(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for sample ``index``."""
        label = self.labels[index]
        image = Image.open(self.data[index]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label, index
    
class Stream51Dataset(Dataset):
    """Stream-51 training frames with per-object ids and frame positions.

    Reads every file under ``data_dir + 'train/<class>/'``.  ``data_dir`` must
    end with a path separator.  Classes are labelled in sorted order.  Each
    file name is split on ``'_'``:

    * The first field identifies the object within its class.
    * The third field, up to the first ``'.'``, is parsed as the integer
      frame position.

    Args:
        data_dir: Dataset root.
        transform: Optional callable applied to each PIL image.

    Attributes:
        object_ids: dense integer ids, one per distinct (class, object prefix)
            pair, numbered in first-seen order.
        frame_pos: integer frame positions parsed from the file names.
    """

    def __init__(self, data_dir, transform=None):
        self.transform = transform

        self.data = []
        self.labels = []
        self.object_ids = []
        self.frame_pos = []
        class_names = np.sort([d for d in os.listdir(data_dir + 'train/')])

        # Maps '<label>-<object prefix>' to a dense id.  setdefault with
        # len(id_map) assigns 0, 1, 2, ... in first-seen order.
        id_map = {}
        for y, clas in enumerate(class_names):
            for img in os.listdir(data_dir + 'train/' + clas):
                fields = img.split('_')
                self.data.append(data_dir + 'train/' + clas + '/' + img)
                self.object_ids.append(id_map.setdefault(str(y) + '-' + fields[0], len(id_map)))
                self.frame_pos.append(int(fields[2].split('.')[0]))
                self.labels.append(y)

        self.labels = np.array(self.labels)
        self.object_ids = np.array(self.object_ids)
        self.frame_pos = np.array(self.frame_pos)

        # Placeholder task id of -1 for every sample.
        self.task_ids = -1 * np.ones(len(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label, task_id, object_id)`` for sample ``index``."""
        f_name = self.data[index]
        label = self.labels[index]
        task = self.task_ids[index]
        obj_lab = self.object_ids[index]

        img = Image.open(f_name).convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label, task, obj_lab

class SVHNDataset(Dataset):
    """SVHN with the train and test splits merged, served as PIL images.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.SVHN``.  The
            splits are downloaded there if missing.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
    """

    def __init__(self, data_dir, transform=None):
        self.transform = transform

        self.train = torchvision.datasets.SVHN(data_dir, split='train', download=True)
        self.test = torchvision.datasets.SVHN(data_dir, split='test', download=True)

        self.data = np.concatenate((self.train.data, self.test.data))
        # Move axis 1 to the end, i.e. (N, C, H, W) -> (N, H, W, C), the layout
        # Image.fromarray expects.  Assumes torchvision stores SVHN
        # channels-first.
        self.data = np.transpose(self.data, (0, 2, 3, 1))
        self.labels = np.concatenate((self.train.labels, self.test.labels))
        # Placeholder task id of -1 for every sample.
        self.task_ids = -1 * np.ones(len(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label, task_id)`` for sample ``index``."""
        img = Image.fromarray(self.data[index])
        label = self.labels[index]
        task = self.task_ids[index]

        # Apply the transform given to __init__, as the other datasets here do.
        if self.transform:
            img = self.transform(img)

        return img, label, task

class CIFAR100Dataset(Dataset):
    """CIFAR-100 with the train and test splits merged, served as PIL images.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.CIFAR100``.
            The splits are downloaded there if missing.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
    """

    def __init__(self, data_dir, transform=None):
        self.transform = transform

        self.train = torchvision.datasets.CIFAR100(data_dir, train=True, download=True)
        self.test = torchvision.datasets.CIFAR100(data_dir, train=False, download=True)

        self.data = np.concatenate((self.train.data, self.test.data))
        self.labels = np.concatenate((self.train.targets, self.test.targets))
        # Placeholder task id of -1 for every sample.
        self.task_ids = -1 * np.ones(len(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label, task_id)`` for sample ``index``."""
        img = Image.fromarray(self.data[index])
        label = self.labels[index]
        task = self.task_ids[index]

        # Apply the transform given to __init__, as the other datasets here do.
        if self.transform:
            img = self.transform(img)

        return img, label, task
        
class NumpyDataset(Dataset):
    """In-memory dataset over pre-loaded samples and labels.

    Despite the name, ``data`` may be either a ``torch.Tensor`` or a NumPy
    array.  ``add_data`` keeps whichever container type was used at
    construction.  ``__getitem__`` returns ``(sample, label, 0)``; the third
    element is a constant placeholder.

    Args:
        data: Indexable samples, one per entry along the first axis.
        labels: Labels aligned with ``data``.
        transform: Optional callable applied to each sample in
            ``__getitem__``.
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def add_data(self, x, y):
        """Append samples ``x`` with labels ``y`` along the first axis."""
        if isinstance(self.data, torch.Tensor):
            # as_tensor is a no-op for tensors and converts array-likes, so
            # NumPy batches can be appended to tensor data too.
            self.data = torch.cat((self.data, torch.as_tensor(x)))
        else:
            # NumPy-backed data: torch.cat would reject an ndarray here.
            self.data = np.concatenate((self.data, x))
        self.labels = np.concatenate((self.labels, y))

    def __getitem__(self, index):
        """Return ``(sample, label, 0)`` for sample ``index``."""
        img = self.data[index]
        label = self.labels[index]

        if self.transform:
            img = self.transform(img)

        return img, label, 0


def load_dataset(dataset_conf):
    """Build the dataset named by ``dataset_conf.name``.

    ``dataset_conf`` is read through attribute access.  It always needs
    ``.name``; depending on the dataset it also needs ``.dir``, ``.seed`` and
    ``.specific_classes``.  CIFAR-100 and SVHN always use ``'datasets/'`` as
    their root, whatever ``.dir`` is set to.

    Raises:
        ValueError: if ``dataset_conf.name`` is not a supported dataset.
    """
    name = dataset_conf.name

    if name == 'imagenet':
        return ImageNetDataset(dataset_conf.dir, dataset_conf.seed, specific_classes=dataset_conf.specific_classes)
    elif name == 'tinyimagenet':
        return TinyImageNetDataset(dataset_conf.dir, dataset_conf.seed)
    elif name == 'places365':
        return Places365Dataset(dataset_conf.dir, dataset_conf.seed)
    elif name == 'stream51':
        return Stream51Dataset(dataset_conf.dir)
    elif name == 'cifar100':
        return CIFAR100Dataset('datasets/')
    elif name == 'svhn':
        return SVHNDataset('datasets/')

    # Fail loudly instead of returning None, which would only surface later as
    # an unrelated AttributeError/TypeError in the caller.
    raise ValueError(
        "Unknown dataset name %r; expected one of: imagenet, tinyimagenet, "
        "places365, stream51, cifar100, svhn" % (name,)
    )
