import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import copy
import numpy as np
from model.resnet import ResNet18

logger = logging.getLogger('GFedCL')

class Identity(nn.Module):
    """Pass-through module: ``forward`` hands its argument back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation at all; the very same object is returned.
        return x
#-----------------------------
# For GFedCL - Updated for EMNIST
#-----------------------------
class GNet(nn.Module):
    """
    Graph Network - maps a client relationship vector to an embedding.

    A two-layer MLP (``opt.num_clients -> opt.nh -> opt.nt``) applied along
    the last axis of the input.
    """
    def __init__(self, opt):
        super().__init__()
        self.num_clients = opt.num_clients
        self.hidden_dim = opt.nh
        self.output_dim = opt.nt

        # Layers are listed first and then wrapped, keeping indices 0/1/2.
        mlp_layers = [
            nn.Linear(self.num_clients, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """
        Embed ``x``.

        Args:
            x: Tensor whose last axis has ``num_clients`` entries; any dtype
               that can be converted to float32.

        Returns:
            Tensor with the same leading axes as ``x`` and a last axis of
            size ``opt.nt``.
        """
        # Operate on a private float32 copy; the caller's tensor is untouched.
        # (``.float()`` is a no-op when the copy is already float32.)
        relations = x.clone().float()

        # Rank <= 2 goes straight through the MLP.
        if relations.dim() <= 2:
            return self.net(relations)

        # Higher rank: collapse all leading axes into one batch axis, run the
        # MLP on the 2D view, then restore the leading axes.
        leading_axes = relations.shape[:-1]
        flat_out = self.net(relations.reshape(-1, relations.size(-1)))
        return flat_out.reshape(*leading_axes, self.output_dim)
        
class FeatureEncoder(nn.Module):
    """
    Feature Encoder Network - fuses input data, class labels and a graph embedding.

    Sized for EMNIST-letter style inputs: every sample must flatten to
    1 * 28 * 28 = 784 values.
    """
    def __init__(self, opt):
        super(FeatureEncoder, self).__init__()
        # Configuration - input width is fixed for grayscale 28x28 images
        self.input_dim = 1 * 28 * 28
        self.hidden_dim = opt.nh
        self.graph_dim = opt.nt
        self.num_classes = opt.num_classes

        # Encoder for data: input_dim -> nh -> nh
        self.data_encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU()
        )

        # Label embedding: one learned vector of width nh // 4 per class
        self.label_embedding = nn.Embedding(self.num_classes, self.hidden_dim // 4)

        # Graph embedding processor: nt -> nh // 4
        self.graph_processor = nn.Sequential(
            nn.Linear(self.graph_dim, self.hidden_dim // 4),
            nn.ReLU()
        )

        # Feature fusion: concatenated (nh + nh//4 + nh//4) -> nh
        self.fusion = nn.Sequential(
            nn.Linear(self.hidden_dim + self.hidden_dim // 4 + self.hidden_dim // 4, self.hidden_dim),
            nn.ReLU()
        )

    def forward(self, x, labels, graph_embed=None):
        """
        Encode a batch into the latent space.

        Args:
            x: Input data [batch_size, channels, height, width] or
               [batch_size, input_dim]; inputs of rank > 2 are flattened per sample.
            labels: Class labels, either indices [batch_size] or per-class
               scores / one-hot rows [batch_size, num_classes] (the position
               of the row maximum is used). Labels are detached.
            graph_embed: Optional graph embedding. Accepted shapes are
               [batch_size, nt], [1, nt] (shared by the whole batch) or a
               bare vector [nt] (also shared by the whole batch).

        Returns:
            latent: Encoded latent representation [batch_size, nh]
        """
        # Private copies: the caller's tensors are never modified and labels
        # carry no gradient.
        x_copy = x.clone()
        labels_copy = labels.clone().detach()

        # Always work with flattened input
        batch_size = x_copy.size(0)
        x_flat = x_copy.reshape(batch_size, -1) if x_copy.dim() > 2 else x_copy

        # Process data
        data_features = self.data_encoder(x_flat)

        # Process labels: reduce score / one-hot rows to class indices first
        if labels_copy.dim() == 1:
            class_ids = labels_copy
        else:
            _, class_ids = torch.max(labels_copy, dim=1)
        label_features = self.label_embedding(class_ids.long())

        if graph_embed is not None:
            graph_copy = graph_embed.clone()
            # A bare [nt] vector is treated as a single row so that it can be
            # shared across the batch (previously this failed in torch.cat).
            if graph_copy.dim() == 1:
                graph_copy = graph_copy.unsqueeze(0)
            # Broadcast a single row to every sample in the batch
            if graph_copy.size(0) == 1 and batch_size > 1:
                graph_copy = graph_copy.expand(batch_size, -1)
            graph_features = self.graph_processor(graph_copy)
        else:
            # No graph information: zero placeholder matching the dtype and
            # device of the other features so the concatenation is uniform.
            graph_features = data_features.new_zeros(batch_size, self.hidden_dim // 4)

        # Combine features and generate latent representation
        combined = torch.cat([data_features, label_features, graph_features], dim=1)
        return self.fusion(combined)

class GraphDNet(nn.Module):
    """
    Graph Discriminator - maps encoder latents back to a graph embedding.

    A two-layer MLP (``opt.nh -> opt.nh -> opt.nt``) applied along the last
    axis of the input.
    """
    def __init__(self, opt):
        super().__init__()
        self.input_dim = opt.nh
        self.hidden_dim = opt.nh
        self.output_dim = opt.nt

        # Layers are listed first and then wrapped, keeping indices 0/1/2.
        mlp_layers = [
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """
        Args:
            x: Tensor whose last axis has ``opt.nh`` entries.

        Returns:
            Tensor with the same leading axes as ``x`` and a last axis of
            size ``opt.nt``.
        """
        # Work on a copy so the caller's tensor is never touched.
        latent = x.clone()

        # Rank <= 2 goes straight through the MLP.
        if latent.dim() <= 2:
            return self.net(latent)

        # Higher rank: fold the leading axes into one batch axis for the MLP,
        # then unfold them again on the way out.
        leading_axes = latent.shape[:-1]
        flat_out = self.net(latent.reshape(-1, latent.size(-1)))
        return flat_out.reshape(*leading_axes, self.output_dim)

class PredNet(nn.Module):
    """
    Prediction Network - classifies encoded features.

    A two-layer MLP (``opt.ni -> opt.nh -> opt.nc``) followed by a
    (log-)softmax over the class axis.
    """
    def __init__(self, opt):
        super(PredNet, self).__init__()
        self.input_dim = opt.ni
        self.hidden_dim = opt.nh
        self.num_classes = opt.nc  # e.g. 26 for EMNIST-letter

        # Simple classifier with minimal operations
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_classes)
        )

    def forward(self, x, return_softmax=False):
        """
        Classify encoded features.

        Args:
            x: Feature tensor whose last axis has ``opt.ni`` entries. Inputs
               of rank > 2 are flattened to 2D for the MLP and the leading
               axes are restored on the outputs.
            return_softmax: When true, also return the class probabilities.

        Returns:
            log_probs: Log class probabilities, leading axes of ``x`` plus a
               final axis of size ``opt.nc``.
            softmax_probs: Class probabilities of the same shape (only when
               ``return_softmax`` is true).
        """
        # Work on a copy so the caller's tensor is never touched
        x_copy = x.clone()

        needs_unflatten = x_copy.dim() > 2
        x_flat = x_copy.reshape(-1, x_copy.size(-1)) if needs_unflatten else x_copy

        logits = self.net(x_flat)

        # log_softmax is computed in a numerically stable way. The previous
        # log(softmax + 1e-10) floored every log-probability at ~-23 and lost
        # the gradient once a probability underflowed to zero. The class axis
        # is always the last one, which also covers 1D inputs.
        log_probs = F.log_softmax(logits, dim=-1)
        softmax_probs = F.softmax(logits, dim=-1) if return_softmax else None

        # Restore the caller's leading axes
        if needs_unflatten:
            new_shape = x_copy.shape[:-1] + (self.num_classes,)
            log_probs = log_probs.reshape(*new_shape)
            if softmax_probs is not None:
                softmax_probs = softmax_probs.reshape(*new_shape)

        if return_softmax:
            return log_probs, softmax_probs
        return log_probs

class ResNetPredNet(nn.Module):
    """
    Prediction network that classifies feature vectors with a ResNet18.

    Feature vectors are projected to 3 x 32 x 32 "images" by a small MLP and
    then handed to ``ResNet18`` (imported from ``model.resnet``).
    """
    def __init__(self, opt):
        super().__init__()
        self.input_dim = opt.nh
        self.hidden_dim = opt.nh
        self.num_classes = opt.nc
        self.device = opt.device

        # Projection: feature vector -> 512 -> 3*32*32 values, later reshaped
        # into a 3-channel 32x32 input for the ResNet.
        projection = [
            nn.Linear(self.input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 3 * 32 * 32),
            nn.ReLU(),
        ]
        self.proj_layer = nn.Sequential(*projection)

        # ResNet18 backbone/classifier over the projected 3-channel input.
        self.resnet = ResNet18(num_classes=self.num_classes, input_channels=3)

    def forward(self, x, return_softmax=False):
        """
        Classify feature vectors.

        Args:
            x: Feature tensor whose last axis has ``opt.nh`` entries; inputs
               of rank > 2 are flattened to 2D and unflattened on output.
            return_softmax: Whether to additionally return ``exp`` of the
               ResNet output.

        Returns:
            log_probs: The ResNet output, reshaped to the leading axes of
               ``x`` plus ``num_classes``.
            softmax_probs: ``exp(log_probs)``, same shape (only when
               ``return_softmax`` is true).
        """
        # Work on a copy so the caller's tensor is never touched.
        features = x.clone()
        in_shape = features.shape
        multi_axis = features.dim() > 2

        flat = features.reshape(-1, features.size(-1)) if multi_axis else features

        # Project and view as a batch of 3x32x32 inputs.
        images = self.proj_layer(flat).reshape(-1, 3, 32, 32)

        # NOTE(review): the ResNet output is treated as log-probabilities
        # (presumably it applies log_softmax internally) - confirm against
        # model.resnet.ResNet18.
        log_probs = self.resnet(images)
        softmax_probs = torch.exp(log_probs) if return_softmax else None

        # Put the caller's leading axes back on every returned tensor.
        if multi_axis:
            out_shape = in_shape[:-1] + (self.num_classes,)
            log_probs = log_probs.reshape(*out_shape)
            if return_softmax:
                softmax_probs = softmax_probs.reshape(*out_shape)

        if return_softmax:
            return log_probs, softmax_probs
        return log_probs

class Resnet_plus(nn.Module):
    """
    Classifier head over 512-dimensional encoder features.

    Two fully connected layers (512 -> xa_dim -> xa_dim, leaky ReLU) followed
    by a linear classifier to ``num_classes`` and a softmax over axis 1.
    """
    def __init__(self, image_size, xa_dim, num_classes=100):
        # configurations
        super().__init__()
        self.image_size = image_size  # stored only; not used by this class
        self.num_classes = num_classes

        # For GFedCL, we create a simple network from encoder features to class predictions
        self.fc1 = nn.Linear(512, xa_dim)
        self.fc2 = nn.Linear(xa_dim, xa_dim)

        # aux-classifier fc
        self.fc_classifier = nn.Linear(xa_dim, self.num_classes)

        # activation functions:
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, return_softmax=False):
        """
        Args:
            x: Feature tensor [batch_size, 512].
            return_softmax: When true, also return the class probabilities.

        Returns:
            log_probs [batch_size, num_classes], and additionally
            softmax_probs of the same shape when ``return_softmax`` is true.
        """
        # Get the latent representation through forward_to_xa
        xa = self.forward_to_xa(x)

        # Then get the class probabilities through forward_from_xa
        log_probs, softmax_probs = self.forward_from_xa(xa)

        if return_softmax:
            return log_probs, softmax_probs
        return log_probs

    def forward_to_xa(self, x):
        """Return ``x`` unchanged; the input is used directly as the feature vector."""
        return x

    def forward_from_xa(self, xa):
        """
        Map feature vectors to class (log-)probabilities.

        Returns:
            (log_probs, softmax_probs), both [batch_size, num_classes].
        """
        xa = F.leaky_relu(self.fc1(xa))
        xb = F.leaky_relu(self.fc2(xa))
        logits = self.fc_classifier(xb)

        # Class probabilities over axis 1
        softmax_probs = self.softmax(logits)

        # Numerically stable log-probabilities over the same axis. The
        # previous log(softmax + 1e-10) floored every value at ~-23 and lost
        # the gradient once a probability underflowed to zero.
        log_probs = F.log_softmax(logits, dim=1)

        return log_probs, softmax_probs
    
#-----------------------------
# For FedAvg - Updated for EMNIST
#-----------------------------

class ClassicEncoder(nn.Module):
    """
    Encoder that fuses flattened samples with their class labels.

    Unlike the graph-aware encoder it takes no graph embedding. Sized for
    EMNIST-letter style inputs: every sample must flatten to
    1 * 28 * 28 = 784 values.
    """
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.input_dim = 1 * 28 * 28  # grayscale 28x28
        self.hidden_dim = opt.nh
        self.num_classes = opt.num_classes

        # Data branch: input_dim -> nh -> nh -> ni
        data_layers = [
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.opt.ni),
        ]
        self.data_encoder = nn.Sequential(*data_layers)

        # Label branch: one learned vector of width ni // 4 per class
        self.label_embedding = nn.Embedding(self.num_classes, self.opt.ni // 4)

        # Fusion: concatenated (ni + ni // 4) -> ni
        fused_width = self.opt.ni + self.opt.ni // 4
        self.fusion = nn.Sequential(
            nn.Linear(fused_width, self.opt.ni),
            nn.ReLU(),
        )

    def forward(self, x, labels):
        """
        Encode a batch of samples together with their labels.

        Args:
            x: Input data [batch_size, channels, height, width] or
               [batch_size, input_dim]; rank > 2 is flattened per sample.
            labels: Class indices [batch_size], or per-class rows
               [batch_size, num_classes] whose row-wise maximum position is
               used as the class index. Labels are detached.

        Returns:
            latent: Encoded representation [batch_size, opt.ni]
        """
        # Private copies; labels never take part in autograd.
        samples = x.clone()
        targets = labels.clone().detach()

        # One row per sample.
        n_samples = samples.size(0)
        flat = samples.reshape(n_samples, -1) if samples.dim() > 2 else samples

        encoded = self.data_encoder(flat)

        # Reduce score / one-hot rows to class indices before the lookup.
        if targets.dim() != 1:
            targets = torch.max(targets, dim=1)[1]
        embedded = self.label_embedding(targets.long())

        # Concatenate both branches and fuse them.
        return self.fusion(torch.cat([encoded, embedded], dim=1))

class ClassicDiscriminator(nn.Module):
    """
    Feature-level discriminator producing a real/fake probability.

    Operates on feature vectors (width ``opt.ni``) rather than raw images.
    Three Linear -> (BatchNorm1d | Identity) -> LeakyReLU(0.2) stages are
    followed by a single-unit Linear and a Sigmoid.
    """
    def __init__(self, opt):
        super().__init__()
        self.input_dim = opt.ni
        self.hidden_dim = opt.nh // 2
        self.device = opt.device

        def make_norm(width):
            # Batch norm unless the options disable it.
            return nn.Identity() if opt.no_bn else nn.BatchNorm1d(width)

        narrow = self.hidden_dim // 2
        stage_widths = (
            (self.input_dim, self.hidden_dim),
            (self.hidden_dim, self.hidden_dim),
            (self.hidden_dim, narrow),
        )

        # Stages are appended in order so layer indices (0..10) stay fixed.
        stack = []
        for fan_in, fan_out in stage_widths:
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(make_norm(fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(narrow, 1))
        stack.append(nn.Sigmoid())

        self.discriminator = nn.Sequential(*stack)

    def forward(self, x):
        """
        Score a batch of feature vectors.

        Args:
            x: Feature vectors [batch_size, opt.ni]; inputs of rank > 2 are
               flattened per sample first.

        Returns:
            Sigmoid output in (0, 1), shape [batch_size, 1]
        """
        # Rank > 2 is collapsed to one row per sample; 2D input passes as is.
        rows = x.reshape(x.size(0), -1) if x.dim() > 2 else x
        return self.discriminator(rows)
    
#-----------------------------
# Graph Attention Network (GAT) Implementation
#-----------------------------

class GraphAttentionLayer(nn.Module):
    """
    Single-head graph attention layer.

    Node features are projected with ``W``; the attention logit between
    nodes i and j is ``LeakyReLU(a^T [Wh_i || Wh_j])``, normalised with a
    softmax over j.
    """
    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout  # dropout probability applied to the attention weights
        self.alpha = alpha      # negative slope of the LeakyReLU

        # Feature projection
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Attention vector; rows [0, F) score the source node and rows
        # [F, 2F) the target node, where F = out_features.
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj=None):
        """
        Args:
            h: Input node features [N, in_features].
            adj: Optional adjacency matrix [N, N]; entries that are not > 0
               are masked out of the attention. ``None`` means every node
               attends to every node.

        Returns:
            h_prime: Attention-weighted node features [N, out_features].
            attention: Attention weights [N, N] (after dropout in training mode).
        """
        Wh = torch.mm(h, self.W)  # [N, out_features]

        # a^T [Wh_i || Wh_j] splits into (Wh_i . a_src) + (Wh_j . a_dst), so
        # the [N, N] logits are a broadcast sum of two [N, 1] projections and
        # the [N, N, 2F] tensor of all pairs never has to be built.
        src_score = torch.matmul(Wh, self.a[:self.out_features])   # [N, 1]
        dst_score = torch.matmul(Wh, self.a[self.out_features:])   # [N, 1]
        e = self.leakyrelu(src_score + dst_score.t())              # [N, N]

        # Mask non-edges with a very negative logit so that the softmax
        # assigns them (numerically) zero weight.
        if adj is not None:
            attention = torch.where(adj > 0, e, torch.full_like(e, -9e15))
        else:
            attention = e

        # Normalise over the neighbours of each node, then regularise
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Aggregate neighbour features
        h_prime = torch.matmul(attention, Wh)

        return h_prime, attention

    def _prepare_attentional_mechanism_input(self, Wh):
        """
        Build the explicit tensor of all node pairs.

        No longer used by ``forward`` (which computes the same scores without
        materialising the pairs); kept so existing callers keep working.

        Args:
            Wh: Projected node features [N, out_features].

        Returns:
            Tensor [N, N, 2 * out_features] whose entry (i, j) is
            ``[Wh[i] || Wh[j]]``.
        """
        N = Wh.size(0)

        # Row i*N + j holds Wh[i] ...
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)

        # ... paired with Wh[j]
        Wh_repeated_alternating = Wh.repeat(N, 1)

        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

class MultiHeadGAT(nn.Module):
    """
    Multi-head graph attention block.

    Runs ``n_heads`` independent ``GraphAttentionLayer`` heads, concatenates
    their outputs, and projects the result to ``out_features`` (dropout ->
    linear -> ELU).
    """
    def __init__(self, in_features, hidden_dim, out_features, n_heads=4, dropout=0.6, alpha=0.2):
        super().__init__()
        self.n_heads = n_heads

        # One attention layer per head, each producing hidden_dim features.
        heads = [
            GraphAttentionLayer(in_features, hidden_dim, dropout, alpha)
            for _ in range(n_heads)
        ]
        self.attentions = nn.ModuleList(heads)

        # Projection applied to the concatenated head outputs.
        self.out_proj = nn.Linear(hidden_dim * n_heads, out_features)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ELU()

    def forward(self, x, adj=None):
        """
        Args:
            x: Node features [N, in_features].
            adj: Optional adjacency matrix [N, N], forwarded to every head.

        Returns:
            Tuple of the projected node features [N, out_features] and the
            attention weights averaged over all heads [N, N].
        """
        # Evaluate the heads in order; each yields (features, attention).
        per_head = [head(x, adj) for head in self.attentions]

        # Concatenate head features along the feature axis, then project.
        stacked_features = torch.cat([feats for feats, _ in per_head], dim=1)
        projected = self.activation(self.out_proj(self.dropout(stacked_features)))

        # Mean of the per-head attention matrices.
        mean_attention = torch.stack([weights for _, weights in per_head]).mean(dim=0)

        return projected, mean_attention

class EnhancedDyGAT(nn.Module):
    """
    Enhanced Dynamic Graph Attention Network for generating relational graphs
    with attention scores between clients.

    Each client's model update is reduced to a fixed-size vector, encoded by
    an MLP, and passed through a multi-head GAT; the averaged attention
    matrix is used as the client relation graph.
    """
    # Length of the fixed-size vector every client update is reduced to.
    # Must match the input width of ``self.encoder``.
    FEATURE_DIM = 1000

    def __init__(self, opt):
        super(EnhancedDyGAT, self).__init__()
        self.opt = opt
        self.num_clients = opt.num_clients
        self.device = torch.device(opt.device)
        self.hidden_dim = 128
        self.embedding_dim = 64

        # Feature extraction network: FEATURE_DIM -> 128 -> 64
        self.encoder = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, self.hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_dim, self.embedding_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.embedding_dim)
        )

        # Graph attention network
        self.gat = MultiHeadGAT(
            in_features=self.embedding_dim,
            hidden_dim=self.embedding_dim // 2,
            out_features=self.embedding_dim,
            n_heads=4,
            dropout=0.2
        )

    def _flatten_updates(self, updates):
        """
        Flatten model updates into a fixed-size feature vector.

        Args:
            updates: Mapping of parameter name to update tensor.

        Returns:
            float32 tensor of length ``FEATURE_DIM`` on ``self.device``. With
            at least ``FEATURE_DIM`` values the largest-magnitude entries are
            returned in descending order of magnitude; with fewer the values
            are kept in order and zero-padded.
        """
        # reshape (not view) so non-contiguous updates are accepted; every
        # piece is brought to one device and dtype so that the result can be
        # stacked with the zero placeholders and fed to the float32 encoder.
        pieces = [
            update.reshape(-1).to(device=self.device, dtype=torch.float32)
            for update in updates.values()
            if update.numel() > 0
        ]
        if not pieces:
            return torch.zeros(self.FEATURE_DIM, device=self.device)

        flat_tensor = torch.cat(pieces)

        if flat_tensor.size(0) >= self.FEATURE_DIM:
            # Use magnitude as importance and keep the most significant values
            _, indices = torch.topk(torch.abs(flat_tensor), self.FEATURE_DIM)
            return flat_tensor[indices]

        # Pad if smaller
        padded = torch.zeros(self.FEATURE_DIM, device=self.device)
        padded[:flat_tensor.size(0)] = flat_tensor
        return padded

    def _process_updates(self, model_updates):
        """
        Process and encode client model updates.

        Args:
            model_updates: Mapping of client id (0 .. num_clients - 1) to that
               client's update mapping; missing clients contribute zeros.

        Returns:
            Encoder output for the stacked [num_clients, FEATURE_DIM] features.
        """
        client_features = [
            self._flatten_updates(model_updates[client_id])
            if client_id in model_updates
            else torch.zeros(self.FEATURE_DIM, device=self.device)
            for client_id in range(self.num_clients)
        ]

        # Stack all client features and encode
        return self.encoder(torch.stack(client_features))

    def learn(self, epochs, model_updates):
        """
        Generate relational graph using attention mechanism.

        Args:
            epochs: Unused; kept for interface compatibility.
            model_updates: Mapping of client id to parameter updates.

        Returns:
            float32 numpy array [num_clients, num_clients]: the averaged
            attention weights, or a random symmetric fallback graph with a
            unit diagonal if anything goes wrong.
        """
        with torch.no_grad():
            try:
                # Process model updates into embeddings
                embeddings = self._process_updates(model_updates)

                # Apply graph attention network (without adjacency)
                _, attention_weights = self.gat(embeddings)

                # Convert attention weights to numpy for easier processing
                relation_graph = attention_weights.cpu().numpy()

                # Explicitly reject unusable graphs (NaN/inf entries or no
                # positive relation at all) so they take the fallback path.
                if not np.isfinite(relation_graph).all() or not (relation_graph > 0).any():
                    raise ValueError("attention produced a degenerate relation graph")

                return relation_graph

            except Exception as e:
                # Deliberate best-effort boundary: log with traceback and
                # continue with a fallback graph instead of failing the round.
                logger.exception("Error in DyGAT.learn: %s", e)

                # Create a fallback relation graph with some random structure
                fallback_graph = np.random.rand(self.num_clients, self.num_clients) * 0.3
                fallback_graph = (fallback_graph + fallback_graph.T) / 2  # Make symmetric
                np.fill_diagonal(fallback_graph, 1.0)  # Self-connections are strong

                logger.warning("Using fallback relation graph due to error")
                # Same dtype as the attention path so callers see one dtype
                return fallback_graph.astype(np.float32)

    def forward(self, model_updates):
        """
        Forward pass for compatibility: the relation graph from ``learn`` as
        a tensor.
        """
        with torch.no_grad():
            return torch.tensor(self.learn(1, model_updates))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from model.modules import MultiHeadGAT
import logging

logger = logging.getLogger('GFedCL')

class TemporalGAT(nn.Module):
    """
    Graph attention model that blends the current ("spatial") client attention
    with similarities derived from previously stored attention matrices.

    The original author describes the design as following the GAEN (Graph
    Attention Evolving Networks) approach. Concretely, each call to ``learn``:
    1. Flattens every client's model update into a fixed-size vector and encodes it
    2. Runs the encodings through a multi-head GAT to obtain a spatial attention matrix
    3. Compares clients by their rows in the most recent stored attention matrices
    4. Returns a weighted mix of the spatial attention and those temporal similarities
    """
    # Length of the fixed-size feature vector extracted from each client's update
    FEATURE_DIM = 1000
    # Weight of the spatial attention when it is blended with the temporal patterns
    SPATIAL_WEIGHT = 0.7
    # Maximum number of spatial attention matrices kept in the history
    MAX_HISTORY = 10

    def __init__(self, opt):
        """
        Args:
            opt: Options object; ``opt.num_clients`` and ``opt.device`` are read here
        """
        super(TemporalGAT, self).__init__()
        self.opt = opt
        self.num_clients = opt.num_clients
        self.device = torch.device(opt.device)
        self.hidden_dim = 128
        self.embedding_dim = 64
        self.max_time_window = 2  # Maximum window size as specified
        self.current_task = 0     # Track the current task ID

        # Feature extraction network
        # NOTE(review): learn() does not switch the module to eval mode, so in
        # training mode Dropout is active and BatchNorm1d needs more than one
        # client per call -- confirm which mode callers intend.
        self.encoder = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, self.hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_dim, self.embedding_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.embedding_dim)
        )

        # Graph attention network for spatial attention
        self.gat = MultiHeadGAT(
            in_features=self.embedding_dim,
            hidden_dim=self.embedding_dim // 2,
            out_features=self.embedding_dim,
            n_heads=4,
            dropout=0.2
        )

        # Kept for backward compatibility; nothing in this class reads it
        self.historical_graphs = []

        # State used by learn() / _calculate_temporal_patterns(). Creating it here
        # lets those methods be called in any order.
        self.spatial_attention_history = []
        self.last_spatial_attention = None
        self.last_temporal_patterns = None

    def _ensure_history_state(self):
        """
        Create the history / visualization attributes if they are missing.

        They are normally set in ``__init__``; this guard covers instances that
        were created without running it.
        """
        if not hasattr(self, 'spatial_attention_history'):
            logger.warning("spatial_attention_history not initialized, creating it now")
            self.spatial_attention_history = []
        if not hasattr(self, 'last_spatial_attention'):
            self.last_spatial_attention = None
        if not hasattr(self, 'last_temporal_patterns'):
            self.last_temporal_patterns = None

    def _remember_spatial_attention(self, graph):
        """Append ``graph`` to the history and keep only the newest ``MAX_HISTORY`` entries."""
        self.spatial_attention_history.append(graph)
        if len(self.spatial_attention_history) > self.MAX_HISTORY:
            self.spatial_attention_history = self.spatial_attention_history[-self.MAX_HISTORY:]

    def _log_graph_stats(self, graph):
        """Log density / mean / max / smallest positive entry of ``graph``."""
        density = np.mean(graph > 0.01)
        avg_relation = np.mean(graph)
        max_relation = np.max(graph)
        # Guard the empty selection: np.min raises on a zero-size array, and a
        # failure in a logging statistic must not discard a valid graph.
        positive = graph[graph > 0]
        min_relation = np.min(positive) if positive.size else 0.0

        logger.info(f"Task {self.current_task}: Generated relational graph with density: {density:.4f}, "
                   f"avg: {avg_relation:.4f}, max: {max_relation:.4f}, min: {min_relation:.4f}")

    def _flatten_updates(self, updates):
        """
        Flatten model updates into a fixed-size feature vector

        Args:
            updates: Dictionary of parameter updates (name -> tensor)

        Returns:
            torch.Tensor: float32 vector of length ``FEATURE_DIM`` on ``self.device``.
                If the updates hold at least ``FEATURE_DIM`` values, the entries with
                the largest magnitude are returned (in ``torch.topk`` order);
                otherwise the values are zero-padded at the end.
        """
        # reshape (not view) so non-contiguous tensors are accepted; move each piece
        # to self.device so concatenation and padding never mix devices.
        param_list = [
            param_update.reshape(-1).float().to(self.device)
            for param_update in updates.values()
            if param_update.numel() > 0
        ]

        if not param_list:
            return torch.zeros(self.FEATURE_DIM, device=self.device, dtype=torch.float32)

        flat_tensor = torch.cat(param_list)

        if flat_tensor.size(0) >= self.FEATURE_DIM:
            # Use magnitude as importance and keep the most significant parameters
            _, indices = torch.topk(torch.abs(flat_tensor), self.FEATURE_DIM)
            return flat_tensor[indices]

        # Pad if smaller
        padded = torch.zeros(self.FEATURE_DIM, device=self.device, dtype=torch.float32)
        padded[:flat_tensor.size(0)] = flat_tensor
        return padded

    def _process_updates(self, model_updates):
        """
        Process and encode client model updates

        Args:
            model_updates: Dictionary mapping client IDs to parameter updates

        Returns:
            torch.Tensor: Encoder output with one row per client ID in
                ``range(self.num_clients)``; clients missing from
                ``model_updates`` are represented by an all-zero feature vector
        """
        client_features = [
            self._flatten_updates(model_updates[client_id])
            if client_id in model_updates
            else torch.zeros(self.FEATURE_DIM, device=self.device, dtype=torch.float32)
            for client_id in range(self.num_clients)
        ]

        # Stack all client features, make sure they are float32, and encode
        features = torch.stack(client_features).float()
        return self.encoder(features)

    def _calculate_temporal_patterns(self):
        """
        Calculate temporal variation patterns based on historical spatial attention matrices

        Each client is described by its rows in the most recent stored attention
        matrices (flattened over clients x time); clients are then compared by
        cosine similarity and every row is passed through a softmax. The original
        author describes this as a simplified stand-in for the PARAFAC tensor
        factorization of the GAEN paper, suitable for a small time window.

        Returns:
            numpy.ndarray: Client-to-client similarity matrix based on temporal patterns,
                          or None if not enough historical data (or on error)
        """
        self._ensure_history_state()

        # Calculate appropriate window size based on current task
        effective_window = min(self.current_task + 1, self.max_time_window)

        # For task 1, we can't calculate temporal patterns (need at least 2 graphs)
        if effective_window < 2:
            logger.info(f"Task {self.current_task}: Cannot calculate temporal patterns for first task")
            return None

        # Ensure we have enough historical data
        if len(self.spatial_attention_history) < effective_window - 1:  # -1 because we don't include current task yet
            logger.info(f"Task {self.current_task}: Not enough historical data for temporal patterns")
            return None

        # Use the most recent (effective_window-1) graphs from history
        # Note: the current spatial attention is not included; learn() blends it in separately
        recent_graphs = self.spatial_attention_history[-(effective_window-1):]
        logger.info(f"Task {self.current_task}: Calculating temporal patterns with {len(recent_graphs)} previous graphs")

        try:
            # Stack adjacency matrices into a 3-way tensor: [clients x clients x time]
            stacked_tensor = np.stack(recent_graphs, axis=2)

            # One row per client: all of its connections over time, flattened
            variation_patterns = stacked_tensor.reshape(stacked_tensor.shape[0], -1).astype(np.float64)

            # Pairwise cosine similarity dot(p_i, p_j) / (|p_i| * |p_j|), computed for
            # all pairs at once. Pairs involving an all-zero pattern stay at 0.
            norms = np.linalg.norm(variation_patterns, axis=1)
            gram = variation_patterns @ variation_patterns.T
            denom = np.outer(norms, norms)
            pattern_similarities = np.divide(
                gram, denom, out=np.zeros_like(gram), where=denom > 0
            )

            # Apply softmax row-wise to get attention weights (shifted by the row
            # maximum for numerical stability)
            shifted = pattern_similarities - np.max(pattern_similarities, axis=1, keepdims=True)
            exp_x = np.exp(shifted)
            pattern_similarities = exp_x / np.sum(exp_x, axis=1, keepdims=True)

            logger.info(f"Task {self.current_task}: Successfully calculated temporal pattern similarities")
            return pattern_similarities

        except Exception as e:
            # Best effort by design: the caller falls back to spatial attention only
            logger.error(f"Task {self.current_task}: Error in calculating temporal patterns: {str(e)}")
            return None

    def learn(self, epochs, model_updates, task_id=None):
        """
        Generate relational graph using attention mechanism enhanced with temporal patterns

        Args:
            epochs: Number of epochs (unused, kept for API compatibility)
            model_updates: Dictionary mapping client IDs to parameter updates
            task_id: Current task ID (if provided, updates the internal task counter)

        Returns:
            numpy.ndarray: float32 relational graph with combined spatial and temporal
                attention. If anything fails, a random symmetric fallback graph with
                ones on the diagonal is returned instead (and stored in the history).
        """
        # Update current task if provided
        if task_id is not None:
            self.current_task = task_id
            logger.info(f"Updated current task to {self.current_task}")

        self._ensure_history_state()

        with torch.no_grad():
            try:
                # Process model updates into embeddings
                embeddings = self._process_updates(model_updates)

                # Apply graph attention network to get spatial attention for current task
                _, spatial_attention = self.gat(embeddings)

                # Convert spatial attention weights to numpy for processing
                spatial_attention_np = spatial_attention.cpu().numpy().astype(np.float32)

                # Store for visualization
                self.last_spatial_attention = spatial_attention_np

                # Calculate temporal pattern similarities using previous spatial attention matrices
                logger.debug(f"Calculating temporal patterns with {len(self.spatial_attention_history)} historical graphs")
                temporal_patterns = self._calculate_temporal_patterns()

                # Store for visualization
                self.last_temporal_patterns = temporal_patterns

                # Combine spatial and temporal attention
                if temporal_patterns is not None:
                    # Balance between spatial and temporal components (tunable parameter)
                    alpha = self.SPATIAL_WEIGHT  # Weight for spatial attention
                    combined_attention = alpha * spatial_attention_np + (1 - alpha) * temporal_patterns
                    logger.info(f"Task {self.current_task}: Combined spatial and temporal attention with alpha={alpha}")
                else:
                    # If we don't have enough historical data, use only spatial attention
                    combined_attention = spatial_attention_np
                    logger.info(f"Task {self.current_task}: Using only spatial attention (no temporal patterns available)")

                # Log the generated graph properties
                self._log_graph_stats(combined_attention)

                # astype copies, so the caller never aliases the array kept in the history
                result = combined_attention.astype(np.float32)

            except Exception as e:
                # Log the error details with traceback
                import traceback
                logger.error(f"Error in TemporalGAT.learn for task {self.current_task}: {str(e)}")
                logger.error(traceback.format_exc())

                # Build the fallback as float32
                fallback_graph = np.random.rand(self.num_clients, self.num_clients).astype(np.float32) * 0.3
                fallback_graph = (fallback_graph + fallback_graph.T) / 2  # Make symmetric
                np.fill_diagonal(fallback_graph, 1.0)  # Self-connections are strong

                logger.warning(f"Task {self.current_task}: Using fallback relation graph due to error")

                # Still store the fallback graph in our history (bounded like the normal path)
                self._remember_spatial_attention(fallback_graph)

                # Set last components to fallback for visualization
                self.last_spatial_attention = fallback_graph
                self.last_temporal_patterns = None

                return fallback_graph.astype(np.float32)

            # Store the "pure" spatial attention (not the combined one) for future
            # temporal pattern calculation. Done only after everything above succeeded,
            # so a failed round never leaves two history entries behind.
            self._remember_spatial_attention(spatial_attention_np)
            logger.debug(f"Added spatial attention matrix to history, now have {len(self.spatial_attention_history)} matrices")

            return result

    def forward(self, x):
        """
        Forward pass for compatibility with the ``nn.Module`` call interface

        Args:
            x: Dictionary mapping client IDs to parameter updates (the same
               structure ``learn`` accepts as ``model_updates``)

        Returns:
            torch.Tensor: The relational graph produced by ``learn``
        """
        # The previous body was a copy of GNet.forward and referenced self.net /
        # self.output_dim, which this class never defines, so every call failed.
        with torch.no_grad():
            return torch.tensor(self.learn(1, x))
        

# CNN-based ResNet18 for EMNIST (updated for grayscale)
class ResNet18EMNIST(nn.Module):
    """
    Thin wrapper around the project's ResNet18, built with a single input
    channel (the original author targets 28x28 grayscale EMNIST-letter images).
    """
    def __init__(self, num_classes=26):
        """
        Args:
            num_classes: Number of output classes forwarded to ResNet18
        """
        super().__init__()
        from model.resnet import ResNet18

        # Single input channel: grayscale images
        backbone = ResNet18(num_classes=num_classes, input_channels=1)
        self.resnet = backbone

    def forward(self, x):
        """Return the wrapped ResNet18's output for ``x``."""
        outputs = self.resnet(x)
        return outputs

    def extract_features(self, x):
        """Return whatever the wrapped ResNet18's ``extract_features`` yields for ``x``."""
        features = self.resnet.extract_features(x)
        return features
    
