import numpy as np

class RandomTaskReplayMemory:
    """
    A simple replay memory for nccl with one ring buffer per task.

    The total ``memory_size`` budget is split evenly across ``num_tasks``:
    each task owns ``memory_size // num_tasks`` slots in
    ``input_memory[task_idx]`` / ``label_memory[task_idx]``.

    NOTE(review): a single write pointer ``ptr`` is shared by all tasks, so a
    new task starts writing wherever the previous ``store`` call stopped.
    Slots that were never written stay zero and can still be sampled.
    """

    def __init__(self, num_tasks, memory_size, input_shape, num_outputs, sample_batch_size):
        if not isinstance(input_shape, list):
            input_shape = list(input_shape)

        # The total budget is split evenly; each task gets this many slots.
        memory_size = memory_size // num_tasks

        self.input_memory = np.zeros([num_tasks, memory_size] + input_shape, dtype=np.float32)
        self.label_memory = np.zeros([num_tasks, memory_size, num_outputs], dtype=np.int32)
        self.sample_batch_size = sample_batch_size
        self.memory_size = memory_size
        self.num_tasks = num_tasks

        self.num_outputs = num_outputs
        # Output columns are treated as equal, contiguous per-task class
        # blocks (see ``sample_for_split``).
        self.num_classes_per_task = self.num_outputs // self.num_tasks
        # Next write position inside a task's ring buffer, in [0, memory_size).
        self.ptr = 0

    def store(self, inputs, labels, task_idx):
        """
        Write a batch into the ring buffer of ``task_idx``.

        Items go to consecutive slots starting at ``self.ptr``, wrapping
        around to slot 0. If the batch is longer than the per-task buffer,
        only its last ``memory_size`` items survive, exactly as if the items
        had been written one at a time.

        Args:
            inputs: array whose first axis is the batch axis; trailing axes
                must match ``input_shape``.
            labels: array of shape ``(batch, num_outputs)``.
            task_idx: index of the task buffer to write into.
        """
        batch_size = inputs.shape[0]
        if batch_size == 0 or self.memory_size == 0:
            return

        # Only the last ``memory_size`` items can survive the wrap-around.
        # Dropping the others also keeps the target slots unique. For repeated
        # indices, NumPy does not define which value a fancy assignment keeps.
        first = max(0, batch_size - self.memory_size)
        slots = (self.ptr + np.arange(first, batch_size)) % self.memory_size

        self.input_memory[task_idx, slots] = inputs[first:]
        self.label_memory[task_idx, slots] = labels[first:]

        self.ptr = (self.ptr + batch_size) % self.memory_size

    def sample(self, current_task_idx):
        """
        Sample stored items from all tasks before ``current_task_idx``.

        memory architecture: [T, M, *input_shape]
        B samples on [T, M]

        Slots are drawn uniformly without replacement from the first
        ``current_task_idx * memory_size`` flattened (task, slot) positions.
        Those positions cover tasks ``0 .. current_task_idx - 1``. Results come
        back in ascending position order. If fewer positions exist than
        ``sample_batch_size`` (e.g. for task 0), every available position is
        returned instead of raising.

        Returns:
            ``(inputs, labels)`` whose leading dimension is
            ``min(sample_batch_size, current_task_idx * memory_size)``.
        """
        population = current_task_idx * self.memory_size
        size = min(self.sample_batch_size, population)
        sampled_idx = np.random.choice(population, size=size, replace=False)

        mask = np.zeros(self.num_tasks * self.memory_size, dtype=bool)
        mask[sampled_idx] = True
        mask = mask.reshape(self.num_tasks, self.memory_size)

        inputs = self.input_memory[mask]
        labels = self.label_memory[mask]

        return inputs, labels

    def sample_for_split(self, current_task_idx):
        """
        Like ``sample``, but also return a per-sample mask of its task's classes.

        Labels are assumed to be one-hot. ``argmax`` recovers the class, so an
        all-zero (never written) row maps to class 0.

        Returns:
            ``(inputs, labels, offsets)``. ``offsets`` is int32 with the same
            shape as ``labels``. In row ``i`` it is 1 on the contiguous block
            of ``num_classes_per_task`` columns containing
            ``argmax(labels[i])``, and 0 elsewhere.
        """
        inputs, labels = self.sample(current_task_idx)

        offsets = np.zeros_like(labels, np.int32)
        if self.num_classes_per_task > 0:
            # Task block of each sample vs. task block of each output column.
            sample_task = np.argmax(labels, axis=-1) // self.num_classes_per_task
            column_task = np.arange(labels.shape[-1]) // self.num_classes_per_task
            offsets[column_task[None, :] == sample_task[:, None]] = 1

        return inputs, labels, offsets

