import torch
import numpy as np
import random
import os
import logging
import time
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
import copy
import csv
import ray

# Initialize Ray as a side effect of importing this module.
# ignore_reinit_error=True makes a repeated ray.init() in the same process a
# no-op instead of raising, so importing this module after Ray is already up
# (or importing it twice) is safe.
ray.init(ignore_reinit_error=True)

from model.modules import ResNet18Classifier, EnhancedDyGAT, TemporalGAT
from utils.dp_ks_analysis import DifferentialPrivacyAnalyzer
from utils.model_utils import *
from utils.visualization_utils import *
from utils.log_utils import *
from utils.plot_utils import *
from utils.evaluation_utils import *
from utils.metrics_utils import MetricsRecorder  # Import the new metrics recorder
from utils.communication_tracker import CommunicationTracker
from utils.quality_evaluator import QualityEvaluator
from utils.heatmap_analysis import (
    compute_data_distributions,
    build_classifier_update_similarity,
    build_distribution_similarity,
    plot_updates_vs_data_heatmaps,
    plot_similarity_scatter,
    pearson_correlation,
    plot_correlation_over_tasks,
    plot_temporal_heatmaps,
)

# Import our Server and ModifiedClient classes
from model.server import Server
from model.client import ModifiedClient

# Named logger shared by the helper functions and the server class in this module.
# NOTE(review): the @ray.remote functions below log through this logger from Ray
# worker processes, where it may have no handlers configured — confirm those
# messages actually reach the driver's output.
logger = logging.getLogger('GFedCL')

def to_tensor(x, device="cuda"):
    """Return ``x`` as a torch tensor moved to ``device``.

    NumPy arrays are wrapped with ``torch.from_numpy`` first; any other input
    must already provide a ``.to`` method (e.g. a ``torch.Tensor``).
    """
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    return tensor.to(device)

def add_laplace_noise(data, scale):
    """Return ``data`` plus i.i.d. Laplace(0, ``scale``) noise of the same shape.

    The noise is drawn once from NumPy's global RNG (``np.random.laplace``)
    and cast to float32. NumPy input yields a NumPy result. Any other input
    (expected to be a torch tensor) gets the noise converted to a tensor on
    the input's ``device``, or on ``'cuda'`` if it has no ``device`` attribute.
    """
    sample = np.asarray(np.random.laplace(0, scale, data.shape), dtype=np.float32)
    if not isinstance(data, np.ndarray):
        target_device = getattr(data, 'device', 'cuda')
        sample = torch.from_numpy(sample).to(target_device)
    return data + sample

def add_laplace_noise_to_graph(relational_graph, scale, normalize=True):
    """
    Add Laplace noise to a relational graph while keeping it a valid
    non-negative (and, optionally, row-stochastic) attention matrix.

    Args:
        relational_graph: 2-D numpy array representing the graph
        scale: Scale parameter of the Laplace noise
        normalize: Whether to rescale every row to sum to 1 after adding noise

    Returns:
        noisy_graph: float32 numpy array with the same shape as the input.
            When ``normalize`` is True every row sums to 1; a row whose
            entries were all clipped to zero is replaced by a uniform row.
    """
    # Add Laplace noise
    noisy_graph = add_laplace_noise(relational_graph, scale)

    # Attention scores must be non-negative. Clipping only post-processes the
    # noised graph, so it does not weaken the differential-privacy guarantee.
    noisy_graph = np.maximum(noisy_graph, 0)

    if normalize:
        row_sums = noisy_graph.sum(axis=1, keepdims=True)
        # A row whose entries were all clipped to zero cannot be rescaled to
        # sum to 1. Fall back to a (data-independent) uniform distribution
        # instead of leaving an all-zero attention row.
        empty_rows = row_sums[:, 0] <= 0
        if np.any(empty_rows):
            noisy_graph[empty_rows] = 1.0
            row_sums[empty_rows] = noisy_graph.shape[1]
        # Divide by the true positive row sum so every row sums to 1, even
        # when that sum is tiny.
        noisy_graph = noisy_graph / row_sums

    return noisy_graph.astype(np.float32)

# Request a GPU share only when CUDA is visible. On a CPU-only host a
# fractional GPU request can never be satisfied, so the task would stay
# pending and ray.get() would block forever.
@ray.remote(num_gpus=0.3 if torch.cuda.is_available() else 0)
def generate_encodings_remote(client, task, relational_graphs, dataloader, generate_synthetic=False):
    """
    Generate encodings from a client in parallel (executed as a Ray task).

    Args:
        client: Client object; ``client.generate_encodings`` does the work
        task: Task ID
        relational_graphs: Relational graphs for all tasks; numpy entries of a
            list are converted to float32 before use
        dataloader: DataLoader for the client/task
        generate_synthetic: Whether to generate from synthetic samples

    Returns:
        Whatever ``client.generate_encodings`` returns (the training loop in
        this file reads its ``'encodings'`` and ``'graph_embeddings'`` keys).
    """
    logger.info(f"Generating encodings from client {client.getId()} for task {task}")

    # Build a new float32 list rather than mutating the caller's list in place.
    if isinstance(relational_graphs, list):
        relational_graphs = [
            graph.astype(np.float32) if isinstance(graph, np.ndarray) else graph
            for graph in relational_graphs
        ]

    return client.generate_encodings(task, relational_graphs, dataloader, generate_synthetic)

