import torch
import numpy as np
import random
import os
import logging
import matplotlib
matplotlib.use('Agg')  # Use Agg backend for environments without a display

# Import the EMNIST configuration
from configs.EMNIST import parse_args
from gfedcl import ParallelServerGFedCL
from utils.plot_utils import plot_results, plot_all_tasks_accuracy

# Set random seeds for reproducibility
def set_seed(seed):
    """Seed the NumPy, PyTorch and Python ``random`` generators with ``seed``.

    When CUDA is available, the CUDA generator is seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed handed to every generator.
    """
    # The three generators are independent, so they are seeded in one pass.
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def setup_logging(opt):
    """Configure root logging to write to both the console and a log file.

    Args:
        opt: Options object. ``opt.output_dir`` is created if missing.
            ``opt.log_path`` selects the log file; when it is empty or None
            the file defaults to ``<output_dir>/run.log``. ``opt.log_path`` is
            updated to the path actually used so callers that report it show
            the real location.

    Returns:
        The ``'GFedCL-EMNIST-Server'`` logger.
    """
    # Create output directory first
    os.makedirs(opt.output_dir, exist_ok=True)

    # Resolve the log file path, falling back to a file inside output_dir
    log_file_path = opt.log_path or os.path.join(opt.output_dir, 'run.log')

    # An explicit log_path may point outside output_dir, and FileHandler does
    # not create missing parent directories. A bare filename has no directory
    # component, so there is nothing to create in that case.
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Record the resolved path so later "Log file: ..." messages are accurate
    opt.log_path = log_file_path

    # Remove any existing handlers to avoid duplicates, closing each one so a
    # previously opened log file is released rather than leaked.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    # Set up logging with both console and file handlers
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),  # Console output
            logging.FileHandler(log_file_path, mode='w')  # File output (truncates)
        ]
    )

    logger = logging.getLogger('GFedCL-EMNIST-Server')
    logger.info(f"Logging initialized. Log file: {log_file_path}")

    return logger

def apply_ablation(opt):
    """Apply the ablation selected by ``opt.ablation`` to ``opt`` in place.

    ``"no_graph"``, ``"no_temporal"`` and ``"no_dp"`` set ``opt.use_graph``,
    ``opt.use_temporal`` and ``opt.dp`` to False respectively; any other value
    leaves those flags alone. If ``opt.output_dir_is_default`` is truthy and
    the ablation is not ``"none"``, the ablation name is appended to
    ``opt.output_dir`` and ``opt.log_path`` is pointed at ``run.log`` inside
    that directory.

    Returns:
        The same ``opt`` object.
    """
    variant = opt.ablation

    # (ablation name, option flag that the ablation switches off)
    disabled_by = (
        ("no_graph", "use_graph"),
        ("no_temporal", "use_temporal"),
        ("no_dp", "dp"),
    )
    for ablation_name, flag in disabled_by:
        if variant == ablation_name:
            setattr(opt, flag, False)
            break

    # Only an untouched default output directory gets the ablation suffix.
    if variant == "none" or not getattr(opt, "output_dir_is_default", False):
        return opt

    suffixed_dir = f"{opt.output_dir}_{variant}"
    opt.output_dir = suffixed_dir
    opt.log_path = os.path.join(suffixed_dir, "run.log")
    return opt

def read_metrics_summary(output_dir):
    """Log the saved metrics summary file, if one exists.

    Looks for ``metrics/metrics_summary.txt`` under ``output_dir`` and logs
    every non-blank line at INFO level on the ``'GFedCL-EMNIST-Server'``
    logger, framed by a header and a footer banner.

    Args:
        output_dir: Directory expected to contain the ``metrics`` folder.

    Returns:
        True if the summary file existed and was logged, False otherwise.
    """
    summary_file = os.path.join(output_dir, 'metrics', 'metrics_summary.txt')
    if not os.path.exists(summary_file):
        return False

    log = logging.getLogger('GFedCL-EMNIST-Server')
    log.info("\n===== METRICS SUMMARY =====")

    with open(summary_file, 'r') as handle:
        rows = handle.read().split('\n')

    # One log record per line preserves the file's layout; blank lines are dropped.
    for row in rows:
        if not row.strip():
            continue
        log.info(row)

    log.info("===========================\n")
    return True

def _log_run_overview(logger, opt):
    """Log the dataset and client/task layout options read from ``opt``."""
    logger.info(f'Dataset: {opt.dataset}')
    logger.info(f'Number of clients: {opt.num_clients}')
    logger.info(f'Number of tasks per client: {opt.num_task}')
    logger.info(f'Classes per task: {opt.class_per_task}')
    logger.info(f'Total classes: {opt.num_classes}')

