import torch
import torch.nn.functional as F
import numpy as np

class CGR_INT:
    """Cyclic-group VSA ("CGR") over Z_n, exposed through float [cos, sin] vectors.

    Conceptually each hypervector is an integer vector with entries in
    {0, ..., n-1}. Externally every vector is the interleaved float layout
    ``[cos(t0), sin(t0), cos(t1), sin(t1), ...]`` with ``t_i = 2*pi*k_i/n``.
    Integer logic (binding = modular addition) is done on the integers
    recovered from that layout.

    Logic source (cornell-relaxml repo):
      - Mapping: 'pts_map' from model.py
      - Bundling: 'group_bundle' from encoder.py
      - Binding: 'group_bind' from encoder.py (implicit modular addition)
      - Similarity: 'GroupSim' from model.py
    """

    def __init__(self, dimension, n=16, device=None):
        """Create an encoder.

        Args:
            dimension: Requested float-vector length. ``dimension // 2`` group
                elements are used (one cos/sin pair each), so an odd value
                yields vectors of length ``dimension - 1``.
            n: Group order ('gorder' in the reference code). Must be >= 1.
            device: Device for generated vectors. Defaults to CUDA when
                available, otherwise CPU.

        Raises:
            ValueError: If ``n < 1``; every method divides by or reduces
                modulo ``n``.
        """
        if n < 1:
            raise ValueError(f"group order n must be >= 1, got {n}")
        # dimension//2 integers so the interleaved (cos, sin) output has
        # (at most) the requested length.
        self.d = dimension // 2
        self.n = n  # 'gorder' in the reference code
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def _pts_map(self, x, r=1.0):
        """Map group elements to interleaved [cos, sin] points (port of 'pts_map').

        Args:
            x: Tensor of group elements (any numeric dtype). The last axis
                indexes the ``D`` elements; any leading axes are kept.
            r: Radius of the points.

        Returns:
            Float tensor with the same leading dims as ``x`` and a last axis
            of length ``2 * D`` laid out as (cos, sin) pairs. A 1-D ``x``
            gives the flat ``[2*D]`` vector used throughout this class.
        """
        theta = 2.0 * np.pi / (1.0 * self.n) * x.float()
        pts = r * torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)
        # Merge the trailing (D, 2) axes into one interleaved axis of length 2*D.
        return pts.reshape(*x.shape[:-1], -1)

    def _quantize(self, angles):
        """Snap angles (radians) to the nearest group index in [0, n-1].

        Uses the reference rule ``floor(angle / (2*pi) * n + 1/2)``. Exact
        half-way angles round up, unlike ``torch.round``, which rounds half to
        even.
        """
        idx = torch.floor(angles / (2 * np.pi) * self.n + 0.5)
        # Tensor `%` is Python-style (result takes the divisor's sign), so the
        # negative angles produced by atan2 wrap into [0, n).
        return idx.long() % self.n

    def _get_ints(self, float_vec):
        """Recover group elements from a [cos, sin] vector (inverse of ``_pts_map``).

        Args:
            float_vec: Tensor whose last axis holds interleaved (cos, sin)
                pairs. Any leading axes are kept.

        Returns:
            Long tensor of indices in [0, n-1], one per pair.
        """
        # (..., 2*D) -> (..., D, 2); reshape (unlike view) accepts any layout.
        pairs = float_vec.reshape(*float_vec.shape[:-1], -1, 2)
        # atan2 is quadrant-aware, replacing the reference's manual
        # arctan(y/x) + pi correction.
        angles = torch.atan2(pairs[..., 1], pairs[..., 0])
        return self._quantize(angles)

    def generate_vector(self):
        """Return a random hypervector (cf. 'LinearEncoder.get_hdv' for Group VSA).

        Draws ``d`` uniform integers in [0, n-1] on ``self.device`` and maps
        them to the interleaved [cos, sin] layout.
        """
        k = torch.randint(0, self.n, (self.d,), device=self.device)
        return self._pts_map(k)

    def bind(self, u, v):
        """Bind two vectors by element-wise modular addition (cf. 'group_bind').

        The reference sums the integer lists and reduces modulo ``gorder``;
        here the integers are first recovered from the float vectors.
        """
        u_int = self._get_ints(u)
        v_int = self._get_ints(v)
        return self._pts_map((u_int + v_int) % self.n)

    def unbind(self, u, z):
        """Invert ``bind``: return ``v`` such that ``bind(u, v) == z``.

        Implemented as element-wise modular subtraction ``(z - u) mod n``.
        """
        u_int = self._get_ints(u)
        z_int = self._get_ints(z)
        return self._pts_map((z_int - u_int) % self.n)

    def bundle(self, vectors):
        """Bundle vectors via the circular mean of their points (cf. 'group_bundle').

        The inputs are already ``pts_map`` outputs, so the reference's mapping
        step is skipped. The points are summed, and each summed pair is snapped
        back to the nearest group element using the reference's half-up
        rounding rule.

        Args:
            vectors: A list or tuple of [cos, sin] vectors, or a tensor whose
                axis 0 indexes the vectors to bundle.

        Returns:
            A [cos, sin] vector of the same layout as one input vector. Where
            a summed pair nearly cancels to zero, its angle is dominated by
            rounding noise and the chosen index is arbitrary.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)
        # "pts = torch.sum(self.pts_map(lst), dim=0)" — inputs are already mapped.
        sum_pts = torch.sum(vectors, dim=0)
        return self._pts_map(self._get_ints(sum_pts))

    def similarity(self, u, v):
        """Return the cosine similarity of two 1-D [cos, sin] vectors as a float.

        The reference 'GroupSim' sums the dot products of mapped points. For
        unit-radius points each vector has squared norm ``d``, so this cosine
        equals that score divided by ``d``.
        """
        return F.cosine_similarity(u, v, dim=0).item()