import torch
import numpy as np
import random
import os
import logging
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
import copy
import csv
import ray

# Initialize Ray when this module is imported (a module-level side effect).
# ignore_reinit_error=True avoids an error if Ray is already running.
ray.init(ignore_reinit_error=True)

from model.modules import ResNet18Classifier, EnhancedDyGAT, TemporalGAT
from utils.dp_ks_analysis import DifferentialPrivacyAnalyzer
from utils.model_utils import *
from utils.visualization_utils import *
from utils.log_utils import *
from utils.plot_utils import *
from utils.evaluation_utils import *
from utils.communication_tracker import CommunicationTracker
from utils.quality_evaluator import QualityEvaluator

# Import our Server and ModifiedClient classes
from model.server import Server
from model.client import ModifiedClient

logger = logging.getLogger('GFedCL')

def to_tensor(x, device="cuda"):
    """Return ``x`` as a torch tensor placed on ``device``.

    NumPy arrays are wrapped with ``torch.from_numpy`` first; any other input
    must already provide ``.to(device)`` (e.g. a torch tensor).
    """
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    return tensor.to(device)

def add_laplace_noise(data, scale):
    """Return ``data`` plus zero-centred Laplace noise with the given ``scale``.

    The noise is sampled with NumPy in the shape of ``data`` and converted to
    float32. NumPy inputs receive the noise directly; any other input is
    treated as a tensor, and the noise is moved to its device before adding.
    """
    # Both input kinds need the same sample, so draw it once up front.
    laplace = np.asarray(np.random.laplace(0, scale, data.shape), dtype=np.float32)

    if isinstance(data, np.ndarray):
        return data + laplace

    # Tensor path: fall back to 'cuda' when the input exposes no device attribute.
    target_device = getattr(data, 'device', 'cuda')
    return data + to_tensor(laplace, target_device)

def add_laplace_noise_to_graph(relational_graph, scale, normalize=True):
    """
    Perturb a relational graph with Laplace noise and return it as float32.

    Args:
        relational_graph: numpy array representing the graph
        scale: Scale parameter of the Laplace noise
        normalize: When True, rescale every row to sum to 1 after the noise
            has been added and negative entries have been clipped

    Returns:
        noisy_graph: The noisy, non-negative graph as a float32 array
    """
    # Negative entries produced by the noise are clipped to zero.
    perturbed = np.maximum(add_laplace_noise(relational_graph, scale), 0)

    if normalize:
        # Floor the row sums at 1e-8 so an all-zero row does not divide by zero.
        totals = np.maximum(perturbed.sum(axis=1, keepdims=True), 1e-8)
        perturbed = perturbed / totals

    return perturbed.astype(np.float32)

@ray.remote(num_gpus=3/10)
def generate_encodings_remote(client, task, relational_graphs, dataloader, generate_synthetic=False):
    """
    Ray task: ask one client to generate encodings for ``task``.

    When ``relational_graphs`` is a list, every numpy entry is replaced in
    place by a float32 copy before being handed to the client. The value
    returned by ``client.generate_encodings`` is passed through unchanged.
    """
    logger.info(f"Generating encodings from client {client.getId()} for task {task}")

    if isinstance(relational_graphs, list):
        for idx, graph in enumerate(relational_graphs):
            if isinstance(graph, np.ndarray):
                relational_graphs[idx] = graph.astype(np.float32)

    return client.generate_encodings(task, relational_graphs, dataloader, generate_synthetic)

@ray.remote(num_gpus=3/10)
def train_client_remote(client, task, relational_graphs, dataloader, epochs, generate_synthetic=False):
    """
    Ray task: run ``epochs`` local training epochs on one client.

    Args:
        client: Client object
        task: Task ID
        relational_graphs: Relational graphs for all tasks
        dataloader: DataLoader for the client
        epochs: Number of calls to ``client.learn`` (one per epoch index)
        generate_synthetic: Forwarded unchanged to ``client.learn``

    Returns:
        The value of ``client.get_weights()`` after the last epoch
    """
    logger.info(f"Training client {client.getId()} for task {task}")

    for local_epoch in range(epochs):
        # The per-epoch return value of learn() is not used.
        client.learn(local_epoch, task, relational_graphs, dataloader, generate_synthetic)

    return client.get_weights()