class ReservoirMemory:
    """
    Replay memory filled by reservoir sampling (Algorithm R).

    A single buffer of ``memory_size`` slots, shared by all tasks, holds a
    uniform random subset of every item passed to ``store`` so far.
    """

    def __init__(self, num_tasks, memory_size, input_shape, num_outputs, sample_batch_size):
        if not isinstance(input_shape, list):
            input_shape = list(input_shape)

        self.sample_batch_size = sample_batch_size
        self.memory_size = memory_size  # total slots (not split per task)
        self.num_tasks = num_tasks
        self.num_outputs = num_outputs

        # Output columns are treated as equal, contiguous per-task class
        # blocks (see ``sample_for_split``).
        self.num_classes_per_task = self.num_outputs // self.num_tasks
        self.input_memory = np.zeros([self.memory_size] + input_shape, dtype=np.float32)
        self.label_memory = np.zeros([self.memory_size, num_outputs], dtype=np.int32)

        # Not used by this class; kept so the attribute exists as before.
        self.ptr = 0
        # The number of seen data is N
        self.N = 0

    def store(self, inputs, labels, task_idx):
        """
        Offer each item of a batch to the reservoir.

        The first ``memory_size`` items fill the buffer in order. After that,
        the (N+1)-th item replaces a uniformly chosen slot with probability
        ``memory_size / (N + 1)``. ``task_idx`` is accepted for interface
        compatibility and ignored.
        """
        for inp, label in zip(inputs, labels):
            if self.N < self.memory_size:
                self.input_memory[self.N] = inp
                self.label_memory[self.N] = label
            else:
                # j must be uniform over [0, N] *inclusive* (N + 1 values);
                # randint's upper bound is exclusive.
                j = np.random.randint(0, self.N + 1)
                if j < self.memory_size:
                    self.input_memory[j] = inp
                    self.label_memory[j] = label

            self.N += 1

    def sample(self, current_task_idx):
        """
        Sample stored items uniformly without replacement.

        memory architecture: [M, *input_shape]
        B samples on [M]

        ``current_task_idx`` is accepted for interface compatibility and
        ignored. Results come back in ascending slot order. If fewer than
        ``sample_batch_size`` items are stored, all stored items are returned
        instead of raising.

        Returns:
            ``(inputs, labels)`` whose leading dimension is
            ``min(sample_batch_size, N, memory_size)``.
        """
        population = min(self.N, self.memory_size)
        size = min(self.sample_batch_size, population)
        sampled_idx = np.random.choice(population, size=size, replace=False)

        mask = np.zeros(self.memory_size, dtype=bool)
        mask[sampled_idx] = True

        inputs = self.input_memory[mask]
        labels = self.label_memory[mask]

        return inputs, labels

    def sample_for_split(self, current_task_idx):
        """
        Like ``sample``, but also return a per-sample mask of its task's classes.

        Labels are assumed to be one-hot. ``argmax`` recovers the class.

        Returns:
            ``(inputs, labels, offsets)``. ``offsets`` is int32 with the same
            shape as ``labels``. In row ``i`` it is 1 on the contiguous block
            of ``num_classes_per_task`` columns containing
            ``argmax(labels[i])``, and 0 elsewhere.
        """
        inputs, labels = self.sample(current_task_idx)

        offsets = np.zeros_like(labels, np.int32)
        if self.num_classes_per_task > 0:
            # Task block of each sample vs. task block of each output column.
            sample_task = np.argmax(labels, axis=-1) // self.num_classes_per_task
            column_task = np.arange(labels.shape[-1]) // self.num_classes_per_task
            offsets[column_task[None, :] == sample_task[:, None]] = 1

        return inputs, labels, offsets