import torch

import ShiftingWindowSetting as sw
import random
import math


class ER_reservoir(sw.CLLearningAlgo):
    """Experience replay backed by a fixed-size reservoir-sampled memory.

    After every optimiser step the current batch is offered to a memory of at
    most ``mem_size`` ``(input, label, task_id)`` tuples, maintained with
    reservoir sampling (Li's Algorithm L). ``calc_reg_loss_term`` computes the
    loss on a random replay minibatch drawn from that memory; presumably the
    base class adds it to the training loss -- confirm in sw.CLLearningAlgo.

    Args:
        args: forwarded unchanged to ``sw.CLLearningAlgo.__init__``.
        mem_size: maximum number of stored samples; ``<= 0`` disables storage.
        replay_batch_size: number of samples drawn per replay step.
    """

    def __init__(self, args, mem_size=1000, replay_batch_size=10):
        super().__init__(args=args)
        # Per-instance buffer. A class-level list would be shared (and filled)
        # by every ER_reservoir instance in the process.
        self.mem = []
        self.mem_size = mem_size
        self.remaining_space = mem_size  # free slots left before memory is full
        self.seen_count = 0  # stream samples processed so far
        self.next = 0  # 0-based stream position of the next sample to store (once full)
        self.w = 0  # Algorithm L's running W value (set once memory is full)
        self.full = False
        self.replay_batch_size = replay_batch_size

    def calc_reg_loss_term(self):
        """Return the loss on a random replay minibatch drawn from memory.

        While memory holds fewer than ``replay_batch_size`` samples (or replay
        is disabled), returns ``torch.zeros(1)`` on ``self.device`` with
        ``requires_grad=True``.
        """
        if self.replay_batch_size <= 0 or len(self.mem) < self.replay_batch_size:
            return torch.zeros(1, device=self.device, requires_grad=True)
        xs, ys, task_ids = zip(*random.sample(self.mem, self.replay_batch_size))
        X = torch.stack(xs, dim=0).to(self.device)
        Y = torch.tensor(ys).to(self.device)
        classes = self.task_stream.classes
        # One null-class set per replayed sample, based on the task it came from.
        per_point_null_classes = [self.calc_null_classes(task, classes) for task in task_ids]
        return self.loss_fn(sw.calc_multi_head_model_output(self.model, X, per_point_null_classes), Y)

    @staticmethod
    def _open_uniform():
        """Draw U ~ Uniform(0, 1) with 0 excluded, so ``math.log(U)`` is finite."""
        u = random.random()  # in [0.0, 1.0)
        while u == 0.0:
            u = random.random()
        return u

    def _draw_skip(self):
        """Return how many stream samples to pass over before the next replacement.

        Geometric with success probability ``self.w``: floor(log(U) / log(1 - w)).
        """
        return math.floor(math.log(self._open_uniform()) / math.log1p(-self.w))

    def _update_memory(self):
        """Offer every sample of the current batch to the reservoir.

        The first ``mem_size`` samples ever seen are stored directly. After that,
        Algorithm L decides which later samples overwrite a uniformly chosen slot,
        skipping ahead geometrically instead of drawing once per sample.
        Assumes ``self.batch`` is ``(inputs, labels)`` for the step just taken,
        populated by sw.CLLearningAlgo -- TODO confirm.
        """
        if self.mem_size <= 0:
            return  # no memory configured; nothing to store
        inputs, labels = self.batch[0], self.batch[1]
        # detach().clone() so each entry owns its own storage instead of being a
        # view that keeps the whole batch tensor alive and aliases in-place edits.
        batch = [(inputs[i].detach().clone(), labels[i].item(), self.task_id)
                 for i in range(inputs.shape[0])]

        if not self.full:
            take = min(len(batch), self.remaining_space)
            self.mem += batch[:take]
            self.remaining_space -= take
            self.seen_count += take
            if self.remaining_space == 0:
                self.full = True
                self.w = math.exp(math.log(self._open_uniform()) / self.mem_size)
                # Positions 0..seen_count-1 are already processed, so the first
                # candidate is position seen_count itself; an extra +1 here would
                # make that sample unreachable and bias the sampling.
                self.next = self.seen_count + self._draw_skip()
            if take == len(batch):
                return
            batch = batch[take:]

        # batch[0] sits at stream position seen_count.
        end = self.seen_count + len(batch)
        while self.next < end:
            self.mem[random.randint(0, self.mem_size - 1)] = batch[self.next - self.seen_count]
            self.w *= math.exp(math.log(self._open_uniform()) / self.mem_size)
            self.next += self._draw_skip() + 1
        self.seen_count = end

    def after_optimiser_step(self):
        """Hook run after each optimiser step: add the current batch to memory."""
        self._update_memory()