# Request a GPU share only when CUDA is visible. On a CPU-only host a
# fractional GPU request can never be satisfied, so the task would stay
# pending and ray.get() would block forever.
@ray.remote(num_gpus=0.3 if torch.cuda.is_available() else 0)
def train_client_remote(client, task, relational_graphs, dataloader, epochs, generate_synthetic=False, metrics_recorder=None):
    """
    Train a client in parallel with per-epoch timing (executed as a Ray task).

    Ray hands the worker a serialized copy of ``client``, so training here does
    not modify the caller's object; the trained weights are returned instead.

    Args:
        client: Client object
        task: Task ID
        relational_graphs: Relational graphs for all tasks
        dataloader: DataLoader for the client
        epochs: Number of epochs to train
        generate_synthetic: Whether to use synthetic samples
        metrics_recorder: Unused; kept so existing call sites remain valid
            (timings are returned in the result instead)

    Returns:
        dict: ``'weights'`` (from ``client.get_weights()``), ``'epoch_times'``
        (seconds per epoch, one entry per epoch) and ``'client_id'``.
    """
    logger.info(f"Training client {client.getId()} for task {task}")

    epoch_times = []
    for epoch in range(epochs):
        # perf_counter is monotonic, so durations are immune to wall-clock changes.
        epoch_start = time.perf_counter()
        client.learn(epoch, task, relational_graphs, dataloader, generate_synthetic)
        epoch_times.append(time.perf_counter() - epoch_start)

    return {
        'weights': client.get_weights(),
        'epoch_times': epoch_times,
        'client_id': client.getId()
    }

# Request a GPU share only when CUDA is visible. On a CPU-only host a
# fractional GPU request can never be satisfied, so the task would stay
# pending and ray.get() would block forever.
@ray.remote(num_gpus=0.3 if torch.cuda.is_available() else 0)
def test_client_remote(client, task, dataloader):
    """
    Test a client in parallel (executed as a Ray task).

    Args:
        client: Client object; ``client.test`` does the evaluation
        task: Task ID
        dataloader: DataLoader for testing

    Returns:
        Whatever ``client.test(task, dataloader)`` returns (test metrics).
    """
    logger.info(f"Testing client {client.getId()} for task {task}")
    return client.test(task, dataloader)

# Request a GPU share only when CUDA is visible. On a CPU-only host a
# fractional GPU request can never be satisfied, so the task would stay
# pending and ray.get() would block forever.
@ray.remote(num_gpus=0.3 if torch.cuda.is_available() else 0)
def train_classifier_remote(resnet18, task, dataloader, num_epochs):
    """
    Train a ResNet18 classifier in parallel (executed as a Ray task).

    Ray hands the worker a serialized copy of ``resnet18``, so training here
    does not update the caller's model. The trained weights are therefore
    returned under ``'weights'`` (mirroring ``train_client_remote``) so the
    caller can load them back into its own model.

    Args:
        resnet18: ResNet18 classifier object; ``resnet18.learn(epoch, dataloader)``
            is called once per epoch
        task: Task ID (used for logging)
        dataloader: DataLoader for training
        num_epochs: Number of epochs to train

    Returns:
        dict: ``'status'`` (always 0), ``'epoch_times'`` (seconds per epoch)
        and ``'weights'`` (from ``resnet18.get_weights()`` after training).
    """
    logger.info(f"Training ResNet18 classifier for task {task}")

    epoch_times = []
    for epoch in range(num_epochs):
        # perf_counter is monotonic, so durations are immune to wall-clock changes.
        epoch_start = time.perf_counter()
        resnet18.learn(epoch, dataloader)
        epoch_times.append(time.perf_counter() - epoch_start)

    return {
        'status': 0,
        'epoch_times': epoch_times,
        'weights': resnet18.get_weights()
    }

def create_modified_clients(opt):
    """
    Build one ``ModifiedClient`` per configured client.

    Args:
        opt: Configuration options; ``opt.num_clients`` sets how many clients
            are created, and each client is constructed as
            ``ModifiedClient(index, opt)`` (the index is presumably its id).

    Returns:
        list: Client objects for indices ``0 .. opt.num_clients - 1``, in order.
    """
    return [ModifiedClient(client_index, opt) for client_index in range(opt.num_clients)]

