import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import copy
import numpy as np

logger = logging.getLogger('GFedCL')

class Identity(nn.Module):
    """Pass-through module: ``forward`` hands its argument back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No copy and no computation: the very same object is returned.
        return x

#-----------------------------
# For GFedCL with ILI Dataset
#-----------------------------
class GNet(nn.Module):
    """
    Graph network: turns a client-relationship vector into an embedding.

    The mapping is a two-layer MLP, ``opt.num_clients -> opt.nh -> opt.nt``,
    with a ReLU between the two linear layers.
    """
    def __init__(self, opt):
        super().__init__()
        self.num_clients = opt.num_clients
        self.hidden_dim = opt.nh
        self.output_dim = opt.nt

        mlp_layers = [
            nn.Linear(self.num_clients, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """
        Args:
            x: Tensor whose last dimension has size ``num_clients``; any
               number of leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``output_dim``.
        """
        # Work on a private float32 copy so the caller's tensor is never touched.
        inputs = x.clone().to(torch.float32)

        if inputs.dim() <= 2:
            return self.net(inputs)

        # Collapse every leading dimension into one batch axis, run the MLP,
        # then restore the original leading shape.
        leading_shape = inputs.shape[:-1]
        flat_output = self.net(inputs.reshape(-1, inputs.size(-1)))
        return flat_output.reshape(*leading_shape, self.output_dim)
        
class FeatureEncoder(nn.Module):
    """
    Encoder for ILI time series data.

    A time-series MLP (``input_dim -> 2*nh -> nh``) and a small graph-embedding
    branch (``nt -> nh // 4``) are concatenated and fused into a latent vector
    of size ``opt.ni``.
    """
    def __init__(self, opt):
        super(FeatureEncoder, self).__init__()
        # Input: sequence_length * num_states_per_client
        self.input_dim = opt.sequence_length * opt.states_per_client
        self.hidden_dim = opt.nh
        self.output_dim = opt.ni
        self.graph_dim = opt.nt

        # Time series feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Graph embedding processor
        self.graph_processor = nn.Sequential(
            nn.Linear(self.graph_dim, self.hidden_dim // 4),
            nn.ReLU()
        )

        # Combine features
        self.fusion = nn.Sequential(
            nn.Linear(self.hidden_dim + self.hidden_dim // 4, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim)
        )

    def forward(self, x, labels, graph_embed=None):
        """
        Args:
            x: Time series input ``[..., sequence_length * num_states]``.
            labels: Accepted for interface compatibility; not used by the
                    computation.
            graph_embed: Optional graph embedding ``[..., graph_dim]`` with the
                         same leading dimensions as ``x``. When omitted, an
                         all-zero block takes the place of the processed graph
                         features.

        Returns:
            Latent representation ``[..., opt.ni]``.
        """
        # Extract time series features
        ts_features = self.feature_extractor(x)

        if graph_embed is not None:
            graph_features = self.graph_processor(graph_embed)
        else:
            # Placeholder for the missing graph branch. It mirrors the leading
            # shape, dtype and device of the time-series features so the
            # concatenation works for any batch layout and model precision.
            graph_features = torch.zeros(
                *ts_features.shape[:-1],
                self.hidden_dim // 4,
                device=ts_features.device,
                dtype=ts_features.dtype,
            )

        # Concatenate along the feature axis (the last one) so inputs with
        # extra leading dimensions are handled the same way as 2-D batches.
        combined = torch.cat([ts_features, graph_features], dim=-1)
        return self.fusion(combined)

class GraphDNet(nn.Module):
    """
    Graph discriminator: maps an encoder latent vector (size ``opt.ni``) to a
    vector of graph-embedding size (``opt.nt``) through one hidden layer of
    width ``opt.nh``.
    """
    def __init__(self, opt):
        super().__init__()
        self.input_dim = opt.ni  # same width as the encoder's latent output
        self.hidden_dim = opt.nh
        self.output_dim = opt.nt

        mlp_layers = [
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """
        Args:
            x: Tensor whose last dimension has size ``input_dim``; any number
               of leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``output_dim``.
        """
        # Operate on a copy so the caller's tensor is left alone.
        latent = x.clone()

        if latent.dim() <= 2:
            return self.net(latent)

        # Fold all leading dimensions into a single batch axis for the MLP,
        # then unfold them again on the way out.
        leading_shape = latent.shape[:-1]
        flat_output = self.net(latent.reshape(-1, latent.size(-1)))
        return flat_output.reshape(*leading_shape, self.output_dim)

class PredNet(nn.Module):
    """
    Multi-output regressor for ILI values.

    A shared MLP trunk feeds one single-output linear head per state; the
    head outputs are squashed into (0, 1) with a sigmoid.
    """
    def __init__(self, opt):
        super().__init__()
        self.input_dim = opt.ni
        self.hidden_dim = opt.nh
        self.num_states = opt.states_per_client

        half_hidden = self.hidden_dim // 2

        # Trunk shared by every state head
        self.shared = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # One scalar regression head per state
        heads = []
        for _ in range(self.num_states):
            heads.append(nn.Linear(half_hidden, 1))
        self.state_heads = nn.ModuleList(heads)

    def forward(self, x, return_raw=False):
        """
        Args:
            x: Latent input ``[batch_size, opt.ni]``.
            return_raw: When true, also return the pre-sigmoid outputs.

        Returns:
            ``predictions`` ``[batch_size, num_states]``, or the tuple
            ``(raw_outputs, predictions)`` when ``return_raw`` is set.
        """
        trunk_features = self.shared(x)

        # Each head yields [batch_size, 1]; stacking on dim 1 gives
        # [batch_size, num_states, 1] and the trailing axis is then dropped.
        per_state = [head(trunk_features) for head in self.state_heads]
        outputs = torch.stack(per_state, dim=1).squeeze(-1)

        # Sigmoid keeps predictions inside (0, 1), matching normalized targets.
        predictions = torch.sigmoid(outputs)

        return (outputs, predictions) if return_raw else predictions
        
# Graph attention building blocks (used by TemporalGAT and EnhancedDyGAT)
class GraphAttentionLayer(nn.Module):
    """
    Single-head graph attention layer.

    Node features are projected with ``W``; the attention logit for the pair
    (i, j) is ``LeakyReLU(a^T [Wh_i || Wh_j])``; logits are soft-maxed per row
    and used to average the projected features.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each projected node feature vector.
        dropout: Dropout probability applied to the attention weights
                 (active only in training mode).
        alpha: Negative slope of the LeakyReLU applied to the logits.
    """
    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha

        # Feature projection
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Attention vector: first half scores the source node, second half
        # scores the target node.
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj=None):
        """
        Args:
            h: Input node features ``[N, in_features]``.
            adj: Optional adjacency matrix ``[N, N]``; entries ``> 0`` mark
                 pairs that are allowed to attend to each other.

        Returns:
            Tuple ``(h_prime, attention)`` with ``h_prime`` of shape
            ``[N, out_features]`` and ``attention`` of shape ``[N, N]``.
        """
        Wh = torch.mm(h, self.W)  # [N, out_features]

        # a^T [Wh_i || Wh_j] splits into a_src^T Wh_i + a_dst^T Wh_j, so the
        # N x N logits come from two matrix-vector products and a broadcast
        # sum; no [N, N, 2 * out_features] pair tensor is built.
        source_scores = torch.matmul(Wh, self.a[:self.out_features])  # [N, 1]
        target_scores = torch.matmul(Wh, self.a[self.out_features:])  # [N, 1]
        e = self.leakyrelu(source_scores + target_scores.T)           # [N, N]

        if adj is not None:
            # Large negative logit for masked pairs. Clamped to the dtype's
            # finite range so low-precision dtypes do not overflow to -inf,
            # which would turn a fully masked row into NaN after softmax.
            mask_value = max(-9e15, torch.finfo(e.dtype).min)
            masked_logits = mask_value * torch.ones_like(e)
            attention = torch.where(adj > 0, e, masked_logits)
        else:
            attention = e

        # Row-wise softmax, then dropout on the attention weights
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Attention-weighted average of the projected features
        h_prime = torch.matmul(attention, Wh)

        return h_prime, attention

    def _prepare_attentional_mechanism_input(self, Wh):
        """
        Build the explicit pair tensor ``[N, N, 2 * out_features]`` whose entry
        ``[i, j]`` is ``Wh_i || Wh_j``.

        ``forward`` no longer calls this (it uses the equivalent split form);
        the helper is kept so existing callers keep working.
        """
        N = Wh.size(0)

        # Rows: Wh_0 x N, Wh_1 x N, ...
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)

        # Rows: Wh_0 .. Wh_{N-1}, repeated N times
        Wh_repeated_alternating = Wh.repeat(N, 1)

        # Row i*N + j holds Wh_i || Wh_j
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

class MultiHeadGAT(nn.Module):
    """
    Multi-head graph attention network.

    Runs ``n_heads`` independent ``GraphAttentionLayer`` heads, concatenates
    their outputs, and projects the result to ``out_features``.
    """
    def __init__(self, in_features, hidden_dim, out_features, n_heads=4, dropout=0.6, alpha=0.2):
        super().__init__()
        self.n_heads = n_heads

        # One attention layer per head
        heads = []
        for _ in range(n_heads):
            heads.append(GraphAttentionLayer(in_features, hidden_dim, dropout, alpha))
        self.attentions = nn.ModuleList(heads)

        # Projection applied to the concatenated head outputs
        self.out_proj = nn.Linear(hidden_dim * n_heads, out_features)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ELU()

    def forward(self, x, adj=None):
        """
        Args:
            x: Node features passed to every head.
            adj: Optional adjacency matrix forwarded to every head.

        Returns:
            Tuple ``(features, attention)``: the projected, ELU-activated
            concatenation of all head outputs, and the attention matrix
            averaged over heads.
        """
        # Each head returns (node_features, attention_matrix)
        head_results = [head(x, adj) for head in self.attentions]

        merged = torch.cat([features for features, _ in head_results], dim=1)
        projected = self.activation(self.out_proj(self.dropout(merged)))

        # Mean of the per-head attention matrices
        mean_attention = torch.stack([weights for _, weights in head_results]).mean(dim=0)

        return projected, mean_attention

class TemporalGAT(nn.Module):
    """
    Graph attention network that blends the current ("spatial") client-to-client
    attention with a similarity derived from earlier attention matrices
    ("temporal patterns"), following the GAEN (Graph Attention Evolving
    Networks) idea.

    This implementation:
    1. Encodes per-client feature vectors
    2. Maintains a bounded history of spatial attention matrices
    3. Derives temporal variation patterns with a simplified stand-in for
       tensor factorization (cosine similarity of flattened history rows)
    4. Combines spatial attention with the temporal patterns
    """
    # Maximum number of past attention matrices kept in memory.
    MAX_HISTORY = 10
    # Weight of the spatial term when blending spatial and temporal attention;
    # the temporal term receives ``1 - SPATIAL_WEIGHT``.
    SPATIAL_WEIGHT = 0.7

    def __init__(self, opt):
        super(TemporalGAT, self).__init__()
        self.opt = opt
        self.num_clients = opt.num_clients
        self.device = torch.device(opt.device)
        self.hidden_dim = 128
        self.embedding_dim = 64
        self.max_time_window = 2  # Maximum window size as specified
        self.current_task = 0     # Track the current task ID

        # Feature extraction network - adapted for ILI data
        self.encoder = nn.Sequential(
            nn.Linear(self.embedding_dim, self.hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_dim, self.embedding_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.embedding_dim)
        )

        # Graph attention network for spatial attention
        self.gat = MultiHeadGAT(
            in_features=self.embedding_dim,
            hidden_dim=self.embedding_dim // 2,
            out_features=self.embedding_dim,
            n_heads=4,
            dropout=0.2
        )

        # Historical adjacency matrices storage - initialized as empty
        self.spatial_attention_history = []

        # Components of the most recent graph, kept for visualization.
        # Defined here so they exist before the first call to learn().
        self.last_spatial_attention = None
        self.last_temporal_patterns = None

    def _process_client_features(self, client_features=None):
        """
        Encode client features.

        Args:
            client_features: Optional per-client features. If None, random
                             features of shape ``[num_clients, embedding_dim]``
                             are generated on ``self.device``.

        Returns:
            torch.Tensor: Encoded features for all clients
        """
        if client_features is None:
            # Generate random features for initialization
            features = torch.randn(self.num_clients, self.embedding_dim,
                                 device=self.device, dtype=torch.float32)
        else:
            features = client_features

        # Ensure features are float32
        features = features.float()

        # Process through encoder
        return self.encoder(features)

    @staticmethod
    def _pattern_similarity(stacked_tensor):
        """
        Row-softmaxed cosine similarity between the clients' temporal patterns.

        Args:
            stacked_tensor: Array ``[clients, clients, time]``. Row ``i``,
                            flattened over the last two axes, is client
                            ``i``'s pattern.

        Returns:
            numpy.ndarray: ``[clients, clients]`` float64 matrix whose rows
            sum to 1. A pair involving an all-zero pattern gets similarity 0
            before the softmax.
        """
        n_clients = stacked_tensor.shape[0]
        patterns = np.asarray(stacked_tensor, dtype=np.float64).reshape(n_clients, -1)

        norms = np.linalg.norm(patterns, axis=1)
        has_signal = norms > 0
        # Substitute 1 for zero norms to avoid dividing by zero; the affected
        # rows and columns are zeroed out right after.
        safe_norms = np.where(has_signal, norms, 1.0)

        cosine = (patterns @ patterns.T) / np.outer(safe_norms, safe_norms)
        cosine[~has_signal, :] = 0.0
        cosine[:, ~has_signal] = 0.0

        # Numerically stable row-wise softmax
        weights = np.exp(cosine - cosine.max(axis=1, keepdims=True))
        return weights / weights.sum(axis=1, keepdims=True)

    def _calculate_temporal_patterns(self):
        """
        Calculate temporal variation patterns from the stored spatial attention
        matrices.

        The window is ``min(current_task + 1, max_time_window)``; the current
        task's own matrix is not part of it, so ``window - 1`` stored matrices
        are used.

        Returns:
            numpy.ndarray: Client-to-client similarity matrix based on temporal
                          patterns, or None if the window is too small, the
                          history is too short, or the computation fails
        """
        # Calculate appropriate window size based on current task
        effective_window = min(self.current_task + 1, self.max_time_window)

        # A window of 1 leaves no earlier graph to compare against
        if effective_window < 2:
            logger.info(f"Task {self.current_task}: Cannot calculate temporal patterns for first task")
            return None

        # The current task's graph is not stored yet, hence window - 1
        graphs_needed = effective_window - 1
        if len(self.spatial_attention_history) < graphs_needed:
            logger.info(f"Task {self.current_task}: Not enough historical data for temporal patterns")
            return None

        recent_graphs = self.spatial_attention_history[-graphs_needed:]
        logger.info(f"Task {self.current_task}: Calculating temporal patterns with {len(recent_graphs)} previous graphs")

        try:
            # Stack adjacency matrices into a 3-way tensor: [clients x clients x time]
            stacked_tensor = np.stack(recent_graphs, axis=2)

            expected_shape = (self.num_clients, self.num_clients)
            if stacked_tensor.shape[:2] != expected_shape:
                raise ValueError(
                    f"expected history matrices of shape {expected_shape}, "
                    f"got {stacked_tensor.shape[:2]}"
                )

            pattern_similarities = self._pattern_similarity(stacked_tensor)

            logger.info(f"Task {self.current_task}: Successfully calculated temporal pattern similarities")
            return pattern_similarities

        except Exception as e:
            # Best effort: the caller falls back to spatial attention only.
            logger.error(f"Task {self.current_task}: Error in calculating temporal patterns: {str(e)}")
            return None

    def _store_in_history(self, graph):
        """Append ``graph`` to the history, keeping at most MAX_HISTORY entries."""
        self.spatial_attention_history.append(graph)
        if len(self.spatial_attention_history) > self.MAX_HISTORY:
            self.spatial_attention_history = self.spatial_attention_history[-self.MAX_HISTORY:]

    def _make_fallback_graph(self):
        """
        Build a random symmetric float32 graph with unit diagonal, used when
        the attention computation fails.
        """
        graph = np.random.rand(self.num_clients, self.num_clients).astype(np.float32) * 0.3
        graph = (graph + graph.T) / 2  # Make symmetric
        np.fill_diagonal(graph, 1.0)   # Self-connections are strong
        return graph

    def learn(self, epochs, model_updates=None, task_id=None, client_features=None):
        """
        Generate relational graph using attention mechanism enhanced with temporal patterns

        Args:
            epochs: Number of epochs (unused, kept for API compatibility)
            model_updates: Unused; accepted for API compatibility
            task_id: Current task ID (if provided, updates the internal task counter)
            client_features: Optional client features for generating the graph

        Returns:
            numpy.ndarray: float32 relational graph. If any step fails, or the
            graph has no positive entry, the error is logged and a random
            fallback graph is returned instead.

        NOTE(review): the module's train/eval mode is left as the caller set
        it; in training mode Dropout and BatchNorm make the result stochastic
        and batch-dependent - confirm that callers switch to eval() if that is
        not intended.
        """
        # Update current task if provided
        if task_id is not None:
            self.current_task = task_id
            logger.info(f"Updated current task to {self.current_task}")

        with torch.no_grad():
            try:
                # Process client features
                embeddings = self._process_client_features(client_features)

                # Spatial attention for the current task
                _, spatial_attention = self.gat(embeddings)
                spatial_attention_np = spatial_attention.cpu().numpy().astype(np.float32)

                # Temporal pattern similarities from earlier spatial attention matrices
                logger.debug(f"Calculating temporal patterns with {len(self.spatial_attention_history)} historical graphs")
                temporal_patterns = self._calculate_temporal_patterns()

                # Combine spatial and temporal attention
                if temporal_patterns is not None:
                    alpha = self.SPATIAL_WEIGHT
                    combined_attention = alpha * spatial_attention_np + (1 - alpha) * temporal_patterns
                    logger.info(f"Task {self.current_task}: Combined spatial and temporal attention with alpha={alpha}")
                else:
                    # Not enough history: use only spatial attention
                    combined_attention = spatial_attention_np
                    logger.info(f"Task {self.current_task}: Using only spatial attention (no temporal patterns available)")

                # A graph with no positive entry (e.g. all NaN) is unusable.
                # Reject it here, before any state is updated, so that the
                # fallback path is the only one touching the history.
                positive_entries = combined_attention[combined_attention > 0]
                if positive_entries.size == 0:
                    raise ValueError("relational graph has no positive entries")

            except Exception as e:
                # Log the error details with traceback
                import traceback
                logger.error(f"Error in TemporalGAT.learn for task {self.current_task}: {str(e)}")
                logger.error(traceback.format_exc())

                fallback_graph = self._make_fallback_graph()
                logger.warning(f"Task {self.current_task}: Using fallback relation graph due to error")

                # The fallback takes this task's slot in the (bounded) history
                self._store_in_history(fallback_graph)

                # Set last components to fallback for visualization
                self.last_spatial_attention = fallback_graph
                self.last_temporal_patterns = None

                # Return a copy so callers cannot mutate the stored matrix
                return fallback_graph.astype(np.float32)

            # Store for visualization
            self.last_spatial_attention = spatial_attention_np
            self.last_temporal_patterns = temporal_patterns

            # Store the "pure" spatial attention (not the combined one) for
            # future temporal pattern calculation
            self._store_in_history(spatial_attention_np)
            logger.debug(f"Added spatial attention matrix to history, now have {len(self.spatial_attention_history)} matrices")

            # Log the generated graph properties
            density = np.mean(combined_attention > 0.01)
            avg_relation = np.mean(combined_attention)
            max_relation = np.max(combined_attention)
            min_relation = np.min(positive_entries)

            logger.info(f"Task {self.current_task}: Generated relational graph with density: {density:.4f}, "
                       f"avg: {avg_relation:.4f}, max: {max_relation:.4f}, min: {min_relation:.4f}")

            return combined_attention.astype(np.float32)
#-----------------------------
# For FedAvg - Updated for TinyImageNet
#-----------------------------

class ClassicEncoder(nn.Module):
    """
    Encoder that uses only samples and labels as input, without relying on
    graph embeddings.

    A three-stage CNN followed by global average pooling encodes the image; a
    label embedding is concatenated and fused into a latent vector of size
    ``opt.ni``.
    """
    def __init__(self, opt):
        super(ClassicEncoder, self).__init__()
        self.opt = opt
        self.input_dim = 3 * 64 * 64  # TinyImageNet dimensions
        self.hidden_dim = opt.nh
        self.num_classes = opt.num_classes

        # CNN-based encoder. The size comments assume 64x64 inputs; the
        # adaptive pooling makes the network accept other spatial sizes too.
        self.data_encoder = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 64x64 -> 32x32

            # Second conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 32x32 -> 16x16

            # Third conv block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 16x16 -> 8x8

            # Global average pooling
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, self.opt.ni)
        )

        # Label embedding
        self.label_embedding = nn.Embedding(self.num_classes, self.opt.ni // 4)

        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Linear(self.opt.ni + self.opt.ni // 4, self.opt.ni),
            nn.ReLU()
        )

    def forward(self, x, labels):
        """
        Forward pass through the encoder.

        Args:
            x: Input images ``[batch_size, 3, height, width]``.
            labels: Either class indices ``[batch_size]`` or per-class scores /
                    one-hot rows ``[batch_size, num_classes]``; in the latter
                    case the arg-max of each row is used as the class index.

        Returns:
            latent: Encoded latent representation ``[batch_size, opt.ni]``.
        """
        # Nothing below writes to its inputs in place, so the tensors are used
        # directly instead of cloning the whole batch on every call.
        data_features = self.data_encoder(x)

        # Labels only serve as embedding indices; no gradient flows through them.
        if labels.dim() == 1:
            class_indices = labels.long()
        else:
            class_indices = torch.max(labels, dim=1).indices.long()
        label_features = self.label_embedding(class_indices)

        # Combine features and generate latent representation
        combined = torch.cat([data_features, label_features], dim=1)
        return self.fusion(combined)

class ClassicDiscriminator(nn.Module):
    """
    Feature-level discriminator: scores encoder feature vectors (not raw
    images) with a probability in (0, 1).
    """
    def __init__(self, opt):
        super().__init__()
        self.input_dim = opt.ni  # width of the encoder output
        self.hidden_dim = opt.nh // 2
        self.device = opt.device

        # Widths of the three hidden stages; each stage is
        # Linear -> (BatchNorm1d | Identity) -> LeakyReLU(0.2).
        widths = [self.input_dim, self.hidden_dim, self.hidden_dim, self.hidden_dim // 2]

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            # opt.no_bn swaps batch norm for a pass-through layer
            layers.append(nn.Identity() if opt.no_bn else nn.BatchNorm1d(fan_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Scalar output squashed to a probability
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())

        self.discriminator = nn.Sequential(*layers)

    def forward(self, x):
        """
        Score a batch of feature vectors.

        Args:
            x: Feature vectors ``[batch_size, opt.ni]``; inputs with more than
               two dimensions are flattened per sample first.

        Returns:
            Probabilities ``[batch_size, 1]``.
        """
        num_samples = x.size(0)

        # The network expects a 2-D batch
        features = x.reshape(num_samples, -1) if x.dim() > 2 else x

        return self.discriminator(features)

# ResNet18 classifier adapted for TinyImageNet
class ResNet18TinyImageNet(nn.Module):
    """
    Thin wrapper around the project's ``ResNet18`` configured for 3-channel
    (RGB) input, intended for TinyImageNet.
    """
    def __init__(self, num_classes=200):
        super().__init__()
        # Imported at construction time, so the dependency on model.resnet is
        # only resolved when an instance is actually created.
        from model.resnet import ResNet18

        backbone = ResNet18(num_classes=num_classes, input_channels=3)
        self.resnet = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        output = self.resnet(x)
        return output

    def extract_features(self, x):
        """Delegate to the wrapped network's ``extract_features``."""
        features = self.resnet.extract_features(x)
        return features
    
class ResNet18Classifier(nn.Module):
    """
    Trainable classifier built on ``ResNet18TinyImageNet``.

    Bundles the model, a cross-entropy criterion and (after
    ``init_optimizer``) an SGD optimizer, plus helpers to read weights and
    weight deltas.
    """
    def __init__(self, num_classes=200, device='cuda'):
        super(ResNet18Classifier, self).__init__()
        self.device = device

        # Use ResNet18 adapted for TinyImageNet
        self.model = ResNet18TinyImageNet(num_classes=num_classes)
        self.model = self.model.to(device)

        # Initialize criterion for training
        self.criterion = nn.CrossEntropyLoss()

    def init_optimizer(self, lr=0.001, momentum=0.9, weight_decay=5e-4):
        """Create the SGD optimizer over the model parameters and return it."""
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay
        )
        return self.optimizer

    def get_weights(self):
        """
        Return the model's ``state_dict``.

        NOTE: ``state_dict()`` tensors share storage with the live parameters;
        callers that need a snapshot (e.g. as ``old_state`` for
        ``get_gradient_updates``) must copy it themselves.
        """
        return self.model.state_dict()

    def get_gradient_updates(self, old_state):
        """
        Return ``current - old`` for every state entry present in both.

        Args:
            old_state: Mapping from state-dict key to an earlier tensor value.

        Returns:
            dict: key -> difference tensor; keys missing from ``old_state``
            are skipped.
        """
        current_state = self.model.state_dict()
        return {
            key: value - old_state[key]
            for key, value in current_state.items()
            if key in old_state
        }

    def learn(self, epoch, dataloader):
        """
        Train the model for one pass over ``dataloader``.

        Args:
            epoch: Epoch index (unused; kept for API compatibility).
            dataloader: Iterable of ``(inputs, targets)`` batches.

        Returns:
            dict: ``{"loss": mean batch loss, "acc": accuracy in percent}``;
            both are 0.0 when the dataloader yields no batch.

        Raises:
            RuntimeError: if ``init_optimizer`` has not been called.
        """
        optimizer = getattr(self, 'optimizer', None)
        if optimizer is None:
            raise RuntimeError(
                "optimizer is not initialized; call init_optimizer() before learn()"
            )

        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0  # counted manually so len(dataloader) is not required

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            optimizer.zero_grad()

            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Track statistics
            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            num_batches += 1

        # An empty dataloader would otherwise divide by zero
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100.0 * correct / total if total else 0.0

        return {
            "loss": epoch_loss,
            "acc": epoch_acc
        }

    def forward(self, x):
        """Forward pass through the model"""
        return self.model(x)

