import torch
import numpy as np
from typing import Tuple

class Buffer:
    """
    The memory buffer of the rehearsal method.

    Holds up to ``buffer_max_rel`` relations. Two parallel containers are
    created lazily by :meth:`init_tensors`:

    * ``rel`` -- int64 tensor of length ``buffer_max_rel``. Slot ``i`` holds a
      relation id, or ``-1`` while the slot is free. Slots are filled in
      order, so the first ``num_rels_seen`` entries are the occupied ones.
    * ``rel_triplets`` -- list of ``buffer_max_rel`` sets. Slot ``i`` holds
      the triplets stored for ``rel[i]``. :meth:`intersect_per_epoch` stacks
      them row-wise, so they must be 1-D tensors of equal length (presumably
      (head, relation, tail) ids -- confirm with callers).
    """

    def __init__(self, buffer_max_rel, device):
        self.buffer_max_rel = buffer_max_rel  # capacity, counted in relations
        self.device = device  # device on which ``rel`` is allocated
        self.num_rels_seen = 0  # number of occupied slots
        self.attributes = ['rel', 'rel_triplets']

    def init_tensors(self):
        """Create every storage attribute that does not exist yet."""
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                continue
            if attr_str.endswith('rel_triplets'):
                setattr(self, attr_str, [set() for _ in range(self.buffer_max_rel)])
            elif attr_str.endswith('rel'):
                # -1 marks a free slot.
                setattr(
                    self,
                    attr_str,
                    torch.full((self.buffer_max_rel,), -1, dtype=torch.int64, device=self.device),
                )
            else:
                print("Error init buffer!")

    # ------------------------------------------------------------------ helpers

    def _find_rel(self, rel):
        """Return the slot index of ``rel`` among the occupied slots, or None."""
        # Only the filled prefix is searched, so the -1 sentinel never matches.
        matches = (self.rel[:self.num_rels_seen] == rel).nonzero(as_tuple=True)[0]
        if matches.numel() == 0:
            return None
        return int(matches[0])

    def _append_rel(self, rel, triplets):
        """Store ``rel`` in the next free slot with a fresh set of ``triplets``."""
        if self.num_rels_seen >= self.buffer_max_rel:
            raise IndexError(
                f"Buffer is full: cannot store more than {self.buffer_max_rel} relations"
            )
        self.rel[self.num_rels_seen] = rel
        # Copy into a new set so this buffer never aliases another buffer's storage.
        self.rel_triplets[self.num_rels_seen] = set(triplets)
        self.num_rels_seen += 1

    @staticmethod
    def _unique_rows(triplets):
        """Stack an iterable of 1-D tensors and drop duplicate rows; None if empty."""
        # torch.Tensor hashes by identity, so a set of tensors can still hold
        # equal-valued rows; torch.unique removes those duplicates by value.
        rows = list(triplets)
        if not rows:
            return None
        return torch.unique(torch.stack(rows), dim=0)

    # --------------------------------------------------------------- public API

    def add_data(self, rel, triplets):
        """
        Add ``triplets`` to the entry of relation ``rel``.

        If ``rel`` is already stored, the triplets are merged into its set.
        Otherwise ``rel`` is placed in the next free slot.

        Raises:
            IndexError: ``rel`` is new and all ``buffer_max_rel`` slots are taken.
        """
        if not hasattr(self, 'rel') or not hasattr(self, 'rel_triplets'):
            self.init_tensors()

        slot = self._find_rel(rel)
        if slot is None:
            self._append_rel(rel, triplets)
        else:
            self.rel_triplets[slot].update(triplets)

    def intersect_per_epoch(self, buffer):
        """
        Intersect this buffer's triplets with those of ``buffer``.

        If this buffer has no storage yet, it instead copies every relation
        of ``buffer``, together with a copy of its triplet set.

        Otherwise, for each relation stored in both buffers, only the
        triplets of ``self`` that also appear (by value) in ``buffer`` are
        kept. The result is stored back as a set, so ``add_data`` keeps
        working. Relations found only in ``self`` are left untouched, and
        relations found only in ``buffer`` are ignored.
        """
        if buffer.is_empty():
            return
        if not hasattr(self, 'rel') or not hasattr(self, 'rel_triplets'):
            self.init_tensors()
            for idx, rel in enumerate(buffer.rel):
                if rel == -1:
                    break
                if self._find_rel(rel) is None:
                    self._append_rel(rel, buffer.rel_triplets[idx])
            return

        for idx, rel in enumerate(buffer.rel):
            if rel == -1:
                break
            # `idx` indexes `buffer`; `slot` indexes `self`.
            slot = self._find_rel(rel)
            if slot is None:
                continue
            mine = self._unique_rows(self.rel_triplets[slot])
            theirs = self._unique_rows(buffer.rel_triplets[idx])
            if mine is None or theirs is None:
                self.rel_triplets[slot] = set()
                continue
            # Pairwise row equality: keep rows of `mine` that match any row of `theirs`.
            keep = torch.all(mine[:, None, :] == theirs[None, :, :], dim=2).any(dim=1)
            self.rel_triplets[slot] = set(mine[keep].unbind(0))

    def union_buffer(self, buffer):
        """
        Merge every relation of ``buffer`` into this buffer.

        A relation that is already stored gets its triplet set extended. A new
        relation takes the next free slot, with a copy of ``buffer``'s set.

        Raises:
            IndexError: the new relations do not fit in ``buffer_max_rel`` slots.
        """
        if buffer.is_empty():
            return
        if not hasattr(self, 'rel') or not hasattr(self, 'rel_triplets'):
            self.init_tensors()

        for idx, rel in enumerate(buffer.rel):
            if rel == -1:
                break
            slot = self._find_rel(rel)
            if slot is None:
                self._append_rel(rel, buffer.rel_triplets[idx])
            else:
                self.rel_triplets[slot].update(buffer.rel_triplets[idx])

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_rels_seen == 0

    def empty(self) -> None:
        """
        Delete the storage attributes and reset the relation counter.

        The storage is recreated lazily on the next call that needs it.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_rels_seen = 0
