import torch
import math
from itertools import permutations

class Memory:
    """Tensor-product associative memory over N-dimensional vectors.

    ``bind`` turns ``p`` filler vectors into a rank-``p`` tensor (a plain outer
    product, or its antisymmetrised "wedge" form) and accumulates it into
    ``self.memory``.  The scoring methods build the same kind of tensor for a
    candidate and take its inner product with the accumulated memory.

    Before a tensor is built, the fillers are "carved" with the role vectors:
    each filler is replaced by ``(filler @ roles.T) @ roles`` and then scaled
    to unit length.
    """

    # A carved vector whose norm is not above this value counts as vanished;
    # a candidate with any vanished vector scores 0.
    _SURVIVAL_EPS = 1e-3

    def __init__(self, N, p=2, isWedge=False, orthogonalize_roles=False, threshold=None, device='cuda'):
        """
        N: Vector dimension
        p: Rank of the filler (number of vectors bound to the role)
        isWedge: Use the wedge product (antisymmetrised sum over all
                 permutations) instead of the plain outer product.
        orthogonalize_roles: If True, applies Gram-Schmidt to roles before binding/checking.
                             Set to False if roles are already orthonormal.
        threshold: Score threshold for recognition. If None, defaults to 0.6 (Wedge) or 0.3 (Outer).
        device: Device the memory tensor is allocated on, e.g. 'cuda' or 'cpu'.
                The default 'cuda' is passed straight to torch.zeros.
        """
        self.N = N
        self.p = p
        self.isWedge = isWedge
        self.orthogonalize_roles = orthogonalize_roles
        self.device = device

        # Memory is a tensor of rank p: (N, N, ..., N)
        self.memory = torch.zeros((N,) * p, device=device)

        # An explicit threshold wins; otherwise pick the mode-specific default.
        if threshold is not None:
            self.threshold = threshold
        else:
            self.threshold = 0.6 if isWedge else 0.3

    # ------------------------------------------------------------------
    # Linear-algebra helpers
    # ------------------------------------------------------------------

    def _gram_schmidt(self, vectors):
        """
        Orthonormalizes a set of vectors using the Gram-Schmidt process.
        vectors: torch tensor of shape (num_vectors, N)
        Returns: tensor of the same shape. A row whose remainder has norm
                 <= 1e-9 after removing the earlier directions is left as zeros.
        """
        basis = torch.zeros_like(vectors)

        for i, vector in enumerate(vectors):
            w = vector.clone()
            if i > 0:
                # Coefficients of w along every basis vector found so far,
                # all taken against the same (unmodified) w.
                projections = torch.mv(basis[:i], w)  # (i,)
                w -= torch.sum(projections.unsqueeze(1) * basis[:i], dim=0)

            norm = torch.norm(w)
            if norm > 1e-9:
                basis[i] = w / norm

        return basis

    @staticmethod
    def _signed_permutations(p):
        """
        Yields (sign, sigma) for every permutation sigma of range(p).
        sign is +1 when sigma has an even number of inversions, -1 otherwise.
        """
        for sigma in permutations(range(p)):
            inversions = sum(
                1
                for i in range(p)
                for j in range(i + 1, p)
                if sigma[i] > sigma[j]
            )
            yield (-1 if inversions % 2 else 1), sigma

    def _score_normalizer(self):
        """Divisor applied to raw overlaps: p! in wedge mode, 1.0 otherwise."""
        return math.factorial(self.p) if self.isWedge else 1.0

    # ------------------------------------------------------------------
    # Tensor construction (single candidate)
    # ------------------------------------------------------------------

    def _generalized_outer_optimized(self, vectors):
        """
        Computes the tensor (outer) product of the first self.p rows.
        Uses einsum for p = 2, 3, 4 and an iterative broadcast otherwise
        (which also covers p = 1, returning the single vector).
        vectors: torch tensor of shape (p, N)
        Returns: tensor of shape (N,) * p
        """
        if self.p == 2:
            return torch.einsum('i,j->ij', vectors[0], vectors[1])
        if self.p == 3:
            return torch.einsum('i,j,k->ijk', vectors[0], vectors[1], vectors[2])
        if self.p == 4:
            return torch.einsum('i,j,k,l->ijkl', vectors[0], vectors[1], vectors[2], vectors[3])

        # Generic path: append one axis per additional vector.
        result = vectors[0]
        for v in vectors[1:]:
            result = result.reshape(result.shape + (1,)) * v
        return result

    def _generalized_wedge(self, vectors):
        """
        Computes the wedge product of p vectors: the sum over all permutations
        of sign(sigma) * outer(vectors[sigma]). No 1/p! factor is applied here.
        vectors: torch tensor of shape (p, N)
        Returns: tensor of shape (N,) * p
        """
        # Allocate alongside the inputs so accumulation never mixes devices.
        tensor_sum = torch.zeros((self.N,) * self.p, device=vectors.device, dtype=vectors.dtype)

        for sign, sigma in self._signed_permutations(self.p):
            tensor_sum += sign * self._generalized_outer_optimized(vectors[list(sigma)])

        return tensor_sum

    def get_tensor_representation(self, vectors):
        """
        Builds the rank-p tensor for one set of vectors (wedge or outer,
        depending on self.isWedge).
        vectors: torch tensor of shape (p, N)
        """
        if self.isWedge:
            return self._generalized_wedge(vectors)
        return self._generalized_outer_optimized(vectors)

    # ------------------------------------------------------------------
    # Tensor construction (batched candidates)
    # ------------------------------------------------------------------

    def _batched_outer(self, vectors):
        """
        Batched outer product for any number of vectors.
        vectors: torch tensor of shape (batch_size, p, N)
        Returns: tensor of shape (batch_size,) + (N,) * p
        """
        batch_size, count, dim = vectors.shape
        result = vectors[:, 0, :]
        for k in range(1, count):
            # result is (batch, N * k axes); line the k-th vector up on a new
            # trailing axis and let broadcasting form the product.
            factor = vectors[:, k, :].reshape((batch_size,) + (1,) * k + (dim,))
            result = result.unsqueeze(-1) * factor
        return result

    def _batched_tensor_representation(self, vectors):
        """
        Batched counterpart of get_tensor_representation.
        vectors: torch tensor of shape (batch_size, p, N)
        Returns: tensor of shape (batch_size,) + (N,) * p
        """
        if not self.isWedge:
            return self._batched_outer(vectors)

        batch_size, count, dim = vectors.shape
        tensor_sum = torch.zeros(
            (batch_size,) + (dim,) * count, device=vectors.device, dtype=vectors.dtype
        )
        # Same signed sum as _generalized_wedge so that the p! normalisation
        # used by the scorers means the same thing in both code paths.
        for sign, sigma in self._signed_permutations(count):
            tensor_sum += sign * self._batched_outer(vectors[:, list(sigma), :])
        return tensor_sum

    # ------------------------------------------------------------------
    # Role handling
    # ------------------------------------------------------------------

    def _process_role_vectors(self, role_factors):
        """Helper to handle optional orthogonalization."""
        if self.orthogonalize_roles:
            return self._gram_schmidt(role_factors)
        return role_factors

    def _project_out_roles_batched(self, filler_factors, clean_roles):
        """
        Carves filler vectors with the role vectors.
        filler_factors: (p, N) or (batch_size, p, N)
        clean_roles: (num_roles, N)
        Returns: tensor shaped like filler_factors holding
                 (filler_factors @ clean_roles.T) @ clean_roles,
                 i.e. the part of each filler expressed through the roles.
                 This is an orthogonal projection only if the rows of
                 clean_roles are orthonormal.

        NOTE(review): despite the method name, the result is the projection
        ONTO the roles, not the rejection (filler minus projection). The
        subtraction was disabled in the code this was derived from; confirm
        with the owner which of the two is intended before changing it.
        """
        coefficients = torch.matmul(filler_factors, clean_roles.T)  # (..., p, num_roles)
        return torch.matmul(coefficients, clean_roles)  # (..., p, N)

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def get_dual_contraction_score(self, role_factors, filler_factors):
        """
        Carves the p filler vectors with the role and scores them against memory.
        role_factors: torch tensor of shape (num_roles, N)
        filler_factors: torch tensor of shape (p, N)
        Returns: Python float; 0.0 if any carved vector has norm <= 1e-3.
        Raises: ValueError if the number of filler vectors is not p.
        """
        if filler_factors.shape[0] != self.p:
            raise ValueError(f"Expected {self.p} filler vectors, got {filler_factors.shape[0]}")

        # 1. OPTIONAL ORTHOGONALIZATION
        clean_roles = self._process_role_vectors(role_factors)

        # 2. CARVING
        residue = self._project_out_roles_batched(filler_factors, clean_roles)

        # 3. CHECK SURVIVAL: every carved vector must keep a usable norm
        norms = torch.norm(residue, dim=1, keepdim=True)  # (p, 1)
        if not bool(torch.all(norms > self._SURVIVAL_EPS)):
            return 0.0

        # 4. NORMALIZE (order of the vectors is preserved)
        clean_vectors = residue / torch.clamp(norms, min=1e-9)

        # 5. CREATE CANDIDATE
        candidate_tensor = self.get_tensor_representation(clean_vectors)

        # 6. CHECK MEMORY: full inner product of memory and candidate
        dot_product = torch.sum(self.memory * candidate_tensor)

        # 7. NORMALIZE SCORE
        return (dot_product / self._score_normalizer()).item()

    def get_dual_contraction_score_batched(self, role_factors, filler_factors_batch):
        """
        Batched version: score multiple filler candidates against the same role.
        Produces the same value per candidate as get_dual_contraction_score,
        in both outer and wedge mode, for any p >= 1.
        role_factors: torch tensor of shape (num_roles, N)
        filler_factors_batch: torch tensor of shape (batch_size, p, N)
        Returns: torch tensor of shape (batch_size,); entries whose carved
                 vectors did not all survive are 0.
        Raises: ValueError if the number of filler vectors per item is not p.

        Note: a (batch_size, N, ..., N) candidate tensor is materialised, so
        memory use grows as batch_size * N**p.
        """
        batch_size = filler_factors_batch.shape[0]
        if filler_factors_batch.shape[1] != self.p:
            raise ValueError(f"Expected {self.p} filler vectors, got {filler_factors_batch.shape[1]}")

        # 1. OPTIONAL ORTHOGONALIZATION (same for all)
        clean_roles = self._process_role_vectors(role_factors)

        # 2. CARVING - same helper as the single-candidate path
        residue = self._project_out_roles_batched(filler_factors_batch, clean_roles)  # (batch, p, N)

        # 3. CHECK SURVIVAL - BATCHED
        norms = torch.norm(residue, dim=2)  # (batch_size, p)
        survival_mask = torch.all(norms > self._SURVIVAL_EPS, dim=1)  # (batch_size,)

        # 4. NORMALIZE - BATCHED (order of the vectors is preserved)
        clean_vectors = residue / torch.clamp(norms.unsqueeze(2), min=1e-9)

        # 5. CREATE CANDIDATE TENSORS - antisymmetrised when isWedge is set
        candidate_tensors = self._batched_tensor_representation(clean_vectors)

        # 6. CHECK MEMORY - one matrix-vector product over flattened tensors.
        #    The explicit column count keeps reshape valid for an empty batch.
        memory_flat = self.memory.flatten()  # (N^p,)
        candidates_flat = candidate_tensors.reshape(batch_size, memory_flat.numel())
        dot_products = torch.mv(candidates_flat, memory_flat)  # (batch_size,)

        # 7. NORMALIZE SCORE
        scores = dot_products / self._score_normalizer()

        # Zero out candidates that failed the survival check.
        return scores * survival_mask.to(scores.dtype)

    # ------------------------------------------------------------------
    # Storage
    # ------------------------------------------------------------------

    def bind(self, role_factors, filler_factors):
        """
        Carves the Fillers using the Role and adds the resulting tensor to memory.
        role_factors: torch tensor of shape (num_roles, N)
        filler_factors: torch tensor of shape (p, N)
        Raises: ValueError if the number of filler vectors is not p.

        Unlike the scoring methods there is no survival check here, and norms
        are clamped at 1e-6 (scoring clamps at 1e-9).
        """
        if filler_factors.shape[0] != self.p:
            raise ValueError(f"Expected {self.p} filler vectors, got {filler_factors.shape[0]}")

        # 1. OPTIONAL ORTHOGONALIZATION
        clean_roles = self._process_role_vectors(role_factors)

        # 2. CARVING
        residue = self._project_out_roles_batched(filler_factors, clean_roles)

        # 3. NORMALIZE (order of the vectors is preserved)
        norms = torch.norm(residue, dim=1, keepdim=True)
        clean_vectors = residue / torch.clamp(norms, min=1e-6)

        # 4. UPDATE MEMORY (in-place addition)
        self.memory.add_(self.get_tensor_representation(clean_vectors))

    def evaluate_score(self, score):
        """Returns whether the magnitude of score exceeds self.threshold."""
        return abs(score) > self.threshold