class EnhancedDyGAT(nn.Module):
    """
    Dynamic graph attention network that turns per-client model updates into
    a client-to-client relational graph of attention scores.
    """
    # Length of the fixed-size feature vector each client's update is reduced to.
    FEATURE_DIM = 1000

    def __init__(self, opt):
        super(EnhancedDyGAT, self).__init__()
        self.opt = opt
        self.num_clients = opt.num_clients
        self.device = torch.device(opt.device)
        self.hidden_dim = 128
        self.embedding_dim = 64

        # Feature extraction network
        self.encoder = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, self.hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_dim, self.embedding_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.embedding_dim)
        )

        # Graph attention network
        self.gat = MultiHeadGAT(
            in_features=self.embedding_dim,
            hidden_dim=self.embedding_dim // 2,
            out_features=self.embedding_dim,
            n_heads=4,
            dropout=0.2
        )

    def _flatten_updates(self, updates):
        """
        Reduce one client's updates to a float32 vector of length FEATURE_DIM
        on ``self.device``.

        Args:
            updates: Mapping from parameter name to update tensor.

        Returns:
            torch.Tensor: If the flattened update has at least FEATURE_DIM
            entries, the FEATURE_DIM entries of largest magnitude (ordered by
            decreasing magnitude); otherwise the flattened update zero-padded
            on the right. All zeros when there is nothing to flatten.
        """
        # reshape (not view) so non-contiguous tensors are accepted. Every
        # piece is moved to the module's device and dtype so that
        # concatenation and the later stacking with zero vectors cannot
        # fail on mismatches.
        pieces = [
            update.reshape(-1).to(device=self.device, dtype=torch.float32)
            for update in updates.values()
            if update.numel() > 0
        ]

        if not pieces:
            return torch.zeros(self.FEATURE_DIM, device=self.device, dtype=torch.float32)

        flat_tensor = torch.cat(pieces)

        if flat_tensor.size(0) >= self.FEATURE_DIM:
            # Use magnitude as importance and keep the most significant entries
            _, indices = torch.topk(torch.abs(flat_tensor), self.FEATURE_DIM)
            return flat_tensor[indices]

        # Pad if smaller
        padded = torch.zeros(self.FEATURE_DIM, device=self.device, dtype=torch.float32)
        padded[:flat_tensor.size(0)] = flat_tensor
        return padded

    def _process_updates(self, model_updates):
        """
        Build and encode one feature vector per client.

        Args:
            model_updates: Mapping from client id (``0 .. num_clients - 1``)
                           to that client's update dict; clients without an
                           entry get an all-zero feature vector.

        Returns:
            torch.Tensor: Encoded client features, one row per client.
        """
        client_features = [
            self._flatten_updates(model_updates[client_id])
            if client_id in model_updates
            else torch.zeros(self.FEATURE_DIM, device=self.device, dtype=torch.float32)
            for client_id in range(self.num_clients)
        ]

        # Stack all client features and encode
        return self.encoder(torch.stack(client_features))

    def learn(self, epochs, model_updates):
        """
        Generate relational graph using attention mechanism

        Args:
            epochs: Unused; kept for API compatibility.
            model_updates: Mapping from client id to update dict.

        Returns:
            numpy.ndarray: float32 ``[num_clients, num_clients]`` attention
            matrix. If any step fails, or the matrix has no positive entry,
            the error is logged and a random symmetric fallback graph with
            unit diagonal is returned instead.
        """
        with torch.no_grad():
            try:
                # Process model updates into embeddings
                embeddings = self._process_updates(model_updates)

                # Apply graph attention network (without adjacency)
                _, attention_weights = self.gat(embeddings)

                # Convert attention weights to numpy for easier processing
                relation_graph = attention_weights.cpu().numpy()

                # A graph without any positive entry (e.g. all NaN) is unusable
                if not np.any(relation_graph > 0):
                    raise ValueError("relational graph has no positive entries")

                return relation_graph

            except Exception as e:
                # Log the error details
                logger.error(f"Error in DyGAT.learn: {str(e)}")

                # Fallback graph with some random structure; float32 to match
                # the dtype of the regular result.
                fallback_graph = np.random.rand(self.num_clients, self.num_clients).astype(np.float32) * 0.3
                fallback_graph = (fallback_graph + fallback_graph.T) / 2  # Make symmetric
                np.fill_diagonal(fallback_graph, 1.0)  # Self-connections are strong

                logger.warning("Using fallback relation graph due to error")
                return fallback_graph

    def forward(self, model_updates):
        """
        Forward pass for compatibility: ``learn`` wrapped into a tensor.
        """
        # learn() already runs under no_grad
        return torch.tensor(self.learn(1, model_updates))

 
def tensor_memory_in_MB(tensor):
    """Return the memory taken by ``tensor``'s elements, in units of 1024**2 bytes."""
    total_bytes = tensor.element_size() * tensor.nelement()
    bytes_per_unit = 1024 ** 2
    return total_bytes / bytes_per_unit