import torch
import numpy as np
from AbstractVSA import AbstractVSA

class BSDC_SEG(AbstractVSA):
    """Segmented sparse binary vector-symbolic architecture.

    A vector of length ``d`` is split into ``num_segments`` contiguous
    segments of ``seg_size`` elements each; a freshly generated vector has
    exactly one ``1`` per segment.  Any elements past ``used_dim``
    (``num_segments * seg_size``) do not belong to a segment and are handled
    separately by each operation.

    NOTE(review): ``self.d`` is read throughout but never assigned in this
    class -- presumably ``AbstractVSA.__init__`` stores ``dimension`` there;
    confirm against the base class.
    """

    def __init__(self, dimension, device=None):
        """Configure the segment layout for vectors of length ``dimension``.

        Args:
            dimension: Vector length; must be at least 1.
            device: Torch device for generated tensors.  Defaults to CUDA
                when available, otherwise CPU.

        Raises:
            ValueError: If ``dimension`` is smaller than 1.
        """
        # The device is assigned before the base initializer runs, as in the
        # original ordering (the base class may rely on it -- not visible here).
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        super().__init__(dimension)

        if dimension < 1:
            raise ValueError(f"dimension must be >= 1, got {dimension}")

        # Target density 1/sqrt(d) yields floor(sqrt(d)) segments.  Taking the
        # square root directly avoids the rounding-sensitive product
        # d * (1/sqrt(d)), which is not guaranteed to land exactly on sqrt(d)
        # for perfect squares and could then floor to one segment too few.
        self.density = 1.0 / np.sqrt(dimension)
        self.num_segments = int(np.floor(np.sqrt(dimension)))
        self.seg_size = int(np.floor(dimension / self.num_segments))

        # When the dimension does not divide evenly, only the first
        # ``used_dim`` elements are covered by segments.
        self.used_dim = self.num_segments * self.seg_size

    def generate_vector(self):
        """Return a random segmented sparse vector with one ``1`` per segment.

        Returns:
            A 1-D tensor of length ``self.d`` on ``self.device`` containing
            zeros except for a single ``1.0`` inside each segment.
        """
        v = torch.zeros(self.d, device=self.device)

        # One random in-segment position per segment, values in [0, seg_size).
        offsets = torch.randint(0, self.seg_size, (self.num_segments,), device=self.device)

        # Segment start positions: [0, seg_size, 2 * seg_size, ...].
        base_indices = torch.arange(0, self.used_dim, self.seg_size, device=self.device)

        v[base_indices + offsets] = 1.0
        return v

    def _roll_segments(self, u, v, direction):
        """Circularly shift each segment of ``v`` by the argmax of ``u``'s segment.

        Args:
            u: Vector whose per-segment argmax supplies the shift amounts.
            v: Vector whose segments are shifted.
            direction: ``+1`` shifts forward (bind), ``-1`` shifts back (unbind).

        Returns:
            A new tensor of length ``self.d`` on ``self.device``; elements
            past ``used_dim`` are copied from ``v`` unchanged.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        u_seg = u[:self.used_dim].reshape(self.num_segments, self.seg_size)
        v_seg = v[:self.used_dim].reshape(self.num_segments, self.seg_size)

        # Position of the maximum (the '1' for a one-hot segment) per segment.
        shifts = torch.argmax(u_seg, dim=1)

        # There is no per-row roll, so build source indices explicitly:
        # result[i] = v[(i - direction * shift) % seg_size].
        grid = torch.arange(self.seg_size, device=self.device).expand(self.num_segments, -1)
        src_indices = (grid - direction * shifts.unsqueeze(1)) % self.seg_size

        result_seg = torch.gather(v_seg, 1, src_indices)

        result = torch.zeros(self.d, device=self.device)
        result[:self.used_dim] = result_seg.reshape(-1)

        # Trailing elements outside any segment pass through from ``v``.
        if self.used_dim < self.d:
            result[self.used_dim:] = v[self.used_dim:]

        return result

    def bind(self, u, v):
        """Segment-wise shift binding.

        Each segment of ``v`` is rolled forward by the position of the ``1``
        in the corresponding segment of ``u``, so for one-hot segments the
        active indices add modulo ``seg_size``.  The source comments cite a
        Matlab reference (``bind_vectors.m``) that uses the same index
        addition.

        Args:
            u: Role vector supplying the per-segment shifts.
            v: Filler vector being shifted.

        Returns:
            The bound vector, length ``self.d``.
        """
        return self._roll_segments(u, v, 1)

    def unbind(self, u, z):
        """Segment-wise shift unbinding; the inverse of :meth:`bind`.

        Each segment of ``z`` is rolled back by the position of the ``1`` in
        the corresponding segment of ``u``.

        Args:
            u: Role vector that was used for binding.
            z: Bound vector to unbind.

        Returns:
            The recovered vector, length ``self.d``.
        """
        return self._roll_segments(u, z, -1)

    def bundle(self, vectors):
        """Bundle vectors segment-wise with a hard winner-take-all.

        The inputs are summed element-wise and, within every segment, only
        the position with the largest sum is set to ``1`` (``torch.max``
        decides which index wins on ties).

        Args:
            vectors: A list or tuple of 1-D tensors, or a tensor stacked
                along dimension 0.

        Returns:
            A tensor of length ``self.d`` with values in ``{0, 1}``.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        sum_vec = torch.sum(vectors, dim=0)

        sum_seg = sum_vec[:self.used_dim].reshape(self.num_segments, self.seg_size)

        # Hard winner-take-all: keep a single bit per segment.
        _, max_indices = torch.max(sum_seg, dim=1)  # (num_segments,)

        result_seg = torch.zeros_like(sum_seg)
        result_seg.scatter_(1, max_indices.unsqueeze(1), 1.0)

        result = torch.zeros(self.d, device=self.device)
        result[:self.used_dim] = result_seg.reshape(-1)

        # Trailing elements are not part of a segment; threshold the sum so
        # the output stays strictly binary.
        if self.used_dim < self.d:
            result[self.used_dim:] = (sum_vec[self.used_dim:] > 0.5).float()

        return result

    def similarity(self, u, v):
        """Return the cosine similarity of ``u`` and ``v`` as a Python float.

        A small epsilon in the denominator avoids division by zero when
        either vector has zero norm.
        """
        overlap = torch.dot(u, v)
        return (overlap / (torch.norm(u) * torch.norm(v) + 1e-9)).item()