import torch
import torch.nn.functional as F
import numpy as np
from AbstractVSA import AbstractVSA

# https://arxiv.org/pdf/2202.04805

class CGR(AbstractVSA):
    """
    Cyclic Group Representation (CGR) - Compatibility Mode.

    Each atomic element is a discrete phase k * (2*pi / n), k in {0, ..., n-1}.
    To work with generic VSA test scripts (which use F.normalize and Dot
    Product), every phase is stored as a floating-point (cos, sin) pair instead
    of an integer. A hypervector uses the interleaved layout
    [cos_0, sin_0, cos_1, sin_1, ...].

    bind / unbind / bundle accept 1-D tensors of length 2*D, and also batched
    tensors of shape [..., 2*D]; leading dimensions broadcast and are kept.
    """

    def __init__(self, dimension, n=16, device=None):
        """
        Args:
            dimension (int): Requested length of the float tensor. Each phase
                occupies two floats, so ``dimension // 2`` phases are used and
                generated vectors have length ``2 * (dimension // 2)``: exactly
                ``dimension`` when it is even, one less when it is odd.
            n (int): Group order (number of discrete phases / resolution).
            device (torch.device | str | None): Device for generated tensors.
                Defaults to CUDA when available, otherwise CPU.
        """
        # Hand the base class the number of *phases*, not the float length, so
        # the tensor size matches ``dimension`` for fair comparison with other VSAs.
        # NOTE(review): AbstractVSA is assumed to store this as ``self.d``
        # (generate_vector reads it) — confirm against AbstractVSA.
        super().__init__(dimension // 2)
        self.n = n
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_pairs(x):
        """Reshape an interleaved ``[..., 2*D]`` tensor into ``[..., D, 2]`` (real, imag)."""
        # reshape rather than view: view fails on non-contiguous batched inputs.
        return x.reshape(*x.shape[:-1], -1, 2)

    @staticmethod
    def _phases_to_pairs(theta):
        """Map phases ``[..., D]`` to the interleaved [cos, sin] layout ``[..., 2*D]``."""
        return torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1).flatten(-2)

    def _quantize_pairs(self, real, imag):
        """
        Snap each (real, imag) pair to the nearest of the n discrete phases.

        Keeps results "clean" group elements, like the integer form of CGR.
        A zero-length pair has atan2(0, 0) == 0 and therefore maps to phase 0.
        """
        step = 2 * np.pi / self.n
        discrete_angles = torch.round(torch.atan2(imag, real) / step) * step
        return self._phases_to_pairs(discrete_angles)

    # ---------------------------------------------------------------- public API

    def generate_vector(self):
        """
        Draw a random hypervector.

        Each of the ``self.d`` phases is drawn uniformly from the n values
        k * 2*pi/n and stored as a [cos, sin] pair, so the plain dot product of
        two such vectors is the sum of cos(phase differences).

        Returns:
            Tensor of shape [2 * self.d] on ``self.device``.
        """
        k = torch.randint(0, self.n, (self.d,), device=self.device)
        theta = k.float() * (2 * np.pi / self.n)
        return self._phases_to_pairs(theta)

    def bind(self, u, v):
        """
        Binding: addition of phases.

        Implemented as element-wise complex multiplication of the pairs,
        (a+bi)(c+di) = (ac-bd) + (ad+bc)i, then quantized back onto the n phases.

        Args:
            u, v: Tensors of shape [..., 2*D] (leading dimensions broadcast).

        Returns:
            Tensor of shape [..., 2*D].
        """
        u_c = self._as_pairs(u)
        v_c = self._as_pairs(v)
        a, b = u_c[..., 0], u_c[..., 1]
        c, d = v_c[..., 0], v_c[..., 1]
        return self._quantize_pairs(a * c - b * d, a * d + b * c)

    def unbind(self, u, z):
        """
        Unbinding: subtraction of phases (recovers v from z = bind(u, v)).

        Implemented as multiplication of z by the complex conjugate of u,
        followed by quantization onto the n phases.

        Args:
            u: Key tensor of shape [..., 2*D].
            z: Bound tensor of shape [..., 2*D] (leading dimensions broadcast).

        Returns:
            Tensor of shape [..., 2*D].
        """
        u_c = self._as_pairs(u)
        z_c = self._as_pairs(z)
        # Conjugate u: (a, b) -> (a, -b)
        a, b = u_c[..., 0], -u_c[..., 1]
        c, d = z_c[..., 0], z_c[..., 1]
        return self._quantize_pairs(a * c - b * d, a * d + b * c)

    def bundle(self, vectors):
        """
        Superposition: vector sum, then projection back onto the n phases.

        The angle of the summed [cos, sin] pairs is the circular mean of the
        input phases; it is then quantized.

        Args:
            vectors: A list/tuple of [..., 2*D] tensors, or one tensor whose
                dim 0 indexes the vectors to bundle.

        Returns:
            Tensor of shape [..., 2*D]. Where the inputs cancel exactly (zero
            resultant), that slot falls back to phase 0.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        summed = self._as_pairs(torch.sum(vectors, dim=0))
        return self._quantize_pairs(summed[..., 0], summed[..., 1])

    def similarity(self, u, v):
        """
        Cosine similarity of two 1-D hypervectors, as a Python float.

        For vectors produced by this class every [cos, sin] pair has unit norm,
        so this equals the *mean* of cos(theta_u - theta_v) over the phases
        (the dot product divided by the number of phases).
        """
        return F.cosine_similarity(u, v, dim=0).item()