"""
TAN (Topological Attention Network) Architecture Implementation
Based on the ICLR 2026 paper specifications
Clean implementation showing exact architecture as described in the paper
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Optional, Tuple, Dict, List
from dataclasses import dataclass

@dataclass
class TANConfig:
    """Configuration for Topological Attention Network.

    ``multi_scale_k`` holds one k-NN neighbourhood size per attention head.
    When left as ``None`` it is filled in by ``__post_init__`` with a
    repeating 8/16/32/64 pattern of length ``num_heads``.
    """
    vocab_size: int = 50265
    embed_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    max_seq_length: int = 512
    dropout: float = 0.1

    # Topological parameters
    k_neighbors: int = 32
    use_topology: bool = True
    topology_dim: int = 128

    # LSH parameters
    use_lsh: bool = True
    num_hashes: int = 8
    hash_bits: int = 256
    lsh_temperature: float = 0.1

    # Multi-scale parameters (None -> derived from num_heads in __post_init__)
    multi_scale_k: Optional[List[int]] = None

    def __post_init__(self):
        if self.multi_scale_k is None:
            # Cycle through the base scales so that every head receives a value,
            # including when num_heads is not a multiple of 4 (or is below 4).
            base_scales = (8, 16, 32, 64)
            self.multi_scale_k = [
                base_scales[head % len(base_scales)] for head in range(self.num_heads)
            ]

class TopologicalFeatureExtractor(nn.Module):
    """
    Extracts topological features from embeddings using k-NN graphs
    and persistent homology approximations.

    Pipeline: cosine-distance k-NN graph (self excluded) -> linear projection
    to ``topology_dim`` -> softmax(-distance)-weighted average of each token's
    neighbours, added to the token's own projection -> ``topology_encoder``
    -> ``persistence_mlp``.
    """

    def __init__(self, embed_dim: int, k_neighbors: int, topology_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_neighbors = k_neighbors
        self.topology_dim = topology_dim

        # Learnable projection for topological features
        self.topology_proj = nn.Linear(embed_dim, topology_dim)

        # Multi-layer encoder for topological features
        self.topology_encoder = nn.Sequential(
            nn.Linear(topology_dim, topology_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(topology_dim * 2, topology_dim),
            nn.LayerNorm(topology_dim)
        )

        # Persistence landscape approximation
        self.persistence_mlp = nn.Sequential(
            nn.Linear(topology_dim, topology_dim),
            nn.ReLU(),
            nn.Linear(topology_dim, topology_dim)
        )

    def compute_knn_graph(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the k-nearest-neighbour graph under cosine distance.

        Args:
            embeddings: ``[batch, seq_len, embed_dim]`` tensor.

        Returns:
            ``(neighbor_distances, neighbor_indices)``, each of shape
            ``[batch, seq_len, k]`` with ``k = min(k_neighbors, seq_len - 1)``.
            A token is never listed as its own neighbour.
        """
        seq_len = embeddings.size(1)

        # Cosine distance = 1 - cosine similarity (eps guards zero vectors).
        embeddings_norm = embeddings / (embeddings.norm(dim=-1, keepdim=True) + 1e-8)
        distances = 1 - torch.matmul(embeddings_norm, embeddings_norm.transpose(-2, -1))

        # Tiny noise breaks ties between equidistant neighbours.
        distances = distances + torch.randn_like(distances) * 1e-6

        # Exclude self-connections; the [seq, seq] mask broadcasts over the batch.
        self_mask = torch.eye(seq_len, dtype=torch.bool, device=distances.device)
        distances = distances.masked_fill(self_mask, float('inf'))

        # Never ask for more neighbours than exist (and never a negative k).
        k = max(0, min(self.k_neighbors, seq_len - 1))
        neighbor_distances, neighbor_indices = torch.topk(distances, k, dim=-1, largest=False)

        return neighbor_distances, neighbor_indices

    def extract_topological_features(self, embeddings: torch.Tensor,
                                     neighbor_distances: torch.Tensor,
                                     neighbor_indices: torch.Tensor) -> torch.Tensor:
        """
        Extract topological features from the k-NN graph.

        Args:
            embeddings: ``[batch, seq_len, embed_dim]`` tensor.
            neighbor_distances: ``[batch, seq_len, k]`` distances to neighbours.
            neighbor_indices: ``[batch, seq_len, k]`` neighbour positions in
                the sequence (as produced by ``compute_knn_graph``).

        Returns:
            ``[batch, seq_len, topology_dim]`` feature tensor.
        """
        batch_size = embeddings.size(0)

        # Project embeddings to topology space: [batch, seq_len, topology_dim]
        topo_embeddings = self.topology_proj(embeddings)

        # Gather each token's neighbour features:
        # neighbor_features[b, i, j] = topo_embeddings[b, neighbor_indices[b, i, j]]
        # -> [batch, seq_len, k, topology_dim]
        batch_index = torch.arange(batch_size, device=topo_embeddings.device).view(-1, 1, 1)
        neighbor_features = topo_embeddings[batch_index, neighbor_indices]

        # Closer neighbours get larger weight (softmax over negative distance).
        weights = F.softmax(-neighbor_distances, dim=-1).unsqueeze(-1)
        weighted_neighbors = (neighbor_features * weights).sum(dim=2)

        # Combine with original features
        combined = topo_embeddings + weighted_neighbors

        # Encode topological features
        topo_features = self.topology_encoder(combined)

        # Approximate persistence landscape
        persistence_features = self.persistence_mlp(topo_features)

        return persistence_features

    def forward(self, embeddings: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Build the k-NN graph and extract topological features.

        Returns:
            Dict with ``'features'`` (``[batch, seq_len, topology_dim]``),
            ``'distances'`` and ``'indices'`` (both ``[batch, seq_len, k]``).
        """
        distances, indices = self.compute_knn_graph(embeddings)
        topo_features = self.extract_topological_features(embeddings, distances, indices)

        return {
            'features': topo_features,
            'distances': distances,
            'indices': indices
        }

class LocalitySensitiveHashing(nn.Module):
    """
    Random-projection LSH that turns vectors into a binary attention mask.

    Every vector is projected by ``num_hashes`` fixed random matrices of
    shape ``[embed_dim, hash_bits]`` (stored as a non-trainable buffer),
    optionally nudged by topology features, shifted by a learnable per-hash
    bias, and binarised with ``torch.sign``.
    """

    def __init__(self, embed_dim: int, num_hashes: int, hash_bits: int, temperature: float):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_hashes = num_hashes
        self.hash_bits = hash_bits
        self.temperature = temperature

        # Fixed random hyperplanes: saved in the state dict, never trained.
        self.register_buffer('hash_proj', torch.randn(num_hashes, embed_dim, hash_bits))

        # Learnable per-hash, per-bit offset applied before binarisation.
        self.topology_bias = nn.Parameter(torch.zeros(num_hashes, hash_bits))

    def hash_vectors(self, vectors: torch.Tensor, topology_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Map ``[batch, seq_len, embed_dim]`` vectors to sign codes of shape
        ``[batch, seq_len, num_hashes, hash_bits]``.

        If ``topology_features`` is given and its last dimension fits within
        ``embed_dim``, it is projected through the leading rows of the same
        hyperplanes and added with weight ``temperature``; otherwise it is
        ignored.
        """
        scores = torch.einsum('bsd,hdk->bshk', vectors, self.hash_proj)

        if topology_features is not None and topology_features.size(-1) <= self.embed_dim:
            width = topology_features.size(-1)
            guidance = torch.einsum('bsd,hdk->bshk', topology_features, self.hash_proj[:, :width, :])
            scores = scores + self.temperature * guidance

        scores = scores + self.topology_bias[None, None]
        return torch.sign(scores)

    def compute_hash_similarity(self, query_hashes: torch.Tensor, key_hashes: torch.Tensor) -> torch.Tensor:
        """
        Pairwise similarity of sign codes, averaged over hash functions.

        Both inputs are ``[batch, seq_len, num_hashes, hash_bits]``; the
        result is ``[batch, seq_len_q, seq_len_k]``. For ±1 codes, each
        per-hash value is the code dot product divided by ``hash_bits``.
        It lies in [-1, 1], where 1 means identical codes.
        """
        per_hash = torch.einsum('bqhk,bthk->bhqt', query_hashes, key_hashes) / self.hash_bits
        return per_hash.mean(dim=1)

    def forward(self, queries: torch.Tensor, keys: torch.Tensor,
                topology_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Return a float ``[batch, seq_len_q, seq_len_k]`` mask.

        An entry is 1.0 where the averaged code similarity exceeds 0.5 and
        0.0 elsewhere.
        """
        query_codes = self.hash_vectors(queries, topology_features)
        key_codes = self.hash_vectors(keys, topology_features)
        similarity = self.compute_hash_similarity(query_codes, key_codes)
        return (similarity > 0.5).float()

class TopologicalAttention(nn.Module):
    """
    Multi-head attention with topological features and LSH.

    Standard scaled dot-product attention that can optionally be
      * restricted by an LSH-derived key mask (``config.use_lsh``), and
      * followed by a sigmoid gate that blends the attention context with the
        layer input, conditioned on topological features
        (``config.use_topology``).
    """

    def __init__(self, config: TANConfig):
        super().__init__()
        if config.embed_dim % config.num_heads != 0:
            raise ValueError(
                f"embed_dim ({config.embed_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.head_dim = config.embed_dim // config.num_heads
        self.use_topology = config.use_topology
        self.use_lsh = config.use_lsh

        # Standard attention projections
        self.q_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.k_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.v_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.out_proj = nn.Linear(config.embed_dim, config.embed_dim)

        # Topological attention components
        if self.use_topology:
            self.topology_gate = nn.Sequential(
                nn.Linear(config.embed_dim + config.topology_dim, config.embed_dim),
                nn.Sigmoid()
            )

            # Multi-scale k values for different heads (stored; not used in forward)
            self.head_k_neighbors = config.multi_scale_k[:config.num_heads]

        # LSH for efficiency
        if self.use_lsh:
            self.lsh = LocalitySensitiveHashing(
                config.embed_dim,
                config.num_hashes,
                config.hash_bits,
                config.lsh_temperature
            )

        self.dropout = nn.Dropout(config.dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, hidden_states: torch.Tensor,
                topology_features: Optional[Dict[str, torch.Tensor]] = None,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass with topological attention.

        Args:
            hidden_states: ``[batch, seq_len, embed_dim]`` input.
            topology_features: optional dict whose ``'features'`` entry is
                concatenated with the context on the last dim, so it is
                expected to be ``[batch, seq_len, topology_dim]``.
            attention_mask: optional ``[batch, seq_len]`` tensor; keys where
                it equals 0 are masked out.

        Returns:
            ``[batch, seq_len, embed_dim]`` output.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # Reshape for multi-head attention: [batch, heads, seq_len, head_dim]
        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores: [batch, heads, seq_len, seq_len]
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale

        # Apply LSH masking whenever it is enabled. Topology features, when
        # available, only bias the hashes; their absence must not disable LSH.
        if self.use_lsh:
            topo_for_lsh = topology_features.get('features') if topology_features is not None else None
            lsh_mask = self.lsh(hidden_states, hidden_states, topo_for_lsh)
            # [batch, seq, seq] -> [batch, 1, seq, seq], broadcast across heads
            attention_scores = attention_scores.masked_fill(lsh_mask.unsqueeze(1) == 0, -1e4)

        # Apply standard (key padding) attention mask if provided
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0,
                -1e4
            )

        # Compute attention probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Apply attention to values
        context = torch.matmul(attention_probs, values)

        # Reshape back to [batch, seq_len, embed_dim]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        # Topological gating: blend attention context with the layer input
        if self.use_topology and topology_features is not None:
            topo_features = topology_features['features']
            combined = torch.cat([context, topo_features], dim=-1)
            gate = self.topology_gate(combined)
            context = gate * context + (1 - gate) * hidden_states

        # Output projection
        output = self.out_proj(context)
        output = self.dropout(output)

        return output

class TANLayer(nn.Module):
    """
    One pre-norm block of the Topological Attention Network.

    Topological features are computed from the un-normalised layer input.
    The block then applies LayerNorm -> topological attention -> residual,
    followed by LayerNorm -> feed-forward -> residual.
    """

    def __init__(self, config: TANConfig):
        super().__init__()
        self.config = config
        hidden = config.embed_dim
        inner = hidden * 4

        # k-NN based feature extractor (only built when topology is enabled)
        if config.use_topology:
            self.topology_extractor = TopologicalFeatureExtractor(
                embed_dim=hidden,
                k_neighbors=config.k_neighbors,
                topology_dim=config.topology_dim,
            )

        self.attention = TopologicalAttention(config)

        # Position-wise feed-forward network (4x expansion)
        self.ffn = nn.Sequential(
            nn.Linear(hidden, inner),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(inner, hidden),
            nn.Dropout(config.dropout)
        )

        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Run the block on ``hidden_states`` and return a tensor of the same shape.
        """
        topo = self.topology_extractor(hidden_states) if self.config.use_topology else None

        attended = self.attention(self.norm1(hidden_states), topo, attention_mask)
        hidden_states = hidden_states + attended

        transformed = self.ffn(self.norm2(hidden_states))
        return hidden_states + transformed

class TAN(nn.Module):
    """
    Topological Attention Network.

    Token + learned absolute position embeddings, a stack of ``TANLayer``
    blocks, and a final LayerNorm.
    """

    def __init__(self, config: TANConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embeddings = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embeddings = nn.Embedding(config.max_seq_length, config.embed_dim)
        self.embedding_dropout = nn.Dropout(config.dropout)

        # TAN layers
        self.layers = nn.ModuleList([
            TANLayer(config) for _ in range(config.num_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(config.embed_dim)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through TAN.

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]; positions
                equal to 0 are excluded as attention keys.

        Returns:
            Dictionary containing:
                - hidden_states: Final (layer-normalised) hidden states
                  [batch_size, seq_length, embed_dim]
                - all_hidden_states: List with the output of each layer, in
                  order, taken before the final LayerNorm (embeddings excluded)

        Raises:
            ValueError: if ``seq_length`` exceeds ``config.max_seq_length``
                (the position embedding table has no rows for it).
        """
        batch_size, seq_length = input_ids.shape
        if seq_length > self.config.max_seq_length:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_seq_length "
                f"{self.config.max_seq_length}"
            )

        # Create position IDs
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        token_embeddings = self.embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Combine embeddings
        hidden_states = token_embeddings + position_embeddings
        hidden_states = self.embedding_dropout(hidden_states)

        # Pass through layers
        all_hidden_states = []
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
            all_hidden_states.append(hidden_states)

        # Final normalization
        hidden_states = self.final_norm(hidden_states)

        return {
            'hidden_states': hidden_states,
            'all_hidden_states': all_hidden_states
        }

class TANForSequenceClassification(nn.Module):
    """
    TAN model for sequence classification (``num_labels > 1``) or
    regression (``num_labels == 1``), using the first token's final hidden
    state as the pooled representation.
    """

    def __init__(self, config: TANConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Base TAN model
        self.tan = TAN(config)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout),
            nn.Linear(config.embed_dim, num_labels)
        )

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass for classification / regression.

        Args:
            input_ids: ``[batch, seq_len]`` token IDs.
            attention_mask: optional ``[batch, seq_len]`` mask (0 = ignore).
            labels: optional targets. For regression, one value per example
                (any numeric dtype); for classification, whatever
                ``nn.CrossEntropyLoss`` accepts (class indices or probabilities).

        Returns:
            Dict with ``loss`` (``None`` when no labels), ``logits``,
            ``hidden_states`` and ``all_hidden_states``.
        """
        outputs = self.tan(input_ids, attention_mask)
        hidden_states = outputs['hidden_states']

        # Pool the hidden states (use first / CLS token)
        pooled_output = hidden_states[:, 0, :]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression: flatten both sides to [batch] and match the
                # logits' dtype so integer targets don't make MSELoss raise.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.reshape(-1), labels.reshape(-1).to(logits.dtype))
            else:
                # Classification
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': hidden_states,
            'all_hidden_states': outputs['all_hidden_states']
        }

class TANForMultiLabelClassification(nn.Module):
    """
    TAN model for multi-label classification (e.g. GoEmotions).

    Each label is an independent binary decision trained with
    ``BCEWithLogitsLoss`` on the first token's final hidden state.
    """

    def __init__(self, config: TANConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        self.tan = TAN(config)

        width = config.embed_dim
        self.classifier = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(width, num_labels)
        )

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Return a dict with ``loss`` (``None`` without labels), ``logits``,
        ``hidden_states`` and ``all_hidden_states``. ``labels`` is converted
        to float before it is passed to the loss.
        """
        backbone = self.tan(input_ids, attention_mask)
        sequence_output = backbone['hidden_states']

        cls_vector = sequence_output[:, 0, :]
        logits = self.classifier(cls_vector)

        loss = None if labels is None else nn.BCEWithLogitsLoss()(logits, labels.float())

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': sequence_output,
            'all_hidden_states': backbone['all_hidden_states']
        }

def create_tan_model(task_type: str = 'classification',
                     num_labels: int = 2,
                     **kwargs) -> nn.Module:
    """
    Build a TAN model for the requested task and print a short summary.

    Args:
        task_type: ``'classification'``, ``'multi_label'`` or ``'base'``.
        num_labels: Output size of the classification head (ignored for ``'base'``).
        **kwargs: Forwarded verbatim to ``TANConfig``.

    Returns:
        The constructed model.

    Raises:
        ValueError: for an unrecognised ``task_type``.
    """
    config = TANConfig(**kwargs)

    factories = (
        ('classification', lambda: TANForSequenceClassification(config, num_labels)),
        ('multi_label', lambda: TANForMultiLabelClassification(config, num_labels)),
        ('base', lambda: TAN(config)),
    )
    for name, factory in factories:
        if task_type == name:
            model = factory()
            break
    else:
        raise ValueError(f"Unknown task type: {task_type}")

    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    print(f"Created TAN model for {task_type}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Configuration: {config}")

    return model

# Example usage
if __name__ == "__main__":
    # Multi-label model sized for GoEmotions
    # (27 emotion categories plus "neutral" = 28 labels).
    model = create_tan_model(
        task_type='multi_label',
        num_labels=28,
        max_seq_length=128,
        use_topology=True,
        use_lsh=True
    )
    # Inference demo: switch off dropout.
    model.eval()

    # Test forward pass with inputs sized from the model's own config
    batch_size = 4
    seq_length = model.config.max_seq_length

    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length))
    attention_mask = torch.ones(batch_size, seq_length)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        print(f"Output shape: {outputs['logits'].shape}")
        print(f"Hidden states shape: {outputs['hidden_states'].shape}")