@ray.remote(num_gpus=3/10)
def test_client_remote(client, task, dataloader):
    """
    Ray task: evaluate one client on ``task``.

    Args:
        client: Client object
        task: Task ID
        dataloader: DataLoader holding the evaluation data

    Returns:
        Whatever ``client.test(task, dataloader)`` returns (the test metrics)
    """
    logger.info(f"Testing client {client.getId()} for task {task}")
    return client.test(task, dataloader)

@ray.remote(num_gpus=3/10)
def train_classifier_remote(resnet18, task, dataloader, num_epochs):
    """
    Ray task: run ``num_epochs`` training epochs on a ResNet18 classifier.

    Args:
        resnet18: ResNet18 classifier object
        task: Task ID (only used in the log message)
        dataloader: DataLoader for training
        num_epochs: Number of calls to ``resnet18.learn`` (one per epoch index)

    Returns:
        int: Always 0

    NOTE(review): Ray tasks receive their arguments by value, so the training
    presumably happens on a worker-side copy and is not visible on the
    caller's ``resnet18`` object. Only the status code is returned, so confirm
    that callers do not rely on the trained weights.
    """
    logger.info(f"Training ResNet18 classifier for task {task}")

    epoch = 0
    while epoch < num_epochs:
        # The metrics returned by learn() are not used.
        resnet18.learn(epoch, dataloader)
        epoch += 1

    return 0

def create_modified_clients(opt):
    """
    Build one ``ModifiedClient`` per client index.

    Args:
        opt: Configuration options; ``opt.num_clients`` sets how many clients
            are created and ``opt`` itself is passed to every client

    Returns:
        list: Clients constructed with ids ``0 .. opt.num_clients - 1``
    """
    return [ModifiedClient(client_id, opt) for client_id in range(opt.num_clients)]

