# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Any
import time

import torch
import torch.nn.functional as F
import numpy as np
import tqdm

from torchvision import transforms
from einops import rearrange, repeat
from data.base_video_dataset import Buffer_VideoDataset


def reservoir(num_seen_examples: int, buffer_size: int) -> int:
    """
    Picks the buffer slot for the next example using reservoir sampling.

    While the buffer still has free room the example goes to the next free
    slot. Afterwards a slot is drawn uniformly from
    ``[0, num_seen_examples]`` and the example is kept only if the draw
    falls inside the buffer.
    :param num_seen_examples: the number of examples seen before this one
    :param buffer_size: the maximum buffer size
    :return: the slot to write to, or -1 if the example is discarded
    """
    # Buffer not full yet: fill the slots sequentially.
    if num_seen_examples < buffer_size:
        return num_seen_examples

    slot = np.random.randint(0, num_seen_examples + 1)
    return slot if slot < buffer_size else -1

class Buffer_task_free(object):
    """
    Rehearsal memory buffer for video/audio clips and their logits.

    Storage is allocated lazily: the first ``add_data`` call creates one
    ``(buffer_size, ...)`` float32 tensor for every attribute that was
    supplied. New samples are placed with reservoir sampling unless the
    caller passes explicit slots through ``sample_idx``.
    """
    def __init__(self, memory_size, device):
        """
        :param memory_size: the maximum number of samples kept in the buffer
        :param device: stored as ``self.device``; this class does not use it itself
        """
        self.buffer_size = memory_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['video_data', 'audio_data', 'logits_a', 'logits_v']
        self.age = np.zeros(memory_size)

    def init_tensors(self, video_data: torch.Tensor, audio_data: torch.Tensor,
                     logits_a: torch.Tensor, logits_v: torch.Tensor) -> None:
        """
        Allocates zero-filled storage for every attribute that is not None
        and has not been allocated yet.
        :param video_data: batch of video clips (or None)
        :param audio_data: batch of audio clips (or None)
        :param logits_a: batch of audio logits (or None)
        :param logits_v: batch of video logits (or None)
        """
        # Explicit name -> value mapping instead of eval() on parameter names.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'logits_a': logits_a, 'logits_v': logits_v}
        for attr_str in self.attributes:
            attr = supplied[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # Leading batch dimension is replaced by the buffer size.
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=torch.float32))

    def add_data(self, video_data, audio_data=None, logits_v=None, logits_a=None, sample_idx=None):
        """
        Adds a batch to the memory buffer according to the reservoir strategy.
        :param video_data: tensor containing the video clips (required)
        :param audio_data: tensor containing the audio clips
        :param logits_v: tensor containing the video outputs of the network
        :param logits_a: tensor containing the audio outputs of the network
        :param sample_idx: optional per-sample target slots; when given,
            reservoir sampling is skipped and ``num_seen_examples`` is unchanged
        :return:
        """
        if not hasattr(self, 'video_data'):
            self.init_tensors(video_data=video_data, audio_data=audio_data,
                              logits_v=logits_v, logits_a=logits_a)

        # Only the attributes that were actually passed get written.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'logits_v': logits_v, 'logits_a': logits_a}
        supplied = {name: batch for name, batch in supplied.items() if batch is not None}

        for i in range(video_data.shape[0]):
            # Randomly select location to store new data when buffer is full
            if sample_idx is None:
                index = reservoir(self.num_seen_examples, self.buffer_size)
                self.num_seen_examples += 1
            # Sample idx is given. This function will be used in GMED for sample update
            else:
                index = sample_idx[i]

            # A negative index means the sample is discarded.
            if index >= 0:
                for name, batch in supplied.items():
                    getattr(self, name)[index] = batch[i].detach().cpu()

    def get_data(self, size: int, transform: Any = None, sample_idx=None, return_indices=False) -> Tuple:
        """
        Random samples a batch of size items.
        :param size: the number of requested items (clamped to the number of stored items)
        :param transform: optional callable applied to ``{name: tensor}`` for
            'video_data' and 'audio_data'; it must return a dict
        :param sample_idx: explicit indices to fetch instead of sampling
        :param return_indices: if True, also return the indices that were used
        :return: dict of the selected tensors, or ``(dict, indices)`` when
            ``return_indices`` is True
        """
        if sample_idx is None:
            # Never request more items than are actually stored.
            available = min(self.num_seen_examples, self.video_data.shape[0])
            size = min(size, available)
            choice = np.random.choice(available, size=size, replace=False)
        else:
            choice = sample_idx

        clips = {}
        # Bring data (the first two attributes go through the transform)
        for attr_str in self.attributes[:2]:
            if hasattr(self, attr_str):
                picked = {attr_str: getattr(self, attr_str)[choice]}
                clips.update(transform(picked) if transform is not None else picked)

        # Bring logits (never transformed)
        for attr_str in self.attributes[2:]:
            if hasattr(self, attr_str):
                clips[attr_str] = getattr(self, attr_str)[choice]

        # If need sample idx for updating sample, return sample idx
        if return_indices:
            return clips, choice
        return clips

    def get_mem_ages(self, indices, astype):
        """
        Returns the ages stored at ``indices``.
        :param indices: index expression applied to the ``age`` array
        :param astype: if a tensor, the ages are returned as a float tensor
            on its device; otherwise the NumPy values are returned as-is
        """
        ages = self.age[indices]
        if torch.is_tensor(astype):
            ages = torch.from_numpy(ages).float().to(astype.device)
        return ages

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def empty(self) -> None:
        """
        Deletes all storage tensors and resets the seen-example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0

    def state_dict(self):
        """
        Returns a dict with the counters, the ages and every allocated tensor.
        """
        state_dict = {'num_seen_examples': self.num_seen_examples,
                      'buffer_size': self.buffer_size,
                      'age': self.age}
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                state_dict[attr_str] = getattr(self, attr_str)

        return state_dict

    def load_state_dict(self, state):
        """
        Restores the buffer from a dict produced by ``state_dict``.
        :param state: must contain 'num_seen_examples', 'buffer_size' and
            'age'; storage tensors are restored only if present
        """
        self.num_seen_examples = state['num_seen_examples']
        self.buffer_size = state['buffer_size']
        self.age = state['age']
        for attr_str in self.attributes:
            if attr_str in state:
                setattr(self, attr_str, state[attr_str])


class Buffer_audio(object):
    """
    Rehearsal memory buffer for clips, attention-map ids, task labels,
    logits and query tensors.

    Storage is allocated lazily on the first ``add_data`` call, one
    ``(buffer_size, ...)`` tensor per supplied attribute. Attention-map ids
    and task labels are stored as int64, everything else as float32.
    """
    def __init__(self, memory_size, device):
        """
        :param memory_size: the maximum number of samples kept in the buffer
        :param device: stored as ``self.device``; this class does not use it itself
        """
        self.buffer_size = memory_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['video_data', 'audio_data', 'att_map_av_ids', 'att_map_va_ids', 'task_labels', 'logits_a', 'logits_v',
                           'video_query', 'audio_query']
        self.age = np.zeros(memory_size)

    def init_tensors(self, video_data: torch.Tensor, audio_data: torch.Tensor, att_map_av_ids: torch.Tensor, att_map_va_ids: torch.Tensor,
                     logits_a: torch.Tensor, logits_v: torch.Tensor, task_labels: torch.Tensor, video_query: torch.Tensor, audio_query: torch.Tensor) -> None:
        """
        Allocates zero-filled storage for every argument that is not None
        and has not been allocated yet.
        """
        # Explicit name -> value mapping instead of eval() on parameter names.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'att_map_av_ids': att_map_av_ids, 'att_map_va_ids': att_map_va_ids,
                    'task_labels': task_labels, 'logits_a': logits_a, 'logits_v': logits_v,
                    'video_query': video_query, 'audio_query': audio_query}
        for attr_str in self.attributes:
            attr = supplied[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # Index-like attributes are integer tensors; the rest are floats.
                is_integral = attr_str.startswith('att_map') or attr_str == 'task_labels'
                typ = torch.int64 if is_integral else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                                                     *attr.shape[1:]), dtype=typ))

    def add_data(self, video_data, audio_data=None, att_map_av_ids=None, att_map_va_ids=None,
                 logits_a=None, logits_v=None, task_labels=None, sample_idx=None, video_query=None, audio_query=None):
        """
        Adds a batch to the memory buffer according to the reservoir strategy.
        :param video_data: tensor containing the video clips (required)
        :param audio_data: tensor containing the audio clips
        :param att_map_av_ids: tensor of attention-map ids (stored as int64)
        :param att_map_va_ids: tensor of attention-map ids (stored as int64)
        :param logits_a: tensor containing the audio outputs of the network
        :param logits_v: tensor containing the video outputs of the network
        :param task_labels: tensor containing the task labels
        :param sample_idx: optional per-sample target slots; when given,
            reservoir sampling is skipped and ``num_seen_examples`` is unchanged
        :param video_query: tensor of video queries
        :param audio_query: tensor of audio queries
        :return:
        """
        if not hasattr(self, 'video_data'):
            self.init_tensors(video_data=video_data, audio_data=audio_data,
                              att_map_av_ids=att_map_av_ids, att_map_va_ids=att_map_va_ids,
                              logits_a=logits_a, logits_v=logits_v,
                              task_labels=task_labels,
                              video_query=video_query, audio_query=audio_query)

        # Only the attributes that were actually passed get written.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'att_map_av_ids': att_map_av_ids, 'att_map_va_ids': att_map_va_ids,
                    'logits_v': logits_v, 'logits_a': logits_a,
                    'video_query': video_query, 'audio_query': audio_query,
                    'task_labels': task_labels}
        supplied = {name: batch for name, batch in supplied.items() if batch is not None}

        for i in range(video_data.shape[0]):
            # Randomly select location to store new data when buffer is full
            if sample_idx is None:
                index = reservoir(self.num_seen_examples, self.buffer_size)
                self.num_seen_examples += 1
            # Sample idx is given. This function will be used in GMED for sample update
            else:
                index = sample_idx[i]

            # A negative index means the sample is discarded.
            if index >= 0:
                for name, batch in supplied.items():
                    getattr(self, name)[index] = batch[i].detach().cpu()

    def get_data(self, size: int, sample_idx=None, return_indices=False) -> Tuple:
        """
        Random samples a batch of size items.
        :param size: the number of requested items (clamped to the number of stored items)
        :param sample_idx: explicit indices to fetch instead of sampling
        :param return_indices: if True, also return the indices that were used
        :return: dict of the selected tensors, or ``(dict, indices)`` when
            ``return_indices`` is True
        """
        if sample_idx is None:
            # Never request more items than are actually stored.
            available = min(self.num_seen_examples, self.video_data.shape[0])
            size = min(size, available)
            choice = np.random.choice(available, size=size, replace=False)
        else:
            choice = sample_idx

        # Bring data
        clips = {attr_str: getattr(self, attr_str)[choice]
                 for attr_str in self.attributes if hasattr(self, attr_str)}

        # If need sample idx for updating sample, return sample idx
        if return_indices:
            return clips, choice
        return clips

    def get_mem_ages(self, indices, astype):
        """
        Returns the ages stored at ``indices``.
        :param indices: index expression applied to the ``age`` array
        :param astype: if a tensor, the ages are returned as a float tensor
            on its device; otherwise the NumPy values are returned as-is
        """
        ages = self.age[indices]
        if torch.is_tensor(astype):
            ages = torch.from_numpy(ages).float().to(astype.device)
        return ages

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def empty(self) -> None:
        """
        Deletes all storage tensors and resets the seen-example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0

    def state_dict(self):
        """
        Returns a dict with the counters, the ages and every allocated tensor.
        """
        state_dict = {'num_seen_examples': self.num_seen_examples,
                      'buffer_size': self.buffer_size,
                      'age': self.age}
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                state_dict[attr_str] = getattr(self, attr_str)

        return state_dict

    def load_state_dict(self, state):
        """
        Restores the buffer from a dict produced by ``state_dict``.
        :param state: must contain 'num_seen_examples', 'buffer_size' and
            'age'; storage tensors are restored only if present
        """
        self.num_seen_examples = state['num_seen_examples']
        self.buffer_size = state['buffer_size']
        self.age = state['age']
        for attr_str in self.attributes:
            if attr_str in state:
                setattr(self, attr_str, state[attr_str])


