import torch
import numpy as np
import random
import os
import logging
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
import copy
import csv
import ray

# Initialize Ray - ignore reinit error in case Ray is already running.
# NOTE(review): this runs as an import-time side effect (before the project
# imports below), so merely importing this module initialises Ray.
# ignore_reinit_error=True turns a repeated ray.init() in the same process into
# a no-op instead of an error. The @ray.remote functions defined further down
# are dispatched with .remote(), which needs Ray to be initialised.
ray.init(ignore_reinit_error=True)

from model.modules import TemporalGAT
from utils.dp_ks_analysis import DifferentialPrivacyAnalyzer
from utils.model_utils import *
from utils.visualization_utils import *
from utils.log_utils import *
from utils.plot_utils import *
from utils.evaluation_utils import *

# Import our Server and ModifiedClient classes
from model.server import Server
from model.client import ModifiedClient

# Module logger under the shared 'GFedCL' name. No handlers are configured in
# this file; output depends on logging setup done elsewhere.
logger = logging.getLogger('GFedCL')

def to_tensor(x, device="cuda"):
    """
    Return ``x`` as a torch tensor placed on ``device``.

    numpy arrays are wrapped with ``torch.from_numpy`` (dtype preserved);
    anything else is assumed to already expose ``.to(device)``.
    """
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    return tensor.to(device)

def add_laplace_noise(data, scale):
    """
    Add zero-mean Laplace noise of scale ``scale`` to ``data``.

    The noise is always drawn from NumPy's global RNG as float32, so seeding
    ``np.random`` controls it for both input kinds.

    Args:
        data: numpy array, or a tensor-like object with ``.shape`` (and
            usually ``.device``)
        scale: Laplace scale parameter ``b``

    Returns:
        ``data + noise``: a numpy array for numpy input, otherwise the result
        of adding a noise tensor placed on ``data``'s device (``'cuda'`` if
        ``data`` has no ``.device`` attribute).
    """
    # Tensor input needs its target device; look it up before drawing noise.
    target_device = None
    if not isinstance(data, np.ndarray):
        target_device = data.device if hasattr(data, 'device') else 'cuda'

    laplace_noise = np.asarray(np.random.laplace(0, scale, data.shape), dtype=np.float32)

    if target_device is None:
        return data + laplace_noise
    return data + to_tensor(laplace_noise, target_device)

def add_laplace_noise_to_graph(relational_graph, scale, normalize=True):
    """
    Add Laplace noise to a relational graph while keeping it a valid attention matrix.

    Noise is drawn from NumPy's global RNG (same draw as ``add_laplace_noise``
    on a numpy array). Negative entries are then clipped to 0, and optionally
    every row is rescaled to sum to 1.

    Args:
        relational_graph: 2-D graph. numpy arrays are used as-is; torch
            tensors (CPU or GPU) and other array-likes are converted to numpy.
        scale: Scale parameter ``b`` (>= 0) of the zero-mean Laplace noise.
        normalize: Whether to row-normalize the graph after adding noise.
            A row whose entries were all clipped to 0 becomes the uniform
            row ``1 / n_cols``, so every row still sums to 1.

    Returns:
        noisy_graph: float32 numpy array with the same shape as the input.
    """
    if isinstance(relational_graph, torch.Tensor):
        # np.maximum/astype below need host memory; detach drops autograd history.
        relational_graph = relational_graph.detach().cpu().numpy()
    graph = np.asarray(relational_graph)

    # Add float32 Laplace noise (identical RNG call to add_laplace_noise).
    noise = np.random.laplace(0, scale, graph.shape).astype(np.float32)

    # Ensure non-negative values (attention scores should be non-negative).
    noisy_graph = np.maximum(graph + noise, 0)

    if normalize:
        row_sums = noisy_graph.sum(axis=1, keepdims=True)
        empty_rows = row_sums <= 0
        # Divide empty rows by 1 only to avoid 0/0; they are replaced below.
        normalized = noisy_graph / np.where(empty_rows, 1.0, row_sums)
        # An all-zero row has no attention mass left; spread it uniformly
        # instead of returning a row that sums to 0.
        uniform_weight = 1.0 / max(noisy_graph.shape[1], 1)
        noisy_graph = np.where(empty_rows, uniform_weight, normalized)

    return noisy_graph.astype(np.float32)

