#!/usr/bin/env python3
"""
Complete TAN (Topological Attention Network) Implementation for LEDGAR Dataset
Task: Legal Document Categorization/Classification (100 classes)
"""

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
import numpy as np
from pathlib import Path
import json
import time
import logging
import math
from tqdm import tqdm
from sklearn.metrics import (
    f1_score, accuracy_score, hamming_loss, 
    top_k_accuracy_score, classification_report
)
from transformers import AutoTokenizer
from datasets import load_dataset
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass
import warnings
# Silence every Python warning process-wide (this also hides library
# deprecation notices).
warnings.filterwarnings(action='ignore')

# Logging setup: INFO and above go to 'tan_ledgar_training.log' (opened in the
# current working directory) and to the console stream, in that order.
_log_handlers = [
    logging.FileHandler('tan_ledgar_training.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# ================== LEDGAR DATASET WITH EXPLICIT 100 CLASSES ==================

# LEDGAR's 100 contract-provision categories, in alphabetical order.
# NOTE(review): position i is assumed to be the name of integer label i in the
# lex_glue/ledgar dataset (its ClassLabel order) - TODO confirm against the
# dataset features before relying on it for reporting.
LEDGAR_CLASSES = [
    "Adjustments", "Agreements", "Amendments", "Anti-Corruption Laws", "Applicable Laws", 
    "Approvals", "Arbitration", "Assignments", "Assigns", "Authority", "Authorizations", 
    "Base Salary", "Benefits", "Binding Effects", "Books", "Brokers", "Capitalization", 
    "Change In Control", "Closings", "Compliance With Laws", "Confidentiality", 
    "Consent To Jurisdiction", "Consents", "Construction", "Cooperation", "Costs", 
    "Counterparts", "Death", "Defined Terms", "Definitions", "Disability", "Disclosures", 
    "Duties", "Effective Dates", "Effectiveness", "Employment", "Enforceability", 
    "Enforcements", "Entire Agreements", "Erisa", "Existence", "Expenses", "Fees", 
    "Financial Statements", "Forfeitures", "Further Assurances", "General", "Governing Laws", 
    "Headings", "Indemnifications", "Indemnity", "Insurances", "Integration", 
    "Intellectual Property", "Interests", "Interpretations", "Jurisdictions", "Liens", 
    "Litigations", "Miscellaneous", "Modifications", "No Conflicts", "No Defaults", 
    "No Waivers", "Non-Disparagement", "Notices", "Organizations", "Participations", 
    "Payments", "Positions", "Powers", "Publicity", "Qualifications", "Records", 
    "Releases", "Remedies", "Representations", "Sales", "Sanctions", "Severability", 
    "Solvency", "Specific Performance", "Submission To Jurisdiction", "Subsidiaries", 
    "Successors", "Survival", "Tax Withholdings", "Taxes", "Terminations", "Terms", 
    "Titles", "Transactions With Affiliates", "Use Of Proceeds", "Vacations", "Venues", 
    "Vesting", "Waiver Of Jury Trials", "Waivers", "Warranties", "Withholdings"
]

# Explicit checks instead of `assert`, so they still run under `python -O`.
if len(LEDGAR_CLASSES) != 100:
    raise ValueError(f"Expected 100 classes, got {len(LEDGAR_CLASSES)}")
if len(set(LEDGAR_CLASSES)) != len(LEDGAR_CLASSES):
    raise ValueError("LEDGAR_CLASSES contains duplicate class names")

class LEDGARDataset(Dataset):
    """
    LEDGAR dataset for legal document classification (100 contract-provision
    categories from SEC filings).

    Loads one split of the ``ledgar`` config of ``coastalcph/lex_glue`` through
    ``datasets.load_dataset`` and tokenizes every usable example up front
    (truncated and padded to ``max_length``). Each stored item is a dict with:

    * ``input_ids`` / ``attention_mask``: 1-D tensors of length ``max_length``
    * ``labels``: 0-d ``torch.long`` tensor holding the raw dataset label
    * ``text``: the first 200 characters of the provision (debugging aid)

    Note: ``label_to_idx`` / ``idx_to_label`` describe a compact index over the
    labels present in this split; they are NOT applied to ``labels``.
    """

    # Candidate field names for the provision text, tried in this order.
    _TEXT_FIELDS = ('text', 'provision', 'content', 'document')

    def __init__(self, split: str = 'train', tokenizer=None, max_length: int = 512,
                 max_samples: Optional[int] = None, seed: Optional[int] = None):
        """
        Initialize LEDGAR dataset.

        Args:
            split: Dataset split ('train', 'validation', 'test').
            tokenizer: Tokenizer callable used to encode each provision (required).
            max_length: Maximum sequence length (truncation and padding target).
            max_samples: If truthy, randomly subsample to this many examples
                (for debugging).
            seed: Optional seed for that subsampling. ``None`` keeps the previous
                behaviour of drawing from NumPy's global RNG.

        Raises:
            ValueError: if ``tokenizer`` is None, or if no example survives
                processing.
        """
        if tokenizer is None:
            raise ValueError("Tokenizer is required")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        dataset = self._load_split(split, max_samples, seed)

        self.data = []
        self.label_to_idx = {}
        self.idx_to_label = {}
        unique_labels = sorted(self._encode_examples(dataset))

        if not self.data:
            # Without this guard the statistics below crash with an opaque
            # "min() arg is an empty sequence".
            raise ValueError(f"No usable samples found in LEDGAR {split} split")

        self.num_labels = len(unique_labels)
        for idx, label in enumerate(unique_labels):
            self.label_to_idx[label] = idx
            self.idx_to_label[idx] = label

        self._log_statistics(unique_labels)

    @staticmethod
    def _load_split(split: str, max_samples: Optional[int], seed: Optional[int]):
        """Load one LEX-GLUE LEDGAR split, optionally subsampled to ``max_samples`` rows."""
        try:
            logger.info(f"Loading LEDGAR dataset split: {split}")
            dataset = load_dataset("coastalcph/lex_glue", "ledgar", split=split, trust_remote_code=True)
        except Exception as e:
            logger.error(f"Failed to load LEDGAR dataset: {e}")
            raise

        original_size = len(dataset)
        if max_samples and original_size > max_samples:
            # A dedicated RandomState makes the subset reproducible when a seed is given.
            rng = np.random if seed is None else np.random.RandomState(seed)
            indices = rng.choice(original_size, max_samples, replace=False)
            dataset = dataset.select(indices)
            logger.info(f"Limited {split} data from {original_size} to {max_samples} samples")
        return dataset

    def _encode_examples(self, dataset) -> set:
        """Tokenize every usable example into ``self.data``; return the set of labels kept."""
        kept_labels = set()
        logger.info(f"Processing {len(dataset)} {self.split} samples...")

        for idx, item in enumerate(tqdm(dataset, desc=f"Processing {self.split}")):
            try:
                text = next((item[field] for field in self._TEXT_FIELDS if field in item), None)

                if not text:
                    logger.warning(f"No text found in item {idx}, skipping")
                    continue

                if not isinstance(text, str) or len(text.strip()) == 0:
                    logger.warning(f"Invalid text in item {idx}, skipping")
                    continue

                label = item['label']

                try:
                    encoding = self.tokenizer(
                        text,
                        truncation=True,
                        padding='max_length',
                        max_length=self.max_length,
                        return_tensors='pt'
                    )
                except Exception as e:
                    logger.warning(f"Tokenization failed for item {idx}: {e}")
                    continue

                self.data.append({
                    'input_ids': encoding['input_ids'].squeeze(0),
                    'attention_mask': encoding['attention_mask'].squeeze(0),
                    'labels': torch.tensor(label, dtype=torch.long),
                    'text': text[:200] + '...' if len(text) > 200 else text  # Store truncated text for debugging
                })
                # Record the label only once the example is actually kept, so the
                # label statistics and mappings describe exactly what is in self.data.
                kept_labels.add(label)

            except Exception as e:
                logger.warning(f"Error processing item {idx}: {e}")
                continue

        return kept_labels

    def _log_statistics(self, unique_labels: List) -> None:
        """Log sample/label counts and warn when the split does not cover 100 classes."""
        split = self.split
        logger.info(f"Dataset Statistics for {split}:")
        logger.info(f"  Total samples: {len(self.data)}")
        logger.info(f"  Unique labels: {self.num_labels}")
        logger.info(f"  Expected labels: 100")
        logger.info(f"  Label range: {min(unique_labels)} - {max(unique_labels)}")

        if self.num_labels != 100:
            logger.warning(f"Expected 100 classes but found {self.num_labels} in {split} split")
            if split == 'train' and self.num_labels < 100:
                logger.error("Training split missing some classes - this will cause issues!")

        label_counts = {}
        for item in self.data:
            label = item['labels'].item()
            label_counts[label] = label_counts.get(label, 0) + 1
        counts = list(label_counts.values())

        logger.info(f"Class distribution summary for {split}:")
        logger.info(f"  Classes with samples: {len(label_counts)}")
        logger.info(f"  Min samples per class: {min(counts)}")
        logger.info(f"  Max samples per class: {max(counts)}")
        logger.info(f"  Avg samples per class: {np.mean(counts):.1f}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def get_class_names(self):
        """Return a human-readable name for each compact label index ``0..num_labels-1``.

        An integer raw label within ``LEDGAR_CLASSES`` is translated to that name.
        Any other raw label is returned as ``str(label)``. An index with no mapping
        falls back to ``"Class_<i>"``.
        """
        names = []
        for i in range(self.num_labels):
            label = self.idx_to_label.get(i)
            if label is None:
                names.append(f"Class_{i}")
            elif isinstance(label, (int, np.integer)) and 0 <= label < len(LEDGAR_CLASSES):
                names.append(LEDGAR_CLASSES[label])
            else:
                names.append(str(label))
        return names

# ================== TAN CONFIGURATION ==================

@dataclass
class TANConfig:
    """Configuration for Topological Attention Network"""
    vocab_size: int = 30522  # BERT vocab size
    embed_dim: int = 768
    num_heads: int = 12
    num_layers: int = 6
    max_seq_length: int = 512
    dropout: float = 0.1
    
    # Topological parameters
    k_neighbors: int = 32
    use_topology: bool = True
    topology_dim: int = 128
    
    # LSH parameters
    use_lsh: bool = True
    num_hashes: int = 8
    hash_bits: int = 64
    lsh_temperature: float = 0.1
    
    # Multi-scale parameters: one k value per attention head. None (the default)
    # is replaced in __post_init__ by [8, 16, 32, 64] repeated cyclically and
    # cut to num_heads entries.
    multi_scale_k: Optional[List[int]] = None
    
    def __post_init__(self):
        """Fill in the default per-head k values, or copy the caller's list."""
        if self.multi_scale_k is None:
            base_k_values = [8, 16, 32, 64]
            self.multi_scale_k = [base_k_values[i % len(base_k_values)] for i in range(self.num_heads)]
        else:
            # Copy so later mutation of the caller's list cannot silently change this config.
            self.multi_scale_k = list(self.multi_scale_k)

# ================== TAN ARCHITECTURE IMPLEMENTATION ==================

class TopologicalFeatureExtractor(nn.Module):
    """Build a per-sequence k-NN graph over token embeddings and encode neighbourhood features.

    Pipeline per token:
    1. Pick its k nearest other tokens by cosine distance.
    2. Project embeddings to ``topology_dim`` and add the softmax(-distance)-weighted
       sum of neighbour projections.
    3. Pass the result through an MLP + LayerNorm, then a second MLP.

    Despite the "persistence" naming, no persistent homology is computed here.
    """
    
    def __init__(self, embed_dim: int, k_neighbors: int, topology_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_neighbors = k_neighbors
        self.topology_dim = topology_dim
        
        # Learnable projection for topological features
        self.topology_proj = nn.Linear(embed_dim, topology_dim)
        
        # Multi-layer encoder for topological features
        self.topology_encoder = nn.Sequential(
            nn.Linear(topology_dim, topology_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(topology_dim * 2, topology_dim),
            nn.LayerNorm(topology_dim)
        )
        
        # "Persistence landscape" approximation (a learned MLP)
        self.persistence_mlp = nn.Sequential(
            nn.Linear(topology_dim, topology_dim),
            nn.ReLU(),
            nn.Linear(topology_dim, topology_dim)
        )
    
    def compute_knn_graph(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Find each token's k nearest *other* tokens by cosine distance.

        Args:
            embeddings: (batch, seq_len, embed_dim) tensor.

        Returns:
            ``(neighbor_distances, neighbor_indices)``, both (batch, seq_len, k).
            k is ``min(k_neighbors, seq_len - 1)``. A non-positive ``k_neighbors``
            means "all other tokens". A single-token sequence has no other tokens,
            so it gets k = 0 (empty tensors).
        """
        batch_size, seq_len, _ = embeddings.shape
        
        k = min(self.k_neighbors, seq_len - 1)
        if k <= 0:
            k = max(0, seq_len - 1)
        if k == 0:
            # Previously k was forced to 1 here. topk then selected the +inf
            # self-distance, and softmax(-inf) produced NaN features downstream.
            empty_distances = embeddings.new_empty(batch_size, seq_len, 0)
            empty_indices = torch.empty(batch_size, seq_len, 0, dtype=torch.long, device=embeddings.device)
            return empty_distances, empty_indices
        
        # Pairwise cosine distance
        embeddings_norm = embeddings / (embeddings.norm(dim=-1, keepdim=True) + 1e-8)
        distances = 1 - torch.matmul(embeddings_norm, embeddings_norm.transpose(-2, -1))
        
        # Tiny noise breaks ties during training only, so eval/inference stays deterministic.
        if self.training:
            distances = distances + torch.randn_like(distances) * 1e-6
        
        # Exclude self-connections (the (seq_len, seq_len) mask broadcasts over the batch)
        self_mask = torch.eye(seq_len, dtype=torch.bool, device=distances.device)
        distances = distances.masked_fill(self_mask, float('inf'))
        
        neighbor_distances, neighbor_indices = torch.topk(distances, k, dim=-1, largest=False)
        return neighbor_distances, neighbor_indices
    
    def extract_topological_features(self, embeddings: torch.Tensor,
                                    neighbor_distances: torch.Tensor,
                                    neighbor_indices: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour projections and encode them.

        Args:
            embeddings: (batch, seq_len, embed_dim) tensor.
            neighbor_distances / neighbor_indices: (batch, seq_len, k) outputs of
                :meth:`compute_knn_graph`; k may be 0.

        Returns:
            (batch, seq_len, topology_dim) tensor.
        """
        batch_size, seq_len, _ = embeddings.shape
        k = neighbor_indices.shape[-1]
        
        # Project embeddings to topology space
        topo_embeddings = self.topology_proj(embeddings)
        
        if k == 0:
            # No neighbours: contribute nothing instead of a NaN-producing empty softmax.
            weighted_neighbors = torch.zeros_like(topo_embeddings)
        else:
            # Gather neighbour features: (batch, seq_len, k, topology_dim)
            batch_indices = torch.arange(batch_size, device=embeddings.device).view(-1, 1, 1)
            batch_indices = batch_indices.expand(-1, seq_len, k)
            neighbor_features = topo_embeddings[batch_indices, neighbor_indices]
            
            # Closer neighbours get larger weights (softmax over negative distance)
            weights = F.softmax(-neighbor_distances, dim=-1).unsqueeze(-1)
            weighted_neighbors = (neighbor_features * weights).sum(dim=2)
        
        combined = topo_embeddings + weighted_neighbors
        topo_features = self.topology_encoder(combined)
        return self.persistence_mlp(topo_features)
    
    def forward(self, embeddings: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``{'features', 'distances', 'indices'}`` for ``embeddings``."""
        distances, indices = self.compute_knn_graph(embeddings)
        topo_features = self.extract_topological_features(embeddings, distances, indices)
        
        return {
            'features': topo_features,
            'distances': distances,
            'indices': indices
        }

class LocalitySensitiveHashing(nn.Module):
    """Random-projection LSH that turns token vectors into a binary attention mask.

    Each of ``num_hashes`` hash functions projects a vector onto ``hash_bits``
    fixed random directions and keeps the signs. Query/key similarity is the mean
    (over hash functions) of the normalised sign agreement, in [-1, 1]. Pairs
    above ``similarity_threshold`` are allowed to attend.

    NOTE: ``torch.sign`` has zero gradient, so ``topology_bias`` receives no
    gradient through this module even though it is an ``nn.Parameter``.
    """
    
    def __init__(self, embed_dim: int, num_hashes: int, hash_bits: int, temperature: float,
                 similarity_threshold: float = 0.3):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_hashes = num_hashes
        self.hash_bits = hash_bits
        self.temperature = temperature
        self.similarity_threshold = similarity_threshold
        
        # Fixed random projection matrices: (num_hashes, embed_dim, hash_bits)
        self.register_buffer('hash_proj', torch.randn(num_hashes, embed_dim, hash_bits) * 0.02)
        
        # Bias added to projections before taking signs: (num_hashes, hash_bits)
        self.topology_bias = nn.Parameter(torch.zeros(num_hashes, hash_bits))
    
    def hash_vectors(self, vectors: torch.Tensor, 
                    topology_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Hash (batch, seq, embed_dim) vectors to (batch, seq, num_hashes, hash_bits) sign codes.

        If ``topology_features`` (batch, seq, topo_dim) is given and
        ``topo_dim <= embed_dim``, it is projected with the first ``topo_dim`` rows
        of the same random matrices and added, scaled by ``temperature``.
        """
        projections = torch.einsum('bsd,hdk->bshk', vectors, self.hash_proj)
        
        if topology_features is not None:
            topo_dim = topology_features.size(-1)
            if topo_dim <= self.embed_dim:
                topo_proj_matrix = self.hash_proj[:, :topo_dim, :]
                topo_proj = torch.einsum('bsd,hdk->bshk', topology_features, topo_proj_matrix)
                projections = projections + self.temperature * topo_proj
        
        projections = projections + self.topology_bias.unsqueeze(0).unsqueeze(1)
        return torch.sign(projections)
    
    def compute_hash_similarity(self, query_hashes: torch.Tensor, 
                              key_hashes: torch.Tensor) -> torch.Tensor:
        """Return (batch, q_len, k_len) similarity between query and key hash codes.

        The subscript ``n`` indexes hash bits and ``k`` indexes key positions; they
        must be distinct letters. Reusing one letter for both turns the key operand
        into a diagonal and fails whenever seq_len != hash_bits.
        """
        similarity = torch.einsum('bqhn,bkhn->bhqk', query_hashes, key_hashes) / self.hash_bits
        # Average over hash functions
        return similarity.mean(dim=1)
    
    def forward(self, queries: torch.Tensor, keys: torch.Tensor,
                topology_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a float {0., 1.} mask of shape (batch, q_len, k_len)."""
        query_hashes = self.hash_vectors(queries, topology_features)
        # Self-attention passes the same tensor twice; hashing is deterministic, so reuse it.
        if keys is queries:
            key_hashes = query_hashes
        else:
            key_hashes = self.hash_vectors(keys, topology_features)
        
        hash_similarity = self.compute_hash_similarity(query_hashes, key_hashes)
        return (hash_similarity > self.similarity_threshold).float()

class TopologicalAttention(nn.Module):
    """Multi-head self-attention with optional LSH masking and topology gating.

    * LSH masking (``config.use_lsh``): a binary mask from
      ``LocalitySensitiveHashing`` over ``hidden_states`` (shared by all heads)
      sets the scores of non-matching query/key pairs to -1e4. It only runs
      when topology features are supplied.
    * Topology gating (``config.use_topology``): a sigmoid gate computed from
      ``[context, topology features]`` mixes the attention context with the
      layer input ``hidden_states`` before the output projection.

    If either optional component raises, attention continues without it. The
    first failure of each component is logged with a traceback; repeats are not
    logged, so a persistent error cannot flood the log on every forward pass.
    """
    
    def __init__(self, config: TANConfig):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.head_dim = config.embed_dim // config.num_heads
        self.use_topology = config.use_topology
        self.use_lsh = config.use_lsh
        
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        
        # Standard attention projections
        self.q_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.k_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.v_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.out_proj = nn.Linear(config.embed_dim, config.embed_dim)
        
        # Topological attention components
        if self.use_topology:
            self.topology_gate = nn.Sequential(
                nn.Linear(config.embed_dim + config.topology_dim, config.embed_dim),
                nn.Sigmoid()
            )
            
            # Per-head k values. NOTE: stored for reference; forward() does not read them.
            self.head_k_neighbors = config.multi_scale_k[:config.num_heads]
        
        # LSH for efficiency
        if self.use_lsh:
            self.lsh = LocalitySensitiveHashing(
                config.embed_dim,
                config.num_hashes,
                config.hash_bits,
                config.lsh_temperature
            )
        
        self.dropout = nn.Dropout(config.dropout)
        self.scale = math.sqrt(self.head_dim)
        
        # Names of optional components whose failure has already been logged.
        self._reported_failures = set()
    
    def _report_failure(self, component: str, error: Exception) -> None:
        """Log the first failure of ``component`` with its traceback; skip repeats."""
        if component in self._reported_failures:
            return
        self._reported_failures.add(component)
        logger.warning(
            f"{component} failed: {error} (further {component} failures will not be logged)",
            exc_info=error,
        )
    
    def forward(self, hidden_states: torch.Tensor,
                topology_features: Optional[Dict[str, torch.Tensor]] = None,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``hidden_states`` (batch, seq_len, embed_dim); return the same shape."""
        batch_size, seq_len, _ = hidden_states.shape
        
        # Project to Q, K, V and split heads: (batch, heads, seq_len, head_dim)
        queries = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        
        # LSH masking (needs topology features, per the condition below)
        if self.use_lsh and topology_features is not None:
            try:
                lsh_mask = self.lsh(
                    hidden_states,
                    hidden_states,
                    topology_features.get('features')
                )
                # (batch, q, k) -> (batch, 1, q, k): broadcasts over heads
                attention_scores = attention_scores.masked_fill(lsh_mask.unsqueeze(1) == 0, -1e4)
            except Exception as e:
                self._report_failure("LSH masking", e)
        
        # Padding mask: attention_mask is (batch, seq_len) with 1 for real tokens, 0 for padding
        if attention_mask is not None:
            key_padding = attention_mask[:, None, None, :] == 0  # (batch, 1, 1, seq_len)
            attention_scores = attention_scores.masked_fill(key_padding, -1e4)
        
        attention_probs = self.dropout(F.softmax(attention_scores, dim=-1))
        
        # Merge heads back: (batch, seq_len, embed_dim)
        context = torch.matmul(attention_probs, values)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        
        # Topological gating: mix attention context with the layer input
        if self.use_topology and topology_features is not None:
            try:
                topo_features = topology_features['features']
                gate = self.topology_gate(torch.cat([context, topo_features], dim=-1))
                context = gate * context + (1 - gate) * hidden_states
            except Exception as e:
                self._report_failure("Topological gating", e)
        
        output = self.out_proj(context)
        return self.dropout(output)

class TANLayer(nn.Module):
    """One pre-norm transformer block whose attention is conditioned on k-NN topology.

    Computes ``h = x + Attn(LN1(x), topo(x))`` followed by ``h + FFN(LN2(h))``.
    Topology features are extracted from the raw layer input (before ``norm1``).
    If extraction raises, a warning is logged and attention runs without them.
    """

    def __init__(self, config: TANConfig):
        super().__init__()
        self.config = config
        embed_dim = config.embed_dim
        inner_dim = 4 * embed_dim

        # Submodules are created in a fixed order (this determines parameter
        # initialisation order and state_dict layout).
        if config.use_topology:
            self.topology_extractor = TopologicalFeatureExtractor(
                embed_dim=embed_dim,
                k_neighbors=config.k_neighbors,
                topology_dim=config.topology_dim,
            )

        self.attention = TopologicalAttention(config)

        # Position-wise feed-forward: embed -> 4*embed -> embed
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(config.dropout),
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def _maybe_extract_topology(self, hidden_states: torch.Tensor):
        """Return topology features for ``hidden_states``, or None if disabled/failed."""
        if not self.config.use_topology:
            return None
        try:
            return self.topology_extractor(hidden_states)
        except Exception as e:
            logger.warning(f"Topology extraction failed: {e}")
            return None

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run attention and feed-forward sub-blocks, each with a residual connection."""
        topology_features = self._maybe_extract_topology(hidden_states)

        attended = self.attention(self.norm1(hidden_states), topology_features, attention_mask)
        after_attention = hidden_states + attended

        return after_attention + self.ffn(self.norm2(after_attention))

class TAN(nn.Module):
    """Topological Attention Network encoder (no task head).

    Pipeline: token embeddings + learned absolute position embeddings ->
    LayerNorm -> dropout -> ``config.num_layers`` TANLayers -> final LayerNorm.
    """
    
    def __init__(self, config: TANConfig):
        super().__init__()
        self.config = config
        
        # Token and position embeddings
        self.embeddings = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embeddings = nn.Embedding(config.max_seq_length, config.embed_dim)
        self.embedding_dropout = nn.Dropout(config.dropout)
        self.embedding_norm = nn.LayerNorm(config.embed_dim)
        
        # TAN layers
        self.layers = nn.ModuleList([
            TANLayer(config) for _ in range(config.num_layers)
        ])
        
        # Final layer norm
        self.final_norm = nn.LayerNorm(config.embed_dim)
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        """BERT-style init: N(0, 0.02) for Linear/Embedding weights, zero biases, identity LayerNorm.

        Parameters of other module types (e.g. raw ``nn.Parameter``s and buffers)
        are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    
    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode ``input_ids`` of shape (batch, seq_len).

        Returns:
            ``hidden_states``: final-normalised (batch, seq_len, embed_dim) output.
            ``all_hidden_states``: list of ``num_layers + 1`` tensors (embedding
            output, then each layer's output), all taken before ``final_norm``.

        Raises:
            ValueError: if seq_len exceeds ``config.max_seq_length``. Otherwise the
                position-embedding lookup would fail with an IndexError on CPU or
                a device-side assert on CUDA.
        """
        batch_size, seq_length = input_ids.shape
        if seq_length > self.config.max_seq_length:
            raise ValueError(
                f"Input sequence length {seq_length} exceeds max_seq_length "
                f"{self.config.max_seq_length} supported by the position embeddings"
            )
        device = input_ids.device
        
        # Create position IDs
        position_ids = torch.arange(seq_length, device=device).unsqueeze(0).expand(batch_size, -1)
        
        # Combine token and position embeddings
        hidden_states = self.embeddings(input_ids) + self.position_embeddings(position_ids)
        hidden_states = self.embedding_norm(hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)
        
        # Pass through layers
        all_hidden_states = [hidden_states]
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
            all_hidden_states.append(hidden_states)
        
        # Final normalization
        hidden_states = self.final_norm(hidden_states)
        
        return {
            'hidden_states': hidden_states,
            'all_hidden_states': all_hidden_states
        }

class TANForLegalClassification(nn.Module):
    """TAN encoder plus a pooling/MLP head for single-label legal provision classification.

    Sequence representation: the masked mean of the encoder's final hidden
    states when ``attention_mask`` is given, otherwise the first token's hidden
    state. This is followed by a tanh pooler and a two-layer classifier.
    """

    def __init__(self, config: TANConfig, num_labels: int = 100):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        half_dim = config.embed_dim // 2

        # Creation order (encoder, pooler, classifier) fixes parameter init order.
        self.tan = TAN(config)

        self.pooler = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout),
        )

        self.classifier = nn.Sequential(
            nn.Linear(config.embed_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(half_dim, num_labels),
        )

        logger.info(f"TAN model initialized with {sum(p.numel() for p in self.parameters())/1e6:.1f}M parameters")

    @staticmethod
    def _pool(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Masked mean over the sequence, or the first token when no mask is given."""
        if attention_mask is None:
            return hidden_states[:, 0, :]
        weights = attention_mask.unsqueeze(-1).float()
        total = (hidden_states * weights).sum(1)
        count = weights.sum(1).clamp(min=1e-9)  # guards against all-padding rows
        return total / count

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return logits (and cross-entropy loss when ``labels`` is given)."""
        hidden_states = self.tan(input_ids, attention_mask)['hidden_states']

        pooled_output = self.pooler(self._pool(hidden_states, attention_mask))
        logits = self.classifier(pooled_output)

        loss = None if labels is None else nn.CrossEntropyLoss()(logits, labels)

        return dict(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            pooled_output=pooled_output,
        )

# ================== TRAINER CLASS WITH COMPREHENSIVE METRICS ==================

class TANTrainer:
    """
    Comprehensive trainer for TAN model on LEDGAR legal document classification.

    Provides:
      * ``train`` / ``train_epoch``: AdamW with a linear warmup/decay schedule,
        gradient clipping at norm 1.0 and optional CUDA mixed precision.
      * ``evaluate`` / ``compute_metrics``: accuracy, micro/macro/weighted F1,
        Hamming loss and (when there are >= 5 classes) top-5 accuracy.
      * ``save_model`` / ``load_model``: checkpoints that also carry the training
        history and best metrics.
    """

    # Valid values for ``train(save_strategy=...)``.
    SAVE_STRATEGIES = ("best_f1_macro", "best_accuracy", "every_epoch")

    def __init__(self, model: TANForLegalClassification, device: torch.device = None,
                 model_name: str = "TAN-Legal"):
        """
        Initialize trainer

        Args:
            model: TAN model for legal classification. It is called as
                ``model(input_ids, attention_mask, labels)`` and must return a dict
                containing 'loss' and 'logits'.
            device: Training device (``torch.device`` or a device string such as
                'cpu'). Defaults to CUDA when available, otherwise CPU.
            model_name: Name for saving/logging; best checkpoints are written to
                ``best_{model_name}_LEDGAR.pt``.
        """
        self.device = (torch.device(device) if device is not None
                       else torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.model_name = model_name

        # Move model to device
        self.model = model.to(self.device)

        # Mixed precision training. Keyed on the device actually used for training
        # (not on torch.cuda.is_available()), so an explicitly requested CPU device
        # never gets CUDA autocast / GradScaler.
        self.scaler = torch.cuda.amp.GradScaler() if self.device.type == 'cuda' else None

        # Training history (one entry per epoch in every list)
        self.history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_f1_micro': [],
            'val_f1_macro': [],
            'val_hamming_loss': [],
            'val_top5_accuracy': [],
            'learning_rates': [],
            'epoch_times': []
        }

        # Best metrics tracking; 'epoch' == 0 means "no epoch recorded yet"
        self.best_metrics = {
            'f1_macro': 0.0,
            'f1_micro': 0.0,
            'accuracy': 0.0,
            'epoch': 0
        }

        logger.info(f"Trainer initialized for {model_name}")
        logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.1f}M")
        logger.info(f"Training device: {self.device}")

    def compute_metrics(self, predictions: np.ndarray, labels: np.ndarray,
                        probabilities: np.ndarray = None) -> Dict[str, float]:
        """
        Compute comprehensive metrics for legal document classification

        Args:
            predictions: Predicted class labels, one per sample
            labels: True class labels, one per sample
            probabilities: Optional (n_samples, n_classes) class scores used for
                top-5 accuracy; skipped when fewer than 5 classes are present

        Returns:
            Dict with 'accuracy', 'f1_micro', 'f1_macro', 'f1_weighted',
            'hamming_loss' and, when computable, 'top5_accuracy'.
        """
        metrics = {}

        # Basic accuracy
        metrics['accuracy'] = accuracy_score(labels, predictions)

        # F1 scores (zero_division=0 keeps classes absent from a split from warning)
        metrics['f1_micro'] = f1_score(labels, predictions, average='micro', zero_division=0)
        metrics['f1_macro'] = f1_score(labels, predictions, average='macro', zero_division=0)
        metrics['f1_weighted'] = f1_score(labels, predictions, average='weighted', zero_division=0)

        # Hamming loss (fraction of wrong labels; equals 1 - accuracy for
        # single-label multiclass targets)
        metrics['hamming_loss'] = hamming_loss(labels, predictions)

        # Top-k accuracy if probabilities provided
        if probabilities is not None and probabilities.shape[1] >= 5:
            try:
                metrics['top5_accuracy'] = top_k_accuracy_score(
                    labels, probabilities, k=5,
                    labels=list(range(probabilities.shape[1]))
                )
            except Exception as e:
                logger.warning(f"Could not compute top-5 accuracy: {e}")
                # Fallback: vectorized manual computation over the 5 highest scores
                top5_preds = np.argsort(probabilities, axis=1)[:, -5:]
                hits = (top5_preds == np.asarray(labels)[:, None]).any(axis=1)
                metrics['top5_accuracy'] = float(hits.mean()) if hits.size else 0.0

        return metrics

    def train_epoch(self, train_loader: DataLoader, optimizer, scheduler,
                    epoch: int, num_epochs: int) -> Dict[str, float]:
        """
        Train for one epoch.

        The scheduler (if given) is stepped once per batch.

        Returns:
            Dict with mean batch 'loss', training 'accuracy' (measured in train
            mode, i.e. with dropout active) and the last 'learning_rate'.
        """
        self.model.train()

        total_loss = 0.0
        all_predictions = []
        all_labels = []
        num_batches = len(train_loader)
        # Initialised up front so an empty loader still yields a defined value.
        current_lr = optimizer.param_groups[0]['lr']

        # Progress bar
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')

        for batch_idx, batch in enumerate(pbar):
            # Move batch to device
            input_ids = batch['input_ids'].to(self.device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(self.device, non_blocking=True)
            labels = batch['labels'].to(self.device, non_blocking=True)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass with mixed precision
            if self.scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = self.model(input_ids, attention_mask, labels)
                    loss = outputs['loss']

                # Backward pass; unscale before clipping so the norm threshold
                # applies to the true gradients, not the scaled ones.
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs['loss']

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()

            # Update scheduler
            if scheduler is not None:
                scheduler.step()
                current_lr = scheduler.get_last_lr()[0]
            else:
                current_lr = optimizer.param_groups[0]['lr']

            # Accumulate metrics (single .item() call = single device sync)
            step_loss = loss.item()
            total_loss += step_loss

            # Get predictions
            with torch.no_grad():
                predictions = torch.argmax(outputs['logits'], dim=-1)
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            # Update progress bar
            avg_loss = total_loss / (batch_idx + 1)
            pbar.set_postfix({
                'loss': f'{step_loss:.4f}',
                'avg_loss': f'{avg_loss:.4f}',
                'lr': f'{current_lr:.2e}'
            })

        # Compute epoch metrics
        epoch_metrics = {
            'loss': total_loss / num_batches if num_batches > 0 else 0.0,
            'accuracy': accuracy_score(all_labels, all_predictions) if all_labels else 0.0,
            'learning_rate': current_lr
        }

        return epoch_metrics

    def evaluate(self, dataloader: DataLoader, split_name: str = "Validation") -> Dict[str, float]:
        """
        Evaluate the model on ``dataloader`` without updating weights.

        Returns:
            The metrics from ``compute_metrics`` plus the mean batch 'loss'.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        self.model.eval()

        total_loss = 0.0
        all_predictions = []
        all_labels = []
        all_probabilities = []
        batches_seen = 0

        with torch.no_grad():
            pbar = tqdm(dataloader, desc=f'{split_name} Evaluation')

            for batch in pbar:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(self.device, non_blocking=True)
                labels = batch['labels'].to(self.device, non_blocking=True)

                # Forward pass
                if self.scaler is not None:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(input_ids, attention_mask, labels)
                else:
                    outputs = self.model(input_ids, attention_mask, labels)
                batches_seen += 1

                # Accumulate loss
                if outputs['loss'] is not None:
                    total_loss += outputs['loss'].item()

                # Softmax in float32: under autocast the logits may be fp16, which
                # loses precision and creates ties that distort top-k metrics.
                logits = outputs['logits'].float()
                probabilities = F.softmax(logits, dim=-1)
                predictions = torch.argmax(logits, dim=-1)

                # Store results
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probabilities.append(probabilities.cpu().numpy())

                # Update progress bar (average over batches seen so far; robust to a
                # smaller final batch)
                pbar.set_postfix({'avg_loss': f'{total_loss / batches_seen:.4f}'})

        if batches_seen == 0:
            raise ValueError(f"{split_name} dataloader produced no batches; cannot compute metrics")

        # Concatenate probabilities
        all_probabilities = np.concatenate(all_probabilities, axis=0)
        all_labels = np.array(all_labels)
        all_predictions = np.array(all_predictions)

        # Compute comprehensive metrics
        metrics = self.compute_metrics(all_predictions, all_labels, all_probabilities)
        metrics['loss'] = total_loss / batches_seen

        return metrics

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              num_epochs: int = 10, learning_rate: float = 2e-5,
              weight_decay: float = 0.01, warmup_ratio: float = 0.1,
              save_strategy: str = "best_f1_macro") -> Dict[str, float]:
        """
        Train the TAN model

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            weight_decay: Weight decay for regularization
            warmup_ratio: Fraction of total steps used for linear warmup
            save_strategy: When to save model ("best_f1_macro", "best_accuracy",
                "every_epoch"). The first epoch always counts as the current best,
                so a checkpoint exists after training even if the metric is 0.0.
                With "every_epoch" the checkpoint file is overwritten each epoch and
                the best-epoch summary is tracked by macro-F1.

        Returns:
            Best validation metrics

        Raises:
            ValueError: if ``save_strategy`` is not one of ``SAVE_STRATEGIES``.
        """
        if save_strategy not in self.SAVE_STRATEGIES:
            raise ValueError(
                f"Unknown save_strategy {save_strategy!r}; expected one of {self.SAVE_STRATEGIES}"
            )

        logger.info(f"Starting training for {num_epochs} epochs")
        logger.info(f"Task: Legal Document Categorization/Classification (100 classes)")
        logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")
        logger.info(f"Warmup ratio: {warmup_ratio}, Save strategy: {save_strategy}")

        # Setup optimizer and scheduler
        optimizer = AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=1e-8
        )

        total_steps = len(train_loader) * num_epochs
        num_warmup_steps = int(warmup_ratio * total_steps)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_steps
        )

        logger.info(f"Total training steps: {total_steps}")
        logger.info(f"Warmup steps: {num_warmup_steps}")

        checkpoint_path = f'best_{self.model_name}_LEDGAR.pt'

        # Training loop
        start_time = time.time()

        for epoch in range(num_epochs):
            epoch_start = time.time()

            # Train epoch
            train_metrics = self.train_epoch(train_loader, optimizer, scheduler, epoch, num_epochs)

            # Validation
            val_metrics = self.evaluate(val_loader, "Validation")

            # Update history
            epoch_time = time.time() - epoch_start
            self.history['train_loss'].append(train_metrics['loss'])
            self.history['train_accuracy'].append(train_metrics['accuracy'])
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['val_accuracy'].append(val_metrics['accuracy'])
            self.history['val_f1_micro'].append(val_metrics['f1_micro'])
            self.history['val_f1_macro'].append(val_metrics['f1_macro'])
            self.history['val_hamming_loss'].append(val_metrics['hamming_loss'])
            self.history['val_top5_accuracy'].append(val_metrics.get('top5_accuracy', 0.0))
            self.history['learning_rates'].append(train_metrics['learning_rate'])
            self.history['epoch_times'].append(epoch_time)

            # Log results
            logger.info(f"\nEpoch {epoch+1}/{num_epochs} completed in {epoch_time:.2f}s")
            logger.info(f"Train - Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")
            logger.info(f"Val - Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")
            logger.info(f"Val - F1-Micro: {val_metrics['f1_micro']:.4f}, F1-Macro: {val_metrics['f1_macro']:.4f}")
            logger.info(f"Val - Hamming Loss: {val_metrics['hamming_loss']:.4f}")
            if 'top5_accuracy' in val_metrics:
                logger.info(f"Val - Top-5 Accuracy: {val_metrics['top5_accuracy']:.4f}")

            # Is this the best epoch so far? The first recorded epoch always is, so
            # a checkpoint exists even when the selection metric is still 0.0.
            no_best_yet = self.best_metrics['epoch'] == 0
            if save_strategy == "best_accuracy":
                improved = no_best_yet or val_metrics['accuracy'] > self.best_metrics['accuracy']
            else:
                # "best_f1_macro" selects on macro-F1; "every_epoch" uses it only for
                # the best-epoch summary.
                improved = no_best_yet or val_metrics['f1_macro'] > self.best_metrics['f1_macro']

            if improved:
                self.best_metrics.update({
                    'f1_macro': val_metrics['f1_macro'],
                    'f1_micro': val_metrics['f1_micro'],
                    'accuracy': val_metrics['accuracy'],
                    'epoch': epoch + 1
                })

            # Save model if needed (only report success when the write succeeded)
            if improved or save_strategy == "every_epoch":
                if self.save_model(checkpoint_path):
                    if improved:
                        logger.info(f"✓ Saved best model with F1-Macro: {val_metrics['f1_macro']:.4f}")
                    else:
                        logger.info(f"✓ Saved epoch {epoch + 1} model with F1-Macro: {val_metrics['f1_macro']:.4f}")

        total_time = time.time() - start_time
        logger.info(f"\nTraining completed in {total_time/3600:.2f} hours")
        logger.info(f"Best metrics achieved at epoch {self.best_metrics['epoch']}:")
        logger.info(f"  F1-Macro: {self.best_metrics['f1_macro']:.4f}")
        logger.info(f"  F1-Micro: {self.best_metrics['f1_micro']:.4f}")
        logger.info(f"  Accuracy: {self.best_metrics['accuracy']:.4f}")

        return self.best_metrics

    def save_model(self, filepath: str) -> bool:
        """
        Save model checkpoint with comprehensive information.

        Failures are logged rather than raised, so an I/O error does not abort a
        long training run.

        Returns:
            True if the checkpoint was written, False otherwise.
        """
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'model_config': self.model.config,
            'model_name': self.model_name,
            'num_labels': self.model.num_labels,
            'history': self.history,
            'best_metrics': self.best_metrics,
            'task_type': 'Legal Document Categorization/Classification',
            'dataset': 'LEDGAR-100',
            'classes': LEDGAR_CLASSES
        }

        try:
            torch.save(checkpoint, filepath)
        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            return False
        logger.info(f"Model saved to {filepath}")
        return True

    def load_model(self, filepath: str):
        """
        Load a checkpoint written by ``save_model`` into the current model.

        Restores the model weights and, when present in the checkpoint, the
        training ``history`` and ``best_metrics``.

        Returns:
            The full checkpoint dict.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raised (logged first).
        """
        try:
            # weights_only=False is needed because the checkpoint also stores
            # non-tensor objects (model config, history, class list). It unpickles
            # arbitrary objects, so only load checkpoints from trusted sources.
            checkpoint = torch.load(filepath, map_location=self.device, weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.history = checkpoint.get('history', self.history)
            self.best_metrics = checkpoint.get('best_metrics', self.best_metrics)
            logger.info(f"Model loaded from {filepath}")
            return checkpoint
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

# ================== MAIN TRAINING SCRIPT ==================

def create_data_loaders(tokenizer, batch_size: int = 16, max_length: int = 512,
                       max_samples_train: int = None, max_samples_val: int = None,
                       max_samples_test: int = None,
                       num_workers: int = 4) -> Tuple[DataLoader, DataLoader, DataLoader, int]:
    """
    Create data loaders for LEDGAR dataset.

    Args:
        tokenizer: Tokenizer passed to each ``LEDGARDataset``.
        batch_size: Batch size for all three loaders.
        max_length: Maximum sequence length passed to ``LEDGARDataset``.
        max_samples_train / max_samples_val / max_samples_test: Optional per-split
            sample limits passed to ``LEDGARDataset``.
        num_workers: DataLoader worker processes per loader.

    Returns:
        (train_loader, val_loader, test_loader, num_labels), where num_labels is
        the training split's ``num_labels``. Only the training split is shuffled.
    """

    logger.info("Creating LEDGAR datasets...")

    # Create datasets (insertion order = train, validation, test)
    splits = {
        'train': LEDGARDataset('train', tokenizer, max_length, max_samples_train),
        'validation': LEDGARDataset('validation', tokenizer, max_length, max_samples_val),
        'test': LEDGARDataset('test', tokenizer, max_length, max_samples_test),
    }

    for name, dataset in splits.items():
        logger.info(f"{name.capitalize()} dataset: {len(dataset)} samples, {dataset.num_labels} classes")

    # Verify consistent number of labels: the model head is sized from the train
    # split, so a different count on another split deserves a loud warning.
    num_labels = splits['train'].num_labels
    mismatched = {name: ds.num_labels for name, ds in splits.items()
                  if ds.num_labels != num_labels}
    if mismatched:
        logger.warning(
            f"Number of classes differs across splits: train={num_labels}, others={mismatched}"
        )

    # Pinned host memory only helps host->GPU copies; without CUDA it just warns.
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        """Wrap one split in a DataLoader with the shared settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    train_loader = _make_loader(splits['train'], shuffle=True)
    val_loader = _make_loader(splits['validation'], shuffle=False)
    test_loader = _make_loader(splits['test'], shuffle=False)

    return train_loader, val_loader, test_loader, num_labels

def main():
    """
    Main training function.

    Builds the tokenizer, the LEDGAR data loaders and a TAN classifier. It trains
    with ``TANTrainer``, evaluates the best checkpoint on the test split, and
    writes a JSON summary to ``TAN_LEDGAR_results.json``.
    """
    logger.info("="*80)
    logger.info("TAN (Topological Attention Network) Training")
    logger.info("Task: Legal Document Categorization/Classification")
    logger.info("Dataset: LEDGAR (100 Legal Contract Provision Classes)")
    logger.info("="*80)

    # Configuration
    config = {
        # Model parameters
        'embed_dim': 768,
        'num_heads': 12,
        'num_layers': 6,
        'max_seq_length': 512,
        'dropout': 0.1,

        # Topological parameters
        'k_neighbors': 32,
        'use_topology': True,
        'topology_dim': 128,

        # LSH parameters
        'use_lsh': True,
        'num_hashes': 8,
        'hash_bits': 64,
        'lsh_temperature': 0.1,

        # Training parameters
        'batch_size': 16,
        'num_epochs': 10,
        'learning_rate': 2e-5,
        'weight_decay': 0.01,
        'warmup_ratio': 0.1,

        # Data parameters
        'max_samples_train': 40000,  # Limit for faster training
        'max_samples_val': 5000,
        'max_samples_test': 5000
    }

    logger.info("Configuration:")
    for key, value in config.items():
        logger.info(f"  {key}: {value}")

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    if torch.cuda.is_available():
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

    # Initialize tokenizer
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Add special tokens if needed (BERT already defines a pad token; this only
    # matters if the checkpoint above is swapped for one without one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create data loaders
    train_loader, val_loader, test_loader, num_labels = create_data_loaders(
        tokenizer,
        batch_size=config['batch_size'],
        max_length=config['max_seq_length'],
        max_samples_train=config['max_samples_train'],
        max_samples_val=config['max_samples_val'],
        max_samples_test=config['max_samples_test']
    )

    logger.info(f"Number of classes detected: {num_labels}")

    # Verify we have 100 classes
    if num_labels != 100:
        logger.warning(f"Expected 100 classes but found {num_labels}")
        logger.warning("This may indicate issues with the dataset or preprocessing")

    # Create TAN configuration.
    # len(tokenizer) (not tokenizer.vocab_size) counts added tokens too, so every
    # id the tokenizer can emit has an embedding row.
    tan_config = TANConfig(
        vocab_size=len(tokenizer),
        embed_dim=config['embed_dim'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        max_seq_length=config['max_seq_length'],
        dropout=config['dropout'],
        k_neighbors=config['k_neighbors'],
        use_topology=config['use_topology'],
        topology_dim=config['topology_dim'],
        use_lsh=config['use_lsh'],
        num_hashes=config['num_hashes'],
        hash_bits=config['hash_bits'],
        lsh_temperature=config['lsh_temperature']
    )

    logger.info(f"TAN Configuration: {tan_config}")

    # Create model
    logger.info("Initializing TAN model...")
    model = TANForLegalClassification(tan_config, num_labels=num_labels)

    # Create trainer
    trainer = TANTrainer(model, device, "TAN-Legal")

    # Start training
    logger.info("Starting training...")
    best_metrics = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['num_epochs'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        warmup_ratio=config['warmup_ratio'],
        save_strategy="best_f1_macro"
    )

    # Load best model and evaluate on test set. The filename is derived from the
    # trainer's model name, matching the naming used when the trainer saves it.
    logger.info("\nEvaluating on test set...")
    best_checkpoint = f'best_{trainer.model_name}_LEDGAR.pt'
    if Path(best_checkpoint).exists():
        trainer.load_model(best_checkpoint)
    else:
        logger.warning(
            f"Best checkpoint {best_checkpoint} not found; evaluating the final in-memory model"
        )
    test_metrics = trainer.evaluate(test_loader, "Test")

    # Final results
    logger.info("\n" + "="*80)
    logger.info("FINAL RESULTS - Legal Document Categorization/Classification")
    logger.info("="*80)
    logger.info(f"Task Type: Legal Document Categorization/Classification")
    logger.info(f"Dataset: LEDGAR (100 contract provision classes)")
    logger.info(f"Model: TAN (Topological Attention Network)")
    logger.info("\nTest Set Results:")
    logger.info(f"  Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f"  F1-Micro: {test_metrics['f1_micro']:.4f}")
    logger.info(f"  F1-Macro: {test_metrics['f1_macro']:.4f}")
    logger.info(f"  F1-Weighted: {test_metrics['f1_weighted']:.4f}")
    logger.info(f"  Hamming Loss: {test_metrics['hamming_loss']:.4f}")
    if 'top5_accuracy' in test_metrics:
        logger.info(f"  Top-5 Accuracy: {test_metrics['top5_accuracy']:.4f}")

    # Save final results
    results = {
        'task_type': 'Legal Document Categorization/Classification',
        'dataset': 'LEDGAR-100',
        'model': 'TAN',
        'config': config,
        'best_validation_metrics': best_metrics,
        'test_metrics': test_metrics,
        'training_history': trainer.history,
        'total_parameters': sum(p.numel() for p in model.parameters()),
        'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad)
    }

    with open('TAN_LEDGAR_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)

    logger.info("Results saved to TAN_LEDGAR_results.json")
    logger.info("Training completed successfully!")

if __name__ == "__main__":
    # Script entry point: Ctrl-C ends the run quietly. Any other failure is
    # logged with its full traceback (so it also lands in the log file) and
    # then re-raised.
    try:
        main()
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.exception(f"Training failed with error: {e}")
        raise