import torch
import numpy as np
from AbstractVSA import AbstractVSA

class FHRR(AbstractVSA):
    """Fourier Holographic Reduced Representation (FHRR).

    Hypervectors are sequences of complex phasors ``e^(i * theta)`` stored as
    flat real tensors with interleaved real and imaginary parts:
    ``[Re_0, Im_0, Re_1, Im_1, ..., Re_k, Im_k]``.
    """

    def __init__(self, dimension, device=None):
        """Set up the architecture.

        Args:
            dimension: Length of the flattened real vector. An odd value is
                rounded down to the next even number, because every complex
                component occupies two real slots.
            device: Anything accepted by ``torch.device`` (a ``torch.device``,
                a string such as ``"cpu"`` or ``"cuda:0"``), or ``None`` to
                pick CUDA when available and fall back to the CPU.
        """
        # Normalise to a torch.device so that self.device has the same type
        # whether it was auto-detected or supplied by the caller.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # One complex component needs two real slots; drop a trailing odd one.
        if dimension % 2 != 0:
            dimension -= 1

        super().__init__(dimension)

        # Number of complex components
        self.n_complex = dimension // 2

    @staticmethod
    def _split_complex(x):
        """Return the (real, imaginary) parts of an interleaved vector."""
        # reshape (rather than view) also accepts tensors whose memory layout
        # cannot be viewed; it only copies when it has to.
        pairs = x.reshape(-1, 2)
        return pairs[:, 0], pairs[:, 1]

    @staticmethod
    def _interleave(re, im):
        """Flatten real/imaginary parts back into ``[Re_0, Im_0, Re_1, ...]``."""
        return torch.stack([re, im], dim=-1).reshape(-1)

    def generate_vector(self):
        """Draw a random phasor hypervector.

        Each complex component is ``e^(i * theta)`` with an independently
        sampled angle, so every (Re, Im) pair has unit modulus.

        Returns:
            A 1-D tensor of length ``2 * n_complex`` on ``self.device`` laid
            out as ``[Re_0, Im_0, Re_1, Im_1, ..., Re_k, Im_k]``.
        """
        # torch.rand samples from [0, 1), so the angles fall in [-pi, pi).
        theta = (torch.rand(self.n_complex, device=self.device) * 2 * np.pi) - np.pi

        # Re = cos(theta), Im = sin(theta)
        return self._interleave(torch.cos(theta), torch.sin(theta))

    def bind(self, u, v):
        """Bind two hypervectors by element-wise complex multiplication.

        ``(a + ib)(c + id) = (ac - bd) + i(ad + bc)``

        Args:
            u: Interleaved real/imaginary tensor with an even number of
                elements.
            v: Interleaved real/imaginary tensor matching ``u``.

        Returns:
            The flattened interleaved product ``u * v``.
        """
        u_re, u_im = self._split_complex(u)
        v_re, v_im = self._split_complex(v)

        re = u_re * v_re - u_im * v_im
        im = u_re * v_im + u_im * v_re

        return self._interleave(re, im)

    def unbind(self, u, z):
        """Unbind key ``u`` from bound vector ``z``.

        Complex division is carried out as multiplication by the conjugate,
        ``z / u = z * conj(u)``, which is exact only when ``|u| = 1``:

        ``(a + ib)(c - id) = (ac + bd) + i(bc - ad)``

        Args:
            u: The key (first argument), interleaved real/imaginary tensor.
            z: The bound object, interleaved real/imaginary tensor.

        Returns:
            The flattened interleaved result ``z * conj(u)``.
        """
        u_re, u_im = self._split_complex(u)
        z_re, z_im = self._split_complex(z)

        # z * conj(u), with conj(u) = (u_re, -u_im)
        re = z_re * u_re + z_im * u_im
        im = z_im * u_re - z_re * u_im

        return self._interleave(re, im)

    def bundle(self, vectors):
        """Bundle hypervectors: component-wise sum, then global normalisation.

        The result is scaled to unit Euclidean norm as a whole instead of
        projecting every component back onto the unit circle. This relaxes
        the strict unit-modulus (phasor) constraint in order to preserve
        superposition information, similar to standard HRR.

        Args:
            vectors: Either a tensor whose first axis indexes the vectors, or
                an iterable (list, tuple, ...) of equally shaped tensors.
                An empty iterable is rejected by ``torch.stack``.

        Returns:
            The normalised sum, or the raw sum when its norm is zero.
        """
        # Stack any non-tensor iterable, not only lists.
        if not isinstance(vectors, torch.Tensor):
            vectors = torch.stack(list(vectors))

        # Sum vectors
        total = torch.sum(vectors, dim=0)

        # Global Normalization (Euclidean); skip it for an all-zero sum to
        # avoid dividing by zero.
        norm = torch.norm(total)
        if norm > 0:
            return total / norm
        return total

    def similarity(self, u, v):
        """Cosine similarity of the flattened vectors.

        When both arguments are pure phasor vectors (every component has unit
        modulus) this equals the mean cosine of the component-wise phase
        differences. For bundled vectors it is the plain cosine similarity of
        the real-valued representations.

        Args:
            u: 1-D interleaved real/imaginary tensor.
            v: 1-D interleaved real/imaginary tensor of the same length.

        Returns:
            The similarity as a Python float.
        """
        return torch.nn.functional.cosine_similarity(u, v, dim=0).item()