import torch
from AbstractVSA import AbstractVSA

class MAP_C(AbstractVSA):
    """Multiply-Add-Permute VSA with continuous (real-valued) components (MAP-C).

    Hypervectors have components drawn uniformly from [-1, 1). Binding is
    element-wise multiplication, and bundling is an element-wise sum that is
    hard-clipped back into [-1, 1].
    """

    def __init__(self, dimension: int, device: "torch.device | str | None" = None):
        """Create a MAP-C space.

        Args:
            dimension: Hypervector dimensionality, forwarded to
                ``AbstractVSA.__init__``. ``generate_vector`` reads ``self.d``,
                which presumably the base class sets from this value
                (TODO confirm).
            device: Torch device on which new vectors are allocated. If None,
                CUDA is used when available, otherwise CPU.
        """
        # Set the device before the base initializer runs, in case the base
        # class already allocates vectors through generate_vector().
        self.device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        super().__init__(dimension)

    def generate_vector(self) -> torch.Tensor:
        """Return a random hypervector with i.i.d. components uniform in [-1, 1).

        The result has shape ``(self.d,)``, uses torch's default floating
        dtype, and lives on ``self.device``.
        """
        # torch.rand samples [0, 1); the affine map x * 2 - 1 moves it to [-1, 1).
        return torch.rand(self.d, device=self.device) * 2 - 1

    def bind(self, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Bind two hypervectors by element-wise multiplication."""
        return u * v

    def unbind(self, u: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Recover the partner of ``u`` from the bound vector ``z``.

        This uses element-wise multiplication, the same operation as ``bind``.
        With continuous components, ``u * u`` is not all ones, so this is only
        an approximate inverse. The result is similar to the original partner,
        but not equal to it.
        """
        return u * z

    def bundle(self, vectors) -> torch.Tensor:
        """Superimpose hypervectors with an element-wise sum hard-clipped to [-1, 1].

        Args:
            vectors: Either a tensor whose first dimension indexes the vectors
                to bundle (for example shape ``(k, d)``), or any non-tensor
                iterable of equally-shaped tensors, such as a list, tuple, or
                generator.

        Returns:
            A tensor with the per-vector shape. Components greater than 1 become
            1, and components less than -1 become -1.

        Raises:
            RuntimeError: If ``vectors`` is an empty iterable (raised by
                ``torch.stack``).
        """
        if not isinstance(vectors, torch.Tensor):
            # Accept any iterable of tensors, not only lists. Tuples and
            # generators cannot be handed to torch.sum directly.
            vectors = torch.stack(list(vectors))

        total = vectors.sum(dim=0)

        # The hard clip keeps the bundle in the same [-1, 1] range as atomic vectors.
        return total.clamp(min=-1.0, max=1.0)

    def similarity(self, u: torch.Tensor, v: torch.Tensor) -> float:
        """Return the cosine similarity of two hypervectors as a Python float.

        The similarity is computed along dim 0, so ``u`` and ``v`` are expected
        to be single 1-D vectors. ``.item()`` raises if the result has more
        than one element.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()