import torch
import numpy as np
import random
import os
import logging
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
import copy
import csv
import ray

# Initialize Ray at import time, before the project modules below are imported.
# ignore_reinit_error=True avoids an error if Ray has already been initialized.
ray.init(ignore_reinit_error=True)

from model.modules import ResNet18Classifier, EnhancedDyGAT, TemporalGAT
from utils.dp_ks_analysis import DifferentialPrivacyAnalyzer
from utils.model_utils import *
from utils.visualization_utils import *
from utils.log_utils import *
from utils.plot_utils import *
from utils.evaluation_utils import *
from utils.communication_tracker import CommunicationTracker
from utils.quality_evaluator import QualityEvaluator

# Import our Server and ModifiedClient classes
from model.server import Server
from model.client import ModifiedClient

# Logger used by the functions and classes below, registered under the name 'GFedCL'.
logger = logging.getLogger('GFedCL')

def to_tensor(x, device="cuda"):
    """
    Return `x` as a torch tensor placed on `device`.

    Args:
        x: A numpy array (wrapped with torch.from_numpy) or any object that
            already exposes a `.to(device)` method, such as a torch tensor.
        device: Target device passed straight to `.to()`; defaults to "cuda".

    Returns:
        The tensor after the `.to(device)` call.
    """
    # Only numpy arrays need wrapping; anything else is assumed to be a tensor.
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    return tensor.to(device)

def add_laplace_noise(data, scale):
    """
    Return `data` perturbed by zero-centred Laplace noise.

    The noise is always sampled with numpy (np.random.laplace) in the shape of
    `data` and cast to float32 before being added.

    Args:
        data: A numpy array, or a tensor-like object exposing `.shape`.
        scale: Scale parameter handed to np.random.laplace.

    Returns:
        `data + noise`. A numpy array for numpy input; otherwise the result of
        adding the noise after moving it to the input's device via to_tensor.
    """
    # One draw serves both input kinds, so the RNG is consumed identically.
    laplace_sample = np.asarray(
        np.random.laplace(0, scale, data.shape), dtype=np.float32
    )

    if isinstance(data, np.ndarray):
        return data + laplace_sample

    # Non-numpy input: follow its own device, falling back to 'cuda' when the
    # object does not expose one.
    target_device = getattr(data, 'device', 'cuda')
    return data + to_tensor(laplace_sample, target_device)

def add_laplace_noise_to_graph(relational_graph, scale, normalize=True):
    """
    Perturb a relational graph with Laplace noise and clean up the result.

    Steps: add noise via add_laplace_noise, clamp negative entries to zero
    (attention scores should be non-negative), and optionally rescale every
    row so that it sums to 1.

    Args:
        relational_graph: numpy array representing the graph
        scale: Scale parameter for the Laplace noise
        normalize: When True, divide each row by its sum after clamping

    Returns:
        The noisy graph as a float32 numpy array
    """
    # Perturb first, then clamp, so negative noisy entries become 0.
    perturbed = np.maximum(add_laplace_noise(relational_graph, scale), 0)

    if not normalize:
        return perturbed.astype(np.float32)

    # Floor each row total at 1e-8 so an all-zero row does not divide by zero.
    row_totals = np.maximum(perturbed.sum(axis=1, keepdims=True), 1e-8)
    return (perturbed / row_totals).astype(np.float32)

@ray.remote
def generate_encodings_remote(client, task, relational_graphs, dataloader, generate_synthetic=False):
    """
    Ray task that asks one client to generate encodings.

    Args:
        client: Client object; must provide getId() and generate_encodings()
        task: Task ID
        relational_graphs: Relational graphs for all tasks. When this is a
            list, its numpy entries are replaced in place by float32 versions.
        dataloader: DataLoader forwarded to the client
        generate_synthetic: Forwarded unchanged to client.generate_encodings

    Returns:
        Whatever client.generate_encodings returns.
    """
    logger.info(f"Generating encodings from client {client.getId()} for task {task}")

    # Only a list is normalised; non-numpy entries are left untouched.
    if isinstance(relational_graphs, list):
        for idx, graph in enumerate(relational_graphs):
            if isinstance(graph, np.ndarray):
                relational_graphs[idx] = graph.astype(np.float32)

    return client.generate_encodings(task, relational_graphs, dataloader, generate_synthetic)