@ray.remote(num_gpus=3/10)
def generate_encodings_remote(client, task, relational_graphs, dataloader, generate_synthetic=False):
    """
    Generate encodings from a client in parallel (runs as a Ray task).

    Args:
        client: Client object (Ray hands the task a serialised copy)
        task: Task ID to encode
        relational_graphs: Relational graphs for all tasks; numpy entries of a
            list are cast to float32 before reaching the client
        dataloader: DataLoader for this client and task
        generate_synthetic: Forwarded to ``client.generate_encodings``

    Returns:
        Whatever ``client.generate_encodings`` returns. ``train_GFedCL`` in this
        file reads its 'encodings' and 'graph_embeddings' entries.
    """
    logger.info(f"Generating encodings from client {client.getId()} for task {task}")

    # Build a new float32 list instead of overwriting the caller's list
    # element-by-element. The function then has no side effect on its argument,
    # even if Ray ever hands it the caller's object without copying.
    if isinstance(relational_graphs, list):
        relational_graphs = [
            graph.astype(np.float32) if isinstance(graph, np.ndarray) else graph
            for graph in relational_graphs
        ]

    return client.generate_encodings(task, relational_graphs, dataloader, generate_synthetic)

@ray.remote(num_gpus=3/10)
def train_client_remote(client, task, relational_graphs, dataloader, epochs, generate_synthetic=False):
    """
    Run ``epochs`` local epochs of ``client.learn`` inside a Ray worker.

    Args:
        client: Client object (trained as a worker-side copy)
        task: Task ID
        relational_graphs: Relational graphs for all tasks
        dataloader: DataLoader for the client
        epochs: Number of local epochs
        generate_synthetic: Whether to use synthetic samples

    Returns:
        dict: ``client.get_weights()`` after training. This is the only way the
        training result leaves the worker.
    """
    logger.info(f"Training client {client.getId()} for task {task}")
    for epoch_idx in range(epochs):
        # The per-epoch return value of learn() is not needed here.
        client.learn(epoch_idx, task, relational_graphs, dataloader, generate_synthetic)

    trained_weights = client.get_weights()
    return trained_weights

@ray.remote(num_gpus=3/10)
def test_client_remote(client, task, dataloader):
    """
    Evaluate ``client`` on ``task`` inside a Ray worker.

    Args:
        client: Client object
        task: Task ID
        dataloader: DataLoader holding the test split

    Returns:
        dict: ``client.test(task, dataloader)``. ``train_GFedCL`` reads the
        'r2', 'mse', 'mae' and 'rmse' keys from it.
    """
    logger.info(f"Testing client {client.getId()} for task {task}")
    return client.test(task, dataloader)

def create_modified_clients(opt):
    """
    Build one ``ModifiedClient`` per configured client.

    Args:
        opt: Configuration options; ``opt.num_clients`` sets the count.

    Returns:
        list: Clients with ids ``0 .. opt.num_clients - 1`` in order.
    """
    return [ModifiedClient(client_id, opt) for client_id in range(opt.num_clients)]

