import numpy as np
from torch.utils.data.sampler import BatchSampler

# +
class MemoryMultiplerSampler(BatchSampler):
    """Batch sampler mixing current-task samples with samples of earlier tasks.

    Samples whose task id equals ``cur_task`` are the "new" data; samples
    whose task id is smaller than ``cur_task`` are the "old" data. Samples
    with a larger task id are never yielded. Each yielded batch is a list of
    dataset indices made of one chunk of up to ``batch_size`` new indices
    followed by ``mem_multiplier`` chunks of old indices.

    Shuffling uses the global NumPy random state (``np.random.shuffle``).
    The parent class initializer is not invoked; only the attributes set in
    ``__init__`` below exist on the instance.
    """

    def __init__(self, task_id, cur_task, batch_size, mem_multiplier, drop_last=False):
        """Split the sample indices into new and old pools.

        Args:
            task_id: Array-like holding one task id per sample (NumPy array,
                list, tuple, ...). Stored unchanged as ``self.task_id``.
            cur_task: Id of the task currently being trained.
            batch_size: Number of new samples per batch.
            mem_multiplier: Number of old-sample chunks appended to each batch.
            drop_last: If True, a trailing incomplete chunk of new samples is
                not yielded.
        """
        self.task_id = task_id
        self.batch_size = batch_size
        self.mem_multiplier = mem_multiplier

        # Compare on an ndarray so that ``==`` / ``<`` are element-wise for any
        # array-like input; on a plain list ``==`` yields a single bool and
        # ``<`` raises TypeError.
        task_ids = np.asarray(task_id)
        self.new_data_idx = np.where(task_ids == cur_task)[0]
        self.old_data_idx = np.where(task_ids < cur_task)[0]

        # An old chunk can never be larger than the old pool itself.
        self.old_batch_size = min(len(self.old_data_idx), batch_size)

        self.drop_last = drop_last

    def __iter__(self):
        """Yield one list of indices per batch: new indices, then old indices."""
        np.random.shuffle(self.new_data_idx)
        np.random.shuffle(self.old_data_idx)

        num_old = len(self.old_data_idx)
        old_cursor = 0  # start offset of the next old chunk
        for batch_index in range(len(self)):
            start = batch_index * self.batch_size
            # Slicing clamps at the end of the array, so the last batch may be
            # shorter when drop_last is False.
            new_data = self.new_data_idx[start:start + self.batch_size].tolist()

            old_data = []
            for _ in range(self.mem_multiplier):
                stop = old_cursor + self.old_batch_size
                old_data.extend(self.old_data_idx[old_cursor:stop].tolist())
                old_cursor = stop

                # Not enough old samples left for another full chunk:
                # reshuffle the pool and start over from its beginning.
                if stop + self.old_batch_size > num_old:
                    np.random.shuffle(self.old_data_idx)
                    old_cursor = 0

            yield new_data + old_data

    def __len__(self):
        """Return the number of batches, which depends only on the new data."""
        num_new = len(self.new_data_idx)
        if self.drop_last:
            return num_new // self.batch_size
        # Ceiling division: a trailing partial batch counts as a batch.
        return (num_new + self.batch_size - 1) // self.batch_size
        
#         return len(self.new_data_idx) // self.batch_size

