import torch
import numpy as np
from AbstractVSA import AbstractVSA

class BSDC_CDT(AbstractVSA):
    """Binary Sparse Distributed Codes with Context-Dependent Thinning (CDT).

    Vectors are sparse binary {0, 1} tensors with ``floor(d * density)``
    active entries, where ``density = 1 / sqrt(dimension)``.

    * bind:   disjunction (logical OR) followed by CDT, which clears active
              bits until the result is back at (or below) the target density.
    * bundle: top-k thinning of the elementwise sum.
    * unbind: not supported; use item-memory search instead.
    """

    def __init__(self, dimension, device=None, max_iters=50):
        """
        Args:
            dimension: Length of every hypervector.
            device: Torch device for generated vectors. Defaults to CUDA when
                available, otherwise CPU.
            max_iters: Maximum number of shift-and-thin passes performed by
                CDT. Defaults to 50 (noted as the default in bind_vectors.m).
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # NOTE(review): assumes AbstractVSA.__init__ stores `dimension` as
        # `self.d`, which generate_vector/bundle read — confirm in AbstractVSA.
        super().__init__(dimension)
        self.density = 1.0 / np.sqrt(dimension)
        self.max_iters = max_iters

    def _num_active(self):
        """Return the number of active bits a vector has at the target density."""
        return int(np.floor(self.d * self.density))

    def generate_vector(self):
        """Return a random sparse binary vector (same scheme as BSDC-S).

        Exactly ``floor(d * density)`` positions, drawn without replacement,
        are set to 1.0; every other position is 0.0.
        """
        k = self._num_active()
        v = torch.zeros(self.d, device=self.device)
        indices = torch.randperm(self.d, device=self.device)[:k]
        v[indices] = 1.0
        return v

    def _cdt(self, z):
        """Context-Dependent Thinning.

        Compares the vector with circularly shifted copies of itself
        (shift 1, 2, 3, ...). At each shift it clears every active bit that
        coincides with an active bit of the shifted copy. It stops once the
        fraction of active bits is at most ``self.density`` or the shift
        budget is exhausted.

        NOTE(review): each shift is applied to the progressively thinned
        vector rather than to the original input — confirm this matches the
        reference (Matlab) CDT variant.

        Args:
            z: 1-D tensor; every entry > 0 is treated as an active bit.

        Returns:
            A new binary {0, 1} tensor with the same dtype as ``z``. The input
            tensor is not modified.
        """
        # Binarize (this also copies, so the caller's tensor is untouched).
        # A disjunction built as u + v holds 2.0 where both inputs are active.
        # Those entries would otherwise inflate the density estimate and leak
        # non-binary values into the result.
        z = (z > 0).to(z.dtype)
        length = z.shape[0]

        # A shift equal to the length maps z onto itself and would clear every
        # bit. Shifts beyond length - 1 only revisit earlier shifts, which
        # clear nothing further. So only shifts 1 .. length - 1 are useful,
        # and at most max_iters of them are tried.
        max_shift = min(self.max_iters, length - 1)
        for shift in range(1, max_shift + 1):
            if torch.count_nonzero(z).item() / length <= self.density:
                break
            # Matlab: r = counter; permutation = circshift(vector, r)
            permutation = torch.roll(z, shifts=shift, dims=0)
            # Clear bits that are active in both z and its shifted copy (AND).
            overlap = (z > 0) & (permutation > 0)
            z[overlap] = 0
        return z

    def bind(self, u, v):
        """Bind two sparse binary vectors: disjunction followed by CDT.

        Args:
            u, v: 1-D binary tensors of equal length.

        Returns:
            A binary {0, 1} tensor thinned to at most ``self.density``, unless
            the shift budget of CDT runs out first.
        """
        # Note: Matlab uses addition then CDT (values_disj = vectors_1 + vectors_2).
        # _cdt treats every positive entry as active, so the sum acts as a
        # logical OR even where both inputs are set (value 2.0).
        return self._cdt(u + v)

    def unbind(self, u, z):
        """
        NO UNBIND OPERATOR.

        "There is no specific unbind operator for the selected VSA -
         use the finding of the most similar vectors in item memory instead!"

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("BSDC-CDT does not support algebraic unbinding. Use cleanup/search directly.")

    def bundle(self, vectors):
        """Bundle vectors by top-k thinning of their elementwise sum.

        Args:
            vectors: A list or tuple of 1-D tensors, or an already stacked
                tensor whose first axis indexes the vectors.

        Returns:
            A strictly binary {0, 1} tensor with exactly ``floor(d * density)``
            active positions. These are the positions with the largest sums,
            with ties broken by a tiny random perturbation.
        """
        if isinstance(vectors, (list, tuple)):
            vectors = torch.stack(vectors)

        sum_vec = torch.sum(vectors, dim=0)
        k = self._num_active()

        # For binary inputs the sums are integers, so noise below 1e-5 only
        # reorders positions whose sums are equal.
        noise = torch.rand_like(sum_vec) * 1e-5
        _, top_indices = torch.topk(sum_vec + noise, k)

        result = torch.zeros_like(sum_vec)
        result[top_indices] = 1.0
        return result

    def similarity(self, u, v):
        """Return the cosine similarity (normalized overlap) of u and v as a float.

        Returns 0.0 when either vector has zero norm.
        """
        overlap = torch.dot(u, v)
        norm_u = torch.norm(u)
        norm_v = torch.norm(v)
        if norm_u == 0 or norm_v == 0:
            return 0.0
        return (overlap / (norm_u * norm_v)).item()