import torch
import numpy as np
from AbstractVSA import AbstractVSA

class MBAT(AbstractVSA):
    """Matrix-binding vector symbolic architecture backed by PyTorch tensors.

    A single random orthonormal base matrix is created at construction time.
    Binding with a key vector ``u`` multiplies the value vector by a copy of
    that matrix cyclically shifted by an amount derived from ``u``;
    unbinding multiplies by the transpose of the same shifted matrix.

    NOTE(review): the methods read ``self.d`` as the vector dimension; it is
    presumably set by ``AbstractVSA.__init__`` -- confirm against the base
    class.
    """

    def __init__(self, dimension, device=None):
        """Create the base matrix on the requested device.

        Args:
            dimension: Side length of the square base matrix (and, presumably,
                the length of every vector handled by this instance).
            device: Anything PyTorch accepts as a device. When ``None``,
                CUDA is used if available, otherwise the CPU.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # The device is stored before delegating to the base class, exactly
        # as in the original ordering.
        super().__init__(dimension)

        print(f"Initializing MBAT on device: {self.device}")

        # Uniform entries in [-1, 1).
        random_mat = torch.rand((dimension, dimension), device=self.device) * 2 - 1

        # U and Vh from the SVD are both orthogonal, so their product is an
        # orthonormal matrix (the singular values are deliberately dropped).
        u, _, vh = torch.linalg.svd(random_mat, full_matrices=True)
        self.M_base = torch.matmul(u, vh)

    def _as_tensor(self, x, dtype=None):
        """Return ``x`` as a tensor living on ``self.device``.

        Args:
            x: A tensor or anything ``torch.tensor`` can convert.
            dtype: Target dtype. When ``None``, non-tensor input becomes
                ``float32`` and tensor input keeps its own dtype.

        Returns:
            A tensor on ``self.device``. Tensors that already match are
            returned without copying.
        """
        if not isinstance(x, torch.Tensor):
            target = torch.float32 if dtype is None else dtype
            return torch.tensor(x, device=self.device, dtype=target)
        if dtype is None:
            return x.to(self.device)
        return x.to(device=self.device, dtype=dtype)

    def generate_vector(self):
        """Return a random unit-length vector of length ``self.d``.

        Entries are drawn from a zero-mean Gaussian on ``self.device``.
        """
        # Variance 1/D; this scale is cancelled by the normalisation below
        # but is kept so results match the original implementation.
        std = 1.0 / np.sqrt(self.d)
        v = torch.randn(self.d, device=self.device) * std

        return v / torch.norm(v)

    def _get_shift_index(self, u):
        """Hash a vector to a shift amount in ``[0, self.d)``.

        The hash is the sum of the (first-axis) indices at which ``u`` is
        strictly positive, taken modulo ``self.d``.
        """
        positive_indices = (u > 0).nonzero(as_tuple=True)[0]

        # No positive entries: shift is zero, and we skip the reduction.
        if positive_indices.numel() == 0:
            return 0

        shift = torch.sum(positive_indices).item() % self.d
        return int(shift)

    def _get_shifted_matrix(self, u):
        """Return the base matrix rolled along rows and columns for key ``u``.

        Both axes are rolled by the same amount, i.e. the result is a
        simultaneous row/column permutation of ``self.M_base``. A new
        ``dimension x dimension`` tensor is produced on each call.
        """
        shift = self._get_shift_index(u)
        return torch.roll(self.M_base, shifts=(shift, shift), dims=(0, 1))

    def bind(self, u, v):
        """Bind value ``v`` to key ``u``: ``z = M_u @ v``.

        Args:
            u: Key vector; only the positions of its positive entries matter.
            v: Value vector. Converted to the base matrix's dtype and device
               so mixed-precision or off-device input does not fail in matmul.

        Returns:
            The bound vector as a tensor on ``self.device``.
        """
        u = self._as_tensor(u)
        v = self._as_tensor(v, dtype=self.M_base.dtype)

        M_u = self._get_shifted_matrix(u)
        return torch.matmul(M_u, v)

    def unbind(self, u, z):
        """Recover the value bound to key ``u``: ``v = M_u.T @ z``.

        The transpose serves as the inverse because the shifted matrix is a
        row/column permutation of an orthonormal matrix.

        Args:
            u: Key vector used when binding.
            z: Bound vector. Converted to the base matrix's dtype and device.

        Returns:
            The unbound vector as a tensor on ``self.device``.
        """
        u = self._as_tensor(u)
        z = self._as_tensor(z, dtype=self.M_base.dtype)

        M_u = self._get_shifted_matrix(u)
        return torch.matmul(M_u.T, z)

    def bundle(self, vectors):
        """Superpose vectors by element-wise sum, then normalise.

        Args:
            vectors: A list or tuple of vectors (tensors or array-likes), or
                an already stacked 2-D tensor/array whose first axis indexes
                the vectors. An empty list/tuple is rejected by
                ``torch.stack``.

        Returns:
            The summed vector scaled to unit norm, or the unscaled sum when
            its norm is zero.
        """
        if isinstance(vectors, (list, tuple)):
            # Coerce each member so mixed tensor / array input stacks cleanly.
            vectors = torch.stack([self._as_tensor(vec) for vec in vectors])
        else:
            vectors = self._as_tensor(vectors)

        result = torch.sum(vectors, dim=0)

        norm = torch.norm(result)
        if norm > 0:
            return result / norm
        return result

    def similarity(self, u, v):
        """Return the cosine similarity of ``u`` and ``v`` as a Python float.

        Non-tensor inputs are converted to ``float32`` tensors; tensor inputs
        are moved to ``self.device`` and keep their dtype.
        """
        u = self._as_tensor(u)
        v = self._as_tensor(v)

        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()