class Buffer_audio3(object):
    """
    Rehearsal memory buffer for clips, logits, query tensors and the
    attention-related tensors ``prob_*_att`` / ``n_attn_*``.

    Storage is allocated lazily on the first ``add_data`` call, one
    ``(buffer_size, ...)`` float32 tensor per supplied attribute.
    """
    def __init__(self, memory_size, device):
        """
        :param memory_size: the maximum number of samples kept in the buffer
        :param device: stored as ``self.device``; this class does not use it itself
        """
        self.buffer_size = memory_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['video_data', 'audio_data', 'logits_a', 'logits_v',
                           'video_query', 'audio_query', 'prob_v_att', 'prob_a_att', 'n_attn_av', 'n_attn_va']
        self.age = np.zeros(memory_size)

    def init_tensors(self, video_data: torch.Tensor, audio_data: torch.Tensor,
                     logits_a: torch.Tensor, logits_v: torch.Tensor, video_query: torch.Tensor, audio_query: torch.Tensor,
                     prob_v_att: torch.Tensor, prob_a_att: torch.Tensor, n_attn_av: torch.Tensor, n_attn_va: torch.Tensor) -> None:
        """
        Allocates zero-filled float32 storage for every argument that is not
        None and has not been allocated yet.
        """
        # Explicit name -> value mapping instead of eval() on parameter names.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'logits_a': logits_a, 'logits_v': logits_v,
                    'video_query': video_query, 'audio_query': audio_query,
                    'prob_v_att': prob_v_att, 'prob_a_att': prob_a_att,
                    'n_attn_av': n_attn_av, 'n_attn_va': n_attn_va}
        for attr_str in self.attributes:
            attr = supplied[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # Leading batch dimension is replaced by the buffer size.
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                                                     *attr.shape[1:]), dtype=torch.float32))

    def add_data(self, video_data, audio_data=None,
                 logits_a=None, logits_v=None, sample_idx=None, video_query=None, audio_query=None,
                 prob_v_att=None, prob_a_att=None, n_attn_av=None, n_attn_va=None):
        """
        Adds a batch to the memory buffer according to the reservoir strategy.
        :param video_data: tensor containing the video clips (required)
        :param audio_data: tensor containing the audio clips
        :param logits_a: tensor containing the audio outputs of the network
        :param logits_v: tensor containing the video outputs of the network
        :param sample_idx: optional per-sample target slots; when given,
            reservoir sampling is skipped and ``num_seen_examples`` is unchanged
        :param video_query: tensor of video queries
        :param audio_query: tensor of audio queries
        :param prob_v_att: optional tensor stored under 'prob_v_att'
        :param prob_a_att: optional tensor stored under 'prob_a_att'
        :param n_attn_av: optional tensor stored under 'n_attn_av'
        :param n_attn_va: optional tensor stored under 'n_attn_va'
        :return:
        """
        if not hasattr(self, 'video_data'):
            self.init_tensors(video_data=video_data, audio_data=audio_data,
                              logits_a=logits_a, logits_v=logits_v,
                              video_query=video_query, audio_query=audio_query,
                              prob_v_att=prob_v_att, prob_a_att=prob_a_att, n_attn_av=n_attn_av, n_attn_va=n_attn_va)

        # Only the attributes that were actually passed get written.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'logits_v': logits_v, 'logits_a': logits_a,
                    'video_query': video_query, 'audio_query': audio_query,
                    'prob_v_att': prob_v_att, 'prob_a_att': prob_a_att,
                    'n_attn_av': n_attn_av, 'n_attn_va': n_attn_va}
        supplied = {name: batch for name, batch in supplied.items() if batch is not None}

        for i in range(video_data.shape[0]):
            # Randomly select location to store new data when buffer is full
            if sample_idx is None:
                index = reservoir(self.num_seen_examples, self.buffer_size)
                self.num_seen_examples += 1
            # Sample idx is given. This function will be used in GMED for sample update
            else:
                index = sample_idx[i]

            # A negative index means the sample is discarded.
            if index >= 0:
                for name, batch in supplied.items():
                    getattr(self, name)[index] = batch[i].detach().cpu()

    def get_data(self, size: int, sample_idx=None, return_indices=False) -> Tuple:
        """
        Random samples a batch of size items.
        :param size: the number of requested items (clamped to the number of stored items)
        :param sample_idx: explicit indices to fetch instead of sampling
        :param return_indices: if True, also return the indices that were used
        :return: dict of the selected tensors, or ``(dict, indices)`` when
            ``return_indices`` is True
        """
        if sample_idx is None:
            # Never request more items than are actually stored.
            available = min(self.num_seen_examples, self.video_data.shape[0])
            size = min(size, available)
            choice = np.random.choice(available, size=size, replace=False)
        else:
            choice = sample_idx

        # Bring data
        clips = {attr_str: getattr(self, attr_str)[choice]
                 for attr_str in self.attributes if hasattr(self, attr_str)}

        # If need sample idx for updating sample, return sample idx
        if return_indices:
            return clips, choice
        return clips

    def get_mem_ages(self, indices, astype):
        """
        Returns the ages stored at ``indices``.
        :param indices: index expression applied to the ``age`` array
        :param astype: if a tensor, the ages are returned as a float tensor
            on its device; otherwise the NumPy values are returned as-is
        """
        ages = self.age[indices]
        if torch.is_tensor(astype):
            ages = torch.from_numpy(ages).float().to(astype.device)
        return ages

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def empty(self) -> None:
        """
        Deletes all storage tensors and resets the seen-example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0

    def state_dict(self):
        """
        Returns a dict with the counters, the ages and every allocated tensor.
        """
        state_dict = {'num_seen_examples': self.num_seen_examples,
                      'buffer_size': self.buffer_size,
                      'age': self.age}
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                state_dict[attr_str] = getattr(self, attr_str)

        return state_dict

    def load_state_dict(self, state):
        """
        Restores the buffer from a dict produced by ``state_dict``.
        :param state: must contain 'num_seen_examples', 'buffer_size' and
            'age'; storage tensors are restored only if present
        """
        self.num_seen_examples = state['num_seen_examples']
        self.buffer_size = state['buffer_size']
        self.age = state['age']
        for attr_str in self.attributes:
            if attr_str in state:
                setattr(self, attr_str, state[attr_str])



class Buffer_audio2(object):
    """
    Rehearsal memory buffer keyed on ``video_query``.

    Unlike the sibling buffers in this file, ``video_query`` is the required
    input (``video_data`` is optional) and the storage tensors are moved to
    ``self.device`` when they are allocated. Attention-map ids and task
    labels are stored as int64, everything else as float32.
    """
    def __init__(self, memory_size, device):
        """
        :param memory_size: the maximum number of samples kept in the buffer
        :param device: device on which the storage tensors are placed
        """
        self.buffer_size = memory_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['video_data', 'audio_data', 'att_map_av_ids', 'att_map_va_ids', 'task_labels', 'logits_a', 'logits_v',
                           'video_query', 'audio_query']
        self.age = np.zeros(memory_size)

    def init_tensors(self, video_data: torch.Tensor, audio_data: torch.Tensor, att_map_av_ids: torch.Tensor, att_map_va_ids: torch.Tensor,
                     logits_a: torch.Tensor, logits_v: torch.Tensor, task_labels: torch.Tensor, video_query: torch.Tensor, audio_query: torch.Tensor) -> None:
        """
        Allocates zero-filled storage on ``self.device`` for every argument
        that is not None and has not been allocated yet.
        """
        # Explicit name -> value mapping instead of eval() on parameter names.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'att_map_av_ids': att_map_av_ids, 'att_map_va_ids': att_map_va_ids,
                    'task_labels': task_labels, 'logits_a': logits_a, 'logits_v': logits_v,
                    'video_query': video_query, 'audio_query': audio_query}
        for attr_str in self.attributes:
            attr = supplied[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # Index-like attributes are integer tensors; the rest are floats.
                is_integral = attr_str.startswith('att_map') or attr_str == 'task_labels'
                typ = torch.int64 if is_integral else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                                                     *attr.shape[1:]), dtype=typ).to(self.device))

    def add_data(self, video_query, video_data=None, audio_data=None, att_map_av_ids=None, att_map_va_ids=None,
                 logits_a=None, logits_v=None, task_labels=None, sample_idx=None, audio_query=None):
        """
        Adds a batch to the memory buffer according to the reservoir strategy.
        :param video_query: tensor of video queries (required; defines the batch size)
        :param video_data: tensor containing the video clips
        :param audio_data: tensor containing the audio clips
        :param att_map_av_ids: tensor of attention-map ids (stored as int64)
        :param att_map_va_ids: tensor of attention-map ids (stored as int64)
        :param logits_a: tensor containing the audio outputs of the network
        :param logits_v: tensor containing the video outputs of the network
        :param task_labels: tensor containing the task labels
        :param sample_idx: optional per-sample target slots; when given,
            reservoir sampling is skipped and ``num_seen_examples`` is unchanged
        :param audio_query: tensor of audio queries
        :return:
        """
        if not hasattr(self, 'video_query'):
            self.init_tensors(video_data=video_data, audio_data=audio_data,
                              att_map_av_ids=att_map_av_ids, att_map_va_ids=att_map_va_ids,
                              logits_a=logits_a, logits_v=logits_v,
                              task_labels=task_labels,
                              video_query=video_query, audio_query=audio_query)

        # Only the attributes that were actually passed get written.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'att_map_av_ids': att_map_av_ids, 'att_map_va_ids': att_map_va_ids,
                    'logits_v': logits_v, 'logits_a': logits_a,
                    'video_query': video_query, 'audio_query': audio_query,
                    'task_labels': task_labels}
        supplied = {name: batch for name, batch in supplied.items() if batch is not None}

        for i in range(video_query.shape[0]):
            # Randomly select location to store new data when buffer is full
            if sample_idx is None:
                index = reservoir(self.num_seen_examples, self.buffer_size)
                self.num_seen_examples += 1
            # Sample idx is given. This function will be used in GMED for sample update
            else:
                index = sample_idx[i]

            # A negative index means the sample is discarded.
            if index >= 0:
                for name, batch in supplied.items():
                    getattr(self, name)[index] = batch[i].detach().cpu()

    def get_data(self, size: int, sample_idx=None, return_indices=False) -> Tuple:
        """
        Random samples a batch of size items.
        :param size: the number of requested items (clamped to the number of stored items)
        :param sample_idx: explicit indices to fetch instead of sampling
        :param return_indices: if True, also return the indices that were used
        :return: dict of the selected tensors, or ``(dict, indices)`` when
            ``return_indices`` is True
        """
        if sample_idx is None:
            # Never request more items than are actually stored.
            available = min(self.num_seen_examples, self.video_query.shape[0])
            size = min(size, available)
            choice = np.random.choice(available, size=size, replace=False)
        else:
            choice = sample_idx

        # Bring data
        clips = {attr_str: getattr(self, attr_str)[choice]
                 for attr_str in self.attributes if hasattr(self, attr_str)}

        # If need sample idx for updating sample, return sample idx
        if return_indices:
            return clips, choice
        return clips

    def get_mem_ages(self, indices, astype):
        """
        Returns the ages stored at ``indices``.
        :param indices: index expression applied to the ``age`` array
        :param astype: if a tensor, the ages are returned as a float tensor
            on its device; otherwise the NumPy values are returned as-is
        """
        ages = self.age[indices]
        if torch.is_tensor(astype):
            ages = torch.from_numpy(ages).float().to(astype.device)
        return ages

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def empty(self) -> None:
        """
        Deletes all storage tensors and resets the seen-example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0

    def state_dict(self):
        """
        Returns a dict with the counters, the ages and every allocated tensor.
        """
        state_dict = {'num_seen_examples': self.num_seen_examples,
                      'buffer_size': self.buffer_size,
                      'age': self.age}
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                state_dict[attr_str] = getattr(self, attr_str)

        return state_dict

    def load_state_dict(self, state):
        """
        Restores the buffer from a dict produced by ``state_dict``.
        :param state: must contain 'num_seen_examples', 'buffer_size' and
            'age'; storage tensors are restored only if present
        """
        self.num_seen_examples = state['num_seen_examples']
        self.buffer_size = state['buffer_size']
        self.age = state['age']
        for attr_str in self.attributes:
            if attr_str in state:
                setattr(self, attr_str, state[attr_str])





