import torch
import numpy as np
from AbstractVSA import AbstractVSA

class VTB(AbstractVSA):
    """Vector-derived Transformation Binding (VTB) algebra.

    Hypervectors of length ``D`` are viewed (row-major) as square
    ``d' x d'`` matrices with ``d' = sqrt(D)``. Binding is a scaled matrix
    product and unbinding multiplies by the transpose, which acts as an
    approximate inverse for random vectors.
    """

    def __init__(self, dimension, device=None):
        """Create a VTB algebra.

        Args:
            dimension: Vector length ``D``; must be a perfect square so a
                vector can be reshaped into a ``d' x d'`` matrix.
            device: Torch device used by ``generate_vector``. When ``None``,
                CUDA is used if available, otherwise CPU.

        Raises:
            AssertionError: If ``dimension`` is not a perfect square (only
                while assertions are enabled, i.e. not under ``python -O``).
        """
        # Determine device (GPU if available, else CPU)
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # VTB requires the dimension to be a perfect square to form a square matrix.
        self.d_root = int(np.sqrt(dimension))
        assert self.d_root * self.d_root == dimension, \
            "VTB dimension must be a perfect square (e.g., 1024, 2500) for matrix reshaping."

        # NOTE(review): AbstractVSA.__init__ is presumably what sets self.d
        # (read by generate_vector) — confirm against the base class.
        super().__init__(dimension)

    def generate_vector(self):
        """Return a random hypervector of length ``self.d`` with unit L2 norm.

        Entries are drawn i.i.d. from N(0, 1/D) and the vector is then
        normalized, so each entry keeps a variance close to 1/D.
        """
        # Standard deviation is sqrt(1/d)
        v = torch.randn(self.d, device=self.device) * (1 / np.sqrt(self.d))

        # Normalize to unit length; the small epsilon guards against a zero norm.
        norm = torch.norm(v)
        return v / (norm + 1e-9)

    def bind(self, u, v):
        """Bind ``u`` with ``v`` by matrix multiplication.

        Both vectors are reshaped into ``d' x d'`` matrices and combined as
        ``C_mat = sqrt(d') * V_mat @ U_mat``. The result is flattened back to
        length ``D``.

        The ``sqrt(d')`` factor (part of the VTB definition) keeps the expected
        norm of the result equal to that of random unit-norm inputs. Without it,
        a bound vector shrinks by about ``1/sqrt(d')`` and is under-weighted when
        bundled together with unbound vectors.

        Matrix multiplication is non-commutative, so ``bind(u, v)`` and
        ``bind(v, u)`` generally differ. ``unbind(u, bind(u, v))`` approximates ``v``.

        Args:
            u: Vector of length ``D`` that is transformed.
            v: Vector of length ``D`` whose matrix applies the transformation.

        Returns:
            The bound vector, of length ``D``.
        """
        # reshape (not view) so non-contiguous inputs such as slices also work.
        u_mat = u.reshape(self.d_root, self.d_root)
        v_mat = v.reshape(self.d_root, self.d_root)

        # Apply V as a transformation on U, rescaled to preserve the norm.
        result_mat = torch.matmul(v_mat, u_mat) * (self.d_root ** 0.5)

        return result_mat.reshape(-1)

    def unbind(self, u, z):
        """Remove ``u`` from ``z`` using the matrix transpose as approximate inverse.

        For ``z = bind(u, v)``, ``Z_mat = sqrt(d') * V_mat @ U_mat``. Therefore
        ``sqrt(d') * Z_mat @ U_mat.T = d' * V_mat @ (U_mat @ U_mat.T)``, which is
        approximately ``V_mat``, because ``U_mat @ U_mat.T`` is roughly
        ``I / d'`` for a random unit-norm ``u``.

        Args:
            u: Vector of length ``D`` used as the binding key.
            z: Vector of length ``D`` to unbind from.

        Returns:
            The approximate recovered vector, of length ``D``.
        """
        u_mat = u.reshape(self.d_root, self.d_root)
        z_mat = z.reshape(self.d_root, self.d_root)

        # Recover V by multiplying Z by U_transpose, with the same scale as bind.
        result_mat = torch.matmul(z_mat, u_mat.T) * (self.d_root ** 0.5)

        return result_mat.reshape(-1)

    def bundle(self, vectors):
        """Superpose vectors by element-wise addition followed by normalization.

        Args:
            vectors: Either a tensor whose first dimension indexes the vectors,
                or any iterable (list, tuple, generator) of equal-length vectors.

        Returns:
            The sum, scaled to unit L2 norm. A zero sum is returned unchanged.
        """
        if not torch.is_tensor(vectors):
            # Accept any iterable of tensors, not only lists.
            vectors = torch.stack(list(vectors))

        result = torch.sum(vectors, dim=0)

        # Normalization keeps bundled vectors on the same scale as atomic ones.
        norm = torch.norm(result)
        if norm > 0:
            return result / norm
        return result

    def similarity(self, u, v):
        """Return the cosine similarity of two 1-D vectors as a Python float."""
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()