import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

import datautil.imgdata.util as imgutil
from datautil.mydataloader import InfiniteDataLoader

class ExemplarDataset(Dataset):
    '''
    Minimal dataset used by compute_class_mean.

    input: imgs is an indexable sequence of images (PIL images expected by
           the transform); transform is a callable applied to every item.
    return (per item): transform(imgs[index])
    '''
    def __init__(self, imgs, transform):
        self.imgs = imgs
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        image = self.imgs[index]
        return self.transform(image)

class ReplayDataset(Dataset):
    '''
    Dataset over stored replay exemplars.

    input: images, class_labels, domain_labels are parallel indexable
           sequences of equal length (images should be PIL images when a
           torchvision-style transform is supplied).
           transform: optional callable applied to each image.
           target_transform: optional callable applied to each class label.
    return (per item): (image, class_label, domain_label)
    '''
    def __init__(self, images, class_labels, domain_labels, transform=None, target_transform=None):
        self.images = images
        self.labels = class_labels
        self.dlabels = domain_labels
        self.transform = transform
        # Keep and honour target_transform instead of silently dropping it.
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label, self.dlabels[index]

    def get_raw_data(self):
        '''Return the untransformed (images, class_labels, domain_labels).'''
        return self.images, self.labels, self.dlabels

def concat_list(data_list):
    '''
    Flatten exactly one level of nesting.

    input: iterable of iterables, e.g. [[a, b], [c], []]
    return: a new list with every inner item in order, e.g. [a, b, c]
    '''
    return [item for sublist in data_list for item in sublist]

def construct_BF_dataloader(args, Replay_algorithm, task_id):

    if len(Replay_algorithm.exemplar_set) > 0 and task_id > 0:
        # initialization
        domain_exemplar_set = {}         # each element store data in one domain
        domain_exemplar_label_set = {}
        domain_exemplar_dlabel_set = {}
        exemplar_set = Replay_algorithm.exemplar_set
        exemplar_label_set = Replay_algorithm.exemplar_label_set
        exemplar_dlabel_set = Replay_algorithm.exemplar_dlabel_set
        for domain in range(task_id+1):
            domain_exemplar_set['domain{}'.format(domain)] = []
            domain_exemplar_label_set['domain{}'.format(domain)] = []
            domain_exemplar_dlabel_set['domain{}'.format(domain)] = []

        if args.replay_mode == 'class':
            d_index = 0
            for i in range(len(exemplar_set)):
                domain_exemplar_set['domain{}'.format(d_index)].append(exemplar_set[i])
                domain_exemplar_label_set['domain{}'.format(d_index)].append(exemplar_label_set[i])
                domain_exemplar_dlabel_set['domain{}'.format(d_index)].append(exemplar_dlabel_set[i])
                if i % args.num_classes == 0 and i != 0:
                    d_index += 1
            for k in domain_exemplar_set.keys():
                domain_exemplar_set[k] = concat_list(domain_exemplar_set[k])
                domain_exemplar_label_set[k] = concat_list(domain_exemplar_label_set[k])
                domain_exemplar_dlabel_set[k] = concat_list(domain_exemplar_dlabel_set[k])
        elif args.replay_mode == 'domain':
            for i in range(len(exemplar_set)):
                domain_exemplar_set['domain{}'.format(i)] = exemplar_set[i]
                domain_exemplar_label_set['domain{}'.format(i)] = exemplar_label_set[i]
                domain_exemplar_dlabel_set['domain{}'.format(i)] = exemplar_dlabel_set[i]

        dataset_list = [ReplayDataset(domain_exemplar_set[k], domain_exemplar_label_set[k], domain_exemplar_dlabel_set[k], transform=imgutil.image_test(args)) for k in domain_exemplar_set.keys()]
        dataloaders = [InfiniteDataLoader(dataset=d, weights=None, batch_size=args.BF_bs, num_workers=args.N_WORKERS) for d in dataset_list]
        # dataloaders = [DataLoader(dataset=d, batch_size=args.BF_bs, shuffle=True, num_workers=args.N_WORKERS) for d in dataset_list]
        return dataloaders

    