import torch
from AbstractVSA import AbstractVSA

class MAP_I(AbstractVSA):
    """MAP-I vector symbolic architecture on torch tensors.

    Atoms are random bipolar vectors ({-1, +1}, stored as float tensors).
    Binding and unbinding are element-wise multiplication. Bundling is an
    unnormalized element-wise sum, so bundled vectors can grow beyond
    [-1, 1]; when built from atoms they hold integer values (as floats).
    """

    def __init__(self, dimension, device=None):
        """Create a MAP-I space.

        Args:
            dimension: Vector length, forwarded to ``AbstractVSA.__init__``
                (presumably stored there as ``self.d``, which the methods
                below read — confirm against the base class).
            device: torch device for generated vectors. When ``None``, CUDA
                is used if available, otherwise CPU.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): self.device is assigned before the base initializer
        # runs, as in the original; keep this order unless AbstractVSA is
        # known not to depend on it.
        self.device = device
        super().__init__(dimension)

    def generate_vector(self):
        """Return a random bipolar atom of length ``self.d``.

        Each component is independently -1.0 or +1.0 with (approximately)
        equal probability. The result is a float tensor on ``self.device``.
        """
        # Threshold uniform noise to {0, 1}, then map to {-1, +1}.
        return (torch.rand(self.d, device=self.device) > 0.5).float() * 2 - 1

    def bind(self, u, v):
        """Bind two vectors by element-wise multiplication (Hadamard product)."""
        return u * v

    def unbind(self, u, z):
        """Recover the partner of ``u`` from the bound vector ``z``.

        Uses element-wise multiplication. Atoms are bipolar, so ``u * u`` is
        all ones and ``unbind(u, bind(u, v))`` returns ``v`` exactly when
        ``u`` is an atom. If ``u`` is not bipolar (e.g. a bundle), this is
        no longer an exact inverse of ``bind``.
        """
        return u * z

    def bundle(self, vectors):
        """Superpose vectors by element-wise summation (no normalization).

        Values are allowed to grow outside [-1, 1].

        Args:
            vectors: A list or tuple of equally-shaped tensors, or a tensor
                whose first dimension indexes the vectors to bundle.

        Returns:
            The element-wise sum over the vectors. An empty list or tuple
            yields the zero vector of length ``self.d``. This matches what
            ``torch.sum`` returns for an empty stacked tensor.
        """
        if isinstance(vectors, (list, tuple)):
            if not vectors:
                # torch.stack rejects an empty sequence; the sum of nothing
                # is the additive identity.
                return torch.zeros(self.d, device=self.device)
            vectors = torch.stack(vectors)
        return torch.sum(vectors, dim=0)

    def similarity(self, u, v):
        """Return the cosine similarity of ``u`` and ``v`` as a Python float.

        The similarity is computed along dim 0 and converted with
        ``.item()``. It therefore expects single 1-D vectors, not batches.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()