# +
import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

from datasets.transforms.rotation import FixedRotation
from inclearn.lib.data.base import DomainIncrementalDataset

class iAVMNIST(DomainIncrementalDataset):
    """Domain-incremental AV-MNIST benchmark named ``'rot-avmnist'``.

    The class attributes describe the benchmark: ``N_TASKS`` tasks over
    ``N_CLASSES`` classes, input dimensions ``INDIM`` (presumably those of
    the visual modality, channels x height x width; TODO confirm), and at
    most ``MAX_N_SAMPLES_PER_TASK`` samples per task.
    """

    NAME = 'rot-avmnist'
    N_TASKS = 20
    N_CLASSES = 10
    INDIM = (1, 28, 28)
    MAX_N_SAMPLES_PER_TASK = 60000

    def __init__(self, args, run_id):
        """Initialise the base dataset from ``args``.

        Args:
            args: Configuration forwarded unchanged to
                ``DomainIncrementalDataset.__init__``.
            run_id: Accepted for interface compatibility.
                NOTE(review): it is currently unused.
        """
        super().__init__(args)

    def new_task(self):
        """Advance to the next task. Not implemented yet.

        Raises:
            NotImplementedError: Always. The method previously had no body,
                which made this module a SyntaxError on import.
        """
        # TODO: implement task switching for the AV-MNIST benchmark.
        raise NotImplementedError('iAVMNIST.new_task is not implemented yet')



# +
class AVMNIST(torch.utils.data.Dataset):
    """Paired image/audio AV-MNIST samples loaded from pre-split ``.npy`` files.

    Each item is a dict holding the transformed visual sample, the
    transformed audio sample, the integer label, and the ``domain_id`` this
    dataset instance was constructed with.

    Args:
        root: Directory that contains an ``avmnist/`` folder with
            ``image/{split}_data.npy``, ``audio/{split}_data.npy`` and
            ``{split}_labels.npy``, where ``split`` is ``train`` or ``test``.
        domain_id: Value attached unchanged to every returned sample.
        train: Load the training split if True, otherwise the test split.
        download: Accepted for signature compatibility with torchvision-style
            datasets. It is ignored; nothing is downloaded.
    """

    def __init__(self, root, domain_id, train=True, download=False):
        split = 'train' if train else 'test'
        base = os.path.join(root, 'avmnist')

        self.visual = np.load(os.path.join(base, 'image', f'{split}_data.npy'))
        self.audio = np.load(os.path.join(base, 'audio', f'{split}_data.npy'))
        self.targets = np.load(os.path.join(base, f'{split}_labels.npy'))

        self.visual_transform = transforms.ToTensor()
        self.audio_transform = transforms.ToTensor()

        self.domain_id = domain_id

    def __len__(self):
        """Return the number of samples. ``DataLoader`` needs this for default sampling."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict with keys visual/audio/target/domain_id."""
        visual = self.visual[index]
        audio = self.audio[index]
        target = int(self.targets[index])

        # Go through a PIL image so the visual pipeline matches other
        # datasets. NOTE(review): mode='L' assumes each visual sample holds
        # 784 8-bit pixel values; confirm the dtype of image/*_data.npy.
        visual = Image.fromarray(visual.reshape((28, 28)), mode='L')

        visual = self.visual_transform(visual)
        # Apply the audio transform to the audio sample (it was previously
        # applied to the image by mistake).
        audio = self.audio_transform(audio)

        return {'visual': visual, 'audio': audio, 'target': target, 'domain_id': self.domain_id}