class Buffer(object):
    """
    Rehearsal memory buffer for clips, logits, task labels and shuffle ids.

    Storage is allocated lazily on the first ``add_data`` call, one
    ``(buffer_size, ...)`` tensor per supplied attribute. Task labels and
    ``ids_shuffle_*`` are stored as int64, everything else as float32.
    """
    def __init__(self, buffer_size, device):
        """
        :param buffer_size: the maximum number of samples kept in the buffer
        :param device: stored as ``self.device``; this class does not use it itself
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['video_data', 'audio_data', 'logits', 'task_labels', 'ids_shuffle_v', 'ids_shuffle_a']

    def init_tensors(self, video_data: torch.Tensor, audio_data: torch.Tensor,
                     logits: torch.Tensor, task_labels: torch.Tensor, ids_shuffle_v: torch.Tensor, ids_shuffle_a: torch.Tensor) -> None:
        """
        Allocates zero-filled storage for every argument that is not None
        and has not been allocated yet.
        """
        # Explicit name -> value mapping instead of eval() on parameter names.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'logits': logits, 'task_labels': task_labels,
                    'ids_shuffle_v': ids_shuffle_v, 'ids_shuffle_a': ids_shuffle_a}
        for attr_str in self.attributes:
            attr = supplied[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # Labels ('...els') and shuffle ids ('ids...') are integer tensors.
                typ = torch.int64 if (attr_str.endswith('els') or attr_str.startswith('ids')) else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=typ))

    def add_data(self, video_data, audio_data=None, logits=None, task_labels=None, ids_shuffle_v=None, ids_shuffle_a=None):
        """
        Adds a batch to the memory buffer according to the reservoir strategy.
        :param video_data: tensor containing the video clips (required)
        :param audio_data: tensor containing the audio clips
        :param logits: tensor containing the outputs of the network
        :param task_labels: tensor containing the task labels
        :param ids_shuffle_v: optional tensor stored under 'ids_shuffle_v' (int64)
        :param ids_shuffle_a: optional tensor stored under 'ids_shuffle_a' (int64)
        :return:
        """
        if not hasattr(self, 'video_data'):
            self.init_tensors(video_data=video_data, audio_data=audio_data,
                              logits=logits, task_labels=task_labels, ids_shuffle_v=ids_shuffle_v, ids_shuffle_a=ids_shuffle_a)

        # Only the attributes that were actually passed get written.
        supplied = {'video_data': video_data, 'audio_data': audio_data,
                    'logits': logits, 'task_labels': task_labels,
                    'ids_shuffle_v': ids_shuffle_v, 'ids_shuffle_a': ids_shuffle_a}
        supplied = {name: batch for name, batch in supplied.items() if batch is not None}

        for i in range(video_data.shape[0]):
            # Randomly select location to store new data when buffer is full
            index = reservoir(self.num_seen_examples, self.buffer_size)
            self.num_seen_examples += 1
            # A negative index means the sample is discarded.
            if index >= 0:
                for name, batch in supplied.items():
                    getattr(self, name)[index] = batch[i].detach().cpu()

    def get_data(self):
        """
        Returns every allocated tensor, truncated to the filled part of the buffer.
        :return: dict mapping attribute name to its first
            ``min(buffer_size, num_seen_examples)`` rows
        """
        num_samples = min(self.buffer_size, self.num_seen_examples)
        return {attr_str: getattr(self, attr_str)[:num_samples]
                for attr_str in self.attributes if hasattr(self, attr_str)}

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def empty(self) -> None:
        """
        Deletes all storage tensors and resets the seen-example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0

    def state_dict(self):
        """
        Returns a dict with the counters and every allocated tensor.
        """
        state_dict = {'num_seen_examples': self.num_seen_examples,
                      'buffer_size': self.buffer_size}
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                state_dict[attr_str] = getattr(self, attr_str)

        return state_dict

    def load_state_dict(self, state):
        """
        Restores the buffer from a dict produced by ``state_dict``.
        :param state: must contain 'num_seen_examples' and 'buffer_size';
            storage tensors are restored only if present
        """
        self.num_seen_examples = state['num_seen_examples']
        self.buffer_size = state['buffer_size']
        for attr_str in self.attributes:
            if attr_str in state:
                setattr(self, attr_str, state[attr_str])


