import torch
import numpy as np
from AbstractVSA import AbstractVSA

class BSC(AbstractVSA):
    """Binary Spatter Code VSA stored in bipolar {-1, +1} form.

    In the bipolar encoding, XOR binding becomes element-wise
    multiplication and majority-rule bundling becomes the sign of the
    element-wise sum. The dot product of two bipolar vectors is a linear
    function of their Hamming distance, so cosine similarity can be used
    directly as the similarity measure.
    """

    def __init__(self, dimension, device=None):
        """Create a BSC space.

        Args:
            dimension: Forwarded unchanged to ``AbstractVSA.__init__``.
            device: Device on which vectors are created. When ``None``,
                CUDA is used if available, otherwise the CPU.
        """
        # The device is assigned before delegating to the base initializer
        # so it is already available should the base class (not visible
        # here) create vectors during construction.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        super().__init__(dimension)

    @staticmethod
    def _random_signs_like(reference):
        """Return uniform random {-1, +1} values shaped and typed like ``reference``."""
        # Draw fair bits in {0, 1} and map them to {-1, +1}. Unlike
        # sign(randn), this can never produce 0 and it also works for
        # integer dtypes.
        return torch.randint_like(reference, 0, 2) * 2 - 1

    def generate_vector(self):
        """Return a fresh random bipolar hypervector.

        Each component is -1 or +1 with equal probability (Rademacher
        distribution). The vector is created on ``self.device`` with the
        default floating dtype; its size is ``self.d``, which is expected
        to be provided by the base class.
        """
        # torch.empty is only used as a shape/dtype/device template; its
        # (uninitialized) contents are never read.
        template = torch.empty(self.d, device=self.device)
        return self._random_signs_like(template)

    def bind(self, u, v):
        """Bind two vectors by element-wise multiplication.

        Multiplication in the bipolar domain corresponds to XOR in the
        binary domain (with +1 playing the role of 0 and -1 of 1):
             1 *  1 =  1   (0 XOR 0 = 0)
            -1 * -1 =  1   (1 XOR 1 = 0)
             1 * -1 = -1   (0 XOR 1 = 1)
        """
        return u * v

    def unbind(self, u, z):
        """Unbind by element-wise multiplication.

        Bipolar binding is its own inverse, so multiplying a bound vector
        by one of its factors recovers the other factor.
        """
        return u * z

    def bundle(self, vectors):
        """Bundle vectors with the majority rule (sign of the sum).

        Args:
            vectors: A non-empty list or tuple of equally shaped tensors,
                or a single tensor whose first dimension indexes the
                vectors to bundle.

        Returns:
            A tensor of -1 / +1 values. Components whose sum is exactly
            zero (ties) are resolved by an independent fair coin flip.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        totals = torch.sum(vectors, dim=0)
        majority = torch.sign(totals)

        # sign() yields 0 wherever the vote is tied; those positions must
        # still end up in {-1, +1}.
        ties = majority == 0
        if ties.any():
            # Random numbers are only consumed when a tie actually exists.
            majority = torch.where(ties, self._random_signs_like(totals), majority)

        return majority

    def similarity(self, u, v):
        """Return the cosine similarity of ``u`` and ``v`` as a Python float.

        The similarity is computed along dimension 0. For bipolar vectors
        it is linearly related to the Hamming similarity.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()