class ResNet18Classifier(nn.Module):
    """
    Trainable classifier built on ResNet18EMNIST (the original author targets
    the EMNIST-letter dataset).

    Bundles the model with a cross-entropy criterion and an SGD optimizer, and
    offers helpers to read the weights and to diff them against a snapshot.
    """
    def __init__(self, num_classes=26, device='cuda'):
        """
        Args:
            num_classes: Number of output classes
            device: Device the model is moved to and batches are sent to in ``learn``
        """
        super(ResNet18Classifier, self).__init__()
        self.device = device

        # Use ResNet18 adapted for EMNIST
        self.model = ResNet18EMNIST(num_classes=num_classes)
        self.model = self.model.to(device)

        # Initialize criterion for training
        self.criterion = nn.CrossEntropyLoss()

        # Created by init_optimizer(); learn() creates it with the defaults if unset
        self.optimizer = None

    def init_optimizer(self, lr=0.001, momentum=0.9, weight_decay=5e-4):
        """
        Initialize optimizer for the model with given parameters

        Args:
            lr: SGD learning rate
            momentum: SGD momentum
            weight_decay: L2 weight decay

        Returns:
            torch.optim.SGD: The optimizer, also stored as ``self.optimizer``
        """
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay
        )
        return self.optimizer

    def get_weights(self):
        """Return the model's ``state_dict()`` as-is (not a copy)"""
        return self.model.state_dict()

    def get_gradient_updates(self, old_state):
        """
        Compute the per-entry difference between the current weights and a snapshot

        Args:
            old_state: Mapping of state-dict keys to earlier tensor values. It must
                hold copies: ``state_dict()`` is returned as-is by ``get_weights``,
                so an uncopied snapshot would diff against itself.

        Returns:
            dict: ``current - old`` for every key present in both state dicts
        """
        current_state = self.model.state_dict()
        return {
            # Align devices so a snapshot stored elsewhere (e.g. on CPU) still works
            key: value - old_state[key].to(value.device)
            for key, value in current_state.items()
            if key in old_state
        }

    def learn(self, epoch, dataloader):
        """
        Train the model for one epoch

        Args:
            epoch: Epoch index (not used by this method)
            dataloader: Iterable of ``(inputs, targets)`` batches

        Returns:
            dict: ``{"loss": mean loss per batch, "acc": accuracy in percent}``;
                both are 0.0 when the dataloader yields no batches
        """
        # Previously a missing init_optimizer() call surfaced as an AttributeError
        # in the middle of the first batch; fall back to the default settings.
        if getattr(self, 'optimizer', None) is None:
            self.init_optimizer()

        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            self.optimizer.step()

            # Track statistics
            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            num_batches += 1

        # Calculate metrics; counting batches avoids len() on unsized iterables and
        # the guards avoid ZeroDivisionError on an empty dataloader
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100.0 * correct / total if total else 0.0

        return {
            "loss": epoch_loss,
            "acc": epoch_acc
        }

    def forward(self, x):
        """Forward pass through the model"""
        return self.model(x)
    
