import torch
import numpy as np
from AbstractVSA import AbstractVSA

class BSDC_S(AbstractVSA):
    """Sparse binary VSA with shift-based binding (BSDC_SHIFT in generate_vectors.m).

    Vectors are {0, 1} tensors of length D with floor(D * density) active bits,
    where density = 1/sqrt(D). Binding circularly shifts one operand by an amount
    derived from the other operand; bundling sums the inputs and keeps the k
    strongest positions (top-k thinning).
    """

    def __init__(self, dimension, device=None):
        """Create the space.

        Args:
            dimension: Vector length D, forwarded to ``AbstractVSA.__init__``
                (this class reads ``self.d`` but never sets it, so the base
                class is expected to — confirm against AbstractVSA).
            device: Torch device for generated vectors. Defaults to CUDA when
                available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        super().__init__(dimension)

        # Standard sparse density of approximately 1/sqrt(D), i.e. about
        # sqrt(D) active bits per vector.
        self.density = 1.0 / np.sqrt(dimension)

    def _num_active(self):
        """Return k = floor(D * density), the number of active bits per vector."""
        return int(np.floor(self.d * self.density))

    def generate_vector(self):
        """Return a random sparse binary vector.

        Logic from generate_vectors.m (case BSDC/BSDC_SHIFT): k = floor(D * density)
        distinct positions are chosen uniformly at random and set to 1.

        Returns:
            Float tensor of length D on ``self.device``.
        """
        k = self._num_active()
        v = torch.zeros(self.d, device=self.device)

        # The first k entries of a random permutation are k distinct random indices.
        indices = torch.randperm(self.d, device=self.device)[:k]
        v[indices] = 1.0
        return v

    def _get_shift(self, u):
        """Return the shift amount encoded by ``u``.

        Matlab: idx = [1:size(vectors_1,1)]*vectors_1;
        i.e. the dot product of ``u`` with the 1-based positions [1, 2, ..., D].
        For a binary ``u`` this is the sum of the 1-based indices of its active bits.

        The product is evaluated in float64 on the CPU, for three reasons:
        - It is exact (float32 loses integer precision above 2**24, which
          this sum can exceed for D on the order of 1e5).
        - It matches Matlab's double-precision result.
        - It accepts ``u`` of any numeric dtype and on any device, instead
          of raising a dtype/device mismatch.
        """
        positions = torch.arange(1, self.d + 1, dtype=torch.float64)
        weights = u.detach().to(device="cpu", dtype=torch.float64)
        return int(torch.dot(positions, weights).item())

    def bind(self, u, v):
        """Shift binding.

        ``u`` acts as the shifter: ``v`` is circularly shifted by
        ``self._get_shift(u)`` positions.
        """
        shift = self._get_shift(u)
        return torch.roll(v, shifts=shift, dims=0)

    def unbind(self, u, z):
        """Shift unbinding: the inverse of ``bind``.

        ``z`` is circularly shifted by ``-self._get_shift(u)`` positions, which
        undoes ``bind(u, ·)`` exactly.
        """
        shift = self._get_shift(u)
        return torch.roll(z, shifts=-shift, dims=0)

    def bundle(self, vectors):
        """Bundle vectors by summation followed by top-k thinning.

        Args:
            vectors: A list/tuple of equal-length 1-D tensors, or a 2-D tensor
                whose rows are the vectors to bundle.

        Returns:
            Strictly binary {0, 1} tensor with the dtype/device of the summed input.
            At most k = floor(D * density) positions are set: those with the
            largest sums. Ties at the cut-off are broken uniformly at random.
            Positions whose sum is not positive are never set, so an input with
            fewer than k active positions is not padded with random bits.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        sum_vec = torch.sum(vectors, dim=0)
        k = self._num_active()

        # Random tie-breaking without additive noise: visit positions in a
        # random order, then stable-sort by sum. Among equal sums, the random
        # visiting order decides. This stays exact for any magnitude and
        # dtype; a fixed 1e-5 offset is rounded away in float32 once sums
        # reach a few hundred.
        perm = torch.randperm(sum_vec.shape[0], device=sum_vec.device)
        ranked = torch.sort(sum_vec[perm], descending=True, stable=True).indices
        top_indices = perm[ranked[:k]]

        # Never activate a bit that no input contributed to.
        top_indices = top_indices[sum_vec[top_indices] > 0]

        result = torch.zeros_like(sum_vec)
        result[top_indices] = 1
        return result

    def similarity(self, u, v):
        """Cosine similarity of ``u`` and ``v`` as a Python float.

        For binary {0, 1} vectors, ``dot(u, v)`` counts the overlapping active
        bits, so the result is overlap / sqrt(|u| * |v|) and lies in [0, 1].
        Returns 0.0 if either vector is all zeros.
        """
        # torch.norm rejects integer/bool tensors and torch.dot rejects mixed
        # dtypes, so compute in a common floating dtype. This is a no-op when
        # both inputs already share a floating dtype.
        dtype = torch.promote_types(u.dtype, v.dtype)
        if not dtype.is_floating_point:
            dtype = torch.get_default_dtype()
        u = u.to(dtype)
        v = v.to(dtype)

        overlap = torch.dot(u, v)
        norm_u = torch.norm(u)
        norm_v = torch.norm(v)
        if norm_u == 0 or norm_v == 0:
            return 0.0
        return (overlap / (norm_u * norm_v)).item()