def _average_valid_fid(quality_summary):
    """Return the mean of the usable FID scores, or None if there are none.

    Entries under ``quality_summary["fid_scores"]`` whose ``"fid_score"`` is
    missing, None or NaN are ignored.
    """
    fid_scores = quality_summary.get("fid_scores") if quality_summary else None
    if not fid_scores:
        return None
    valid_fid_scores = [
        data["fid_score"]
        for data in fid_scores
        if data.get("fid_score") is not None and not np.isnan(data["fid_score"])
    ]
    if not valid_fid_scores:
        return None
    return sum(valid_fid_scores) / len(valid_fid_scores)

def _log_output_files(logger, opt):
    """Log the paths under ``opt.output_dir`` where run artifacts are expected."""
    logger.info('\n===== OUTPUT FILES =====')
    logger.info(f'Round accuracy CSV: {os.path.join(opt.output_dir, "round_accuracy.csv")}')
    logger.info(f'All tasks accuracy CSV: {os.path.join(opt.output_dir, "all_tasks_accuracy.csv")}')
    logger.info(f'Metrics directory: {os.path.join(opt.output_dir, "metrics")}')
    logger.info(f'Visualizations: {os.path.join(opt.output_dir, "visualizations")}')
    if opt.dp:
        logger.info(f'DP analysis: {os.path.join(opt.output_dir, "dp_analysis")}')
    logger.info('========================')

def main():
    """Run one GFedCL training session and report its results.

    Parses the options, seeds the random generators, applies the selected
    ablation, configures logging, trains via ``ParallelServerGFedCL``, saves
    the accuracy plots, and logs a final summary plus the output locations.
    """
    opt = parse_args()
    # Set random seed
    set_seed(opt.seed)

    opt = apply_ablation(opt)

    # Set up logging first (this will create the output directory)
    logger = setup_logging(opt)

    logger.info('Initializing Parallel Server-based GFedCL for EMNIST-letter...')
    _log_run_overview(logger, opt)
    logger.info(f'Image size: {opt.image_size}x{opt.image_size}')
    logger.info(f'Number of channels: {opt.num_channels}')
    logger.info(f'Output directory: {opt.output_dir}')
    logger.info(f'Log file: {opt.log_path}')
    logger.info(f'Ablation: {opt.ablation}')

    # Display differential privacy settings if enabled
    if opt.dp:
        logger.info('Differential Privacy: ENABLED')
        logger.info(f'  - Epsilon: {opt.epsilon}')
        logger.info(f'  - Sensitivity: {opt.sensitivity}')
        logger.info(f'  - Noise scale (b): {opt.b}')
    else:
        logger.info('Differential Privacy: DISABLED')

    gfedcl = ParallelServerGFedCL(opt)

    logger.info('Starting training with centralized server and Ray parallelization...')
    accuracy_results, all_tasks_accuracy, quality_summary = gfedcl.train_GFedCL()

    logger.info('Training completed.')

    # Plot and save the accuracy results
    plots_dir = plot_results(accuracy_results, opt.output_dir)

    # Plot all tasks accuracy over time
    if all_tasks_accuracy:
        plot_all_tasks_accuracy(opt, all_tasks_accuracy, plots_dir)

    # Print final summary
    logger.info('===== TRAINING SUMMARY =====')
    _log_run_overview(logger, opt)
    logger.info(f'Overall average accuracy: {accuracy_results["overall_avg_acc"]:.2f}%')

    # The FID line is omitted entirely when no usable score was reported
    avg_fid = _average_valid_fid(quality_summary)
    if avg_fid is not None:
        logger.info(f'Average FID Score: {avg_fid:.2f} (lower is better)')
    logger.info(f'Results saved to {opt.output_dir}')
    logger.info(f'Log file: {opt.log_path}')
    logger.info(f'Accuracy plots: {plots_dir}')

    # Log per-task performance (task ids are shown 1-based)
    logger.info('Performance by task:')
    for task_id, acc in accuracy_results['task_avg_acc'].items():
        logger.info(f'  Task {task_id+1}: {acc:.2f}%')

    logger.info('===========================')

    # Read and display metrics summary if available
    read_metrics_summary(opt.output_dir)

    # Display paths to key output files
    _log_output_files(logger, opt)

if __name__ == "__main__":
    # Disable Ray memory monitor (refresh interval of 0 ms) before training runs
    os.environ.update(RAY_memory_monitor_refresh_ms="0")
    main()
