import torch
import numpy as np
from AbstractVSA import AbstractVSA

class HLB(AbstractVSA):
    """HLB vector-symbolic architecture backed by PyTorch tensors.

    Hypervectors are drawn from a bimodal Gaussian centred at -1 and +1.
    Binding is element-wise multiplication, unbinding is element-wise
    division, bundling is element-wise addition, and similarity is cosine
    similarity.
    """

    def __init__(self, dimension, device=None):
        """Create an HLB model.

        Args:
            dimension: Hypervector dimensionality, forwarded to
                ``AbstractVSA.__init__``.
            device: Torch device for generated vectors. Defaults to CUDA when
                available, otherwise CPU.
        """
        # Resolve the device before calling the base constructor, presumably
        # in case AbstractVSA.__init__ generates vectors (not visible here).
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        super().__init__(dimension)

    def generate_vector(self):
        """Sample a random HLB hypervector of length ``self.d``.

        Each element is drawn from N(-1, 1/D) or N(+1, 1/D) with probability
        0.5 each. Here 1/D is the variance, so the standard deviation is
        1/sqrt(D). ``self.d`` is assumed to be set by ``AbstractVSA.__init__``
        (TODO confirm).

        Reference: HLBTensor.random in vsa_models.py

        Returns:
            A 1-D tensor of shape ``(self.d,)`` on ``self.device``.
        """
        std = 1.0 / np.sqrt(self.d)
        size = (self.d,)

        # Draw a full candidate vector from each mode...
        n1 = torch.normal(mean=-1.0, std=std, size=size, device=self.device)
        n2 = torch.normal(mean=1.0, std=std, size=size, device=self.device)

        # ...then pick per element: mask True -> the -1 mode, False -> the +1 mode.
        mask = torch.rand(size, device=self.device) > 0.5
        v = torch.where(mask, n1, n2)

        return v

    def bind(self, u, v):
        """Bind two hypervectors with the element-wise (Hadamard) product.

        Reference: HLBTensor.bind in vsa_models.py (torch.mul)
        """
        return u * v

    def unbind(self, u, z):
        """Recover ``v`` from ``z = bind(u, v)`` via element-wise division.

        The HLB inverse of ``u`` is ``1 / u``, so ``v = z / u``.

        Reference: HLBTensor.inverse in vsa_models.py (1 / self)

        Args:
            u: The key that was bound into ``z``.
            z: The bound (or bundled) hypervector.

        Returns:
            ``z / u``, with a tiny stabilising epsilon applied to ``u``.
        """
        eps = 1e-9
        u = torch.as_tensor(u)
        # Freshly generated vectors sit near +/-1, but derived vectors (e.g.
        # bundles) can have elements near zero. Nudge each element AWAY from
        # zero in its own direction. A plain `u + eps` would flip the sign of
        # negatives in (-eps, 0) and divide by zero at exactly -eps. For
        # u >= 0 this is identical to the original `u + eps`.
        safe_u = torch.where(u >= 0, u + eps, u - eps)
        return z / safe_u

    def bundle(self, vectors):
        """Bundle hypervectors by element-wise addition.

        Reference: HLBTensor.bundle in vsa_models.py (torch.add)

        Args:
            vectors: Either a list/tuple of equally-shaped tensors, or a single
                tensor whose first axis indexes the vectors to sum. An empty
                list/tuple raises from ``torch.stack``.

        Returns:
            The element-wise sum over the first axis.
        """
        # Tuples are stacked too; previously only lists were, and a tuple
        # reached torch.sum directly and raised TypeError.
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(list(vectors))

        return torch.sum(vectors, dim=0)

    def similarity(self, u, v):
        """Return the cosine similarity of ``u`` and ``v`` as a Python float.

        Reference: HLBTensor.cosine_similarity in vsa_models.py

        Note: the similarity is computed along ``dim=0`` and converted with
        ``.item()``, so this expects 1-D vectors. A batched input yielding
        more than one similarity would make ``.item()`` raise.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()