import matplotlib.pyplot as plt
import os
import logging as logger
import numpy as np

def evaluate_all_tasks(opt, clients, dataloaders):
    """
    Evaluate all clients on all tasks to measure final performance.

    Every client ``clients[i]`` (for ``i`` in ``range(opt.num_clients)``) is
    tested on every task ``t`` in ``range(opt.num_task)`` that is present in
    ``dataloaders[i]``, using ``dataloaders[i][t]['test']``. Missing
    (client, task) pairs are skipped with a warning.

    Args:
        opt: Object exposing ``num_clients`` and ``num_task``.
        clients: Indexable collection of clients; each must provide
            ``test(task_id, loader)`` returning a mapping that may contain
            the keys "mse" (or "loss" as a fallback), "mae", "rmse", "r2".
        dataloaders: Indexable by client id, then by task id, then by the
            key 'test'.

    Returns:
        Dictionary of regression metrics for all clients and tasks:
          - 'client_task_metrics': {client_id: {task_id: metrics}}
          - 'task_avg_metrics': {task_id: metrics averaged over the clients
            that have the task}
          - 'client_avg_metrics': {client_id: metrics averaged over the
            tasks the client has; RMSE is sqrt of the averaged MSE}
          - 'overall_metrics': metrics averaged over every evaluated
            (client, task) pair; RMSE is sqrt of the averaged MSE
          - 'client_task_acc', 'task_avg_acc', 'client_avg_acc',
            'overall_avg_acc': backward-compatible views holding R² * 100.
        Entries with nothing to average stay at 0.0.
    """
    metric_names = ('mse', 'mae', 'rmse', 'r2')

    # Structure to store metrics for each client on each task
    client_task_metrics = {i: {} for i in range(opt.num_clients)}

    # Running sums, turned into averages further down
    task_avg_metrics = {t: dict.fromkeys(metric_names, 0.0)
                        for t in range(opt.num_task)}
    client_avg_metrics = {i: dict.fromkeys(metric_names, 0.0)
                          for i in range(opt.num_clients)}
    overall_metrics = dict.fromkeys(metric_names, 0.0)

    # Number of clients that were actually evaluated on each task
    task_client_counts = dict.fromkeys(range(opt.num_task), 0)
    total_entries = 0

    # Evaluate each client on each task
    for client_id in range(opt.num_clients):
        client = clients[client_id]
        client_totals = {'mse': 0.0, 'mae': 0.0, 'r2': 0.0}
        client_task_count = 0

        for task_id in range(opt.num_task):
            # Skip if the client doesn't have data for this task
            if task_id not in dataloaders[client_id]:
                logger.warning(f"Client {client_id} has no data for task {task_id}")
                continue

            # Test the client on this task (test split)
            test_metrics = client.test(
                task_id,
                dataloaders[client_id][task_id]['test']
            )

            mse = test_metrics.get("mse", test_metrics.get("loss", 0))
            # If the client did not report an RMSE, derive it from the MSE
            # instead of recording 0, which would drag the per-task RMSE
            # average towards zero and contradict the per-client/overall
            # RMSE (both computed as sqrt of the mean MSE).
            if "rmse" in test_metrics:
                rmse = test_metrics["rmse"]
            else:
                rmse = np.sqrt(mse)

            entry = {
                'mse': mse,
                'mae': test_metrics.get("mae", 0),
                'rmse': rmse,
                'r2': test_metrics.get("r2", 0)
            }
            client_task_metrics[client_id][task_id] = entry

            # Update per-task sums
            for metric in metric_names:
                task_avg_metrics[task_id][metric] += entry[metric]
            task_client_counts[task_id] += 1

            # Update per-client and overall sums (RMSE is derived later)
            for metric in ('mse', 'mae', 'r2'):
                client_totals[metric] += entry[metric]
                overall_metrics[metric] += entry[metric]
            client_task_count += 1
            total_entries += 1

        # Compute average metrics for this client
        if client_task_count > 0:
            client_avg = client_avg_metrics[client_id]
            client_avg['mse'] = client_totals['mse'] / client_task_count
            client_avg['mae'] = client_totals['mae'] / client_task_count
            client_avg['rmse'] = np.sqrt(client_avg['mse'])
            client_avg['r2'] = client_totals['r2'] / client_task_count

    # Compute per-task average across the clients that have the task
    for task_id, clients_with_task in task_client_counts.items():
        if clients_with_task > 0:
            for metric in metric_names:
                task_avg_metrics[task_id][metric] /= clients_with_task

    # Compute overall average over every evaluated (client, task) pair
    if total_entries > 0:
        overall_metrics['mse'] /= total_entries
        overall_metrics['mae'] /= total_entries
        overall_metrics['rmse'] = np.sqrt(overall_metrics['mse'])
        overall_metrics['r2'] /= total_entries

    # Log the results
    logger.info("===== Final Regression Results =====")
    logger.info("Overall Metrics:")
    logger.info(f"  MSE: {overall_metrics['mse']:.6f}")
    logger.info(f"  MAE: {overall_metrics['mae']:.6f}")
    logger.info(f"  RMSE: {overall_metrics['rmse']:.6f}")
    logger.info(f"  R²: {overall_metrics['r2']:.4f}")

    logger.info("\nAverage metrics per task:")
    for task_id, metrics in task_avg_metrics.items():
        logger.info(f"  Task {task_id}: MSE={metrics['mse']:.6f}, MAE={metrics['mae']:.6f}, "
                    f"RMSE={metrics['rmse']:.6f}, R²={metrics['r2']:.4f}")

    logger.info("\nAverage metrics per client:")
    for client_id, metrics in client_avg_metrics.items():
        logger.info(f"  Client {client_id}: MSE={metrics['mse']:.6f}, MAE={metrics['mae']:.6f}, "
                    f"RMSE={metrics['rmse']:.6f}, R²={metrics['r2']:.4f}")

    # Return all metrics for plotting
    return {
        'client_task_metrics': client_task_metrics,
        'task_avg_metrics': task_avg_metrics,
        'client_avg_metrics': client_avg_metrics,
        'overall_metrics': overall_metrics,
        # For backward compatibility: R² scaled to a percentage
        'client_task_acc': {c: {t: m['r2'] * 100 for t, m in tasks.items()}
                            for c, tasks in client_task_metrics.items()},
        'task_avg_acc': {t: m['r2'] * 100 for t, m in task_avg_metrics.items()},
        'client_avg_acc': {c: m['r2'] * 100 for c, m in client_avg_metrics.items()},
        'overall_avg_acc': overall_metrics['r2'] * 100
    }