# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Tuple

import numpy as np
import torch
from torchvision import transforms

from datasets.seq_imagenet import ImageFolder
from datasets.seq_imagenet import pil_loader
from augmentations.imagenet_aug import MaskGenerator, NonBinaryMaskGenerator


def reservoir(num_seen_examples: int, buffer_size: int):
    """
    Choose the buffer slot for an incoming example via reservoir sampling.

    While the buffer still has free slots, the example goes to the next free
    slot (index ``num_seen_examples``). Once it is full, an integer is drawn
    uniformly from ``[0, num_seen_examples]``. If it falls inside the buffer,
    that slot is overwritten; otherwise the example is discarded.

    :param num_seen_examples: how many examples have been offered so far
    :param buffer_size: the maximum buffer size
    :return: the slot index to overwrite, or -1 if the example is discarded
    """
    if num_seen_examples >= buffer_size:
        candidate = np.random.randint(0, num_seen_examples + 1)
        return candidate if candidate < buffer_size else -1
    return num_seen_examples

def ring(num_seen_examples: int, buffer_portion_size: int, task: int):
    """
    Compute the ring-buffer slot for an incoming example.

    The buffer is split into equal portions, one per task. Examples cycle
    through the portion that belongs to ``task``.

    :param num_seen_examples: how many examples have been offered so far
    :param buffer_portion_size: number of slots reserved for each task
    :param task: index of the task the example belongs to
    :return: the absolute slot index in the buffer
    """
    task_offset = task * buffer_portion_size
    return task_offset + num_seen_examples % buffer_portion_size

