import torch
import torch.nn as nn
import torch.nn.functional as F
import hashlib
from hlb_utils import normal_mixture, index_sequence, cosine_similarity
from network import Network as BaseNetwork

def compute_seed(role_indices):
    """Derive a reproducible 32-bit integer seed from a set of role indices.

    The indices are sorted before hashing, so the result does not depend on
    the order in which they are supplied.

    Args:
        role_indices: Iterable of integers (list, tuple, ...).

    Returns:
        int: The first 8 hex digits of the SHA-256 digest of the sorted
        index tuple's string form, interpreted as a base-16 integer.
    """
    canonical = tuple(sorted(role_indices))
    digest = hashlib.sha256(str(canonical).encode('utf-8')).hexdigest()
    # Only the leading 8 hex characters (32 bits) are kept as the seed.
    return int(digest[:8], 16)

def batched_qr_orthonormal(seeds, n_dim, num_vectors, device='cuda'):
    """
    Build one orthonormal set of row vectors per seed via a batched QR.

    For every seed a Gaussian matrix of shape (n_dim, num_vectors) is drawn
    from a generator seeded with that value; the stacked matrices are then
    QR-decomposed in a single call and the Q factors are transposed so the
    orthonormal vectors lie along the rows.

    Args:
        seeds: Sequence of integer seeds, one per basis.
        n_dim: Dimensionality of every vector.
        num_vectors: Number of vectors requested per basis.
        device: Device for both the generators and the tensors.

    Returns:
        Tensor of shape (len(seeds), min(n_dim, num_vectors), n_dim). At most
        n_dim orthonormal vectors exist, so asking for more yields n_dim rows.
    """
    def _seeded_gaussian(seed):
        # A dedicated generator per seed keeps each basis reproducible
        # independently of the others.
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
        return torch.randn(n_dim, num_vectors, generator=rng, device=device)

    stacked = torch.zeros((len(seeds), n_dim, num_vectors), device=device)
    for slot, seed in enumerate(seeds):
        stacked[slot] = _seeded_gaussian(seed)

    # Columns of Q are orthonormal; swap the last two axes to expose them as rows.
    orthonormal_columns = torch.linalg.qr(stacked).Q
    return orthonormal_columns.transpose(1, 2)

