import numpy as np
import torch
import random


def compute_offsets(task, nc_per_task):
    """
        Return the pair ``(start, stop)`` delimiting the outputs that
        belong to ``task`` when every task owns ``nc_per_task``
        consecutive outputs: ``start = task * nc_per_task`` and
        ``stop = (task + 1) * nc_per_task``.
    """
    # Both bounds are the same product evaluated at `task` and `task + 1`.
    return tuple(boundary * nc_per_task for boundary in (task, task + 1))


class Reservoir:
    """Fixed-size replay memory filled by reservoir sampling.

    Once ``n`` samples have been offered through ``update``, each of them is
    held in the buffer with probability ``mem_size / n`` (every sample is
    kept while ``n <= mem_size``).
    """
    def __init__(self, mem_size, image_size, device='cuda'):
        """Allocate the zero-initialised storage.

        Args:
            mem_size: number of slots in the buffer.
            image_size: sequence whose first three entries are the
                dimensions of one stored sample.
            device: torch device on which the storage tensors are created.
        """
        self.mem_size = mem_size
        self.memory_data = torch.zeros(
            mem_size, image_size[0], image_size[1], image_size[2],
            dtype=torch.float, device=device)
        self.memory_labs = torch.zeros(mem_size, dtype=torch.long, device=device)
        # Total number of samples offered so far (not the number stored).
        self.seen_cnt = 0

    def update(self, x, y, t):
        """Offer every sample of a batch to the reservoir.

        Args:
            x: batch of inputs, indexed along dimension 0.
            y: tensor of labels, indexed in step with ``x``.
            t: unused by this buffer.
        """
        for i in range(x.shape[0]):
            if self.seen_cnt < self.mem_size:
                # Buffer not full yet: append to the next free slot.
                slot = self.seen_cnt
            else:
                # This is sample number ``seen_cnt + 1``; drawing uniformly
                # from that many values keeps it with probability
                # ``mem_size / (seen_cnt + 1)``. Drawing from only
                # ``seen_cnt`` values would always store the first sample
                # arriving after the buffer fills.
                slot = random.randrange(self.seen_cnt + 1)
            if slot < self.mem_size:
                self.memory_data[slot].copy_(x[i])
                self.memory_labs[slot].copy_(y[i])
            self.seen_cnt += 1
        return

    def sample(self, sample_size):
        """Draw up to ``sample_size`` distinct stored samples at random.

        Only slots that have actually been written are eligible, so fewer
        than ``sample_size`` items (possibly none) are returned while the
        buffer holds fewer samples than requested.

        Returns:
            Tuple ``(x, y)`` of data and labels gathered from the buffer.
        """
        # Slots past this count still hold their initial zeros.
        n_valid = min(self.seen_cnt, self.mem_size)
        index = torch.randperm(n_valid)[:sample_size]
        return self.memory_data[index], self.memory_labs[index]


class RingBuffer:
    """Per-task replay memory: one fixed-size ring of samples for each task.

    ``update`` writes incoming batches into the ring of the given task and
    ``sample`` draws an equal share of stored samples from every task seen
    so far.
    """
    def __init__(self, n_tasks, n_memories, image_size, device='cuda'):
        """Allocate the zero-initialised storage.

        Args:
            n_tasks: number of per-task rings to allocate.
            n_memories: number of slots in each ring.
            image_size: sequence whose first three entries are the
                dimensions of one stored sample.
            device: torch device on which the storage tensors are created.
        """
        self.memory_data = torch.zeros(
            n_tasks, n_memories, image_size[0], image_size[1], image_size[2],
            dtype=torch.float, device=device)
        self.memory_labs = torch.zeros(n_tasks, n_memories, dtype=torch.long, device=device)
        self.n_memories = n_memories
        # Task id passed to the most recent update (-1 before the first one).
        self.old_task = -1
        # Write cursor inside the ring of the current task.
        self.mem_cnt = 0
        # Distinct task ids in order of first appearance.
        self.observed_tasks = []
        # Per task: number of leading slots that hold real samples.
        self.n_filled = [0] * n_tasks

    def update(self, x, y, t):
        """Store a batch in the ring of task ``t``.

        Samples that do not fit between the write cursor and the end of the
        ring are dropped; the cursor wraps to slot 0 once the end is reached.

        Args:
            x: batch of inputs.
            y: tensor of labels, indexed in step with ``x``.
            t: index of the task the batch belongs to.
        """
        if t != self.old_task:
            # Register each task once, even when it is revisited later;
            # duplicates would make sample() over-draw from that task.
            if t not in self.observed_tasks:
                self.observed_tasks.append(t)
            self.old_task = t
            # Resume after this task's last filled slot (slot 0 for a new
            # or completely filled ring) so its valid slots stay a prefix.
            filled = self.n_filled[t]
            self.mem_cnt = 0 if filled >= self.n_memories else filled

        # Update ring buffer storing examples from current task
        bsz = y.detach().size(0)
        start = self.mem_cnt
        stop = min(start + bsz, self.n_memories)
        n_new = stop - start
        self.memory_data[t, start:stop].copy_(x.detach()[:n_new])
        if bsz == 1:
            # A single label is written by scalar assignment.
            self.memory_labs[t, start] = y.detach()[0]
        else:
            self.memory_labs[t, start:stop].copy_(y.detach()[:n_new])
        self.n_filled[t] = max(self.n_filled[t], stop)
        self.mem_cnt = 0 if stop == self.n_memories else stop

    def sample(self, sample_size):
        """Draw a shuffled mix of stored samples from every observed task.

        Each observed task contributes up to
        ``sample_size // len(observed_tasks)`` distinct samples, taken only
        from slots that have been written. Empty tensors are returned when
        no task has been observed yet.

        Returns:
            Tuple ``(x, y)`` of data and labels.
        """
        if not self.observed_tasks:
            sample_shape = tuple(self.memory_data.shape[2:])
            return (self.memory_data.new_zeros((0,) + sample_shape),
                    self.memory_labs.new_zeros((0,)))

        per_task = sample_size // len(self.observed_tasks)
        m_x, m_y = [], []
        for tt in self.observed_tasks:
            # Never-written slots still hold their initial zeros; skip them.
            index = torch.randperm(self.n_filled[tt])[:per_task]
            m_x.append(self.memory_data[tt][index])
            m_y.append(self.memory_labs[tt][index])
        m_x, m_y = torch.cat(m_x), torch.cat(m_y)
        # shuffle so samples of different tasks are interleaved
        order = torch.randperm(len(m_x))
        return m_x[order], m_y[order]

#
# class GSS:
#     def __init__(self):
