import torch
import torch.fft
import numpy as np
from AbstractVSA import AbstractVSA

class HRR(AbstractVSA):
    """Holographic Reduced Representations backed by PyTorch.

    Hypervectors are real-valued tensors of length ``self.d``:
      * binding is circular convolution;
      * unbinding is circular correlation (both computed with the real FFT);
      * bundling is element-wise addition followed by unit-norm scaling;
      * similarity is cosine similarity.

    NOTE(review): ``self.d`` is never assigned in this class; it is presumably
    set by ``AbstractVSA.__init__`` from ``dimension``. Confirm.
    """

    def __init__(self, dimension: int, device=None):
        """Create an HRR space.

        Args:
            dimension: Vector length, forwarded to ``AbstractVSA.__init__``.
            device: Torch device (or device string) on which new vectors are
                allocated. When None, CUDA is used if available, else CPU.
        """
        # NOTE(review): the device is resolved *before* the base constructor
        # runs, presumably because AbstractVSA.__init__ may call
        # generate_vector() (which reads self.device). Confirm.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        super().__init__(dimension)

    def generate_vector(self) -> torch.Tensor:
        """Return a random Gaussian vector of length ``self.d`` with unit norm.

        Elements are drawn i.i.d. from N(0, 1) on ``self.device``. The vector
        is then divided by its Euclidean norm.
        """
        v = torch.randn(self.d, device=self.device)
        # The epsilon guards the division against a (practically impossible)
        # all-zero draw.
        return v / (torch.norm(v) + 1e-9)

    def bind(self, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Bind ``u`` and ``v`` by circular convolution.

        Uses the convolution theorem, F(u (*) v) = F(u) . F(v), where (*) is
        circular convolution. This costs O(d log d) instead of O(d^2).
        ``rfft`` is used because the inputs are real-valued. The transform runs
        over the last dimension, so leading batch dimensions broadcast.
        """
        u_f = torch.fft.rfft(u)
        v_f = torch.fft.rfft(v)
        # n must be given explicitly: irfft's default output length is
        # 2 * (m - 1), which is wrong whenever self.d is odd.
        return torch.fft.irfft(u_f * v_f, n=self.d)

    def unbind(self, u: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Approximately undo ``bind``: recover ``x`` from ``z = bind(u, x)``.

        Computes the circular correlation of ``u`` with ``z``. Equivalently,
        it convolves ``z`` with the involution of ``u``,
        u*[i] = u[-i mod d], which keeps element 0 and reverses the rest.
        For real vectors the involution's spectrum is the complex conjugate of
        ``u``'s spectrum, so it is applied in the frequency domain.

        The involution is only an approximate inverse, so the result is a noisy
        estimate that should be compared against known vectors.
        """
        u_f = torch.fft.rfft(u)
        z_f = torch.fft.rfft(z)
        # Multiplying by the conjugate spectrum == correlating with u.
        # n=self.d keeps odd dimensions correct (see bind).
        return torch.fft.irfft(torch.conj(u_f) * z_f, n=self.d)

    def bundle(self, vectors) -> torch.Tensor:
        """Superpose vectors by element-wise addition, then normalize.

        Args:
            vectors: Either a tensor whose first dimension indexes the vectors
                to bundle, or any iterable (list, tuple, generator) of
                equally-shaped tensors.

        Returns:
            The sum divided by its Euclidean norm. If the norm is zero, the raw
            all-zero sum is returned instead. The norm is taken over every
            element of the summed tensor, not per row.

        Raises:
            RuntimeError: from ``torch.stack`` if ``vectors`` is an empty
                iterable.
        """
        # Stack any non-tensor iterable, not just lists. torch.sum cannot
        # consume a tuple or generator of tensors directly.
        if not isinstance(vectors, torch.Tensor):
            vectors = torch.stack(list(vectors))

        result = torch.sum(vectors, dim=0)

        # Normalization keeps magnitudes stable as more vectors are bundled.
        # (On CUDA the `norm > 0` check forces a host/device sync.)
        norm = torch.norm(result)
        if norm > 0:
            return result / norm
        return result

    def similarity(self, u: torch.Tensor, v: torch.Tensor) -> float:
        """Return the cosine similarity of ``u`` and ``v`` as a Python float.

        Similarity is computed along dim 0, and ``.item()`` requires a
        single-element result. In practice both inputs must therefore be 1-D
        vectors of the same length.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()