@ray.remote
def train_client_remote(client, task, relational_graphs, dataloader, epochs, generate_synthetic=False):
    """
    Ray task that trains one client for a fixed number of local epochs.

    Args:
        client: Client object; must provide getId(), learn() and get_weights()
        task: Task ID
        relational_graphs: Relational graphs for all tasks
        dataloader: DataLoader for the client
        epochs: Number of epochs to train
        generate_synthetic: Whether to use synthetic samples

    Returns:
        dict: Client weights after training, as returned by client.get_weights()
    """
    logger.info(f"Training client {client.getId()} for task {task}")

    # The per-epoch return value of learn() is not used; only the final
    # weights are reported back.
    for epoch_idx in range(epochs):
        client.learn(epoch_idx, task, relational_graphs, dataloader, generate_synthetic)

    return client.get_weights()

@ray.remote
def test_client_remote(client, task, dataloader):
    """
    Ray task that evaluates one client on one task.

    Args:
        client: Client object; must provide getId() and test()
        task: Task ID
        dataloader: DataLoader for testing

    Returns:
        dict: Test metrics, exactly as returned by client.test()
    """
    logger.info(f"Testing client {client.getId()} for task {task}")
    return client.test(task, dataloader)

@ray.remote
def train_classifier_remote(resnet18, task, dataloader, num_epochs):
    """
    Ray task that trains a ResNet18 classifier for `num_epochs` epochs.

    Args:
        resnet18: ResNet18 classifier object; must provide learn(epoch, dataloader)
        task: Task ID (used for logging only)
        dataloader: DataLoader for training
        num_epochs: Number of epochs to train

    Returns:
        int: Status code (always 0)

    NOTE(review): only a status code is returned, so no trained state leaves
    this function through its return value. If Ray hands the task a serialized
    copy of `resnet18`, the caller's model would not see this training —
    confirm against the caller.
    """
    logger.info(f"Training ResNet18 classifier for task {task}")

    # Metrics returned by learn() are intentionally not collected.
    for epoch_idx in range(num_epochs):
        resnet18.learn(epoch_idx, dataloader)

    return 0

def create_modified_clients(opt):
    """
    Build one ModifiedClient per client index.

    Args:
        opt: Configuration options; `opt.num_clients` sets how many clients
            are created, and `opt` itself is passed to every client.

    Returns:
        list: ModifiedClient objects with ids 0 .. opt.num_clients - 1
    """
    return [ModifiedClient(client_id, opt) for client_id in range(opt.num_clients)]