class ParallelServerGFedCL:
    def __init__(self, opt):
        """Set up the server, the optional GAT, the data, the clients and the trackers.

        Args:
            opt: Configuration options. ``opt.device`` is overwritten with the
                device that is actually used; ``use_graph`` / ``use_temporal``
                are optional and default to True.
        """
        self.opt = opt
        self.use_graph = getattr(opt, "use_graph", True)
        self.use_temporal = getattr(opt, "use_temporal", True)

        # CUDA is used only when it is both available and requested.
        cuda_selected = torch.cuda.is_available() and opt.device == 'cuda'
        self.device = torch.device('cuda' if cuda_selected else 'cpu')
        if cuda_selected:
            logger.info(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        else:
            logger.info("Using CPU")

        # Keep the options in sync with the device chosen above.
        opt.device = str(self.device)

        logger.info("Initializing server with global discriminator...")
        self.server = Server(opt)

        # Pick the model that produces the relational graph (None for the no_graph ablation).
        if not self.use_graph:
            self.dygat = None
            logger.info("Relational graph generation disabled (ablation: no_graph)")
        elif self.use_temporal:
            logger.info("Initializing TemporalGAT for relational graph generation...")
            self.dygat = TemporalGAT(opt).to(self.device)
        else:
            logger.info("Initializing Enhanced DyGAT for relational graph generation...")
            self.dygat = EnhancedDyGAT(opt).to(self.device)

        logger.info("Setting up TinyImageNet dataloaders...")
        from utils.dataset_utils import setup_tinyimagenet_loaders
        self.dataloaders = setup_tinyimagenet_loaders(opt)
        logger.info("TinyImageNet dataloaders prepared successfully")

        logger.info("Creating modified clients...")
        self.clients = create_modified_clients(opt)
        logger.info(f"Created {len(self.clients)} clients")

        # The DP analyzer is created unconditionally, whether or not opt.dp is set.
        dp_output_dir = os.path.join(self.opt.output_dir, 'dp_analysis')
        self.dp_analyzer = DifferentialPrivacyAnalyzer(
            epsilon=self.opt.epsilon,
            sensitivity=self.opt.sensitivity,
            output_dir=dp_output_dir,
        )

        self.comm_tracker = CommunicationTracker()
        logger.info("Initialized communication tracker")

        # Quality evaluator for the FID/IS evaluations.
        self.quality_evaluator = QualityEvaluator(opt)
        
    def train_GFedCL(self):
        """Run the full task-by-task federated training loop.

        For every task the method:
          1. trains one local ResNet18 classifier per client and derives the
             per-client updates from the trained models,
          2. builds the relational graph for the task from those updates
             (identity matrix when graph learning is disabled or fails) and
             perturbs it with Laplace noise when ``opt.dp`` is set,
          3. runs ``opt.num_rounds`` rounds of: collect encodings, train the
             server discriminator, train the clients, average the
             encoder/predictor weights, evaluate on the current and all
             previous tasks.

        Accuracy and communication reports are written under
        ``opt.output_dir`` once all tasks are done.

        Returns:
            tuple: ``(all_tasks_acc, all_tasks_accuracy, quality_summary)``:
            the result of ``evaluate_all_tasks``, the list of per-round
            accuracy records, and the result of
            ``quality_evaluator.finalize()``.
        """
        # Function-scope import as before, hoisted out of the round loop.
        from utils.server_utils import average_weights

        logger.info('Starting Parallel Server-based GFedCL training for TinyImageNet...')

        # Relational graphs start as identity matrices, one per task.
        relational_graphs = [np.eye(self.opt.num_clients) for _ in range(self.opt.num_task)]

        # Noise-free graphs, kept for visualization and DP analysis.
        original_relational_graphs = [None for _ in range(self.opt.num_task)]

        logger.info("Initializing ResNet18 models for all clients (TinyImageNet-adapted)...")
        resnet18s = []
        for _ in range(self.opt.num_clients):
            model = ResNet18Classifier(
                num_classes=self.opt.num_classes,
                device=self.opt.device
            )
            model.init_optimizer(
                lr=self.opt.lr_f,
                momentum=0.9,
                weight_decay=self.opt.weight_decay
            )
            resnet18s.append(model)
        logger.info(f"Created {len(resnet18s)} ResNet18 models for TinyImageNet")

        # Average accuracy on the current task, one entry per (task, round).
        round_accuracy = []
        round_labels = []

        # Per-round records with the accuracy on the current and all previous tasks.
        all_tasks_accuracy = []

        for task in range(self.opt.num_task):
            logger.info(f'Training for task {task+1}/{self.opt.num_task}')

            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            # Snapshot the weights so the updates can be computed after training.
            old_states = {i: model.get_weights() for i, model in enumerate(resnet18s)}

            logger.info(f'Training classifiers for task {task+1} in parallel...')
            # Keep the models returned by the workers: the driver-side objects
            # passed into the tasks are not modified by the remote training.
            resnet18s = self._train_local_classifiers(resnet18s, task)
            logger.info(f'Classifiers training completed for task {task+1}')

            classifier_updates = {
                i: model.get_gradient_updates(old_states[i])
                for i, model in enumerate(resnet18s)
            }

            logger.info('Tracking communication overhead for classifier updates sent to GAT...')
            self.comm_tracker.add_model_updates_communication(classifier_updates, self.opt.num_clients)

            self._update_relational_graph(
                task, classifier_updates, relational_graphs, original_relational_graphs
            )

            if self.opt.dp and self.use_graph and original_relational_graphs[task] is not None:
                self.dp_analyzer.analyze_relational_graph(
                    original_relational_graphs[task],
                    task_id=task
                )

            # Every client downloads the graph of the current task once.
            for _ in range(self.opt.num_clients):
                self.comm_tracker.add_tensor_communication(
                    relational_graphs[task],
                    direction='download',
                    category='relational_graph'
                )

            # Replay needs a previous task, so it is inactive for the first task.
            use_replay = bool(self.opt.replay) and task >= 1

            for r in range(self.opt.num_rounds):
                logger.info(f'Round {r+1}/{self.opt.num_rounds}')
                self.comm_tracker.start_round(task, r)

                # PHASE 1: collect encodings from all clients (no training).
                logger.info('Phase 1: Collecting encodings from all clients in parallel')
                self._broadcast_discriminator()

                encoding_futures = [
                    self._submit(
                        generate_encodings_remote,
                        client,
                        task,
                        relational_graphs,
                        self.dataloaders[client_id][task]['train'],
                        False  # No synthetic samples
                    )
                    for client_id, client in enumerate(self.clients)
                ]
                if use_replay:
                    logger.info(f'Using replay for previous task {task-1} encodings')
                    encoding_futures.extend(
                        self._submit(
                            generate_encodings_remote,
                            client,
                            task - 1,
                            relational_graphs,
                            self.dataloaders[client_id][task - 1]['train'],
                            True  # Use synthetic samples
                        )
                        for client_id, client in enumerate(self.clients)
                    )

                all_encodings = []
                all_graph_embeddings = []
                for result in ray.get(encoding_futures):
                    all_encodings.extend(result['encodings'])
                    all_graph_embeddings.extend(result['graph_embeddings'])

                self.comm_tracker.add_encodings_communication(
                    all_encodings,
                    all_graph_embeddings
                )

                # PHASE 2: train the server discriminator on the collected encodings.
                logger.info(f'Phase 2: Training server discriminator with {len(all_encodings)} samples')
                if all_encodings:
                    # NOTE(review): the encodings are noised even when opt.dp is
                    # False, unlike the relational graph - confirm this is intended.
                    all_encodings_noised = [
                        add_laplace_noise(encoding.to(self.device), self.opt.b)
                        for encoding in all_encodings
                    ]
                    server_loss = self.server.train_discriminator(all_encodings_noised, all_graph_embeddings)
                    logger.info(f'Server discriminator loss: {server_loss:.4f}')
                else:
                    logger.warning("No encoded samples collected for server training")

                # PHASE 3: push the updated discriminator and train the clients.
                logger.info('Phase 3: Training clients with updated server discriminator in parallel')
                self._broadcast_discriminator()

                training_futures = [
                    self._submit(
                        train_client_remote,
                        client,
                        task,
                        relational_graphs,
                        self.dataloaders[client_id][task]['train'],
                        self.opt.num_local_epochs,
                        False  # No synthetic samples for current task
                    )
                    for client_id, client in enumerate(self.clients)
                ]
                if use_replay:
                    logger.info(f'Using replay for previous task {task-1} training')
                    training_futures.extend(
                        self._submit(
                            train_client_remote,
                            client,
                            task - 1,
                            relational_graphs,
                            self.dataloaders[client_id][task - 1]['train'],
                            self.opt.num_local_epochs,
                            True  # Use synthetic samples
                        )
                        for client_id, client in enumerate(self.clients)
                    )

                logger.info('Waiting for client training to complete...')
                # NOTE(review): only the current-task results are fetched. The
                # replay-training tasks are launched but their weights are never
                # retrieved or merged - confirm whether they should be averaged in.
                training_results = ray.get(training_futures[:len(self.clients)])

                encoder_weights = []
                predictor_weights = []
                for result in training_results:
                    encoder_weights.append(result['encoder'])
                    predictor_weights.append(result['predictor'])
                    self.comm_tracker.add_model_weights_communication(
                        result['encoder'],
                        direction='upload',
                        model_type='encoder_weights'
                    )
                    self.comm_tracker.add_model_weights_communication(
                        result['predictor'],
                        direction='upload',
                        model_type='predictor_weights'
                    )

                self.server.update_learning_rate()

                global_encoder = average_weights(encoder_weights)
                global_predictor = average_weights(predictor_weights)

                # Install the averaged weights on every driver-side client.
                for client in self.clients:
                    client.set_weights({
                        'encoder': global_encoder,
                        'predictor': global_predictor
                    })
                    self.comm_tracker.add_model_weights_communication(
                        global_encoder,
                        direction='download',
                        model_type='encoder_weights'
                    )
                    self.comm_tracker.add_model_weights_communication(
                        global_predictor,
                        direction='download',
                        model_type='predictor_weights'
                    )

                self.comm_tracker.end_round()

                logger.info(f'Evaluating clients for task {task+1} in parallel...')
                avg_accuracy = self._mean_test_accuracy(task, range(len(self.clients)))
                round_accuracy.append(avg_accuracy)
                round_labels.append(f"Task {task+1}, Round {r+1}")

                logger.info(f"Task {task+1}, Round {r+1} - Average Accuracy: {avg_accuracy:.2f}%")

                round_all_tasks_data = {
                    'round': f"Task {task+1}, Round {r+1}",
                    'current_task': task,
                    'round_number': r+1,
                    'tasks': {task: avg_accuracy}
                }

                if task > 0:
                    logger.info(f"Evaluating performance on previous tasks during Task {task+1}, Round {r+1}...")

                    for prev_task in range(task):
                        # NOTE(review): `in` tests keys if the per-client
                        # container is a dict, but element values if it is a
                        # list - confirm the container type.
                        eligible = [
                            client_id
                            for client_id in range(len(self.clients))
                            if prev_task in self.dataloaders[client_id]
                        ]
                        if not eligible:
                            continue

                        avg_prev_accuracy = self._mean_test_accuracy(prev_task, eligible)
                        logger.info(f"  Task {task+1}, Round {r+1} - Accuracy on previous Task {prev_task+1}: {avg_prev_accuracy:.2f}%")
                        round_all_tasks_data['tasks'][prev_task] = avg_prev_accuracy

                all_tasks_accuracy.append(round_all_tasks_data)

                # FID/IS evaluation hook, called after every round.
                self.quality_evaluator.evaluate_round(
                    task, r, self.clients, self.dataloaders, relational_graphs
                )

        self._save_accuracy_csvs(round_labels, round_accuracy, all_tasks_accuracy)

        comm_csv_path = os.path.join(self.opt.output_dir, 'communication_overhead.csv')
        self.comm_tracker.save_to_csv(comm_csv_path)
        logger.info(f"Saved communication overhead data to {comm_csv_path}")

        comm_plots_dir = os.path.join(self.opt.output_dir, 'communication_plots')
        self.comm_tracker.plot_communication_overhead(comm_plots_dir)
        logger.info(f"Saved communication overhead plots to {comm_plots_dir}")

        comm_summary = self.comm_tracker.get_summary()
        logger.info("===== COMMUNICATION SUMMARY =====")
        for key, value in comm_summary.items():
            logger.info(f"{key}: {value}")
        logger.info("===============================")

        logger.info("Evaluating final model accuracy...")
        all_tasks_acc = evaluate_all_tasks(self.opt, self.clients, self.dataloaders)

        if self.opt.dp:
            self.dp_analyzer.save_results()

        quality_summary = self.quality_evaluator.finalize()

        return all_tasks_acc, all_tasks_accuracy, quality_summary

    def _submit(self, remote_fn, *args):
        """Schedule ``remote_fn(*args)`` on Ray and return the future.

        The remote functions request a GPU share by default. When this server
        runs on CPU the request is overridden to 0 for the call, so the task
        can be scheduled on a machine without GPUs.
        """
        if self.device.type != 'cuda':
            remote_fn = remote_fn.options(num_gpus=0)
        return remote_fn.remote(*args)

    def _train_local_classifiers(self, resnet18s, task):
        """Train every client's classifier on ``task`` in parallel.

        Ray passes task arguments by value, so the training happens on
        worker-side copies. The trained copies are therefore returned and must
        replace the caller's models.

        Args:
            resnet18s: Classifiers, one per client (index = client id)
            task: Task ID

        Returns:
            list: The trained classifiers, in the same order as ``resnet18s``
        """
        # Optional setting; 15 is the value that used to be hard-coded.
        num_epochs = getattr(self.opt, 'classifier_epochs', 15)

        @ray.remote(num_gpus=3/10)
        def train_and_return(model, task_id, loader, epochs):
            logger.info(f"Training ResNet18 classifier for task {task_id}")
            for epoch in range(epochs):
                model.learn(epoch, loader)
            return model

        futures = [
            self._submit(
                train_and_return,
                model,
                task,
                self.dataloaders[client_id][task]['train'],
                num_epochs
            )
            for client_id, model in enumerate(resnet18s)
        ]
        return ray.get(futures)

    def _update_relational_graph(self, task, classifier_updates, relational_graphs, original_relational_graphs):
        """Fill ``relational_graphs[task]`` and ``original_relational_graphs[task]``.

        With graph learning enabled the graph comes from ``self.dygat.learn``
        and is perturbed with Laplace noise when ``opt.dp`` is set. If that
        fails, or graph learning is disabled, both entries become the identity
        matrix. Failures while saving visualizations are logged and never
        discard a graph that was learned successfully.
        """
        if not self.use_graph:
            relational_graphs[task] = np.eye(self.opt.num_clients)
            original_relational_graphs[task] = relational_graphs[task].copy()
            return

        logger.info('Generating attention-based relational graph using GAT')
        try:
            relational_graph = self.dygat.learn(
                self.opt.gat_epochs,
                classifier_updates,
                task_id=task
            )

            # Keep the noise-free graph before any perturbation.
            original_relational_graphs[task] = relational_graph.copy()

            if self.opt.dp:
                logger.info(f'Adding Laplace noise to relational graph with scale {self.opt.b}')
                relational_graphs[task] = add_laplace_noise_to_graph(
                    relational_graph,
                    scale=self.opt.b,
                    normalize=True
                )

                noise_magnitude = np.abs(relational_graphs[task] - original_relational_graphs[task])
                logger.info(
                    f'Noise statistics - Mean: {np.mean(noise_magnitude):.4f}, '
                    f'Max: {np.max(noise_magnitude):.4f}, '
                    f'Std: {np.std(noise_magnitude):.4f}'
                )
            else:
                relational_graphs[task] = relational_graph
        except Exception as e:
            # Deliberate best-effort: training continues with an identity graph.
            logger.exception(f"Error generating relational graph: {e}")
            logger.info("Using identity matrix as fallback")
            relational_graphs[task] = np.eye(self.opt.num_clients)
            original_relational_graphs[task] = relational_graphs[task].copy()
        else:
            # Heatmaps are diagnostics only, so they get their own error handling.
            try:
                save_heatmap(self.opt, task, relational_graphs)

                if self.opt.dp:
                    # The noise-free graphs go to a separate sub-directory.
                    original_opt = copy.deepcopy(self.opt)
                    original_opt.output_dir = os.path.join(self.opt.output_dir, 'original_graphs')
                    os.makedirs(original_opt.output_dir, exist_ok=True)
                    save_heatmap(original_opt, task, original_relational_graphs)
            except Exception as e:
                logger.exception(f"Error saving relational graph heatmaps: {e}")

        if self.use_temporal:
            try:
                spatial_attention = getattr(self.dygat, 'last_spatial_attention', None)
                temporal_patterns = getattr(self.dygat, 'last_temporal_patterns', None)

                if spatial_attention is not None:
                    self.visualize_attention_components(
                        task, spatial_attention, temporal_patterns, relational_graphs[task]
                    )
            except Exception as e:
                logger.exception(f"Error visualizing attention components: {e}")

    def _broadcast_discriminator(self):
        """Give every client the server's current discriminator and record the download."""
        server_discriminator = self.server.get_discriminator()
        for client in self.clients:
            client.set_server_discriminator(server_discriminator)
            self.comm_tracker.add_model_weights_communication(
                server_discriminator,
                direction='download',
                model_type='discriminator_weights'
            )

    def _mean_test_accuracy(self, task, client_ids):
        """Test the given clients on ``task`` in parallel and return the mean of their ``"acc"`` values.

        Raises:
            ZeroDivisionError: if ``client_ids`` is empty.
        """
        futures = [
            self._submit(
                test_client_remote,
                self.clients[client_id],
                task,
                self.dataloaders[client_id][task]['test']
            )
            for client_id in client_ids
        ]
        accuracies = [result["acc"] for result in ray.get(futures)]
        return sum(accuracies) / len(accuracies)

    def _save_accuracy_csvs(self, round_labels, round_accuracy, all_tasks_accuracy):
        """Write ``round_accuracy.csv`` and ``all_tasks_accuracy.csv`` into ``opt.output_dir``."""
        csv_path = os.path.join(self.opt.output_dir, 'round_accuracy.csv')
        with open(csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Round', 'Average Accuracy'])
            for label, accuracy in zip(round_labels, round_accuracy):
                writer.writerow([label, f"{accuracy:.2f}"])

        logger.info(f"Saved round accuracy data to {csv_path}")

        all_tasks_csv_path = os.path.join(self.opt.output_dir, 'all_tasks_accuracy.csv')
        task_columns = [f'Task {t+1} Accuracy' for t in range(self.opt.num_task)]
        with open(all_tasks_csv_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['Round', 'Current Task'] + task_columns)
            writer.writeheader()

            for round_data in all_tasks_accuracy:
                row = {
                    'Round': round_data['round'],
                    'Current Task': round_data['current_task'] + 1  # 1-based in the report
                }
                # Tasks not evaluated in this round get an empty cell.
                for t, column in enumerate(task_columns):
                    if t in round_data['tasks']:
                        row[column] = f"{round_data['tasks'][t]:.2f}"
                    else:
                        row[column] = ""

                writer.writerow(row)

        logger.info(f"Saved all tasks accuracy data to {all_tasks_csv_path}")
    
    def visualize_attention_components(self, task, spatial_attention, temporal_patterns, combined_attention):
        """
        Save a side-by-side heatmap figure of the attention components for one task.

        The figure is written to
        ``<opt.output_dir>/attention_visualizations/attention_components_task<task+1>.png``.

        Args:
            task: Current task ID (zero-based; shown as ``task+1``)
            spatial_attention: Spatial attention matrix
            temporal_patterns: Temporal pattern matrix; its panel is omitted when None
            combined_attention: Combined attention matrix
        """
        import matplotlib.pyplot as plt
        import seaborn as sns
        import os

        vis_dir = os.path.join(self.opt.output_dir, 'attention_visualizations')
        os.makedirs(vis_dir, exist_ok=True)

        # Panels in left-to-right order; the temporal panel is optional.
        panels = [(spatial_attention, f'Task {task+1}: Spatial Attention')]
        if temporal_patterns is not None:
            panels.append((temporal_patterns, f'Task {task+1}: Temporal Patterns'))
        panels.append((combined_attention, f'Task {task+1}: Combined Attention'))

        # Six inches per panel: 18x6 with three panels, 12x6 with two.
        fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6))
        for ax, (matrix, title) in zip(axes, panels):
            sns.heatmap(matrix, ax=ax, cmap='viridis', annot=False)
            ax.set_title(title)

        plt.tight_layout()
        plt.savefig(os.path.join(vis_dir, f'attention_components_task{task+1}.png'), dpi=300)
        plt.close()

        logger.info(f"Task {task+1}: Saved attention component visualization")
    
    def visualize_dp_comparison(self, original_graphs, noisy_graphs):
        """
        Save one original / noisy / difference heatmap figure per task.

        Tasks whose original graph is None are skipped. Figures are written to
        ``<opt.output_dir>/dp_comparison/dp_comparison_task<task+1>.png``.

        Args:
            original_graphs: List of original relational graphs, indexed by task
            noisy_graphs: List of relational graphs with Laplace noise, indexed by task
        """
        import matplotlib.pyplot as plt
        import seaborn as sns

        vis_dir = os.path.join(self.opt.output_dir, 'dp_comparison')
        os.makedirs(vis_dir, exist_ok=True)

        # The original and noisy panels share a fixed [0, 1] colour range.
        shared_style = dict(cmap='viridis', vmin=0, vmax=1, annot=False)

        for task in range(self.opt.num_task):
            if original_graphs[task] is None:
                continue

            fig, axes = plt.subplots(1, 3, figsize=(18, 6))
            original_ax, noisy_ax, diff_ax = axes

            sns.heatmap(original_graphs[task], ax=original_ax, **shared_style)
            original_ax.set_title(f'Task {task+1}: Original Graph')

            sns.heatmap(noisy_graphs[task], ax=noisy_ax, **shared_style)
            noisy_ax.set_title(f'Task {task+1}: Graph with Laplace Noise (ε={self.opt.epsilon})')

            # Element-wise absolute difference between the two graphs.
            difference = np.abs(original_graphs[task] - noisy_graphs[task])
            sns.heatmap(difference, ax=diff_ax, cmap='Reds', annot=False)
            diff_ax.set_title(f'Task {task+1}: Absolute Difference')

            plt.tight_layout()
            plt.savefig(os.path.join(vis_dir, f'dp_comparison_task{task+1}.png'), dpi=300)
            plt.close()

        logger.info(f"Saved DP comparison visualizations to {vis_dir}")
