import torch
class SelectData:
    """Mixin that restricts a dataset to the samples of the current task.

    ``__init__`` sets nothing, so the host object must provide these
    attributes itself (presumably a subclass or caller sets them — confirm):

    * ``task_classes``: sequence whose i-th entry lists the class (column)
      indices belonging to task ``i``, or ``None`` to disable filtering.
    * ``task_id``: index of the current task, or ``None`` to disable filtering.
    * ``targets``: list of equal-shape label tensors, one per sample;
      presumably 1-D multi-hot vectors over all classes — TODO confirm.
    * ``images``: per-sample data aligned index-by-index with ``targets``.
    """

    def __init__(self):
        pass

    def select_datas_for_classes(self):
        """Keep only the samples labelled with a class of the current task.

        Does nothing unless both ``task_classes`` and ``task_id`` are set.
        Otherwise it:

        1. sets ``seen_classes`` to the classes of tasks ``0 .. task_id-1``,
           flattened in task order;
        2. sets ``unseen_classes`` to the sorted classes that appear neither in
           an earlier task nor in the current task;
        3. drops every sample whose entries for the current task's classes sum
           to zero;
        4. zeroes the unseen-class columns of the remaining targets, leaving
           the original target tensors unmodified;
        5. filters ``images`` with the same sample indices, preserving order.

        If no sample survives (or ``targets`` is empty), ``targets`` and
        ``images`` both become empty lists. Calling ``torch.stack`` on an
        empty list would raise instead.
        """
        if self.task_classes is None or self.task_id is None:
            return

        current_classes = list(self.task_classes[self.task_id])
        # Classes of every task strictly before the current one, in task order.
        # range() of a non-positive task_id is empty, so no earlier tasks.
        self.seen_classes = [
            cls for task in range(self.task_id) for cls in self.task_classes[task]
        ]
        all_classes = {cls for task in self.task_classes for cls in task}
        self.unseen_classes = sorted(
            all_classes - set(self.seen_classes) - set(current_classes)
        )

        if not self.targets:
            self.targets, self.images = [], []
            return

        target_array = torch.stack(self.targets)
        # A sample is kept when its current-task entries do not sum to zero.
        hit = target_array[:, current_classes].sum(dim=1) != 0
        keep = torch.where(hit)[0].tolist()  # ascending sample indices
        if not keep:
            self.targets, self.images = [], []
            return

        # Advanced indexing copies, so the caller's original tensors stay intact.
        kept = target_array[keep]
        # Remove label information about classes not introduced yet.
        if self.unseen_classes:
            kept[:, self.unseen_classes] = 0
        self.targets = list(kept.unbind(0))
        self.images = [self.images[i] for i in keep]