class ParallelServerGFedCL:
    def __init__(self, opt):
        """
        Build the server-side state used for GFedCL training.

        Resolves the compute device and then constructs, in this order: the
        server, the optional GAT used for relational graphs, the CIFAR100
        dataloaders, the clients, the DP analyzer, the communication tracker
        and the quality evaluator.

        Args:
            opt: Configuration options. Read here: device, epsilon,
                sensitivity, output_dir, and the optional use_graph /
                use_temporal flags (both default to True when absent).
                opt.device is overwritten with the device actually selected.
        """
        self.opt = opt
        self.use_graph = getattr(opt, "use_graph", True)
        self.use_temporal = getattr(opt, "use_temporal", True)

        # CUDA is used only when it is available AND opt.device is exactly
        # 'cuda'; every other combination falls back to the CPU.
        if not (torch.cuda.is_available() and opt.device == 'cuda'):
            self.device = torch.device('cpu')
            logger.info("Using CPU")
        else:
            self.device = torch.device('cuda')
            # Log the name of CUDA device 0 for debugging
            logger.info(f"Using CUDA: {torch.cuda.get_device_name(0)}")

        # Write the resolved device back so later consumers of opt see it
        opt.device = str(self.device)

        # Server (holds the global discriminator)
        logger.info("Initializing server with global discriminator...")
        self.server = Server(opt)

        # GAT for relational graph generation; None when graphs are disabled
        if not self.use_graph:
            self.dygat = None
            logger.info("Relational graph generation disabled (ablation: no_graph)")
        elif self.use_temporal:
            logger.info("Initializing TemporalGAT for relational graph generation...")
            self.dygat = TemporalGAT(opt).to(self.device)
        else:
            logger.info("Initializing Enhanced DyGAT for relational graph generation...")
            self.dygat = EnhancedDyGAT(opt).to(self.device)

        # CIFAR100 dataloaders (imported locally, right before use)
        logger.info("Setting up CIFAR100 dataloaders...")
        from utils.dataset_utils import setup_cifar100_loaders
        self.dataloaders = setup_cifar100_loaders(opt)
        logger.info("Dataloaders prepared successfully")

        # Modified clients
        logger.info("Creating modified clients...")
        self.clients = create_modified_clients(opt)
        logger.info(f"Created {len(self.clients)} clients")

        # DP analyzer: constructed unconditionally, so opt.epsilon and
        # opt.sensitivity must exist whether or not DP is switched on.
        self.dp_analyzer = DifferentialPrivacyAnalyzer(
            epsilon=opt.epsilon,
            sensitivity=opt.sensitivity,
            output_dir=os.path.join(opt.output_dir, 'dp_analysis'),
        )

        # Communication tracker
        self.comm_tracker = CommunicationTracker()
        logger.info("Initialized communication tracker")

        # Quality evaluator (FID/IS)
        self.quality_evaluator = QualityEvaluator(opt)
        
    def train_GFedCL(self):
        logger.info('Starting Parallel Server-based GFedCL training...')
        
        # Initialize relational graphs for each task as identity matrices
        relational_graphs = [np.eye(self.opt.num_clients) for _ in range(self.opt.num_task)]
        
        # Store original (non-noisy) graphs for visualization
        original_relational_graphs = [None for _ in range(self.opt.num_task)]
        
        # Create local ResNet18 models for all clients
        logger.info("Initializing ResNet18 models for all clients...")
        resnet18s = []
        for i in range(self.opt.num_clients):
            # Create model with safe device handling
            model = ResNet18Classifier(
                num_classes=self.opt.num_classes, 
                device=self.opt.device
            )
            # Initialize optimizer
            model.init_optimizer(
                lr=self.opt.lr_f,
                momentum=0.9,
                weight_decay=self.opt.weight_decay
            )
            resnet18s.append(model)
        logger.info(f"Created {len(resnet18s)} ResNet18 models")
        
        # Track accuracy for each round and task
        round_accuracy = []
        round_labels = []
        
        # NEW: Track test accuracy on all previous tasks during each round
        all_tasks_accuracy = []  # Will store data for all rounds and all tasks
        
        # Training loop for each task
        for task in range(self.opt.num_task):
            logger.info(f'Training for task {task+1}/{self.opt.num_task}')
            
            # Clear CUDA cache between tasks if using GPU
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
            
            # Save initial model states for computing updates
            old_states = {}
            for i, model in enumerate(resnet18s):
                old_states[i] = model.get_weights()
            
            # Train local classifiers in parallel using Ray
            logger.info(f'Training classifiers for task {task+1} in parallel...')
            # Run classifier training in batches to limit memory pressure
            for i in range(0, len(resnet18s), self.opt.ray_max_in_flight):
                batch = resnet18s[i : i + self.opt.ray_max_in_flight]
                futures = [
                    train_classifier_remote.options(
                        num_gpus=self.opt.ray_num_gpus_per_task,
                        num_cpus=self.opt.ray_num_cpus_per_task,
                    ).remote(
                        model,
                        task,
                        self.dataloaders[i + j][task]['train'],
                        20
                    )
                    for j, model in enumerate(batch)
                ]
                ray.get(futures)
            logger.info(f'Classifiers training completed for task {task+1}')
            
            # Calculate model updates for each client
            classifier_updates = {}
            for i, model in enumerate(resnet18s):
                # Get gradient updates
                updates = model.get_gradient_updates(old_states[i])
                classifier_updates[i] = updates
            
            # This happens once per task, before generating the relational graph
            logger.info('Tracking communication overhead for classifier updates sent to GAT...')
            self.comm_tracker.add_model_updates_communication(classifier_updates, self.opt.num_clients)

            if self.use_graph:
                # Generate relational graph using GAT with attention scores
                logger.info('Generating attention-based relational graph using GAT')
                try:
                    relational_graph = self.dygat.learn(
                        self.opt.gat_epochs,
                        classifier_updates,
                        task_id=task
                    )

                    # Store the original graph before adding noise
                    original_relational_graphs[task] = relational_graph.copy()

                    # Add Laplace noise to the relational graph for differential privacy
                    if self.opt.dp:
                        logger.info(f'Adding Laplace noise to relational graph with scale {self.opt.b}')
                        relational_graphs[task] = add_laplace_noise_to_graph(
                            relational_graph,
                            scale=self.opt.b,
                            normalize=True
                        )

                        # Log statistics about the noise
                        noise_magnitude = np.abs(relational_graphs[task] - original_relational_graphs[task])
                        logger.info(
                            f'Noise statistics - Mean: {np.mean(noise_magnitude):.4f}, '
                            f'Max: {np.max(noise_magnitude):.4f}, '
                            f'Std: {np.std(noise_magnitude):.4f}'
                        )
                    else:
                        relational_graphs[task] = relational_graph

                    # Save both original and noisy graphs for visualization
                    save_heatmap(self.opt, task, relational_graphs)

                    # Also save the original graph if noise was added
                    if self.opt.dp:
                        original_opt = copy.deepcopy(self.opt)
                        original_opt.output_dir = os.path.join(self.opt.output_dir, 'original_graphs')
                        os.makedirs(original_opt.output_dir, exist_ok=True)
                        save_heatmap(original_opt, task, original_relational_graphs)

                except Exception as e:
                    logger.error(f"Error generating relational graph: {str(e)}")
                    logger.info("Using identity matrix as fallback")
                    relational_graphs[task] = np.eye(self.opt.num_clients)
                    original_relational_graphs[task] = relational_graphs[task].copy()

                if self.use_temporal:
                    try:
                        spatial_attention = (
                            self.dygat.last_spatial_attention
                            if hasattr(self.dygat, "last_spatial_attention")
                            else None
                        )
                        temporal_patterns = (
                            self.dygat.last_temporal_patterns
                            if hasattr(self.dygat, "last_temporal_patterns")
                            else None
                        )

                        if spatial_attention is not None:
                            self.visualize_attention_components(
                                task, spatial_attention, temporal_patterns, relational_graphs[task]
                            )
                    except Exception as e:
                        logger.error(f"Error visualizing attention components: {str(e)}")
            else:
                relational_graphs[task] = np.eye(self.opt.num_clients)
                original_relational_graphs[task] = relational_graphs[task].copy()

            if self.opt.dp and self.use_graph and original_relational_graphs[task] is not None:
                self.dp_analyzer.analyze_relational_graph(
                    original_relational_graphs[task],
                    task_id=task
                )

            for _ in range(self.opt.num_clients):
                self.comm_tracker.add_tensor_communication(
                    relational_graphs[task], 
                    direction='download', 
                    category='relational_graph'
                )

            # Main training loop with the server and modified clients
            for r in range(self.opt.num_rounds):
                logger.info(f'Round {r+1}/{self.opt.num_rounds}')
                
                # Start tracking this round
                self.comm_tracker.start_round(task, r)

                # PHASE 1: Collect encodings from all clients with current encoders (no training)
                logger.info(f'Phase 1: Collecting encodings from all clients in parallel')
                all_encodings = []
                all_graph_embeddings = []
                
                # Track discriminator sent to clients
                server_discriminator = self.server.get_discriminator()
                for client in self.clients:
                    client.set_server_discriminator(server_discriminator)
                    # Each client downloads the discriminator
                    self.comm_tracker.add_model_weights_communication(
                        server_discriminator, 
                        direction='download',
                        model_type='discriminator_weights'
                    )

                # First, distribute the server's current discriminator to all clients
                server_discriminator = self.server.get_discriminator()
                for client in self.clients:
                    client.set_server_discriminator(server_discriminator)
                
                # Collect encodings from each client without training (batched)
                encoding_results = []

                for i in range(0, len(self.clients), self.opt.ray_max_in_flight):
                    batch_clients = self.clients[i : i + self.opt.ray_max_in_flight]
                    futures = []
                    for j, client in enumerate(batch_clients):
                        client_id = i + j
                        futures.append(
                            generate_encodings_remote.options(
                                num_gpus=self.opt.ray_num_gpus_per_task,
                                num_cpus=self.opt.ray_num_cpus_per_task,
                            ).remote(
                                client,
                                task,
                                relational_graphs,
                                self.dataloaders[client_id][task]['train'],
                                False,
                            )
                        )
                    encoding_results.extend(ray.get(futures))

                if task >= 1:
                    for i in range(0, len(self.clients), self.opt.ray_max_in_flight):
                        batch_clients = self.clients[i : i + self.opt.ray_max_in_flight]
                        futures = []
                        for j, client in enumerate(batch_clients):
                            client_id = i + j
                            futures.append(
                                generate_encodings_remote.options(
                                    num_gpus=self.opt.ray_num_gpus_per_task,
                                    num_cpus=self.opt.ray_num_cpus_per_task,
                                ).remote(
                                    client,
                                    task - 1,
                                    relational_graphs,
                                    self.dataloaders[client_id][task - 1]['train'],
                                    True,
                                )
                            )
                        encoding_results.extend(ray.get(futures))
                
                # Process results
                for result in encoding_results:
                    all_encodings.extend(result['encodings'])
                    all_graph_embeddings.extend(result['graph_embeddings'])
                
                # Track encodings uploaded to server
                self.comm_tracker.add_encodings_communication(
                    all_encodings, 
                    all_graph_embeddings
                )

                if self.opt.dp:
                    # Analyze encodings before applying noise
                    self.dp_analyzer.analyze_encodings(
                        all_encodings, 
                        client_ids=list(range(len(all_encodings))),
                        task_id=task,
                        round_id=r
                    )

                # PHASE 2: Train the server's discriminator with collected encodings
                logger.info(f'Phase 2: Training server discriminator with {len(all_encodings)} samples')
                if len(all_encodings) > 0:
                    # Apply Laplace noise to the encodings for differential privacy
                    # We need to handle each encoding tensor individually since all_encodings is a list of tensors
                    all_encodings_noised = []
                    for encoding in all_encodings:
                        # Make sure the encoding is on the correct device
                        encoding = encoding.to(self.device)
                        # Add Laplace noise with scale parameter self.opt.b
                        noised_encoding = add_laplace_noise(encoding, self.opt.b)
                        all_encodings_noised.append(noised_encoding)
                    server_loss = self.server.train_discriminator(all_encodings_noised, all_graph_embeddings)
                    logger.info(f'Server discriminator loss: {server_loss:.4f}')
                else:
                    logger.warning("No encoded samples collected for server training")
                
                # PHASE 3: Distribute the updated discriminator to clients and train them in parallel
                logger.info(f'Phase 3: Training clients with updated server discriminator in parallel')
                
                # Distribute the updated discriminator to all clients
                server_discriminator = self.server.get_discriminator()
                for client in self.clients:
                    client.set_server_discriminator(server_discriminator)
                    # Track updated discriminator download for each client
                    self.comm_tracker.add_model_weights_communication(
                        server_discriminator, 
                        direction='download',
                        model_type='discriminator_weights'
                    )

                # Train clients with the updated discriminator in parallel (batched)
                logger.info('Waiting for client training to complete...')
                training_results = []

                for i in range(0, len(self.clients), self.opt.ray_max_in_flight):
                    batch_clients = self.clients[i : i + self.opt.ray_max_in_flight]
                    futures = []
                    for j, client in enumerate(batch_clients):
                        client_id = i + j
                        futures.append(
                            train_client_remote.options(
                                num_gpus=self.opt.ray_num_gpus_per_task,
                                num_cpus=self.opt.ray_num_cpus_per_task,
                            ).remote(
                                client,
                                task,
                                relational_graphs,
                                self.dataloaders[client_id][task]['train'],
                                self.opt.num_local_epochs,
                                False,
                            )
                        )
                    training_results.extend(ray.get(futures))

                if task >= 1:
                    for i in range(0, len(self.clients), self.opt.ray_max_in_flight):
                        batch_clients = self.clients[i : i + self.opt.ray_max_in_flight]
                        futures = []
                        for j, client in enumerate(batch_clients):
                            client_id = i + j
                            futures.append(
                                train_client_remote.options(
                                    num_gpus=self.opt.ray_num_gpus_per_task,
                                    num_cpus=self.opt.ray_num_cpus_per_task,
                                ).remote(
                                    client,
                                    task - 1,
                                    relational_graphs,
                                    self.dataloaders[client_id][task - 1]['train'],
                                    self.opt.num_local_epochs,
                                    True,
                                )
                            )
                        ray.get(futures)
                
                # Process results and extract weights
                encoder_weights = []
                predictor_weights = []
                
                for result in training_results:
                    encoder_weights.append(result['encoder'])
                    predictor_weights.append(result['predictor'])

                    # Track upload of each client's weights separately
                    self.comm_tracker.add_model_weights_communication(
                        result['encoder'], 
                        direction='upload',
                        model_type='encoder_weights'
                    )
                    self.comm_tracker.add_model_weights_communication(
                        result['predictor'], 
                        direction='upload',
                        model_type='predictor_weights'
                    )
                
                # Update server's learning rate
                self.server.update_learning_rate()
                
                # Average weights across clients using utility function from server_utils.py
                from utils.server_utils import average_weights
                global_encoder = average_weights(encoder_weights)
                global_predictor = average_weights(predictor_weights)

                # Update the local models with the averaged weights
                for client in self.clients:
                    client.set_weights({
                        'encoder': global_encoder,
                        'predictor': global_predictor
                    })
                    
                    # Track download of global model to each client
                    self.comm_tracker.add_model_weights_communication(
                        global_encoder, 
                        direction='download',
                        model_type='encoder_weights'
                    )
                    self.comm_tracker.add_model_weights_communication(
                        global_predictor, 
                        direction='download',
                        model_type='predictor_weights'
                    )
                
                # End round tracking
                round_comm_data = self.comm_tracker.end_round()

                # Evaluate current performance for all clients on this task in parallel
                logger.info(f'Evaluating clients for task {task+1} in parallel...')
                test_results = []
                for i in range(0, len(self.clients), self.opt.ray_max_in_flight):
                    batch_clients = self.clients[i : i + self.opt.ray_max_in_flight]
                    futures = []
                    for j, client in enumerate(batch_clients):
                        client_id = i + j
                        futures.append(
                            test_client_remote.options(
                                num_gpus=self.opt.ray_num_gpus_per_task,
                                num_cpus=self.opt.ray_num_cpus_per_task,
                            ).remote(
                                client,
                                task,
                                self.dataloaders[client_id][task]['test']
                            )
                        )
                    test_results.extend(ray.get(futures))
                task_accuracies = [result["acc"] for result in test_results]
                
                # Calculate average accuracy across all clients for this round
                avg_accuracy = sum(task_accuracies) / len(task_accuracies)
                round_accuracy.append(avg_accuracy)
                round_labels.append(f"Task {task+1}, Round {r+1}")
                
                logger.info(f"Task {task+1}, Round {r+1} - Average Accuracy: {avg_accuracy:.2f}%")
                
                # NEW: Collect data for all tasks (including current and previous) during this round
                round_all_tasks_data = {
                    'round': f"Task {task+1}, Round {r+1}",
                    'current_task': task,
                    'round_number': r+1,
                    'tasks': {}
                }
                
                # Add current task accuracy
                round_all_tasks_data['tasks'][task] = avg_accuracy
                
                # Evaluate on all previous tasks for this round
                if task > 0:
                    logger.info(f"Evaluating performance on previous tasks during Task {task+1}, Round {r+1}...")
                    
                    for prev_task in range(task):
                        prev_task_results = []
                        for i in range(0, len(self.clients), self.opt.ray_max_in_flight):
                            batch_clients = self.clients[i : i + self.opt.ray_max_in_flight]
                            futures = []
                            for j, client in enumerate(batch_clients):
                                client_id = i + j
                                if prev_task in self.dataloaders[client_id]:
                                    futures.append(
                                        test_client_remote.options(
                                            num_gpus=self.opt.ray_num_gpus_per_task,
                                            num_cpus=self.opt.ray_num_cpus_per_task,
                                        ).remote(
                                            client,
                                            prev_task,
                                            self.dataloaders[client_id][prev_task]['test']
                                        )
                                    )

                            if futures:
                                prev_task_results.extend(ray.get(futures))
                        
                        # Collect test results for previous task
                        if prev_task_results:
                            prev_task_accuracies = [result["acc"] for result in prev_task_results]
                            avg_prev_accuracy = sum(prev_task_accuracies) / len(prev_task_accuracies)
                            
                            # Log and store the accuracy on this previous task
                            logger.info(f"  Task {task+1}, Round {r+1} - Accuracy on previous Task {prev_task+1}: {avg_prev_accuracy:.2f}%")
                            round_all_tasks_data['tasks'][prev_task] = avg_prev_accuracy
                
                # Add this round's data to our tracking
                all_tasks_accuracy.append(round_all_tasks_data)

                # Optional FID/IS evaluations after each round
                self.quality_evaluator.evaluate_round(
                    task, r, self.clients, self.dataloaders, relational_graphs
                )

        # Save round accuracy to CSV
        csv_path = os.path.join(self.opt.output_dir, 'round_accuracy.csv')
        with open(csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Round', 'Average Accuracy'])
            for i, label in enumerate(round_labels):
                writer.writerow([label, f"{round_accuracy[i]:.2f}"])
        
        logger.info(f"Saved round accuracy data to {csv_path}")
        
        # NEW: Save all tasks accuracy to CSV
        all_tasks_csv_path = os.path.join(self.opt.output_dir, 'all_tasks_accuracy.csv')
        with open(all_tasks_csv_path, 'w', newline='') as csvfile:
            # Define CSV header
            fieldnames = ['Round', 'Current Task']
            # Add columns for each task
            for t in range(self.opt.num_task):
                fieldnames.append(f'Task {t+1} Accuracy')
            
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            
            # Write data for each round
            for round_data in all_tasks_accuracy:
                row = {
                    'Round': round_data['round'],
                    'Current Task': round_data['current_task'] + 1  # Add 1 for 1-based indexing
                }
                
                # Add accuracy for each task
                for t in range(self.opt.num_task):
                    if t in round_data['tasks']:
                        row[f'Task {t+1} Accuracy'] = f"{round_data['tasks'][t]:.2f}"
                    else:
                        row[f'Task {t+1} Accuracy'] = ""
                
                writer.writerow(row)
        
        logger.info(f"Saved all tasks accuracy data to {all_tasks_csv_path}")
        
        # Create comparison visualization if DP is enabled
        if self.opt.dp and self.use_graph:
            self.visualize_dp_comparison(original_relational_graphs, relational_graphs)
        
        # Test accuracy after training all tasks
        logger.info("Evaluating final model accuracy...")
        all_tasks_acc = evaluate_all_tasks(self.opt, self.clients, self.dataloaders)
        
        if self.opt.dp:
            self.dp_analyzer.visualize_encoding_analysis()
            self.dp_analyzer.visualize_graph_analysis()
            self.dp_analyzer.save_results()

        # After all training is complete
        # Save communication data
        comm_csv_path = os.path.join(self.opt.output_dir, 'communication_overhead.csv')
        self.comm_tracker.save_to_csv(comm_csv_path)
        
        # Create visualizations
        comm_plots_dir = os.path.join(self.opt.output_dir, 'communication_plots')
        self.comm_tracker.plot_communication_overhead(comm_plots_dir)
        
        # Log detailed summary
        comm_summary = self.comm_tracker.get_summary()
        if comm_summary:
            logger.info("===== Communication Overhead Summary =====")
            logger.info(f"Pre-Trained Classifiers: {comm_summary['upload_breakdown'].get('classifier_updates', 0)} MB")
            logger.info(f"Total Upload: {comm_summary['total_upload_mb']:.2f} MB")
            logger.info(f"Total Download: {comm_summary['total_download_mb']:.2f} MB")
            logger.info(f"Total Communication: {comm_summary['total_communication_mb']:.2f} MB")
            logger.info(f"Average per Round: {comm_summary['avg_per_round_mb']:.2f} MB")
            logger.info(f"Number of Rounds: {comm_summary['num_rounds']}")
            
            logger.info("\nUpload Breakdown:")
            for category, size_mb in comm_summary['upload_breakdown'].items():
                if size_mb > 0:
                    percentage = (size_mb / comm_summary['total_upload_mb']) * 100
                    logger.info(f"  {category}: {size_mb:.2f} MB ({percentage:.1f}%)")
            
            logger.info("\nDownload Breakdown:")
            for category, size_mb in comm_summary['download_breakdown'].items():
                if size_mb > 0:
                    percentage = (size_mb / comm_summary['total_download_mb']) * 100
                    logger.info(f"  {category}: {size_mb:.2f} MB ({percentage:.1f}%)")
            
            logger.info("\nCommunication by Task:")
            for task_id, stats in comm_summary['task_statistics'].items():
                logger.info(f"  Task {task_id + 1}:")
                logger.info(f"    Upload: {stats['upload_mb']:.2f} MB")
                logger.info(f"    Download: {stats['download_mb']:.2f} MB")
                logger.info(f"    Total: {stats['total_mb']:.2f} MB")
                logger.info(f"    Rounds: {stats['rounds']}")
            logger.info("=========================================")
            
        quality_summary = self.quality_evaluator.finalize()

        return all_tasks_acc, all_tasks_accuracy, quality_summary
    
    def visualize_attention_components(self, task, spatial_attention, temporal_patterns, combined_attention):
        """
        Render the attention matrices of one task as side-by-side heatmaps.

        The figure is saved as ``attention_components_task<task+1>.png`` inside
        the ``attention_visualizations`` sub-directory of ``self.opt.output_dir``
        (the directory is created on demand).

        Args:
            task: Current task ID (titles and the file name use ``task + 1``)
            spatial_attention: Spatial attention matrix, drawn in the first panel
            temporal_patterns: Temporal pattern similarity matrix; when None the
                middle panel is omitted and only two heatmaps are drawn
            combined_attention: Combined attention matrix, drawn in the last panel
        """
        import matplotlib.pyplot as plt
        import seaborn as sns
        import os

        # Destination folder for every attention figure of this run
        target_dir = os.path.join(self.opt.output_dir, 'attention_visualizations')
        os.makedirs(target_dir, exist_ok=True)

        # Ordered (matrix, title) pairs; the temporal panel is optional
        panels = [(spatial_attention, f'Task {task+1}: Spatial Attention')]
        if temporal_patterns is not None:
            panels.append((temporal_patterns, f'Task {task+1}: Temporal Patterns'))
        panels.append((combined_attention, f'Task {task+1}: Combined Attention'))

        # Each panel is 6 inches wide: 18x6 with three panels, 12x6 with two
        panel_count = len(panels)
        fig, axes = plt.subplots(1, panel_count, figsize=(6 * panel_count, 6))

        for axis, (matrix, title) in zip(axes, panels):
            sns.heatmap(matrix, ax=axis, cmap='viridis', annot=False)
            axis.set_title(title)

        plt.tight_layout()
        output_file = os.path.join(target_dir, f'attention_components_task{task+1}.png')
        plt.savefig(output_file, dpi=300)
        plt.close()

        logger.info(f"Task {task+1}: Saved attention component visualization")
    
    def visualize_dp_comparison(self, original_graphs, noisy_graphs):
        """
        Create a visualization comparing original and noisy relational graphs

        For every task index in ``range(self.opt.num_task)`` a three-panel
        figure (original graph, noisy graph, absolute difference) is written to
        ``dp_comparison/dp_comparison_task<task+1>.png`` under
        ``self.opt.output_dir``. A task is skipped when either of its graphs is
        None, so a missing noisy graph no longer aborts the whole comparison.

        Args:
            original_graphs: List of original relational graphs, indexed by task
            noisy_graphs: List of relational graphs with Laplace noise, indexed by task
        """
        import matplotlib.pyplot as plt
        import seaborn as sns

        vis_dir = os.path.join(self.opt.output_dir, 'dp_comparison')
        os.makedirs(vis_dir, exist_ok=True)

        for task in range(self.opt.num_task):
            original = original_graphs[task]
            if original is None:
                continue

            # Only look up the noisy graph once the original is known to exist,
            # mirroring the access order of the previous implementation.
            noisy = noisy_graphs[task]
            if noisy is None:
                logger.warning(
                    f"Task {task+1}: noisy relational graph is missing, skipping DP comparison"
                )
                continue

            fig, axes = plt.subplots(1, 3, figsize=(18, 6))
            try:
                # Original graph
                sns.heatmap(original, ax=axes[0], cmap='viridis',
                            vmin=0, vmax=1, annot=False)
                axes[0].set_title(f'Task {task+1}: Original Graph')

                # Noisy graph (same fixed colour range so both panels are comparable)
                sns.heatmap(noisy, ax=axes[1], cmap='viridis',
                            vmin=0, vmax=1, annot=False)
                axes[1].set_title(f'Task {task+1}: Graph with Laplace Noise (ε={self.opt.epsilon})')

                # Difference
                diff = np.abs(original - noisy)
                sns.heatmap(diff, ax=axes[2], cmap='Reds', annot=False)
                axes[2].set_title(f'Task {task+1}: Absolute Difference')

                plt.tight_layout()
                plt.savefig(os.path.join(vis_dir, f'dp_comparison_task{task+1}.png'), dpi=300)
            finally:
                # Always release this figure, even if plotting or saving failed,
                # so figures do not pile up in pyplot's registry.
                plt.close(fig)

        logger.info(f"Saved DP comparison visualizations to {vis_dir}")
