import torch
import torch.nn.functional as F
from AbstractVSA import AbstractVSA

# https://arxiv.org/pdf/2405.09689v1

class GHRR(AbstractVSA):
    """Generalized HRR (GHRR) in "compatibility mode".

    Hypervectors are exposed as flat, real-valued 1-D tensors so they work with
    the standard test scripts. Internally a vector is a stack of ``n_blocks``
    blocks:

    * ``m == 2``: each block is a quaternion ``[a, b, c, d]`` (4 reals),
      unit-norm when produced by ``generate_vector``.
    * otherwise: each block is an ``m x m`` complex matrix (unitary when
      produced by ``generate_vector``). The flat layout is *all* real parts
      followed by *all* imaginary parts:
      ``[real(n_blocks*m*m), imag(n_blocks*m*m)]``.

    Binding is the block-wise Hamilton / matrix product and is therefore
    non-commutative: the key must be the *left* operand of both ``bind`` and
    ``unbind``.

    Reference: https://arxiv.org/pdf/2405.09689v1
    """

    def __init__(self, dimension, m=2, device=None):
        """Create a GHRR space.

        Args:
            dimension: Requested number of real parameters per hypervector.
                It is silently rounded *down* to a multiple of the block size
                (4 when ``m == 2``, otherwise ``2 * m * m``). The adjusted
                value is passed to the base class, which this class then reads
                back as ``self.d``.
            m: Block order. ``2`` selects the quaternion representation.
            device: Torch device for generated vectors. Defaults to CUDA when
                available, otherwise CPU.

        Raises:
            ValueError: If ``m < 1`` or ``dimension`` is smaller than one block.
        """
        if m < 1:
            raise ValueError(f"GHRR block order m must be >= 1, got {m}")

        # m == 2 uses quaternions (4 reals per block); otherwise each block is
        # an m x m complex matrix stored as separate real and imaginary parts.
        block_size = 4 if m == 2 else 2 * m * m
        usable = dimension - dimension % block_size
        if usable <= 0:
            raise ValueError(
                f"GHRR dimension {dimension} is too small: need at least "
                f"{block_size} for m={m}"
            )

        # Initialise the base exactly once, with the final dimension, and only
        # then set our own attributes so the base cannot overwrite them.
        super().__init__(usable)
        self.m = m
        self.block_size = block_size
        self.device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.n_blocks = self.d // block_size

    def generate_vector(self):
        """Return a random flat hypervector on ``self.device``.

        For ``m == 2`` each block is a Gaussian 4-vector normalised to a unit
        quaternion. Otherwise each block is the unitary Q factor of a QR
        decomposition of a complex Gaussian matrix. The phases of R's diagonal
        are folded back into Q so the blocks are Haar-distributed (Mezzadri,
        2007) rather than biased by the QR sign convention.
        """
        if self.m == 2:
            q = torch.randn(self.n_blocks, 4, device=self.device)
            return self._from_blocks(F.normalize(q, p=2, dim=-1))

        shape = (self.n_blocks, self.m, self.m)
        z = torch.complex(
            torch.randn(shape, device=self.device),
            torch.randn(shape, device=self.device),
        )
        q, r = torch.linalg.qr(z)
        # Scale column j of Q by r_jj / |r_jj|; a unit-modulus column scaling
        # keeps Q unitary.
        phases = torch.sgn(torch.diagonal(r, dim1=-2, dim2=-1))
        return self._from_blocks(q * phases.unsqueeze(-2))

    def _to_blocks(self, flat_vec):
        """Reshape a flat real hypervector into its block representation.

        Returns a real ``[n_blocks, 4]`` tensor when ``m == 2`` and a complex
        ``[n_blocks, m, m]`` tensor otherwise.
        """
        if self.m == 2:
            # reshape (not view) so non-contiguous inputs are accepted too.
            return flat_vec.reshape(self.n_blocks, 4)
        shape = (self.n_blocks, self.m, self.m)
        half = self.n_blocks * self.m * self.m
        real = flat_vec[:half].reshape(shape)
        imag = flat_vec[half:].reshape(shape)
        return torch.complex(real, imag)

    def _from_blocks(self, blocks):
        """Inverse of ``_to_blocks``: flatten blocks to a real 1-D tensor."""
        if self.m == 2:
            return blocks.flatten()
        return torch.cat([blocks.real.flatten(), blocks.imag.flatten()])

    @staticmethod
    def _hamilton(p, q):
        """Row-wise Hamilton product ``p * q`` of quaternion stacks ``[..., 4]``."""
        a1, b1, c1, d1 = p.unbind(-1)
        a2, b2, c2, d2 = q.unbind(-1)
        return torch.stack(
            [
                a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,
                a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,
                a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,
                a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,
            ],
            dim=-1,
        )

    def bind(self, u, v):
        """Bind ``u`` and ``v`` with the block-wise product ``u * v``.

        Uses the Hamilton product for ``m == 2`` and matrix multiplication
        otherwise; neither is commutative, so ``bind(u, v) != bind(v, u)``
        in general.
        """
        u_b = self._to_blocks(u)
        v_b = self._to_blocks(v)
        if self.m == 2:
            res = self._hamilton(u_b, v_b)
        else:
            res = torch.matmul(u_b, v_b)
        return self._from_blocks(res)

    def unbind(self, u, z):
        """Recover ``v`` from ``z = bind(u, v)`` given the key ``u``.

        Left-multiplies ``z`` block-wise by the quaternion conjugate /
        conjugate transpose of ``u``. That is the exact inverse when ``u``'s
        blocks are unit quaternions or unitary matrices (as produced by
        ``generate_vector``). For other ``u`` (e.g. bundles) the result is only
        an approximation of ``v``.
        """
        u_b = self._to_blocks(u)
        z_b = self._to_blocks(z)
        if self.m == 2:
            # Quaternion conjugate: (a, -b, -c, -d).
            u_inv = torch.cat([u_b[..., :1], -u_b[..., 1:]], dim=-1)
            res = self._hamilton(u_inv, z_b)
        else:
            u_inv = u_b.transpose(-2, -1).conj()
            res = torch.matmul(u_inv, z_b)
        return self._from_blocks(res)

    def bundle(self, vectors):
        """Superpose hypervectors: sum them, then L2-normalise the flat result.

        Args:
            vectors: A list/tuple of flat hypervectors, or a tensor of them
                stacked along dim 0.

        The sum is deliberately *not* projected back onto unit quaternions /
        unitary blocks; leaving the superposition off the manifold preserves
        capacity. A zero sum is returned unnormalised.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        total = torch.sum(vectors, dim=0)
        norm = torch.norm(total)
        if norm > 0:
            return total / norm
        return total

    def similarity(self, u, v):
        """Cosine similarity of two flat hypervectors, as a Python float.

        For the complex-matrix layout, the real dot product of the flat
        representations equals ``Re(trace(A^H B))`` summed over blocks. The
        cosine is that trace similarity normalised by both vector norms.
        """
        return F.cosine_similarity(u, v, dim=0).item()