class Memory(object):
    """
    Rehearsal memory that keeps a share of samples per task.

    Stored tensors live directly on the instance under the names listed in
    ``self.attributes``; an attribute exists only once data for it was added
    or loaded.
    """
    def __init__(self, memory_size, nb_total_tasks, rehearsal, fixed=False):
        """
        :param memory_size: total number of samples the memory may hold
        :param nb_total_tasks: divisor for the per-task share when ``fixed`` is True
        :param rehearsal: sample-selection strategy forwarded to ``herd_samples``
        :param fixed: if True the per-task share never changes
        """
        self.attributes = ['video_data', 'audio_data', 'logits', 'task_labels', 'ids_shuffle_v', 'ids_shuffle_a']
        self.memory_size = memory_size
        self.nb_total_tasks = nb_total_tasks
        self.rehearsal = rehearsal
        self.fixed = fixed

        # Number of tasks currently represented in the memory.
        self.nb_tasks = 0

    @property
    def memory_per_task(self):
        """Number of samples each task is allowed to occupy."""
        if self.fixed:
            return self.memory_size // self.nb_total_tasks
        if self.nb_tasks > 0:
            return self.memory_size // self.nb_tasks
        return self.memory_size

    def get_dataset(self, transform):
        """Wraps the stored tensors in a ``Buffer_VideoDataset``."""
        return Buffer_VideoDataset(self.get(), transform)

    def get(self):
        """Returns a dict of every attribute that is currently stored."""
        return {name: getattr(self, name)
                for name in self.attributes if hasattr(self, name)}

    def __len__(self):
        if not hasattr(self, 'video_data'):
            return 0
        return len(self.video_data)

    def save(self, path):
        """Serializes the stored tensors to ``path`` with ``torch.save``."""
        torch.save(self.get(), path)

    def load(self, path):
        """
        Loads tensors saved by ``save`` and recomputes the task count from
        the distinct values in ``task_labels`` (when present).
        """
        loaded = torch.load(path)
        for name in loaded:
            if name in self.attributes:
                setattr(self, name, loaded[name])

        assert len(self) <= self.memory_size, len(self)
        if hasattr(self, 'task_labels'):
            self.nb_tasks = len(torch.unique(self.task_labels))

    def reduce(self):
        """
        Shrinks every task to at most ``memory_per_task`` samples, keeping
        the first ones of each task.
        """
        kept = {name: [] for name in self.attributes if hasattr(self, name)}

        assert hasattr(self, 'task_labels'), 'no task labels in memory'
        for task_id in torch.unique(self.task_labels):
            task_rows = torch.where(self.task_labels == task_id)[0]
            task_rows = task_rows[:self.memory_per_task]

            for name, chunks in kept.items():
                chunks.append(getattr(self, name)[task_rows])

        for name, chunks in kept.items():
            setattr(self, name, torch.cat(chunks))

    def add(self, model, dataset):
        """
        Registers a new task: selects its samples with ``herd_samples`` and
        appends them, shrinking the existing tasks first unless ``fixed``.
        """
        self.nb_tasks += 1
        data = herd_samples(model, dataset, self.memory_per_task, self.rehearsal)

        if hasattr(self, 'video_data'):
            if not self.fixed:
                self.reduce()
            for name in self.attributes:
                if hasattr(self, name):
                    merged = torch.cat([getattr(self, name), data[name]])
                    setattr(self, name, merged)
        else:
            # First task: adopt whatever attributes the selection returned.
            for name in self.attributes:
                if name in data:
                    setattr(self, name, data[name])

def herd_samples(model, dataset, memory_per_task, rehearsal):
    """
    Selects the samples of a single task that go into the rehearsal memory.

    :param model: object whose ``buffer.attributes`` lists the attribute names to copy
    :param dataset: mapping from attribute name to tensor; must contain
        'task_labels' holding exactly one distinct label
    :param memory_per_task: the number of samples to draw
    :param rehearsal: selection strategy; only "random" is implemented
    :return: dict mapping every attribute present in ``dataset`` to its selected rows
    :raises ValueError: if ``rehearsal`` is not a supported strategy
    """
    assert 'task_labels' in dataset, 'no task labels in buffer'
    if rehearsal != "random":
        # Previously the exception was built but never raised, so the
        # function silently returned None for unknown strategies.
        raise ValueError(f'{rehearsal} is not implemented yet.')

    assert len(torch.unique(dataset['task_labels'])) == 1, 'Other task labels are inside buffer.'
    # NOTE(review): np.random.choice samples with replacement by default, so
    # the selection may contain duplicates — confirm this is intended.
    indices = np.random.choice(len(dataset['task_labels']), size=memory_per_task)

    return {attr_str: dataset[attr_str][indices]
            for attr_str in model.buffer.attributes if attr_str in dataset}
