import torch.nn as nn
import torch
from torchvision import transforms
import numpy as np
from torch.utils.data import DataLoader
from torch.nn import functional as F

import Replay.utils as utils
import datautil.imgdata.util as imgutil
from datautil.mydataloader import InfiniteDataLoader
from utils.util import log_print

class LDAuCID_buff:
    '''
    Exemplar replay buffer that keeps the samples the model's GMM is most confident about.

    A sample's score is the largest posterior that ``model.gmm_model`` assigns to
    its ReLU-activated ``network[1].fc0`` feature. The highest-scoring samples are
    kept. How the memory budget ``args.memory_size`` is split depends on
    ``args.replay_mode``:
      * ``'class'``  - one exemplar set per class for every ``update`` call;
      * ``'domain'`` - one exemplar set (all classes) for every ``update`` call.
    '''

    def __init__(self, args):
        self.args = args
        # The three lists are index-aligned: entry i of each describes exemplar set i.
        self.exemplar_set = []         # list of list[PIL image]
        self.exemplar_label_set = []   # list of np.array of class labels
        self.exemplar_dlabel_set = []  # list of np.array of domain labels
        self.replay_dataset = None     # built by update_dataloader()

    def update_dataloader(self, dataloader=None):
        '''
        Build the replay dataset from all stored exemplars.

        All exemplar sets are concatenated into one ``utils.ReplayDataset`` that
        uses the training transform. Current-task data is NOT merged in; the
        replay data is returned on its own.

        Args:
            dataloader: unused; kept for interface compatibility.

        Returns:
            ``self.replay_dataset``. This is ``None`` until the buffer has held at
            least one exemplar set when this method was called.
        '''
        exemplar_set = self.exemplar_set
        log_print('exemplar_set size: {}'.format(len(exemplar_set[0]) if len(exemplar_set)>0 else 0), self.args.log_file)

        if len(exemplar_set) > 0:
            imgs = utils.concat_list(exemplar_set)
            labels = utils.concat_list(self.exemplar_label_set)
            dlabels = utils.concat_list(self.exemplar_dlabel_set)
            self.replay_dataset = utils.ReplayDataset(imgs, labels, dlabels, transform=imgutil.image_train(self.args))

        return self.replay_dataset

    def update(self, model, task_id, dataloader):
        '''
        Shrink the stored exemplar sets, then add new ones from the current task.

        Args:
            model: scorer that exposes ``network[0]``, ``network[1].fc0`` and a
                fitted ``gmm_model`` with ``predict_proba``.
            task_id: index of the current task. The per-set budget is divided
                by ``task_id + 1``.
            dataloader: loader whose ``dataset`` provides ``get_raw_data()``
                (returning raw entries, class labels and domain labels) and
                ``loader`` (turning a raw entry into an image).

        Raises:
            ValueError: if ``args.replay_mode`` is neither ``'class'`` nor ``'domain'``.
        '''
        mode = self.args.replay_mode
        if mode == 'class':    # budget per class per update call
            m = int(self.args.memory_size / (self.args.num_classes * (task_id + 1)))
        elif mode == 'domain':    # budget per update call
            m = int(self.args.memory_size / (task_id + 1))
        else:
            # Without this branch `m` would be unbound and the next line would
            # raise an unhelpful UnboundLocalError.
            raise ValueError("replay_mode must be 'class' or 'domain', got {!r}".format(mode))
        self._reduce_exemplar_sets(m)

        image_dict, class_label, domain_label = dataloader.dataset.get_raw_data()
        images = [dataloader.dataset.loader(entry) for entry in image_dict]  # list of PIL image

        if mode == 'class':
            # Each exemplar set holds one class of the current task's data.
            for c in range(self.args.num_classes):
                mask = class_label == c
                indices = np.where(mask)[0]
                if len(indices) == 0:
                    log_print('No class {} pseudo labels!!!'.format(c), self.args.log_file)
                    continue
                imgs = [images[i] for i in indices]  # list of PIL image
                self._construct_exemplar_set(model, imgs, class_label[mask], domain_label[mask], m)
        else:
            # Each exemplar set holds all classes of the current task's data.
            self._construct_exemplar_set(model, images, class_label, domain_label, m)

    def _construct_exemplar_set(self, model, images, class_label, domain_label, m):
        '''
        Score ``images`` and append the ``m`` best as a new exemplar set.

        ``images`` is a single class of the current task (class mode) or all of
        it (domain mode). Selected samples are stored in descending score order,
        so ``_reduce_exemplar_sets`` can keep the best ones by truncation.

        Args:
            model: see ``update``. Its train/eval mode is restored afterwards.
            images: list of PIL images to score.
            class_label, domain_label: label arrays aligned with ``images``.
            m: number of exemplars to keep. ``m <= 0`` keeps none.
        '''
        if len(images) == 0:
            # Nothing to score; np.concatenate would raise on an empty list.
            return

        exemplar_dataset = utils.ExemplarDataset(images, transform=imgutil.image_test(self.args))
        exemplar_dataloader = DataLoader(dataset=exemplar_dataset,
                                         shuffle=False,  # scores must stay aligned with `images`
                                         batch_size=self.args.batch_size,
                                         num_workers=self.args.N_WORKERS)

        was_training = getattr(model, 'training', False)
        model.eval()
        all_probabs = []
        try:
            with torch.no_grad():
                for x in exemplar_dataloader:
                    x = x.cuda()
                    features = F.relu(model.network[1].fc0(model.network[0](x))).cpu().numpy()
                    # Confidence = highest posterior over the GMM components.
                    probabs = np.max(model.gmm_model.predict_proba(features), axis=1)
                    all_probabs.append(probabs)
        finally:
            # Do not leave a training model stuck in eval mode.
            if was_training:
                model.train()

        all_probabs = np.concatenate(all_probabs, axis=0)
        selected_index = self._select_top_indices(all_probabs, m)

        self.exemplar_set.append([images[i] for i in selected_index])
        self.exemplar_label_set.append(class_label[selected_index])
        self.exemplar_dlabel_set.append(domain_label[selected_index])

    @staticmethod
    def _select_top_indices(scores, m):
        '''
        Return the indices of the ``m`` highest ``scores``, best first.

        ``m`` is clamped at 0. A naive ``argsort()[-m:]`` would select every
        index when ``m == 0``, because ``[-0:]`` is the whole array.
        '''
        order = np.argsort(scores)
        return order[::-1][:max(int(m), 0)]

    def _reduce_exemplar_sets(self, m):
        '''
        Truncate every stored exemplar set (images and both label arrays) to ``m`` entries.

        Sets are stored best-first, so truncation keeps the highest-scoring
        samples. The lists are updated in place.
        '''
        keep = max(int(m), 0)
        self.exemplar_set[:] = [s[:keep] for s in self.exemplar_set]
        self.exemplar_label_set[:] = [s[:keep] for s in self.exemplar_label_set]
        self.exemplar_dlabel_set[:] = [s[:keep] for s in self.exemplar_dlabel_set]