class ParallelServerGFedCL:
    def __init__(self, opt):
        """
        Set up the server, relational-graph GAT, data loaders, clients and
        bookkeeping helpers used by parallel GFedCL training.

        Args:
            opt: Configuration namespace. Fields read here: ``device``,
                ``num_clients``, ``num_task``, ``num_classes``, ``epsilon``,
                ``sensitivity`` and ``output_dir``, plus the optional ablation
                flags ``use_graph`` and ``use_temporal`` (both default to True).
                Side effect: ``opt.device`` is overwritten with the device
                actually selected (``'cuda'`` or ``'cpu'``).
        """
        self.opt = opt
        self.use_graph = getattr(opt, "use_graph", True)
        self.use_temporal = getattr(opt, "use_temporal", True)

        # Treat indexed requests such as "cuda:0" (or a torch.device) as a CUDA
        # request too, instead of silently falling back to CPU.
        requested_device = str(opt.device)
        wants_cuda = requested_device == 'cuda' or requested_device.startswith('cuda:')
        if wants_cuda and torch.cuda.is_available():
            # The index-free "cuda" device is used because Ray restricts each
            # worker to the GPU(s) assigned to it.
            self.device = torch.device('cuda')
            if requested_device != 'cuda':
                logger.info(f"Requested device '{requested_device}' mapped to 'cuda'")
            # Print CUDA device info for debugging
            logger.info(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        else:
            self.device = torch.device('cpu')
            logger.info("Using CPU")

        # Keep opt.device in sync with the device actually used, so that the
        # components constructed from opt below see the real device.
        opt.device = str(self.device)

        # Initialize the server
        logger.info("Initializing server with global discriminator...")
        self.server = Server(opt)

        # Initialize the GAT for relational graph generation (None when the
        # graph is ablated away via use_graph=False)
        if self.use_graph:
            if self.use_temporal:
                logger.info("Initializing TemporalGAT for relational graph generation...")
                self.dygat = TemporalGAT(opt).to(self.device)
            else:
                logger.info("Initializing Enhanced DyGAT for relational graph generation...")
                self.dygat = EnhancedDyGAT(opt).to(self.device)
        else:
            self.dygat = None
            logger.info("Relational graph generation disabled (ablation: no_graph)")

        # Load and partition EMNIST-letter dataset
        logger.info("Setting up EMNIST-letter dataloaders...")
        from utils.dataset_utils import setup_emnist_loaders
        self.dataloaders = setup_emnist_loaders(opt)
        logger.info("EMNIST dataloaders prepared successfully")
        logger.info("Computing per-client data distributions for visualization...")
        self.data_distributions = compute_data_distributions(
            self.dataloaders, self.opt.num_clients, self.opt.num_task, self.opt.num_classes
        )
        logger.info("Data distribution snapshots ready")

        # Create modified clients
        logger.info("Creating modified clients...")
        self.clients = create_modified_clients(opt)
        logger.info(f"Created {len(self.clients)} clients")

        # Differential-privacy analyzer: constructed unconditionally,
        # regardless of whether opt.dp is set.
        self.dp_analyzer = DifferentialPrivacyAnalyzer(
            epsilon=self.opt.epsilon,
            sensitivity=self.opt.sensitivity,
            output_dir=os.path.join(self.opt.output_dir, 'dp_analysis')
        )

        # Initialize communication tracker
        self.comm_tracker = CommunicationTracker()
        logger.info("Initialized communication tracker")

        # Initialize quality evaluator (FID/IS)
        self.quality_evaluator = QualityEvaluator(opt)

        # Initialize metrics recorder
        logger.info("Initializing metrics recorder...")
        self.metrics_recorder = MetricsRecorder(opt)
        # Per-task history of update/data similarity matrices and their
        # correlation, appended to once per task during training.
        self.update_data_correlations = []
        self.update_similarity_matrices = []
        self.data_similarity_matrices = []
        
    def train_GFedCL(self):
        logger.info('Starting Parallel Server-based GFedCL training for EMNIST-letter...')
        
        # Initialize relational graphs for each task as identity matrices
        relational_graphs = [np.eye(self.opt.num_clients) for _ in range(self.opt.num_task)]
        
        # Store original (non-noisy) graphs for visualization
        original_relational_graphs = [None for _ in range(self.opt.num_task)]
        
        # Create local ResNet18 models for all clients (adapted for EMNIST)
        logger.info("Initializing ResNet18 models for all clients (EMNIST-adapted)...")
        resnet18s = []
        for i in range(self.opt.num_clients):
            # Create model with safe device handling - adapted for EMNIST
            model = ResNet18Classifier(
                num_classes=self.opt.num_classes,  # 26 classes for EMNIST-letter
                device=self.opt.device
            )
            # Initialize optimizer
            model.init_optimizer(
                lr=self.opt.lr_f,
                momentum=0.9,
                weight_decay=self.opt.weight_decay
            )
            resnet18s.append(model)
        logger.info(f"Created {len(resnet18s)} ResNet18 models for EMNIST")
        
        # Track accuracy for each round and task
        round_accuracy = []
        round_labels = []
        
        # NEW: Track test accuracy on all previous tasks during each round
        all_tasks_accuracy = []  # Will store data for all rounds and all tasks
        
        # Training loop for each task
        for task in range(self.opt.num_task):
            logger.info(f'Training for task {task+1}/{self.opt.num_task}')
            
            # Clear CUDA cache between tasks if using GPU
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
            
            # Save initial model states for computing updates
            old_states = {}
            for i, model in enumerate(resnet18s):
                old_states[i] = model.get_weights()
            
            # Train local classifiers in parallel using Ray
            logger.info(f'Training classifiers for task {task+1} in parallel...')
            
            # Start timing for classifier training phase
            classifier_phase_start = self.metrics_recorder.computation_tracker.start_timer()
            
            classifier_futures = [
                train_classifier_remote.remote(resnet18, task, self.dataloaders[i][task]['train'], 15)  # Reduced epochs for EMNIST
                for i, resnet18 in enumerate(resnet18s)
            ]
            # Wait for all classifiers to finish training
            classifier_results = ray.get(classifier_futures)
            
            # Record classifier training time and individual epoch times
            self.metrics_recorder.record_phase_time(
                classifier_phase_start, 'classifier_training', task_id=task
            )
            
            # Aggregate epoch times from classifiers
            for i, result in enumerate(classifier_results):
                for epoch_time in result['epoch_times']:
                    self.metrics_recorder.computation_tracker.epoch_times[task][0].append(epoch_time)
            
            logger.info(f'Classifiers training completed for task {task+1}')
            
            # Calculate model updates for each client
            classifier_updates = {}
            for i, model in enumerate(resnet18s):
                # Get gradient updates
                updates = model.get_gradient_updates(old_states[i])
                classifier_updates[i] = updates

            # Visualize how classifier updates relate to underlying data distributions
            try:
                update_sim = build_classifier_update_similarity(
                    classifier_updates, self.opt.num_clients
                )
                data_sim = build_distribution_similarity(
                    self.data_distributions, task, self.opt.num_clients, self.opt.num_classes
                )
                corr_value = pearson_correlation(update_sim, data_sim)
                self.update_data_correlations.append(corr_value)
                self.update_similarity_matrices.append(update_sim)
                self.data_similarity_matrices.append(data_sim)
                heatmap_dir = os.path.join(self.opt.output_dir, "visualizations")
                pdf_path = plot_updates_vs_data_heatmaps(
                    update_sim, data_sim, task, heatmap_dir, relational_graphs[task]
                )
                scatter_path = plot_similarity_scatter(
                    update_sim, data_sim, task, heatmap_dir, corr_value
                )
                logger.info(
                    f"Task {task+1}: Saved classifier-update vs data-distribution heatmaps to {pdf_path}"
                )
                logger.info(
                    f"Task {task+1}: Pearson correlation (updates vs data) = {corr_value:.3f} "
                    f"(scatter: {scatter_path})"
                )
                # Save running temporal visuals so figures exist even if training stops early
                if self.update_similarity_matrices and self.data_similarity_matrices:
                    _ = plot_temporal_heatmaps(
                        self.update_similarity_matrices, self.data_similarity_matrices, heatmap_dir
                    )
                if self.update_data_correlations:
                    _ = plot_correlation_over_tasks(self.update_data_correlations, heatmap_dir)
            except Exception as e:
                logger.error(f"Failed to create update/data heatmaps for task {task+1}: {e}")

            logger.info('Tracking communication overhead for classifier updates sent to GAT...')
            self.comm_tracker.add_model_updates_communication(classifier_updates, self.opt.num_clients)
            
            if self.use_graph:
                logger.info('Generating attention-based relational graph using GAT')
                try:
                    relational_graph = self.dygat.learn(
                        self.opt.gat_epochs,
                        classifier_updates,
                        task_id=task
                    )

                    # Store the original graph before adding noise
                    original_relational_graphs[task] = relational_graph.copy()

                    # Add Laplace noise to the relational graph for differential privacy
                    if self.opt.dp:
                        logger.info(f'Adding Laplace noise to relational graph with scale {self.opt.b}')
                        relational_graphs[task] = add_laplace_noise_to_graph(
                            relational_graph,
                            scale=self.opt.b,
                            normalize=True
                        )

                        # Log statistics about the noise
                        noise_magnitude = np.abs(relational_graphs[task] - original_relational_graphs[task])
                        logger.info(
                            f'Noise statistics - Mean: {np.mean(noise_magnitude):.4f}, '
                            f'Max: {np.max(noise_magnitude):.4f}, '
                            f'Std: {np.std(noise_magnitude):.4f}'
                        )
                    else:
                        relational_graphs[task] = relational_graph

                    # Save both original and noisy graphs for visualization
                    save_heatmap(self.opt, task, relational_graphs)

                    # Also save the original graph if noise was added
                    if self.opt.dp:
                        original_opt = copy.deepcopy(self.opt)
                        original_opt.output_dir = os.path.join(self.opt.output_dir, 'original_graphs')
                        os.makedirs(original_opt.output_dir, exist_ok=True)
                        save_heatmap(original_opt, task, original_relational_graphs)

                except Exception as e:
                    logger.error(f"Error generating relational graph: {str(e)}")
                    logger.info("Using identity matrix as fallback")
                    relational_graphs[task] = np.eye(self.opt.num_clients)
                    original_relational_graphs[task] = relational_graphs[task].copy()

                if self.use_temporal:
                    try:
                        spatial_attention = (
                            self.dygat.last_spatial_attention
                            if hasattr(self.dygat, 'last_spatial_attention')
                            else None
                        )
                        temporal_patterns = (
                            self.dygat.last_temporal_patterns
                            if hasattr(self.dygat, 'last_temporal_patterns')
                            else None
                        )

                        if spatial_attention is not None:
                            self.visualize_attention_components(
                                task, spatial_attention, temporal_patterns, relational_graphs[task]
                            )
                    except Exception as e:
                        logger.error(f"Error visualizing attention components: {str(e)}")
            else:
                relational_graphs[task] = np.eye(self.opt.num_clients)
                original_relational_graphs[task] = relational_graphs[task].copy()

            if self.opt.dp and self.use_graph and original_relational_graphs[task] is not None:
                self.dp_analyzer.analyze_relational_graph(
                    original_relational_graphs[task],
                    task_id=task
                )

            for _ in range(self.opt.num_clients):
                self.comm_tracker.add_tensor_communication(
                    relational_graphs[task],
                    direction='download',
                    category='relational_graph'
                )
    
            # Main training loop with the server and modified clients
            for r in range(self.opt.num_rounds):
                logger.info(f'Round {r+1}/{self.opt.num_rounds}')

                # Start tracking this round
                self.comm_tracker.start_round(task, r)

                # Start timing for the entire round
                round_start = self.metrics_recorder.computation_tracker.start_timer()
                
                # PHASE 1: Collect encodings from all clients with current encoders (no training)
                logger.info(f'Phase 1: Collecting encodings from all clients in parallel')
                
                phase1_start = self.metrics_recorder.computation_tracker.start_timer()
                
                all_encodings = []
                all_graph_embeddings = []
                
                # First, distribute the server's current discriminator to all clients
                server_discriminator = self.server.get_discriminator()
                for client in self.clients:
                    client.set_server_discriminator(server_discriminator)
                    self.comm_tracker.add_model_weights_communication(
                        server_discriminator,
                        direction='download',
                        model_type='discriminator_weights'
                    )
                
                # Collect encodings from each client without training (in parallel)
                encoding_futures = []
                
                # Launch encoding generation for current task
                for client_id, client in enumerate(self.clients):
                    future = generate_encodings_remote.remote(
                        client, 
                        task,
                        relational_graphs,
                        self.dataloaders[client_id][task]['train'],
                        False  # No synthetic samples
                    )
                    encoding_futures.append(future)
                
                # Launch encoding generation for previous task if applicable
                if task >= 1:
                    for client_id, client in enumerate(self.clients):
                        future = generate_encodings_remote.remote(
                            client,
                            task-1,
                            relational_graphs,
                            self.dataloaders[client_id][task-1]['train'],
                            True  # Use synthetic samples
                        )
                        encoding_futures.append(future)
                
                # Collect results from all encoding futures
                encoding_results = ray.get(encoding_futures)
                
                # Process results
                synthetic_encodings_by_task = {prev_task: [] for prev_task in range(task)}
                real_encodings_by_task = {t: [] for t in range(task+1)}
                
                for i, result in enumerate(encoding_results):
                    encodings = result['encodings']
                    all_encodings.extend(encodings)
                    all_graph_embeddings.extend(result['graph_embeddings'])
                    
                    # Separate real and synthetic encodings for quality evaluation
                    if i < len(self.clients):  
                        # Real encodings for current task
                        real_encodings_by_task[task].extend(encodings)
                    else:  
                        # Synthetic encodings for previous tasks
                        # Calculate which previous task this is for
                        prev_task_idx = (i - len(self.clients)) // len(self.clients)
                        if prev_task_idx < task:
                            synthetic_encodings_by_task[prev_task_idx].extend(encodings)

                # Track encodings uploaded to server
                self.comm_tracker.add_encodings_communication(
                    all_encodings,
                    all_graph_embeddings
                )
                
                # Record phase 1 time
                self.metrics_recorder.record_phase_time(
                    phase1_start, 'encoding_collection', task_id=task, round_id=r
                )
                
                # Store real encodings for current task for future comparisons
                if task in real_encodings_by_task and real_encodings_by_task[task]:
                    # Store these for quality evaluation when we generate synthetic data for this task in the future
                    if not hasattr(self, 'stored_real_encodings'):
                        self.stored_real_encodings = {}
                    self.stored_real_encodings[task] = real_encodings_by_task[task]
                
                # Evaluate synthetic data quality by comparing with real encodings from the SAME task
                if task >= 1:
                    logger.info("Evaluating synthetic data quality...")
                    
                    # For each previous task, compare synthetic with real encodings
                    for prev_task in range(task):
                        if prev_task in synthetic_encodings_by_task and synthetic_encodings_by_task[prev_task]:
                            synthetic_encs = synthetic_encodings_by_task[prev_task]
                            
                            # Get stored real encodings for this task
                            if hasattr(self, 'stored_real_encodings') and prev_task in self.stored_real_encodings:
                                real_encs = self.stored_real_encodings[prev_task]
                                
                                logger.info(f"  Task {prev_task+1} comparison:")
                                logger.info(f"    - Real encodings: {len(real_encs)} batches (from real images)")
                                logger.info(f"    - Synthetic encodings: {len(synthetic_encs)} batches (from random noise)")
                                
                                # Concatenate encodings
                                real_tensor = torch.cat(real_encs, dim=0)
                                synthetic_tensor = torch.cat(synthetic_encs, dim=0)
                                
                                # Calculate IS and FID for same-task comparison
                                if synthetic_tensor.shape[0] > 0 and real_tensor.shape[0] > 0:
                                    # For IS, we need labels - using task-specific class range
                                    task_classes = list(range(prev_task * self.opt.class_per_task, 
                                                            (prev_task + 1) * self.opt.class_per_task))
                                    
                                    # Create dummy labels for this task
                                    num_samples = real_tensor.shape[0]
                                    labels = torch.tensor([task_classes[i % len(task_classes)] 
                                                         for i in range(num_samples)], device=self.device)
                                    
                                    # Add real encodings and labels to metrics recorder
                                    # Pass as lists of individual samples, not concatenated tensors
                                    real_list = [real_tensor[i:i+1] for i in range(real_tensor.shape[0])]
                                    labels_list = [labels[i:i+1] for i in range(labels.shape[0])]
                                    self.metrics_recorder.add_real_encodings(real_list, labels_list)
                                    
                                    # Evaluate synthetic quality
                                    is_score, fid_score = self.metrics_recorder.evaluate_synthetic_quality(synthetic_tensor)
                                    if is_score is not None:
                                        logger.info(f"    - Quality metrics: IS={is_score:.4f}, FID={fid_score:.4f}")
                            else:
                                logger.warning(f"  No stored real encodings for task {prev_task+1} comparison")

                # PHASE 2: Train the server's discriminator with collected encodings
                logger.info(f'Phase 2: Training server discriminator with {len(all_encodings)} samples')
                
                phase2_start = self.metrics_recorder.computation_tracker.start_timer()
                
                if len(all_encodings) > 0:
                    # Apply Laplace noise to the encodings for differential privacy
                    all_encodings_noised = []
                    for encoding in all_encodings:
                        # Make sure the encoding is on the correct device
                        encoding = encoding.to(self.device)
                        # Add Laplace noise with scale parameter self.opt.b
                        noised_encoding = add_laplace_noise(encoding, self.opt.b)
                        all_encodings_noised.append(noised_encoding)
                    server_loss = self.server.train_discriminator(all_encodings_noised, all_graph_embeddings)
                    logger.info(f'Server discriminator loss: {server_loss:.4f}')
                else:
                    logger.warning("No encoded samples collected for server training")
                
                # Record phase 2 time
                self.metrics_recorder.record_phase_time(
                    phase2_start, 'server_training', task_id=task, round_id=r
                )
                
                # PHASE 3: Distribute the updated discriminator to clients and train them in parallel
                logger.info(f'Phase 3: Training clients with updated server discriminator in parallel')
                
                phase3_start = self.metrics_recorder.computation_tracker.start_timer()
                
                # Distribute the updated discriminator to all clients
                server_discriminator = self.server.get_discriminator()
                for client in self.clients:
                    client.set_server_discriminator(server_discriminator)
                    self.comm_tracker.add_model_weights_communication(
                        server_discriminator,
                        direction='download',
                        model_type='discriminator_weights'
                    )
                
                # Train clients with the updated discriminator in parallel
                training_futures = []
                
                # Launch client training for current task
                for client_id, client in enumerate(self.clients):
                    future = train_client_remote.remote(
                        client,
                        task,
                        relational_graphs, 
                        self.dataloaders[client_id][task]['train'],
                        self.opt.num_local_epochs,
                        False  # No synthetic samples for current task
                    )
                    training_futures.append(future)
                
                # Launch client training for previous task if applicable
                if task >= 1:
                    for client_id, client in enumerate(self.clients):
                        future = train_client_remote.remote(
                            client,
                            task-1,
                            relational_graphs,
                            self.dataloaders[client_id][task-1]['train'],
                            self.opt.num_local_epochs,
                            True  # Use synthetic samples
                        )
                        training_futures.append(future)
                
                # Collect results from all training futures
                logger.info(f'Waiting for client training to complete...')
                training_results = ray.get(training_futures[:len(self.clients)])  # Get only current task results
                
                # Process results and extract weights
                encoder_weights = []
                predictor_weights = []
                
                for result in training_results:
                    encoder_weights.append(result['weights']['encoder'])
                    predictor_weights.append(result['weights']['predictor'])

                    # Record individual client epoch times
                    for epoch_time in result['epoch_times']:
                        self.metrics_recorder.computation_tracker.epoch_times[task][r].append(epoch_time)

                    self.comm_tracker.add_model_weights_communication(
                        result['weights']['encoder'],
                        direction='upload',
                        model_type='encoder_weights'
                    )
                    self.comm_tracker.add_model_weights_communication(
                        result['weights']['predictor'],
                        direction='upload',
                        model_type='predictor_weights'
                    )
                
                # Record phase 3 time
                self.metrics_recorder.record_phase_time(
                    phase3_start, 'client_training', task_id=task, round_id=r
                )
                
                # Update server's learning rate
                self.server.update_learning_rate()
                
                # Average weights across clients using utility function from server_utils.py
                from utils.server_utils import average_weights
                global_encoder = average_weights(encoder_weights)
                global_predictor = average_weights(predictor_weights)

                # Update the local models with the averaged weights
                for client in self.clients:
                    client.set_weights({
                        'encoder': global_encoder,
                        'predictor': global_predictor
                    })
                    self.comm_tracker.add_model_weights_communication(
                        global_encoder,
                        direction='download',
                        model_type='encoder_weights'
                    )
                    self.comm_tracker.add_model_weights_communication(
                        global_predictor,
                        direction='download',
                        model_type='predictor_weights'
                    )

                # End round tracking
                self.comm_tracker.end_round()
                
                # Evaluate current performance for all clients on this task in parallel
                logger.info(f'Evaluating clients for task {task+1} in parallel...')
                testing_futures = [
                    test_client_remote.remote(
                        client,
                        task,
                        self.dataloaders[client_id][task]['test']
                    )
                    for client_id, client in enumerate(self.clients)
                ]
                
                # Collect test results
                test_results = ray.get(testing_futures)
                task_accuracies = [result["acc"] for result in test_results]
                
                # Calculate average accuracy across all clients for this round
                avg_accuracy = sum(task_accuracies) / len(task_accuracies)
                round_accuracy.append(avg_accuracy)
                round_labels.append(f"Task {task+1}, Round {r+1}")
                
                logger.info(f"Task {task+1}, Round {r+1} - Average Accuracy: {avg_accuracy:.2f}%")
                
                # Record total round time
                self.metrics_recorder.record_phase_time(
                    round_start, 'total_round', task_id=task, round_id=r
                )
                
                # NEW: Collect data for all tasks (including current and previous) during this round
                round_all_tasks_data = {
                    'round': f"Task {task+1}, Round {r+1}",
                    'current_task': task,
                    'round_number': r+1,
                    'tasks': {}
                }
                
                # Add current task accuracy
                round_all_tasks_data['tasks'][task] = avg_accuracy
                
                # Evaluate on all previous tasks for this round
                if task > 0:
                    logger.info(f"Evaluating performance on previous tasks during Task {task+1}, Round {r+1}...")
                    
                    for prev_task in range(task):
                        prev_task_futures = []
                        for client_id, client in enumerate(self.clients):
                            if prev_task in self.dataloaders[client_id]:
                                future = test_client_remote.remote(
                                    client,
                                    prev_task,
                                    self.dataloaders[client_id][prev_task]['test']
                                )
                                prev_task_futures.append(future)
                        
                        # Collect test results for previous task
                        if prev_task_futures:
                            prev_task_results = ray.get(prev_task_futures)
                            prev_task_accuracies = [result["acc"] for result in prev_task_results]
                            avg_prev_accuracy = sum(prev_task_accuracies) / len(prev_task_accuracies)
                            
                            # Log and store the accuracy on this previous task
                            logger.info(f"  Task {task+1}, Round {r+1} - Accuracy on previous Task {prev_task+1}: {avg_prev_accuracy:.2f}%")
                            round_all_tasks_data['tasks'][prev_task] = avg_prev_accuracy
                
                # Add this round's data to our tracking
                all_tasks_accuracy.append(round_all_tasks_data)

                # Optional FID/IS evaluations after each round
                self.quality_evaluator.evaluate_round(
                    task, r, self.clients, self.dataloaders, relational_graphs
                )

        # Save all metrics
        logger.info("Saving all collected metrics...")
        self.metrics_recorder.save_all_metrics()

        # Save round accuracy to CSV
        csv_path = os.path.join(self.opt.output_dir, 'round_accuracy.csv')
        with open(csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Round', 'Average Accuracy'])
            for i, label in enumerate(round_labels):
                writer.writerow([label, f"{round_accuracy[i]:.2f}"])
        
        logger.info(f"Saved round accuracy data to {csv_path}")
        
        # NEW: Save all tasks accuracy to CSV
        all_tasks_csv_path = os.path.join(self.opt.output_dir, 'all_tasks_accuracy.csv')
        with open(all_tasks_csv_path, 'w', newline='') as csvfile:
            # Define CSV header
            fieldnames = ['Round', 'Current Task']
            # Add columns for each task
            for t in range(self.opt.num_task):
                fieldnames.append(f'Task {t+1} Accuracy')
            
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            
            # Write data for each round
            for round_data in all_tasks_accuracy:
                row = {
                    'Round': round_data['round'],
                    'Current Task': round_data['current_task'] + 1  # Add 1 for 1-based indexing
                }
                
                # Add accuracy for each task
                for t in range(self.opt.num_task):
                    if t in round_data['tasks']:
                        row[f'Task {t+1} Accuracy'] = f"{round_data['tasks'][t]:.2f}"
                    else:
                        row[f'Task {t+1} Accuracy'] = ""
                
                writer.writerow(row)
        
        logger.info(f"Saved all tasks accuracy data to {all_tasks_csv_path}")

        # Save communication overhead data
        comm_csv_path = os.path.join(self.opt.output_dir, 'communication_overhead.csv')
        self.comm_tracker.save_to_csv(comm_csv_path)
        logger.info(f"Saved communication overhead data to {comm_csv_path}")
        
        # Plot communication overhead
        comm_plots_dir = os.path.join(self.opt.output_dir, 'communication_plots')
        self.comm_tracker.plot_communication_overhead(comm_plots_dir)
        logger.info(f"Saved communication overhead plots to {comm_plots_dir}")
        
        comm_summary = self.comm_tracker.get_summary()
        logger.info("===== COMMUNICATION SUMMARY =====")
        for key, value in comm_summary.items():
            logger.info(f"{key}: {value}")
        logger.info("===============================")
        
        # Test accuracy after training all tasks
        logger.info("Evaluating final model accuracy...")
        all_tasks_acc = evaluate_all_tasks(self.opt, self.clients, self.dataloaders)
        
        if self.opt.dp:
            self.dp_analyzer.save_results()

        quality_summary = self.quality_evaluator.finalize()

        # Plot temporal similarities and Pearson correlation dynamics after all tasks complete
        vis_dir = os.path.join(self.opt.output_dir, "visualizations")
        if self.update_similarity_matrices and self.data_similarity_matrices:
            temporal_heatmap_path = plot_temporal_heatmaps(
                self.update_similarity_matrices, self.data_similarity_matrices, vis_dir
            )
            logger.info(
                f"Saved temporal similarity heatmaps (updates vs data) to {temporal_heatmap_path}"
            )
        if self.update_data_correlations:
            corr_path = plot_correlation_over_tasks(
                self.update_data_correlations, vis_dir
            )
            logger.info(
                f"Saved temporal correlation plot (updates vs data distributions) to {corr_path}"
            )

        return all_tasks_acc, all_tasks_accuracy, quality_summary
    
    def visualize_attention_components(self, task, spatial_attention, temporal_patterns, combined_attention):
        """
        Visualize the components of the attention mechanism as side-by-side heatmaps.

        One panel is drawn per available component, in this order:
        - spatial attention;
        - temporal patterns, only when provided;
        - combined attention.

        The figure is saved as ``attention_components_task{task+1}.png`` under
        ``<output_dir>/attention_visualizations``; the directory is created if needed.

        Args:
            task: Current task ID (0-based; shown 1-based in titles and the file name)
            spatial_attention: Spatial attention matrix
            temporal_patterns: Temporal pattern similarity matrix (can be None, in
                which case that panel is omitted)
            combined_attention: Combined attention matrix
        """
        # Create output directory if it doesn't exist
        vis_dir = os.path.join(self.opt.output_dir, 'attention_visualizations')
        os.makedirs(vis_dir, exist_ok=True)

        # Ordered (matrix, title suffix) panels; the temporal panel only exists
        # when temporal patterns were supplied.
        panels = [(spatial_attention, 'Spatial Attention')]
        if temporal_patterns is not None:
            panels.append((temporal_patterns, 'Temporal Patterns'))
        panels.append((combined_attention, 'Combined Attention'))

        # 6 inches of width per panel reproduces the (18, 6) / (12, 6) figure sizes.
        # There are always at least two panels, so `axes` is an array of Axes.
        fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6))
        try:
            for ax, (matrix, label) in zip(axes, panels):
                sns.heatmap(matrix, ax=ax, cmap='viridis', annot=False)
                ax.set_title(f'Task {task+1}: {label}')

            fig.tight_layout()
            fig.savefig(os.path.join(vis_dir, f'attention_components_task{task+1}.png'), dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails, so
            # repeated calls do not accumulate open figures in pyplot.
            plt.close(fig)

        logger.info(f"Task {task+1}: Saved attention component visualization")
    
    def visualize_dp_comparison(self, original_graphs, noisy_graphs):
        """
        Create a visualization comparing original and noisy relational graphs.

        For every task in ``range(self.opt.num_task)`` whose original graph is
        not None, a three-panel figure is saved to
        ``<output_dir>/dp_comparison/dp_comparison_task{task+1}.png``. The panels are:
        - the original graph;
        - the noisy graph (the first two share a fixed [0, 1] color scale);
        - their element-wise absolute difference.

        Args:
            original_graphs: List of original relational graphs, indexed by task;
                entries that are None are skipped
            noisy_graphs: List of relational graphs with Laplace noise, indexed by
                task like ``original_graphs``
        """
        vis_dir = os.path.join(self.opt.output_dir, 'dp_comparison')
        os.makedirs(vis_dir, exist_ok=True)

        for task in range(self.opt.num_task):
            original = original_graphs[task]
            if original is None:
                continue
            noisy = noisy_graphs[task]

            fig, axes = plt.subplots(1, 3, figsize=(18, 6))
            try:
                # A shared [0, 1] scale keeps the two graphs visually comparable;
                # values outside that range saturate the colormap ends.
                sns.heatmap(original, ax=axes[0], cmap='viridis',
                            vmin=0, vmax=1, annot=False)
                axes[0].set_title(f'Task {task+1}: Original Graph')

                sns.heatmap(noisy, ax=axes[1], cmap='viridis',
                            vmin=0, vmax=1, annot=False)
                axes[1].set_title(f'Task {task+1}: Graph with Laplace Noise (ε={self.opt.epsilon})')

                # Absolute difference is auto-scaled so small perturbations remain visible.
                diff = np.abs(original - noisy)
                sns.heatmap(diff, ax=axes[2], cmap='Reds', annot=False)
                axes[2].set_title(f'Task {task+1}: Absolute Difference')

                fig.tight_layout()
                fig.savefig(os.path.join(vis_dir, f'dp_comparison_task{task+1}.png'), dpi=300)
            finally:
                # Release the figure even if plotting or saving fails.
                plt.close(fig)

        logger.info(f"Saved DP comparison visualizations to {vis_dir}")