class ParallelServerGFedCL:
    """Ray-parallel GFedCL training driver for the ILI / US-States setup.

    Holds the global ``Server`` (whose discriminator is shared with clients),
    a ``TemporalGAT`` that produces the per-task client relational graph, the
    per-client dataloaders, and the client objects. Per-client work
    (encoding, local training, testing) goes through the ``*_remote`` Ray
    functions of this module.

    NOTE(review): Ray passes ``client`` to each ``.remote()`` call by value
    (serialised), so client-side state changed inside those calls stays in
    the worker. Only the encoder/predictor weights returned by
    ``train_client_remote`` come back. They are averaged and pushed to every
    local client via ``set_weights``.
    """

    def __init__(self, opt):
        """Build the server, TemporalGAT, dataloaders, clients and DP analyzer.

        Args:
            opt: Configuration namespace. ``opt.device`` is overwritten with
                the device actually selected ('cuda' or 'cpu').
        """
        self.opt = opt
        # GPU only when available AND opt.device is exactly 'cuda'; any other
        # value (including 'cuda:N') falls back to the CPU.
        if torch.cuda.is_available() and opt.device == 'cuda':
            self.device = torch.device('cuda')
            logger.info(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        else:
            self.device = torch.device('cpu')
            logger.info("Using CPU")

        # Keep opt.device in sync so everything built from opt agrees.
        opt.device = str(self.device)

        logger.info("Initializing server with global discriminator...")
        self.server = Server(opt)

        logger.info("Initializing TemporalGAT for relational graph generation...")
        self.dygat = TemporalGAT(opt).to(self.device)

        logger.info("Setting up US-States dataloaders...")
        from utils.dataset_utils import setup_ili_loaders
        self.dataloaders = setup_ili_loaders(opt)
        logger.info("US-States dataloaders prepared successfully")

        logger.info("Creating modified clients...")
        self.clients = create_modified_clients(opt)
        logger.info(f"Created {len(self.clients)} clients")

        # Always constructed; only used when opt.dp is enabled.
        self.dp_analyzer = DifferentialPrivacyAnalyzer(
            epsilon=self.opt.epsilon,
            sensitivity=self.opt.sensitivity,
            output_dir=os.path.join(self.opt.output_dir, 'dp_analysis'),
        )

    def train_GFedCL(self):
        """Run continual federated training over all tasks.

        For each task: build the relational graph (Laplace-noised when
        ``opt.dp``). Then run ``opt.num_rounds`` rounds of:
        encode -> train server discriminator -> train clients -> average
        weights -> evaluate. Per-round metrics are written to CSV files in
        ``opt.output_dir`` and plotted.

        Returns:
            tuple: ``(all_tasks_acc, all_tasks_metrics)``.
            ``all_tasks_acc`` is the result of ``evaluate_all_tasks``.
            ``all_tasks_metrics`` has one dict per (task, round) with keys
            'round', 'current_task', 'round_number', 'tasks' (R² * 100),
            'tasks_mse', 'tasks_mae' and 'tasks_r2' (each keyed by task id).
        """
        logger.info('Starting Parallel Server-based GFedCL training for ILI time series...')

        # Identity placeholders until each task's graph is built.
        relational_graphs = [np.eye(self.opt.num_clients) for _ in range(self.opt.num_task)]
        # Pre-noise graphs, kept for visualisation and DP analysis.
        original_relational_graphs = [None] * self.opt.num_task

        round_metrics = {'r2': [], 'mse': [], 'mae': [], 'rmse': []}
        round_labels = []
        all_tasks_metrics = []

        for task in range(self.opt.num_task):
            logger.info(f'Training for task {task+1}/{self.opt.num_task}')

            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            self._prepare_relational_graph(task, relational_graphs, original_relational_graphs)

            for r in range(self.opt.num_rounds):
                logger.info(f'Round {r+1}/{self.opt.num_rounds}')
                self._train_server_phase(task, relational_graphs)
                self._train_clients_phase(task, relational_graphs)
                all_tasks_metrics.append(
                    self._evaluate_round(task, r, round_metrics, round_labels)
                )

        self._save_round_metrics(round_metrics, round_labels, all_tasks_metrics)

        from utils.plot_utils import plot_training_curve
        plot_training_curve(self.opt, round_metrics['r2'], round_labels, metric='r2')
        plot_training_curve(self.opt, round_metrics['mse'], round_labels, metric='mse')
        plot_training_curve(self.opt, round_metrics['rmse'], round_labels, metric='rmse')

        logger.info("Evaluating final model performance...")
        all_tasks_acc = evaluate_all_tasks(self.opt, self.clients, self.dataloaders)

        if self.opt.dp:
            self.dp_analyzer.save_results()

        return all_tasks_acc, all_tasks_metrics

    # ------------------------------------------------------------------ #
    # Relational graph
    # ------------------------------------------------------------------ #
    def _prepare_relational_graph(self, task, relational_graphs, original_relational_graphs):
        """Build ``task``'s relational graph and store it in the given lists.

        Sets ``relational_graphs[task]`` (Laplace-noised when ``opt.dp``) and
        ``original_relational_graphs[task]`` (pre-noise). If generation fails,
        both fall back to the identity matrix. Plotting failures are logged and
        never affect the stored graphs.
        """
        logger.info('Generating attention-based relational graph using TemporalGAT')
        try:
            if task == 0:
                # No earlier task to relate to: each client attends only to itself.
                relational_graph = np.eye(self.opt.num_clients, dtype=np.float32)
            else:
                # Placeholder: the GAT currently receives no client updates ({}).
                relational_graph = self.dygat.learn(
                    self.opt.gat_epochs,
                    {},
                    task_id=task
                )
            relational_graph = self._graph_to_numpy(relational_graph)
            original_relational_graphs[task] = relational_graph.copy()

            if self.opt.dp:
                logger.info(f'Adding Laplace noise to relational graph with scale {self.opt.b}')
                relational_graphs[task] = add_laplace_noise_to_graph(
                    relational_graph,
                    scale=self.opt.b,
                    normalize=True
                )
                # Difference includes both noise and the re-normalisation.
                noise_magnitude = np.abs(relational_graphs[task] - original_relational_graphs[task])
                logger.info(f'Noise statistics - Mean: {np.mean(noise_magnitude):.4f}, '
                            f'Max: {np.max(noise_magnitude):.4f}, '
                            f'Std: {np.std(noise_magnitude):.4f}')
            else:
                relational_graphs[task] = relational_graph
        except Exception as e:
            logger.exception(f"Error generating relational graph: {str(e)}")
            logger.info("Using identity matrix as fallback")
            relational_graphs[task] = np.eye(self.opt.num_clients, dtype=np.float32)
            original_relational_graphs[task] = relational_graphs[task].copy()

        # Visualisation runs outside the generation try-block so a plotting
        # error can no longer replace a valid graph with the identity fallback.
        self._save_graph_heatmaps(task, relational_graphs, original_relational_graphs)
        self._visualize_task_attention(task, relational_graphs)

        if self.opt.dp and original_relational_graphs[task] is not None:
            self.dp_analyzer.analyze_relational_graph(
                original_relational_graphs[task],
                task_id=task
            )

    def _save_graph_heatmaps(self, task, relational_graphs, original_relational_graphs):
        """Save the used graph heatmap, plus the pre-noise one when DP is on."""
        try:
            save_heatmap(self.opt, task, relational_graphs)
            if self.opt.dp:
                # Same saver, redirected to a sub-directory for the pre-noise graphs.
                original_opt = copy.deepcopy(self.opt)
                original_opt.output_dir = os.path.join(self.opt.output_dir, 'original_graphs')
                os.makedirs(original_opt.output_dir, exist_ok=True)
                save_heatmap(original_opt, task, original_relational_graphs)
        except Exception as e:
            logger.exception(f"Error saving relational graph heatmaps: {str(e)}")

    def _visualize_task_attention(self, task, relational_graphs):
        """Plot the TemporalGAT's last attention components, if it exposes them."""
        try:
            spatial_attention = getattr(self.dygat, 'last_spatial_attention', None)
            temporal_patterns = getattr(self.dygat, 'last_temporal_patterns', None)
            if spatial_attention is not None:
                self.visualize_attention_components(
                    task, spatial_attention, temporal_patterns, relational_graphs[task]
                )
        except Exception as e:
            logger.exception(f"Error visualizing attention components: {str(e)}")

    @staticmethod
    def _graph_to_numpy(graph):
        """Return ``graph`` as a numpy array; tensors are detached and moved to CPU."""
        if isinstance(graph, torch.Tensor):
            return graph.detach().cpu().numpy()
        return np.asarray(graph)

    # ------------------------------------------------------------------ #
    # Training phases
    # ------------------------------------------------------------------ #
    def _broadcast_discriminator(self):
        """Hand the server's current discriminator to every local client."""
        server_discriminator = self.server.get_discriminator()
        for client in self.clients:
            client.set_server_discriminator(server_discriminator)

    def _train_server_phase(self, task, relational_graphs):
        """Phases 1-2: collect client encodings, then train the server discriminator."""
        logger.info('Phase 1: Collecting encodings from all clients in parallel')
        self._broadcast_discriminator()

        encoding_futures = [
            generate_encodings_remote.remote(
                client,
                task,
                relational_graphs,
                self.dataloaders[client_id][task]['train'],
                False  # No synthetic samples
            )
            for client_id, client in enumerate(self.clients)
        ]
        if self.opt.replay and task >= 1:
            logger.info(f'Using replay for previous task {task-1} encodings')
            encoding_futures.extend(
                generate_encodings_remote.remote(
                    client,
                    task - 1,
                    relational_graphs,
                    self.dataloaders[client_id][task - 1]['train'],
                    True  # Use synthetic samples
                )
                for client_id, client in enumerate(self.clients)
            )

        all_encodings = []
        all_graph_embeddings = []
        for result in ray.get(encoding_futures):
            all_encodings.extend(result['encodings'])
            all_graph_embeddings.extend(result['graph_embeddings'])

        logger.info(f'Phase 2: Training server discriminator with {len(all_encodings)} samples')
        if not all_encodings:
            logger.warning("No encoded samples collected for server training")
            return

        encodings = [encoding.to(self.device) for encoding in all_encodings]
        # Encoding noise is a DP mechanism. Apply it only when DP is enabled,
        # matching how the relational-graph noise is gated.
        if self.opt.dp:
            encodings = [add_laplace_noise(encoding, self.opt.b) for encoding in encodings]
        server_loss = self.server.train_discriminator(encodings, all_graph_embeddings)
        logger.info(f'Server discriminator loss: {server_loss:.4f}')

    def _train_clients_phase(self, task, relational_graphs):
        """Phase 3: train clients remotely, then average and redistribute weights."""
        logger.info('Phase 3: Training clients with updated server discriminator in parallel')
        self._broadcast_discriminator()

        training_futures = [
            train_client_remote.remote(
                client,
                task,
                relational_graphs,
                self.dataloaders[client_id][task]['train'],
                self.opt.num_local_epochs,
                False  # No synthetic samples for current task
            )
            for client_id, client in enumerate(self.clients)
        ]
        if self.opt.replay and task >= 1:
            logger.info(f'Using replay for previous task {task-1} training')
            training_futures.extend(
                train_client_remote.remote(
                    client,
                    task - 1,
                    relational_graphs,
                    self.dataloaders[client_id][task - 1]['train'],
                    self.opt.num_local_epochs,
                    True  # Use synthetic samples
                )
                for client_id, client in enumerate(self.clients)
            )

        logger.info('Waiting for client training to complete...')
        # Wait for every launched task, replay included. Otherwise replay tasks
        # keep holding GPU slots into the next phases and their errors vanish.
        all_training_results = ray.get(training_futures)
        # Only current-task weights are aggregated (unchanged behaviour).
        # NOTE(review): replay training runs on separate worker-side client
        # copies, so its weights never reach the global model — confirm intent.
        training_results = all_training_results[:len(self.clients)]

        encoder_weights = [result['encoder'] for result in training_results]
        predictor_weights = [result['predictor'] for result in training_results]

        self.server.update_learning_rate()

        from utils.server_utils import average_weights
        global_encoder = average_weights(encoder_weights)
        global_predictor = average_weights(predictor_weights)

        for client in self.clients:
            client.set_weights({
                'encoder': global_encoder,
                'predictor': global_predictor
            })

    # ------------------------------------------------------------------ #
    # Evaluation
    # ------------------------------------------------------------------ #
    @staticmethod
    def _mean(values):
        """Arithmetic mean of a non-empty sequence."""
        return sum(values) / len(values)

    @staticmethod
    def _has_task_loader(client_loaders, task):
        """Return True if ``client_loaders`` holds an entry for ``task``.

        Works whether the per-client loaders are a mapping keyed by task id or
        a sequence indexed by task. ``task in some_list`` would test list
        membership, not the index, and so is always False for a list of loader
        dicts.
        """
        from collections.abc import Mapping
        if isinstance(client_loaders, Mapping):
            return task in client_loaders
        try:
            client_loaders[task]
        except (IndexError, KeyError, TypeError):
            return False
        return True

    @staticmethod
    def _record_task_metrics(round_data, task_id, r2, mse, mae):
        """Store one task's averaged metrics in a per-round record."""
        round_data['tasks'][task_id] = r2 * 100  # R² * 100, kept for compatibility
        round_data['tasks_mse'][task_id] = mse
        round_data['tasks_mae'][task_id] = mae
        round_data['tasks_r2'][task_id] = r2

    def _evaluate_round(self, task, r, round_metrics, round_labels):
        """Evaluate all clients after round ``r`` of ``task``.

        Appends the current-task client averages to ``round_metrics`` and
        ``round_labels``. Returns this round's record, with metrics for the
        current task and every earlier task that has a test loader.
        """
        logger.info(f'Evaluating clients for task {task+1} in parallel...')
        test_results = ray.get([
            test_client_remote.remote(
                client,
                task,
                self.dataloaders[client_id][task]['test']
            )
            for client_id, client in enumerate(self.clients)
        ])

        avg_r2 = self._mean([result["r2"] for result in test_results])
        avg_mse = self._mean([result["mse"] for result in test_results])
        avg_mae = self._mean([result["mae"] for result in test_results])
        avg_rmse = self._mean([result["rmse"] for result in test_results])

        round_metrics['r2'].append(avg_r2)
        round_metrics['mse'].append(avg_mse)
        round_metrics['mae'].append(avg_mae)
        round_metrics['rmse'].append(avg_rmse)
        label = f"Task {task+1}, Round {r+1}"
        round_labels.append(label)

        logger.info(f"Task {task+1}, Round {r+1} - Regression Metrics:")
        logger.info(f"  R²: {avg_r2:.4f}, MSE: {avg_mse:.6f}, MAE: {avg_mae:.6f}, RMSE: {avg_rmse:.6f}")

        round_data = {
            'round': label,
            'current_task': task,
            'round_number': r + 1,
            'tasks': {},      # R² * 100, for compatibility
            'tasks_mse': {},
            'tasks_mae': {},
            'tasks_r2': {},   # raw R²
        }
        self._record_task_metrics(round_data, task, avg_r2, avg_mse, avg_mae)

        if task > 0:
            logger.info(f"Evaluating performance on previous tasks during Task {task+1}, Round {r+1}...")
            self._evaluate_previous_tasks(task, r, round_data)

        return round_data

    def _evaluate_previous_tasks(self, task, r, round_data):
        """Test clients on every task before ``task`` and record the averages."""
        # Launch every previous-task evaluation before waiting on any of them,
        # so Ray can schedule them concurrently.
        pending = {}
        for prev_task in range(task):
            futures = [
                test_client_remote.remote(
                    client,
                    prev_task,
                    self.dataloaders[client_id][prev_task]['test']
                )
                for client_id, client in enumerate(self.clients)
                if self._has_task_loader(self.dataloaders[client_id], prev_task)
            ]
            if futures:
                pending[prev_task] = futures

        for prev_task, futures in pending.items():
            prev_task_results = ray.get(futures)

            avg_prev_r2 = self._mean([result["r2"] for result in prev_task_results])
            avg_prev_mse = self._mean([result["mse"] for result in prev_task_results])
            avg_prev_mae = self._mean([result["mae"] for result in prev_task_results])

            logger.info(f"  Task {task+1}, Round {r+1} - Metrics on previous Task {prev_task+1}:")
            logger.info(f"    R²: {avg_prev_r2:.4f}, MSE: {avg_prev_mse:.6f}, MAE: {avg_prev_mae:.6f}")

            self._record_task_metrics(round_data, prev_task, avg_prev_r2, avg_prev_mse, avg_prev_mae)

    # ------------------------------------------------------------------ #
    # Reporting
    # ------------------------------------------------------------------ #
    def _save_round_metrics(self, round_metrics, round_labels, all_tasks_metrics):
        """Write the per-round and per-task regression metrics to CSV files."""
        os.makedirs(self.opt.output_dir, exist_ok=True)

        # R² scores (primary metric). utf-8 explicitly: headers contain '²'.
        csv_path_r2 = os.path.join(self.opt.output_dir, 'round_r2_scores.csv')
        with open(csv_path_r2, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Round', 'R² Score'])
            for label, r2 in zip(round_labels, round_metrics['r2']):
                writer.writerow([label, f"{r2:.4f}"])
        logger.info(f"Saved R² scores to {csv_path_r2}")

        # All regression metrics per round.
        csv_path_all = os.path.join(self.opt.output_dir, 'round_regression_metrics.csv')
        with open(csv_path_all, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Round', 'R²', 'MSE', 'MAE', 'RMSE'])
            rows = zip(
                round_labels,
                round_metrics['r2'],
                round_metrics['mse'],
                round_metrics['mae'],
                round_metrics['rmse'],
            )
            for label, r2, mse, mae, rmse in rows:
                writer.writerow([label, f"{r2:.4f}", f"{mse:.6f}", f"{mae:.6f}", f"{rmse:.6f}"])
        logger.info(f"Saved all regression metrics to {csv_path_all}")

        # Metrics on every task, per round (blank where a task was not evaluated).
        all_tasks_csv_path = os.path.join(self.opt.output_dir, 'all_tasks_regression_metrics.csv')
        fieldnames = ['Round', 'Current Task']
        for t in range(self.opt.num_task):
            fieldnames.extend([f'Task {t+1} R²', f'Task {t+1} MSE', f'Task {t+1} MAE'])

        with open(all_tasks_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for round_data in all_tasks_metrics:
                row = {
                    'Round': round_data['round'],
                    'Current Task': round_data['current_task'] + 1
                }
                tasks_r2 = round_data.get('tasks_r2', {})
                for t in range(self.opt.num_task):
                    if t in tasks_r2:
                        row[f'Task {t+1} R²'] = f"{tasks_r2[t]:.4f}"
                        row[f'Task {t+1} MSE'] = f"{round_data['tasks_mse'][t]:.6f}"
                        row[f'Task {t+1} MAE'] = f"{round_data['tasks_mae'][t]:.6f}"
                    else:
                        row[f'Task {t+1} R²'] = ""
                        row[f'Task {t+1} MSE'] = ""
                        row[f'Task {t+1} MAE'] = ""
                writer.writerow(row)

        logger.info(f"Saved all tasks regression metrics to {all_tasks_csv_path}")

    def visualize_attention_components(self, task, spatial_attention, temporal_patterns, combined_attention):
        """
        Visualize the components of the attention mechanism.

        Saves ``attention_components_task{task+1}.png`` under
        ``<output_dir>/attention_visualizations``.

        Args:
            task: Current task ID
            spatial_attention: Spatial attention matrix
            temporal_patterns: Temporal pattern similarity matrix (can be None,
                in which case only two panels are drawn)
            combined_attention: Combined attention matrix
        """
        vis_dir = os.path.join(self.opt.output_dir, 'attention_visualizations')
        os.makedirs(vis_dir, exist_ok=True)

        panels = [(spatial_attention, 'Spatial Attention')]
        if temporal_patterns is not None:
            panels.append((temporal_patterns, 'Temporal Patterns'))
        panels.append((combined_attention, 'Combined Attention'))

        fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6))
        try:
            for ax, (matrix, title) in zip(axes, panels):
                sns.heatmap(matrix, ax=ax, cmap='viridis', annot=False)
                ax.set_title(f'Task {task+1}: {title}')
            fig.tight_layout()
            fig.savefig(os.path.join(vis_dir, f'attention_components_task{task+1}.png'), dpi=300)
        finally:
            # Always release the figure, even if plotting fails.
            plt.close(fig)

        logger.info(f"Task {task+1}: Saved attention component visualization")

    def visualize_dp_comparison(self, original_graphs, noisy_graphs):
        """
        Create a visualization comparing original and noisy relational graphs.

        Writes ``dp_comparison_task{n}.png`` under ``<output_dir>/dp_comparison``
        for every task whose original graph is not None.

        Args:
            original_graphs: List of original relational graphs
            noisy_graphs: List of relational graphs with Laplace noise
        """
        vis_dir = os.path.join(self.opt.output_dir, 'dp_comparison')
        os.makedirs(vis_dir, exist_ok=True)

        for task in range(self.opt.num_task):
            if original_graphs[task] is None:
                continue

            fig, axes = plt.subplots(1, 3, figsize=(18, 6))
            try:
                sns.heatmap(original_graphs[task], ax=axes[0], cmap='viridis',
                            vmin=0, vmax=1, annot=False)
                axes[0].set_title(f'Task {task+1}: Original Graph')

                sns.heatmap(noisy_graphs[task], ax=axes[1], cmap='viridis',
                            vmin=0, vmax=1, annot=False)
                axes[1].set_title(f'Task {task+1}: Graph with Laplace Noise (ε={self.opt.epsilon})')

                diff = np.abs(original_graphs[task] - noisy_graphs[task])
                sns.heatmap(diff, ax=axes[2], cmap='Reds', annot=False)
                axes[2].set_title(f'Task {task+1}: Absolute Difference')

                fig.tight_layout()
                fig.savefig(os.path.join(vis_dir, f'dp_comparison_task{task+1}.png'), dpi=300)
            finally:
                plt.close(fig)

        logger.info(f"Saved DP comparison visualizations to {vis_dir}")