from .ver import *
from torch.nn import functional as F
import random

class ExperienceEvolveApprox(FOExperienceEvolve):
    def __init__(self, base, optimizer, input_size, cfg, goal):
        """Read the memory-editing options from ``cfg`` on top of the parent initialisation.

        Every option below is read with ``get_config_attr`` and stored as an attribute of the
        same (lower-case) name, in the listed order.
        """
        super().__init__(base, optimizer, input_size, cfg, goal)
        option_table = (
            ('edit_least', 'EXTERNAL.OCL.EDIT_LEAST', 0),
            ('edit_random', 'EXTERNAL.OCL.EDIT_RANDOM', 0),
            ('edit_interfere', 'EXTERNAL.OCL.EDIT_INTERFERE', 1),
            ('edit_replace', 'EXTERNAL.OCL.EDIT_REPLACE', 0),
            ('replace_reweight', 'EXTERNAL.OCL.REPLACE_REWEIGHT', 0),
            ('reg_supportive', 'EXTERNAL.OCL.REG_SUPPORTIVE', 0),
            ('use_relu', 'EXTERNAL.OCL.USE_RELU', 0),
            ('reg_strength', 'EXTERNAL.OCL.REG_STRENGTH', 0.1),
            ('always_proj', 'EXTERNAL.OCL.ALWAYS_PROJ', 0),
            ('edit_mir_k', 'EXTERNAL.OCL.EDIT_K', -1),
            ('hal_mem', 'EXTERNAL.OCL.HAL_MEM', 0),
        )
        for attr_name, config_key, fallback in option_table:
            setattr(self, attr_name, get_config_attr(cfg, config_key, default=fallback))

        # EDIT_K == -1 means "same size as the MIR candidate batch"
        self.edit_mir_k = self.mir_k if self.edit_mir_k == -1 else self.edit_mir_k

        if self.reg_supportive:
            # bookkeeping for supportive-example lists: slot -> hash, hash -> slot, next hash
            self.support_hash, self.support_hash_to_index, self.support_cnt = {}, {}, 0

    def sample_mem_batch_same_task(self, device, task_id_or_label, return_indices=False, mem_k=None, seed=0, use_same_label=False):
        """
        Sample up to ``mem_k`` memory examples sharing a task id (or a label) with the query.

        :param device: device the returned tensors are moved to.
        :param task_id_or_label: value compared against ``reservoir['y_extra']`` (task ids), or
            against ``reservoir['y']`` (labels) when ``use_same_label`` is True.
        :param return_indices: if True, the first return value is ``(x, indices)``.
        :param mem_k: maximum number of examples returned; defaults to ``self.mem_bs``.
        :param seed: mixed with ``self.example_seen`` and ``cfg.SEED`` to seed the sampler as
            ``seed * example_seen + cfg.SEED``. NOTE: with ``seed=0`` the sampler seed does not
            depend on ``example_seen`` at all.
        :param use_same_label: match on label instead of task id.
        :return: ``(x, y, y_extra)`` — or ``((x, indices), y, y_extra)`` when ``return_indices`` —
            or ``(None, None, None)`` if no stored example matches. When fewer than ``mem_k``
            examples match, all of them are returned (in memory order).
        """
        if mem_k is None:
            mem_k = self.mem_bs
        column = self.reservoir['y'] if use_same_label else self.reservoir['y_extra']
        target = task_id_or_label

        n_max = min(self.mem_limit, self.example_seen)
        indices = [i for i in range(n_max) if column[i] == target]

        if not indices:
            return None, None, None
        if len(indices) >= mem_k:
            indices = np.random.RandomState(seed * self.example_seen + self.cfg.SEED).\
                choice(indices, mem_k, replace=False)

        x = torch.from_numpy(self.reservoir['x'][indices]).to(device).float()
        y = index_select(self.reservoir['y'], indices, device)  # [  [...], [...] ]
        y_extra = concat_with_padding(index_select(self.reservoir['y_extra'], indices, device))
        if type(y[0]) not in [list, tuple]:
            y_pad = concat_with_padding(y)
        else:
            # multi-output labels: stack each output field separately
            y_pad = [torch.stack(_).to(device) for _ in zip(*y)]

        if not return_indices:
            return x, y_pad, y_extra
        return (x, indices), y_pad, y_extra

    def clear_mem_grad(self, mem_x):
        """Reset ``mem_x`` in place into a gradient-free leaf tensor that requires grad."""
        mem_x.detach_()             # cut any autograd history
        mem_x.grad = None           # forget a previously accumulated gradient
        mem_x.requires_grad_(True)  # track gradients w.r.t. the inputs themselves

    def observe(self, x, y, task_ids, extra=None, optimize=True, sequential=False):
        """
        Train on one incoming batch with replay, editing replayed memory examples first.

        For each of ``EXTERNAL.OCL.N_ITER`` iterations, a replay batch (``mem_*``) and a batch to
        edit (``edit_*``) are drawn from memory (through MIR when ``self.mir`` is set). On the first
        iteration, if ``self.edit_interfere`` is set, the edit batch is modified by
        ``edit_mem_interfere`` and written back with ``evolve_mem``. The model is then trained on
        the concatenation of the incoming batch and the replay batch. Afterwards every incoming
        example is offered to memory with the update rule selected by the config flags.

        :param x: incoming inputs; the first dimension is the batch.
        :param y: incoming labels (a tensor, or a tuple of tensors for multi-output labels).
        :param task_ids: task id of each incoming example.
        :param extra: unused here.
        :param optimize: if False, the loss is computed but no optimizer step is taken.
        :param sequential: unused here.
        :return: the ``forward_net`` output dict of the last iteration.
        """
        n_iter = get_config_attr(self.cfg, 'EXTERNAL.OCL.N_ITER', default=1, mute=True)
        batch_size = x.size(0)
        # snapshot parameters; load_cache() below (and inside edit_mem_interfere) restores them
        self.store_cache()
        for i_iter in range(n_iter):
            if not self.mir:
                mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True, seed=i_iter + 1)
                # NOTE(review): this branch uses self.mir_k for the edit batch size, while the MIR
                # branch's random edit batch uses self.edit_mir_k (EXTERNAL.OCL.EDIT_K) — confirm intent.
                if self.edit_random: # select another batch for editing
                    edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True,
                                                                               k=self.mir_k, seed=i_iter + 2)
                else:
                    edit_x_indices, edit_y, edit_task_ids = mem_x_indices, mem_y, mem_task_ids
            else:
                mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True,
                                                                           input_x=x, input_y=y, input_task_ids=task_ids,
                                                                           mir_k=self.mir_k, mir=self.mir,
                                                                           skip_task=task_ids[0].item(),
                                                                           seed=i_iter + 1
                                                                           )
                if self.edit_least:
                    edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True,
                                                                               input_x=x, input_y=y, input_task_ids=task_ids,
                                                                               mir_k=self.edit_mir_k, mir=self.mir,
                                                                               skip_task=task_ids[0].item(),
                                                                               mir_least=True,
                                                                               seed=i_iter + 2
                                                                               )
                elif self.edit_random:
                    edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True,
                                                                               k=self.edit_mir_k, seed=i_iter + 2)
                else:
                    edit_x_indices, edit_y, edit_task_ids = mem_x_indices, mem_y, mem_task_ids
            self.optimizer.zero_grad()

            # memory examples from the same task as the first incoming example; they are passed to
            # edit_mem_interfere (which does not read them) and gate the loss-stat update at the end
            edit_x_val_indices, edit_y_val, _ = self.sample_mem_batch_same_task(x.device, task_ids.cpu().numpy()[0],
                                                                              return_indices=True, seed=i_iter + 2,
                                                                              mem_k=self.mir_k if self.mir else self.mem_bs)
            # NOTE(review): task_ids has the incoming batch size, not edit_x_val's size
            edit_task_ids_val = task_ids
            #mem_x_val_indices, _, mem_y_val, mem_task_ids_val = self.sample_mem_batch(x.device, return_indices=True, seed=1)

            if edit_x_indices is None:
                # empty memory: train on the incoming batch only
                combined_x, combined_y, combined_task_ids = x, y, task_ids
            else:
                mem_x, mem_indices = mem_x_indices
                edit_x, edit_indices = edit_x_indices
                if edit_x_val_indices is not None:
                    edit_x_val, indices_val = edit_x_val_indices
                else:
                    edit_x_val, indices_val = None, None
                # memory is edited only on the first iteration
                if i_iter == 0 and self.edit_interfere:
                    if self.hal_mem:
                        # look-ahead step uses another memory batch instead of the incoming one
                        train_x, train_y, _ = self.sample_mem_batch(x.device, return_indices=False,
                                                                    seed=i_iter + 3)
                    else:
                        train_x, train_y = x, y

                    edit_x, mem_x = self.edit_mem_interfere(train_x, train_y, task_ids, mem_x, mem_y, edit_x, edit_y, edit_task_ids,
                           edit_x_val, edit_y_val, edit_task_ids_val, edit_indices)
                    # write the edited examples back into memory
                    self.evolve_mem(edit_x, edit_indices)

                    # load cached parameters back
                    self.load_cache()

                combined_x = torch.cat([x, mem_x], 0)
                combined_y = concat_with_padding([y, mem_y])
                combined_task_ids = concat_with_padding([task_ids, mem_task_ids])

            ret_dict = self.forward_net(combined_x, combined_y, task_ids=combined_task_ids, reduce=False)
            loss_tmp = ret_dict['loss']
            if optimize:
                # incoming-batch mean loss plus (if any) replay-batch mean loss, weighted equally
                loss = loss_tmp[: x.size(0)].mean()
                if mem_x_indices is not None:
                    loss += loss_tmp[x.size(0):].mean()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # Offer each incoming example to memory. NOTE(review): loss_tmp / ret_dict come from the
        # last iteration and are undefined if N_ITER == 0.
        if self.edit_replace:
            for b in range(batch_size):
                self.update_mem_reservoir_with_edit(x[b], y[b], task_ids[b])
        elif self.replace_reweight:
            for b in range(batch_size):
                self.update_mem_reservoir_reweight(x[b], y[b], task_ids[b], loss_x=loss_tmp[b].item())
        elif self.reg_supportive:
            for b in range(batch_size):
                self.update_mem_supportive_examples(x[b], y[b], task_ids[b])
        else:
            for b in range(batch_size):
                if type(y) is tuple:
                    self.update_mem(x[b], [_[b] for _ in y], task_ids[b])
                else:
                    self.update_mem(x[b], y[b], task_ids[b])

        # record the replay examples' latest losses (used by the reweighting replacement rule)
        if self.replace_reweight and not (edit_x_indices is None or edit_x_val_indices is None or len(edit_x_val_indices[0]) <= 1):
            mem_loss = loss_tmp[x.size(0):]
            for b in range(mem_x.size(0)):
                self.update_loss_stat(mem_indices[b], mem_loss[b].item())

        return ret_dict

    def edit_mem_interfere(self, x, y, task_ids, mem_x, mem_y, edit_x, edit_y, edit_task_ids,
                           edit_x_val, edit_y_val, edit_task_ids_val, edit_indices):
        """
        Edit memory examples so that training on ``(x, y)`` interferes with them more.

        For ``self.grad_iter`` rounds, the gradient of the (mean) loss on ``edit_x`` w.r.t.
        ``edit_x`` is taken before and after a look-ahead optimizer step on ``(x, y)`` that does
        not update optimizer state. Their difference, scaled by ``EXTERNAL.OCL.USE_LOSS_1``
        (plus a fixed random unit direction per memory slot when it equals -2), is optionally
        projected/regularized. ``edit_x`` is then moved along it by ``self.grad_stride``.
        Parameters are restored with ``self.load_cache()`` after every look-ahead step (assumes
        the caller stored a cache beforehand).

        ``edit_x`` is updated in place, so when ``mem_x`` shares storage with ``edit_x`` (the same
        sampled batch is used for replay and editing) the returned ``mem_x`` carries the edits.

        :param x: inputs used for the look-ahead step.
        :param y: labels used for the look-ahead step.
        :param task_ids: task ids used for the look-ahead step.
        :param mem_x: replay inputs; only detached and returned.
        :param mem_y: unused.
        :param edit_x: memory inputs to edit.
        :param edit_y: labels of ``edit_x``.
        :param edit_task_ids: task ids of ``edit_x``.
        :param edit_x_val: unused.
        :param edit_y_val: unused.
        :param edit_task_ids_val: unused.
        :param edit_indices: memory slots of ``edit_x`` (used by the random-direction and
            supportive-projection variants).
        :return: ``(edit_x, mem_x)``, both detached.
        :raises NotImplementedError: if the optimizer is neither SGD nor Adam.
        """
        device = x.device
        ocl_cfg = self.cfg.EXTERNAL.OCL

        for _ in range(self.grad_iter):
            # gradient of the edit loss w.r.t. edit_x under the current parameters
            self.clear_mem_grad(edit_x)
            ret_dict_edit_before = self.forward_net(edit_x, edit_y, reduce=True, task_ids=edit_task_ids)
            ret_dict_edit_before['loss'].backward()
            edit_x_grad1 = edit_x.grad
            # The loss is a scalar and edit_x.grad was None before backward(), so this is the same
            # gradient a separate autograd.grad call would return, without a second backward pass.
            grad_reg = -edit_x_grad1
            self.clear_mem_grad(edit_x)

            # look-ahead step on (x, y) that does not update optimizer state
            ret_dict_d = self.forward_net(x, y, task_ids=task_ids)
            self.optimizer.zero_grad()
            ret_dict_d['loss'].backward()
            if isinstance(self.optimizer, torch.optim.SGD):
                step_wo_state_update_sgd(self.optimizer, amp=1.)
            elif isinstance(self.optimizer, torch.optim.Adam):
                step_wo_state_update_adam(self.optimizer, amp=1.)
            else:
                raise NotImplementedError

            # gradient of the edit loss w.r.t. edit_x under the look-ahead parameters
            ret_dict_edit_after = self.forward_net(edit_x, edit_y, reduce=True, task_ids=edit_task_ids)
            ret_dict_edit_after['loss'].backward()
            edit_x_grad2 = edit_x.grad

            grad_delta = edit_x_grad2 - edit_x_grad1
            grad_delta_2 = 0

            self.clear_mem_grad(edit_x)
            self.load_cache()

            total_grad = 0
            if ocl_cfg.USE_LOSS_1:
                total_grad += ocl_cfg.USE_LOSS_1 * grad_delta
                if ocl_cfg.USE_LOSS_1 == -2:  # random edit direction baseline
                    total_grad += self.get_random_grad(grad_delta, edit_indices)
            if ocl_cfg.USE_LOSS_2:
                total_grad += ocl_cfg.USE_LOSS_2 * grad_delta_2

            if type(total_grad) is not int:  # has grad update
                if self.reg_supportive:
                    total_grad = self.project_edit_x_gradient_support(total_grad, edit_x, edit_y, edit_task_ids,
                                                                      edit_indices, device)
                elif ocl_cfg.PROJ_LOSS_REG == 1:
                    for b in range(total_grad.size(0)):
                        total_grad[b] = proj_grad(-grad_reg[b], total_grad[b], binary=False, always_proj=self.always_proj)
                elif ocl_cfg.PROJ_LOSS_REG == 2:
                    for b in range(total_grad.size(0)):
                        total_grad[b] = -grad_reg[b] * self.reg_strength + total_grad[b]

                # edit_x is a leaf that requires grad (set by clear_mem_grad), and autograd forbids
                # in-place writes into such a tensor outside no_grad. The update stays in place so
                # storage shared with mem_x is edited too.
                with torch.no_grad():
                    edit_x += total_grad * self.grad_stride
            edit_x = edit_x.detach()
            mem_x = mem_x.detach()

        return edit_x, mem_x

    def get_random_grad(self, grad, indices):
        """Return one fixed random unit-norm direction per memory slot in ``indices``.

        The directions are drawn once, uniformly in [-1, 1], with one per memory slot and the
        per-example shape of ``grad``. They are cached on ``self.random_dirs`` and each is
        normalised by its own norm. The stacked result is placed on ``grad``'s device.
        """
        if not hasattr(self, 'random_dirs'):
            self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).uniform_(-1,1)
        target_device = grad.device
        directions = (self.random_dirs[slot].to(target_device) for slot in indices)
        return torch.stack([direction / direction.norm() for direction in directions])

    def project_edit_x_gradient_support(self, total_grad, edit_x, edit_y, edit_task_ids, edit_indices, device):
        """
        Adjust each row of the edit direction using the example's supportive memory examples.

        For every row ``b`` whose memory slot ``edit_indices[b]`` has supportive examples, a
        gradient ``grad2`` w.r.t. ``edit_x[b]`` is computed:

        * ``reg_supportive`` in {1, 2}: gradient of the increase in ``edit_x[b]``'s loss caused
          by a look-ahead step (no optimizer-state update) on the supportive examples;
        * ``reg_supportive`` in {3, 4}: gradient of the increase in the supportive examples'
          loss caused by a differentiable SGD look-ahead step on ``edit_x[b]``.

        Odd modes project ``total_grad[b]`` with ``proj_grad(-grad2, total_grad[b], binary=True)``.
        Even modes regularize it as ``total_grad[b] - grad2``. Rows without supportive
        examples are copied unchanged. Parameters are restored with ``self.load_cache()`` and
        gradients zeroed after each processed row.

        :return: a new tensor with the same shape as ``total_grad``.
        :raises NotImplementedError: for an unsupported optimizer (modes 3/4 need SGD).
        :raises ValueError: if ``reg_supportive`` is not in 1..4.
        """
        proj_grad_delta = torch.zeros_like(total_grad)
        for b in range(proj_grad_delta.size(0)):
            supportive_examples = self.get_supportive_examples(edit_indices[b], device=device)
            if supportive_examples is None:
                proj_grad_delta[b] = total_grad[b]
                continue

            edit_x_b = edit_x[b].detach()
            edit_x_b.requires_grad = True
            edit_x_b.grad = None
            supp_x, supp_y, supp_task_ids = supportive_examples

            if self.reg_supportive in [1, 2]:
                ret_dict_edit_before = self.forward_net(edit_x_b.view(1, -1), edit_y[b].view(-1),
                                                        edit_task_ids[b].view(-1), reduce=False)

                ret_dict_supp = self.forward_net(supp_x, supp_y, reduce=True,
                                                 task_ids=supp_task_ids)
                self.optimizer.zero_grad()
                ret_dict_supp['loss'].backward()
                if isinstance(self.optimizer, torch.optim.SGD):
                    step_wo_state_update_sgd(self.optimizer, amp=1.)
                elif isinstance(self.optimizer, torch.optim.Adam):
                    step_wo_state_update_adam(self.optimizer, amp=1.)
                else:
                    raise NotImplementedError
                ret_dict_edit_after = self.forward_net(edit_x_b.view(1, -1), edit_y[b].view(-1),
                                                       edit_task_ids[b].view(-1), reduce=False)
                loss_increase = ret_dict_edit_after['loss'] - ret_dict_edit_before['loss']
                # edit_x_b.grad is still None here (the supportive backward does not involve it)
                loss_increase.backward()
                grad2 = edit_x_b.grad
            elif self.reg_supportive in [3, 4]:
                ret_dict_supp_before = self.forward_net(supp_x, supp_y, reduce=True, task_ids=supp_task_ids)
                ret_dict_edit = self.forward_net(edit_x_b.view(1, -1), edit_y[b].view(-1),
                                                 edit_task_ids[b].view(-1))
                self.optimizer.zero_grad()
                # create_graph keeps the parameter gradients differentiable w.r.t. edit_x_b
                ret_dict_edit['loss'].backward(create_graph=True)

                if isinstance(self.optimizer, torch.optim.SGD):
                    weights = get_updated_weights_sgd(self.optimizer)
                else:
                    raise NotImplementedError

                ret_dict_supp_after = self.forward_net(supp_x, supp_y, reduce=True, task_ids=supp_task_ids,
                                                       from_weights=True, weights=weights)

                loss_increase = ret_dict_supp_after['loss'] - ret_dict_supp_before['loss']
                # Differentiate loss_increase alone: the create_graph backward above already stored
                # d(edit loss)/d(edit_x_b) in edit_x_b.grad, and another backward() would add to it.
                grad2 = torch.autograd.grad(loss_increase, edit_x_b, allow_unused=True)[0]
                if grad2 is None:
                    grad2 = torch.zeros_like(edit_x_b)
            else:
                raise ValueError('unsupported REG_SUPPORTIVE value: {}'.format(self.reg_supportive))

            if self.reg_supportive % 2 == 1:  # gradient projection
                proj_grad_delta[b] = proj_grad(-grad2, total_grad[b], binary=True)
            else:  # regularization: keep the edit direction, subtract the supportive-loss gradient
                proj_grad_delta[b] = total_grad[b] - grad2
            self.load_cache()
            self.net.zero_grad()
        return proj_grad_delta

    def update_loss_stat(self, x_idx, loss):
        """Log ``loss`` for memory slot ``x_idx`` and refresh that slot's forgetting score.

        The loss and the current ``self.example_seen`` are appended to the slot's histories. The
        forgetting score is the new loss minus the lowest loss recorded for the slot so far.
        """
        history = self.reservoir['loss_stats'][x_idx]
        history.append(loss)
        self.reservoir['loss_stat_steps'][x_idx].append(self.example_seen)
        self.reservoir['forget'][x_idx] = loss - min(history)
        # defined as loss increase compared to when it is intially learned

    def to_mem_type(self, x, y, y_extra):
        """Convert one example to the host-side types stored in the reservoir.

        ``x`` becomes a numpy array. ``y`` and ``y_extra`` go through ``y_to_cpu`` when they are
        exactly a list or tuple, and through ``y_to_np`` otherwise.
        """
        def _convert(label):
            return y_to_cpu(label) if type(label) in [list, tuple] else y_to_np(label)

        return x.cpu().numpy(), _convert(y), _convert(y_extra)

    def update_mem_reservoir_reweight(self, x, y, y_extra=None, loss_x=None, *args, **kwargs):
        """
        Reservoir-sample one example, redirecting eviction to the least-forgotten memory slot.

        Standard reservoir sampling decides whether the example enters memory (slot ``j``). When
        memory is full and it does, up to 50 candidate slots are drawn. With
        ``replace_reweight == 1`` they come from the same task as slot ``j``; with ``== 2`` they
        come from any slot. If there are several candidates, their current losses are logged with
        ``update_loss_stat`` and the candidate with the smallest forgetting score is overwritten
        instead of ``j``. The stored example's loss history starts at ``[loss_x]``.

        :param x: incoming example (tensor).
        :param y: its label(s).
        :param y_extra: its task id.
        :param loss_x: loss of the incoming example, recorded when it is stored.
        :raises ValueError: if an eviction is needed and ``replace_reweight`` is not 1 or 2.
        """
        device = x.device
        x, y, y_extra = self.to_mem_type(x, y, y_extra)

        if self.example_seen < self.mem_limit:
            j = self.example_seen
            self.reservoir['x'][j] = x
            self.reservoir['y'][j] = y
            self.reservoir['y_extra'][j] = y_extra
        else:
            j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.example_seen)
            if j < self.mem_limit:
                task_j = self.reservoir['y_extra'][j]
                if self.replace_reweight == 1:
                    # only slots from the same task as the reservoir-chosen slot
                    indices = [idx for idx in range(self.mem_limit) if self.reservoir['y_extra'][idx] == task_j]
                elif self.replace_reweight == 2:
                    indices = list(range(self.mem_limit))
                else:
                    raise ValueError('unsupported REPLACE_REWEIGHT value: {}'.format(self.replace_reweight))
                max_candidates = 50
                if len(indices) > max_candidates:
                    indices = random.sample(indices, k=max_candidates)

                if len(indices) > 1:
                    # evaluate the candidates' current losses and update their forgetting scores
                    cand_x, cand_y_pad, cand_y_extra = self.indices_to_examples(indices, device)
                    with torch.no_grad():
                        loss = self.forward_net(cand_x, cand_y_pad, cand_y_extra, reduce=False)['loss']
                    for b, mem_idx in enumerate(indices):
                        self.update_loss_stat(mem_idx, loss[b].item())
                    # overwrite the candidate that has been forgotten the least
                    j = min(indices, key=lambda mem_idx: self.reservoir['forget'][mem_idx])
                else:
                    j = indices[0]

                self.reservoir['x'][j] = x
                self.reservoir['y'][j] = y
                self.reservoir['y_extra'][j] = y_extra
        if j < self.mem_limit:
            # the new occupant of slot j starts a fresh loss history
            self.reservoir['loss_stats'][j] = [loss_x]
            self.reservoir['loss_stat_steps'][j] = [self.example_seen]
            self.reservoir['forget'][j] = 0
        self.example_seen += 1

    def update_mem_reservoir_with_edit(self, x_th, y_th, y_extra_th=None, *args, **kwargs):
        """
        Reservoir-sample one example; once memory is full, first edit same-label memory examples.

        When memory is full, a batch ``x_rand`` of memory examples with the same label as the
        incoming example is drawn. The code then computes the gradient w.r.t. ``x_rand`` of the
        increase in the incoming example's loss after a differentiable SGD look-ahead step on
        ``x_rand``. ``x_rand`` is moved along the negative of that gradient by
        ``self.grad_stride``, and the result is written back with ``evolve_mem``. Parameters are
        restored afterwards. Standard reservoir replacement at slot ``j`` follows.

        :param x_th: incoming example (tensor).
        :param y_th: its label (tensor).
        :param y_extra_th: its task id (tensor).
        :raises NotImplementedError: if the edit runs and the optimizer is not SGD.
        """
        device = x_th.device
        x, y, y_extra = self.to_mem_type(x_th, y_th, y_extra_th)
        if self.example_seen < self.mem_limit:
            self.reservoir['x'][self.example_seen] = x
            self.reservoir['y'][self.example_seen] = y
            self.reservoir['y_extra'][self.example_seen] = y_extra
        else:
            # j-th example will be discarded from the memory (if j < mem_limit)
            j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.example_seen)
            # NOTE(review): seed is left at its default (0), so the sampler's RandomState inside
            # sample_mem_batch_same_task does not vary with example_seen.
            x_rand_and_indices, y_rand, task_ids_rand = self.sample_mem_batch_same_task(
                device, y, return_indices=True, use_same_label=True,
                mem_k=self.mir_k if self.mir else self.mem_bs)
            if x_rand_and_indices is not None:
                x_rand, x_rand_indices = x_rand_and_indices
                self.store_cache()
                # the incoming example, repeated to the size of the sampled batch
                x_rep = x_th.view(1, -1).repeat(x_rand.size(0), 1)
                y_rep = y_th.view(-1).repeat(y_rand.size(0))
                x_rand.requires_grad = True
                y_extra_rep = y_extra_th.view(-1).repeat(y_rand.size(0))

                ret_dict_rand = self.forward_net(x_rand, y_rand, reduce=True, task_ids=task_ids_rand)
                ret_dict_d_before = self.forward_net(x_rep, y_rep, reduce=True, task_ids=y_extra_rep)
                self.optimizer.zero_grad()
                # create_graph keeps the parameter gradients differentiable w.r.t. x_rand
                ret_dict_rand['loss'].backward(create_graph=True)

                if isinstance(self.optimizer, torch.optim.SGD):
                    weights = get_updated_weights_sgd(self.optimizer)
                else:
                    raise NotImplementedError

                ret_dict_d_after = self.forward_net(x_rep, y_rep, reduce=True, task_ids=y_extra_rep,
                                                    from_weights=True, weights=weights)
                loss_increase = ret_dict_d_after['loss'] - ret_dict_d_before['loss']
                # Differentiate loss_increase alone: x_rand.grad already holds d(rand loss)/d(x_rand)
                # from the create_graph backward above, so reading it after another backward()
                # would mix both gradients.
                grad_increase = torch.autograd.grad(loss_increase, x_rand, allow_unused=True)[0]
                if grad_increase is None:
                    grad_increase = torch.zeros_like(x_rand)
                total_grad = -grad_increase

                # out-of-place: x_rand is a leaf that requires grad, so it cannot be written in place
                edited_x_rand = x_rand.detach() + total_grad * self.grad_stride
                self.evolve_mem(edited_x_rand, x_rand_indices)

                self.load_cache()

            if j < self.mem_limit:
                self.reservoir['x'][j] = x
                self.reservoir['y'][j] = y
                self.reservoir['y_extra'][j] = y_extra

        self.example_seen += 1

    def update_mem_supportive_examples(self, x_th, y_th, y_extra_th=None, *args, **kwargs):
        """
        Reservoir-sample one example and record which memory slots "support" it.

        If the example is stored at slot ``j``, slot ``j`` receives a fresh hash. The old hash is
        retired, so support lists naming the evicted example no longer resolve to ``j``. Up to 50
        other memory slots are then evaluated before and after a look-ahead SGD step (no
        optimizer-state update) on the incoming example. The hashes of up to 10 slots whose loss
        decreased the most, and strictly decreased, are stored in ``reservoir['support'][j]``.
        Parameters are restored afterwards.

        :param x_th: incoming example (tensor).
        :param y_th: its label (tensor).
        :param y_extra_th: its task id (tensor).
        :raises NotImplementedError: if candidates exist and the optimizer is not SGD.
        """
        device = x_th.device
        x, y, y_extra = self.to_mem_type(x_th, y_th, y_extra_th)

        if self.example_seen < self.mem_limit:
            j = self.example_seen
        else:
            j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.example_seen)
        if j < self.mem_limit:
            self.reservoir['support'][j] = []
            # first change the hash of j
            if j in self.support_hash:
                self.support_hash_to_index.pop(self.support_hash[j])
            self.support_hash[j] = self.support_cnt
            self.support_hash_to_index[self.support_cnt] = j
            self.support_cnt += 1

            # draw a random subset of examples from the memory
            max_candidates, max_support = 50, 10
            indices = list(range(min(self.example_seen, self.mem_limit)))
            if len(indices) > max_candidates:
                indices = random.sample(indices, k=max_candidates)
            # slot j is about to be overwritten by this example, so it can never support it
            # (must be excluded whether or not the candidate list was subsampled)
            if j in indices:
                indices.remove(j)
            if indices:
                cand_x, cand_y_pad, cand_y_extra = self.indices_to_examples(indices, device)
                self.store_cache()

                # NOTE(review): the incoming example is repeated max_support times to form a batch
                x_rep = x_th.view(1, -1).repeat(max_support, 1)
                y_rep = y_th.view(-1).repeat(max_support)
                y_extra_rep = y_extra_th.view(-1).repeat(max_support)

                # evaluate loss on the candidates before the look-ahead step
                with torch.no_grad():
                    ret_cand_before = self.forward_net(cand_x, cand_y_pad, cand_y_extra, reduce=False)

                ret_x = self.forward_net(x_rep, y_rep, y_extra_rep)
                self.optimizer.zero_grad()
                ret_x['loss'].backward()
                if isinstance(self.optimizer, torch.optim.SGD):
                    step_wo_state_update_sgd(self.optimizer)
                else:
                    raise NotImplementedError
                with torch.no_grad():
                    ret_cand_after = self.forward_net(cand_x, cand_y_pad, cand_y_extra, reduce=False)
                loss_decrease = ret_cand_before['loss'] - ret_cand_after['loss']

                # keep the candidates with the largest (positive) loss decrease; topi maps each
                # ranked value back to its position in `indices`
                topv, topi = loss_decrease.topk(min(loss_decrease.size(0), max_support))
                for value, pos in zip(topv.tolist(), topi.tolist()):
                    if value > 0:
                        self.reservoir['support'][j].append(self.support_hash[indices[pos]])
                self.load_cache()
            self.reservoir['x'][j] = x
            self.reservoir['y'][j] = y
            self.reservoir['y_extra'][j] = y_extra

        self.example_seen += 1

    def get_supportive_examples(self, mem_idx, device):
        """Return ``(x, y, y_extra)`` for the still-present supporters of slot ``mem_idx``.

        Hashes in ``reservoir['support'][mem_idx]`` whose slot has since been overwritten, and is
        therefore missing from ``support_hash_to_index``, are skipped. Returns ``None`` when no
        supporter remains.
        """
        live_slots = self.support_hash_to_index
        indices = [live_slots[h] for h in self.reservoir['support'][mem_idx] if h in live_slots]
        if indices:
            return self.indices_to_examples(indices, device)
        return None

    def indices_to_examples(self, indices, device):
        """Gather memory slots ``indices`` into device tensors ``(x, y_padded, y_extra_padded)``.

        Labels that are lists/tuples (multi-output) are stacked per output field. Otherwise they
        are padded with ``concat_with_padding``, like the task ids.
        """
        batch_x = torch.from_numpy(self.reservoir['x'][indices]).to(device).float()
        labels = index_select(self.reservoir['y'], indices, device)
        extras = index_select(self.reservoir['y_extra'], indices, device)
        if type(labels[0]) in [list, tuple]:
            labels_padded = [torch.stack(field).to(device) for field in zip(*labels)]
        else:
            labels_padded = concat_with_padding(labels)
        return batch_x, labels_padded, concat_with_padding(extras)