class Buffer:
    """
    The memory buffer of a rehearsal method.

    Stores examples and, optionally, labels, logits, task labels, masks and
    image paths in preallocated tensors. Slots are chosen in one of two ways:
    - ``mode='reservoir'``: reservoir sampling over the whole buffer.
    - ``mode='ring'``: a per-task ring, with the buffer split into ``n_tasks``
      equal portions.

    A second "candidate" pool (the ``cnd_*`` attributes) is used for OCS.
    """

    def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir', configs=None):
        """
        :param buffer_size: the maximum number of stored examples
        :param device: device on which the candidate (OCS) tensors are allocated
        :param n_tasks: number of tasks; required when ``mode == 'ring'``
        :param mode: slot-selection strategy, ``'ring'`` or ``'reservoir'``
        :param configs: optional config object; when given, mask generators are
            built from ``configs.model`` and ``configs.dataset``
        :raises NotImplementedError: if ``configs.model.type`` is neither
            'swin' nor 'vit'
        """
        self.configs = configs
        assert mode in ['ring', 'reservoir']
        self.buffer_size = buffer_size
        # The OCS candidate pool holds five times as many items as the buffer.
        self.cnd_buffer_size = buffer_size * 5

        self.device = device
        self.num_seen_examples = 0
        # Explicit lookup instead of eval(mode).
        self.functional_index = {'ring': ring, 'reservoir': reservoir}[mode]
        self.mode = mode
        if mode == 'ring':
            assert n_tasks is not None
            self.task_number = n_tasks
            self.buffer_portion_size = buffer_size // n_tasks

        self.attributes = ['examples', 'labels', 'logits', 'task_labels', 'masks']
        self.cnd_attributes = ['cnd_examples', 'cnd_labels', 'cnd_logits', 'cnd_task_labels']

        if self.configs:
            if self.configs.model.type == 'swin':
                model_patch_size = self.configs.model.swin.patch_size
            elif self.configs.model.type == 'vit':
                model_patch_size = self.configs.model.vit.patch_size
            else:
                raise NotImplementedError

            self.mask_generator = MaskGenerator(
                input_size=self.configs.dataset.image_size,
                mask_patch_size=self.configs.model.mask_patch_size,
                model_patch_size=model_patch_size,
                mask_ratio=self.configs.model.mask_ratio,
            )
            self.non_binary_mask_generator = NonBinaryMaskGenerator(
                input_size=self.configs.dataset.image_size,
                mask_patch_size=self.configs.model.mask_patch_size,
                model_patch_size=model_patch_size,
                mask_ratio=self.configs.model.mask_ratio,
            )

    @staticmethod
    def _identity(x):
        """Default transform: return the input unchanged."""
        return x

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
                     logits: torch.Tensor, masks: torch.Tensor, task_labels: torch.Tensor):
        """
        Allocate zero-filled storage for each provided (non-None) attribute
        that does not exist yet.

        Each tensor has ``buffer_size`` rows and the per-sample shape of the
        given tensor, and is stored on the CPU. Attributes whose name ends in
        'els' (labels, task_labels) are int64; the others are float32.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        :param logits: tensor containing the outputs of the network
        :param masks: tensor containing the masks
        :param task_labels: tensor containing the task labels
        """
        provided = {'examples': examples, 'labels': labels, 'logits': logits,
                    'task_labels': task_labels, 'masks': masks}
        for attr_str in self.attributes:
            attr = provided[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=typ, device='cpu'))

    def init_cnd_tensors(self, cnd_examples: torch.Tensor, cnd_labels: torch.Tensor,
                         cnd_logits: torch.Tensor, cnd_task_labels: torch.Tensor):
        """
        Allocate zero-filled storage for each provided (non-None) OCS candidate
        attribute that does not exist yet.

        Each tensor has ``cnd_buffer_size`` rows and lives on ``self.device``.
        :param cnd_examples: tensor containing the candidate images
        :param cnd_labels: tensor containing the candidate labels
        :param cnd_logits: tensor containing the candidate network outputs
        :param cnd_task_labels: tensor containing the candidate task labels
        """
        provided = {'cnd_examples': cnd_examples, 'cnd_labels': cnd_labels,
                    'cnd_logits': cnd_logits, 'cnd_task_labels': cnd_task_labels}
        for attr_str in self.cnd_attributes:
            attr = provided[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.cnd_buffer_size,
                        *attr.shape[1:]), dtype=typ, device=self.device))

    def add_data(self, examples, labels=None, logits=None, masks=None, paths=None, task_labels=None):
        """
        Add a batch to the memory buffer according to ``self.mode``.

        In 'ring' mode ``task_labels`` must be given, because it selects the
        task's portion of the buffer.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        :param logits: tensor containing the outputs of the network
        :param masks: tensor containing the masks
        :param paths: image file paths; when given, get_data reloads images from them
        :param task_labels: tensor containing the task labels
        :raises NotImplementedError: if ``self.mode`` is not 'ring' or 'reservoir'
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels, logits, masks, task_labels)

        if not hasattr(self, 'paths') and paths is not None:
            self.paths = np.full(self.buffer_size, None, dtype=object)

        for i in range(examples.shape[0]):
            if self.mode == 'ring':
                index = ring(self.num_seen_examples, self.buffer_portion_size, task_labels[i])
            elif self.mode == 'reservoir':
                index = reservoir(self.num_seen_examples, self.buffer_size)
            else:
                raise NotImplementedError(f'unknown buffer mode: {self.mode}')

            self.num_seen_examples += 1
            if index < 0:
                # Reservoir sampling decided to discard this example.
                continue
            self.examples[index] = examples[i]
            if labels is not None:
                self.labels[index] = labels[i]
            if logits is not None:
                self.logits[index] = logits[i]
            if paths is not None:
                self.paths[index] = paths[i]
            if task_labels is not None:
                self.task_labels[index] = task_labels[i]
            if masks is not None:
                self.masks[index] = masks[i]

    def get_data(self, size: int, transform: Optional[Callable] = None, get_mask: bool = False):
        """
        Randomly sample ``size`` items from the filled part of the buffer.

        Sampling is with replacement, so ``size`` may exceed the number of
        stored items. numpy raises ValueError if nothing is stored yet.
        :param size: the number of requested items
        :param transform: the transformation to be applied (data augmentation);
            identity when None
        :param get_mask: if True, also return one freshly generated mask per
            sampled item (needs the mask generator built from ``configs``)
        :return: a tuple ``(examples[, generated_masks], *stored)``, where
            ``stored`` holds whichever of labels, logits, task_labels and masks
            exist, in that order
        """
        num_available = min(self.num_seen_examples, self.examples.shape[0])
        choice = np.random.choice(num_available, size=size, replace=True)

        if transform is None:
            transform = self._identity

        if hasattr(self, 'paths'):
            # Images are reloaded from disk when paths were recorded.
            images = [transform(pil_loader(path)) for path in self.paths[choice]]
        else:
            images = [transform(example.cpu()) for example in self.examples[choice]]
        ret_tuple = (torch.stack(images),)

        if get_mask:
            ret_tuple += (torch.stack([torch.Tensor(self.mask_generator())
                                       for _ in range(len(choice))]),)

        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (getattr(self, attr_str)[choice],)
        return ret_tuple

    # For OCS
    def add_cnd_data(self, examples, labels=None, logits=None, task_labels=None):
        """
        Add candidate data for the memory buffer in OCS, using reservoir
        sampling over ``cnd_buffer_size`` slots.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        :param logits: tensor containing the outputs of the network
        :param task_labels: tensor containing the task labels
        """
        if not hasattr(self, 'cnd_examples'):
            self.init_cnd_tensors(examples, labels, logits, task_labels)

        # NOTE(review): this advances the same self.num_seen_examples counter as
        # add_data. Confirm the two pools are never filled on one instance.
        for i in range(examples.shape[0]):
            index = reservoir(self.num_seen_examples, self.cnd_buffer_size)
            self.num_seen_examples += 1
            if index < 0:
                continue
            self.cnd_examples[index] = examples[i]
            if labels is not None:
                self.cnd_labels[index] = labels[i]
            if logits is not None:
                self.cnd_logits[index] = logits[i]
            if task_labels is not None:
                self.cnd_task_labels[index] = task_labels[i]

    def is_empty(self):
        """
        Return True if no example has been added yet, False otherwise.
        """
        return self.num_seen_examples == 0

    def is_full(self, logger=None):
        """
        Return True once at least ``buffer_size`` examples have been added.

        At that point every slot has been written at least once.
        :param logger: optional logger; when the buffer is not full, the
            current counts are reported via ``logger.info``
        """
        full = self.num_seen_examples >= self.buffer_size
        if not full and logger is not None:
            logger.info('num_seen_examples:%s, self.buffer_size:%s',
                        self.num_seen_examples, self.buffer_size)
        return full

    def get_all_data(self, transform: Optional[Callable] = None):
        """
        Return every slot of the memory buffer, including unwritten ones.
        :param transform: the transformation to be applied (data augmentation);
            identity when None
        :return: a tuple ``(examples, *stored)``, where ``stored`` holds
            whichever of labels, logits, task_labels and masks exist
        """
        if transform is None:
            transform = self._identity
        ret_tuple = (torch.stack([transform(ee.cpu()) for ee in self.examples]),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (getattr(self, attr_str),)
        return ret_tuple

    def get_all_cnd_data(self, transform: Optional[Callable] = None):
        """
        Return every slot of the OCS candidate pool, including unwritten ones.
        :param transform: the transformation to be applied (data augmentation);
            identity when None
        :return: a tuple ``(cnd_examples, *stored)``, where ``stored`` holds
            whichever of cnd_labels, cnd_logits, cnd_task_labels exist
        """
        if transform is None:
            transform = self._identity
        ret_tuple = (torch.stack([transform(ee.cpu()) for ee in self.cnd_examples]),)
        for attr_str in self.cnd_attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (getattr(self, attr_str),)
        return ret_tuple

    def empty(self):
        """
        Delete the stored tensors and recorded image paths, and reset the
        seen-example counter. The next add_data call reallocates storage.
        The OCS candidate tensors are left untouched.
        """
        # 'paths' must go too. Otherwise get_data keeps loading stale files and
        # add_data never re-creates it.
        for attr_str in self.attributes + ['paths']:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
