import torch
from AbstractVSA import AbstractVSA

class MAP_B(AbstractVSA):
    """MAP-B vector symbolic architecture over bipolar {-1, +1} hypervectors.

    Binding and unbinding are element-wise multiplication, bundling is an
    element-wise majority vote (sign of the sum), and similarity is cosine
    similarity.
    """

    def __init__(self, dimension, device=None):
        """Create a MAP-B space.

        Args:
            dimension: Hypervector dimensionality, forwarded to
                ``AbstractVSA.__init__``.
            device: Torch device for generated vectors. When ``None``, uses
                CUDA if available, otherwise CPU.
        """
        # NOTE(review): device is assigned before the base initializer runs;
        # presumably AbstractVSA.__init__ may rely on it — keep this order.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        super().__init__(dimension)

    def generate_vector(self):
        """Return a random bipolar hypervector of length ``self.d``.

        Each entry is +1 or -1 with (essentially) equal probability. The result
        is a float tensor on ``self.device``.
        """
        # self.d is not set in this class; presumably AbstractVSA.__init__ sets it.
        # rand > 0.5 -> True -> +1, otherwise False -> -1.
        return (torch.rand(self.d, device=self.device) > 0.5).float() * 2 - 1

    def bind(self, u, v):
        """Bind two hypervectors by element-wise multiplication.

        For bipolar inputs the product is again bipolar; this is the bipolar
        analogue of XOR on binary vectors.
        """
        return u * v

    def unbind(self, u, z):
        """Recover the partner of ``u`` from a binding ``z`` = bind(u, v).

        Multiplication is self-inverse for bipolar vectors because
        u_i * u_i = 1, so u * (u * v) = v.
        """
        return u * z

    def bundle(self, vectors):
        """Bundle hypervectors by element-wise majority vote (sign of the sum).

        Args:
            vectors: A list or tuple of equally shaped hypervectors, or a
                tensor whose first dimension indexes the vectors to bundle.

        Returns:
            A tensor with entries in {-1, +1}, with the dtype and device of the
            summed input. Positions where the vote is tied (the sum is 0, e.g.
            with an even number of inputs) get an independent random +1 or -1.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        total = torch.sum(vectors, dim=0)
        majority = torch.sign(total)

        ties = majority == 0
        if ties.any():
            # Fair coin flips via randint rather than sign(randn_like(...)):
            # randn_like raises on integer tensors, and sign() of a normal
            # sample can itself be 0, which would leave the tie unresolved.
            coin = torch.randint(0, 2, total.shape, device=total.device) * 2 - 1
            majority = torch.where(ties, coin.to(majority.dtype), majority)

        return majority

    def similarity(self, u, v):
        """Return the cosine similarity of two 1-D hypervectors as a float.

        For bipolar vectors of dimension d this equals dot(u, v) / d. The
        ``.item()`` call requires a single-element result, so inputs are
        expected to be 1-D.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()