class Network(nn.Module):
    """MLP whose training targets are pre-computed VSA label memories.

    The module consists of an optional average-pooling front end, a
    three-layer MLP ending in ``Tanh``, a per-label filler parameter
    (``self.cls``) and a frozen memory bank (``self.bound_labels``) that is
    derived from ``self.cls`` once, at construction time.

    NOTE(review): ``loss`` and ``inference`` only read ``self.bound_labels``,
    which is a detached snapshot. Even when ``requires_grad=True``, updates
    to ``self.cls`` are not reflected in the bank unless
    ``pre_compute_memory_bank`` is called again.
    """

    def __init__(self, device, in_features, hidden, out_features, labels, requires_grad=True, negative=False,
                 factor=1, drop_rate=0., reduce_dim=False, kernel_dim=None, num_role_vecs=None):
        """Build the MLP, the label fillers and the pre-computed memory bank.

        Args:
            device: Device on which the memory bank is computed and stored.
            in_features: Input width (before optional pooling).
            hidden: Width of the first hidden layer.
            out_features: Output width; expected to be a perfect square N*N.
            labels: Number of labels.
            requires_grad: ``requires_grad`` flag of ``self.cls``.
            negative: Accepted for interface compatibility; not used here.
            factor: Multiplier for the second hidden layer's width.
            drop_rate: Dropout probability after each hidden activation.
            reduce_dim: If true, average-pool the input before the MLP.
            kernel_dim: Pooling kernel size and stride; required when
                ``reduce_dim`` is true.
            num_role_vecs: Number of role vectors per label; defaults to N.
        """
        super().__init__()

        # --- Base network architecture ---
        self.reduce_dim = reduce_dim
        if self.reduce_dim:
            # Output length of AvgPool1d with kernel == stride and no padding.
            in_features = (in_features - kernel_dim) // kernel_dim + 1
            self.avgpool = nn.Sequential(
                nn.AvgPool1d(kernel_size=kernel_dim, stride=kernel_dim)
            )
            print(f"Reducing dimension with kernel size {kernel_dim}: New input feature size: {in_features}.")

        self.network = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.LeakyReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden, hidden * factor),
            nn.LeakyReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden * factor, out_features),
            nn.Tanh()
        )

        # --- VSA configuration ---
        self.device = device
        self.num_labels = labels
        self.out_features_flat = out_features

        # Each label memory is a flattened N x N outer product, so
        # N = floor(sqrt(out_features)).
        self.N_dim = int(out_features ** 0.5)
        if self.N_dim * self.N_dim != out_features:
            print(f"WARNING: out_features ({out_features}) is not a perfect square. "
                  f"VSA Logic (Outer Product) expects D^2 size. Using N={self.N_dim}.")

        # Per-label filler material. Only the first 2*N entries of each row
        # are consumed (as two N-dimensional fillers) when building the bank.
        self.cls = nn.Parameter(normal_mixture((1, labels, out_features)), requires_grad=requires_grad)

        # Role bases are generated per label from deterministic seeds rather
        # than stored; only their count is configurable.
        if num_role_vecs is None:
            self.num_roles = self.N_dim
        else:
            self.num_roles = num_role_vecs

        print(f"VSA Config: N_dim={self.N_dim}, NumRoles={self.num_roles}")

        self.pre_compute_memory_bank(self.cls)

    def pre_compute_memory_bank(self, all_fillers):
        """Compute and store the frozen VSA memory for every label.

        For label ``k`` an orthonormal role basis is generated from
        ``compute_seed([0, k])``; the label's two fillers are projected onto
        the span of that basis and the outer product of the two projections
        is flattened into one row of ``self.bound_labels``.

        Args:
            all_fillers: Tensor that, after ``squeeze(0)``, has shape
                (num_labels, F) with F >= 2 * N_dim.

        Raises:
            ValueError: If fewer than 2 * N_dim values are available per label.
        """
        # Use the device given to __init__, not the device of the parameters.
        device = self.device
        print(f"Pre-computing VSA Memory Bank on {device}...")

        needed = 2 * self.N_dim
        # No more than N_dim orthonormal vectors exist in N_dim dimensions.
        limit_roles = min(self.num_roles, self.N_dim)
        chunk_size = 100
        batch_memories = []

        # The bank is stored as a frozen parameter, so no autograd graph is
        # needed. Without no_grad, a grad-requiring `all_fillers` would keep
        # every chunk's intermediates alive and defeat the chunking.
        with torch.no_grad():
            fillers_flat = all_fillers.squeeze(0)  # (Labels, F)
            if fillers_flat.shape[1] < needed:
                raise ValueError(f"Out Features {fillers_flat.shape[1]} too small for rank-2 factorization of dim {self.N_dim}")

            # First 2*N values of each row -> two N-dimensional fillers.
            all_fillers_reshaped = fillers_flat[:, :needed].view(self.num_labels, 2, self.N_dim)

            for start in range(0, self.num_labels, chunk_size):
                chunk_fillers = all_fillers_reshaped[start:start + chunk_size].to(device)  # (B, 2, N)

                # One deterministic seed per global label index.
                chunk_seeds = [
                    compute_seed([0, start + offset])
                    for offset in range(chunk_fillers.shape[0])
                ]
                chunk_roles = batched_qr_orthonormal(
                    chunk_seeds, self.N_dim, limit_roles, device=device
                )  # (B, R, N)

                # Coefficients of each filler in the role basis: (B, 2, R).
                dots = torch.einsum('bjn,brn->bjr', chunk_fillers, chunk_roles)
                # Projection of each filler onto the role span: (B, 2, N).
                residue = torch.einsum('bjr,brn->bjn', dots, chunk_roles)
                # Outer product of the two projected fillers: (B, N, N).
                outer = torch.einsum('bn,bm->bnm', residue[:, 0, :], residue[:, 1, :])

                batch_memories.append(outer.reshape(outer.shape[0], -1))  # (B, N*N)

            memory_bank = torch.cat(batch_memories, dim=0)

        # Every chunk was produced on `device`, so the bank already lives there.
        self.bound_labels = nn.Parameter(memory_bank, requires_grad=False)

    def bind(self, roles, fillers):
        """Bind a single pair of fillers with one role basis.

        Unbatched counterpart of the per-label computation in
        ``pre_compute_memory_bank``.

        Args:
            roles: (NumRoles, N) role vectors.
            fillers: (2, N) filler vectors.

        Returns:
            Tensor of shape (1, N*N): the flattened outer product of the two
            fillers after projection onto the role vectors.
        """
        # Coefficients in the role basis: (2, R).
        dots = torch.mm(fillers, roles.T)
        # Projection onto the role span: (2, N).
        residue = torch.mm(dots, roles)

        tensor = torch.einsum('i,j->ij', residue[0], residue[1])
        return tensor.flatten().unsqueeze(0)

    def forward(self, x):
        """Run the (optionally pooled) input through the MLP."""
        if self.reduce_dim:
            x = self.avgpool(x)
        return self.network(x)

    def loss(self, logits, true):
        """Mean cosine distance between outputs and summed label memories.

        Args:
            logits: (Batch, OutFeatures) network output.
            true: Integer index tensor used to index ``self.bound_labels``;
                the gathered rows are summed over dim 1, so a
                (Batch, LabelsPerSample) layout is expected.
                NOTE(review): no padding index is treated specially here;
                every index contributes its memory row. Confirm against the
                data loader.

        Returns:
            Scalar tensor ``mean(1 - cosine_similarity(logits, target))``.
        """
        batch_targets = torch.sum(self.bound_labels[true], dim=1)  # (Batch, OutFeatures)
        similarity = cosine_similarity(logits, batch_targets, dim=-1)
        # Minimising the distance maximises the similarity.
        return torch.mean(1.0 - similarity)

    def inference(self, logits, steps=100):
        """Score every label by cosine similarity with the network output.

        Args:
            logits: (Batch, OutFeatures) network output.
            steps: Accepted for interface compatibility; currently unused
                (scores for all labels are computed in a single matmul).

        Returns:
            (Batch, NumLabels) tensor of cosine similarities; ranking or
            top-k selection is left to the caller.
        """
        # The small epsilon guards against division by zero for all-zero rows.
        logits_norm = logits / (logits.norm(dim=1, keepdim=True) + 1e-8)
        memory_norm = self.bound_labels / (self.bound_labels.norm(dim=1, keepdim=True) + 1e-8)

        scores = torch.mm(logits_norm, memory_norm.T)
        return scores