class ResNetClassifierCBAM(nn.Module):
    """
    Classifier that projects 1D feature vectors into a 3x32x32 tensor and
    feeds that tensor to a ResNet18_CBAM network.
    """
    def __init__(self, opt):
        """
        Args:
            opt: Options object; ``opt.nh`` (feature size), ``opt.num_classes``
                 and ``opt.device`` are read here
        """
        super().__init__()
        # Imported at construction time rather than at module import
        from model.resnet_cbam import ResNet18_CBAM

        self.input_dim = opt.nh  # Matches encoder output dimension
        self.num_classes = opt.num_classes
        self.device = opt.device

        # Base ResNet18 with CBAM, fed with 3-channel inputs
        self.resnet = ResNet18_CBAM(num_classes=self.num_classes, input_channels=3)

        # Maps each feature vector onto 3 * 32 * 32 values so it can be reshaped
        # into the image-like input given to the ResNet
        projected_size = 3 * 32 * 32
        self.proj_layer = nn.Sequential(
            nn.Linear(self.input_dim, projected_size),
            nn.ReLU()
        )

    def forward(self, x, return_softmax=False):
        """
        Forward pass through the classifier

        Args:
            x: Feature vectors whose first dimension is the batch; inputs with
               more than two dimensions are flattened per sample first
            return_softmax: Whether to additionally return ``exp`` of the output

        Returns:
            log_probs: Output of the ResNet. The original author notes that
                ResNet18_CBAM already returns log_softmax; its definition is not
                visible here.
            softmax_probs: (only if ``return_softmax``) ``torch.exp(log_probs)``
        """
        n_samples = x.size(0)

        # Work on a 2D [batch, features] tensor
        flat = x.reshape(n_samples, -1) if x.dim() > 2 else x

        # Project and reshape into the image-like layout given to the ResNet
        image_like = self.proj_layer(flat).reshape(n_samples, 3, 32, 32)

        log_probs = self.resnet(image_like)

        if not return_softmax:
            return log_probs

        # Exponentiate the (log-)outputs to obtain probabilities
        return log_probs, torch.exp(log_probs)
        
def tensor_memory_in_MB(tensor):
    """Return the size of ``tensor``'s elements in mebibytes (bytes / 1024**2)."""
    bytes_per_element = tensor.element_size()
    element_count = tensor.nelement()
    total_bytes = bytes_per_element * element_count
    return total_bytes / (1024 ** 2)