import torch
import numpy as np
import random
import os
import logging
import matplotlib
matplotlib.use('Agg')  # Use Agg backend for environments without a display

# Import the TinyImageNet configuration
from configs.TinyImageNet import parse_args
from gfedcl import ParallelServerGFedCL
from utils.plot_utils import plot_results, plot_all_tasks_accuracy

# Set random seeds for reproducibility
def set_seed(seed):
    """Seed NumPy, PyTorch and Python's ``random`` with the same value.

    If CUDA is available, the current CUDA device is seeded as well. cuDNN is
    also forced into deterministic, non-benchmark mode so repeated runs pick
    the same kernels.
    """
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def setup_logging(opt):
    """Configure root logging to write to both the console and a log file.

    Creates ``opt.output_dir`` if needed. If ``opt.log_path`` points
    elsewhere, its parent directory is created too. Every handler already
    attached to the root logger is detached and closed. The log file is
    opened in truncating mode (``'w'``).

    Args:
        opt: Parsed options; reads ``output_dir`` and ``log_path``. When
            ``log_path`` is falsy, the log goes to ``<output_dir>/run.log``.

    Returns:
        logging.Logger: the ``'GFedCL-TinyImageNet-Server'`` logger.
    """
    # Create output directory first
    os.makedirs(opt.output_dir, exist_ok=True)

    # Create log file path
    log_file_path = opt.log_path or os.path.join(opt.output_dir, 'run.log')

    # A user-supplied log path may live outside output_dir. Make sure its
    # directory exists too, otherwise FileHandler raises FileNotFoundError.
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Detach AND close existing handlers. This avoids duplicate output, and
    # repeated calls do not leak the previous log file's descriptor.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    # Set up logging with both console and file handlers
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),  # Console output
            logging.FileHandler(log_file_path, mode='w')  # File output
        ]
    )

    logger = logging.getLogger('GFedCL-TinyImageNet-Server')
    logger.info(f"Logging initialized. Log file: {log_file_path}")

    return logger

def apply_ablation(opt):
    """Disable the component named by ``opt.ablation`` and tag the output dir.

    Recognised values:

    - ``"no_graph"`` sets ``opt.use_graph = False``.
    - ``"no_temporal"`` sets ``opt.use_temporal = False``.
    - ``"no_dp"`` sets ``opt.dp = False``.

    Any other value leaves these flags untouched.

    Suppose ``opt.output_dir_is_default`` is truthy and the ablation is not
    ``"none"``. Then ``_<ablation>`` is appended to ``opt.output_dir``,
    presumably so ablation runs don't overwrite the baseline's outputs. The
    log path is moved into the new directory only when it was unset or still
    pointed at the old default ``<output_dir>/run.log``. A user-supplied log
    path is preserved.

    Args:
        opt: Parsed options object; modified in place.

    Returns:
        The same ``opt`` object.
    """
    if opt.ablation == "no_graph":
        opt.use_graph = False
    elif opt.ablation == "no_temporal":
        opt.use_temporal = False
    elif opt.ablation == "no_dp":
        opt.dp = False

    if getattr(opt, "output_dir_is_default", False) and opt.ablation != "none":
        old_default_log = os.path.join(opt.output_dir, "run.log")
        opt.output_dir = f"{opt.output_dir}_{opt.ablation}"
        # Previously the log path was overwritten unconditionally, which
        # silently discarded a log path the user had chosen explicitly.
        current_log = getattr(opt, "log_path", None)
        if not current_log or os.path.normpath(current_log) == os.path.normpath(old_default_log):
            opt.log_path = os.path.join(opt.output_dir, "run.log")
    return opt

def _log_run_configuration(logger, opt, log_path):
    """Log the experiment configuration and CUDA availability at startup."""
    logger.info('Initializing Parallel Server-based GFedCL for TinyImageNet...')
    logger.info(f'Dataset: {opt.dataset}')
    logger.info(f'Number of clients: {opt.num_clients}')
    logger.info(f'Number of tasks per client: {opt.num_task}')
    logger.info(f'Classes per task: {opt.class_per_task}')
    logger.info(f'Total classes: {opt.num_classes}')
    logger.info(f'Image size: {opt.image_size}x{opt.image_size}')
    logger.info(f'Number of channels: {opt.num_channels}')
    logger.info(f'Output directory: {opt.output_dir}')
    logger.info(f'Log file: {log_path}')
    logger.info(f'Ablation: {opt.ablation}')

    # Check for CUDA availability and memory (only device 0 is reported)
    if torch.cuda.is_available():
        logger.info(f'CUDA Device: {torch.cuda.get_device_name(0)}')
        logger.info(f'CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
    else:
        logger.warning('CUDA not available, using CPU')


def _log_training_summary(logger, opt, accuracy_results, quality_summary, plots_dir, log_path):
    """Log the end-of-run summary: configuration, accuracies and average FID."""
    logger.info('===== TRAINING SUMMARY =====')
    logger.info(f'Dataset: {opt.dataset}')
    logger.info(f'Number of clients: {opt.num_clients}')
    logger.info(f'Number of tasks per client: {opt.num_task}')
    logger.info(f'Classes per task: {opt.class_per_task}')
    logger.info(f'Total classes: {opt.num_classes}')
    logger.info(f'Overall average accuracy: {accuracy_results["overall_avg_acc"]:.2f}%')

    # Average only usable FID scores. An entry's "fid_score" may be
    # missing, None or NaN, and those entries are skipped.
    fid_scores = quality_summary.get("fid_scores") if quality_summary else None
    valid_fid_scores = [
        entry["fid_score"]
        for entry in fid_scores or []
        if entry.get("fid_score") is not None and not np.isnan(entry["fid_score"])
    ]
    if valid_fid_scores:
        avg_fid = sum(valid_fid_scores) / len(valid_fid_scores)
        logger.info(f'Average FID Score: {avg_fid:.2f} (lower is better)')
    logger.info(f'Results saved to {opt.output_dir}')
    logger.info(f'Log file: {log_path}')
    logger.info(f'Accuracy plots: {plots_dir}')

    # Log per-task performance. The +1 assumes integer, 0-based task ids;
    # TODO confirm against train_GFedCL's output.
    logger.info('Performance by task:')
    for task_id, acc in accuracy_results['task_avg_acc'].items():
        logger.info(f'  Task {task_id+1}: {acc:.2f}%')

    logger.info('===========================')


def main():
    """Run the GFedCL TinyImageNet experiment end to end.

    The steps are:

    1. Parse options, seed the RNGs and apply the requested ablation.
    2. Configure logging.
    3. Train ``ParallelServerGFedCL``.
    4. Plot the accuracy results under ``opt.output_dir``.
    5. Log a final summary.

    A training failure is logged with its traceback and then re-raised.
    """
    opt = parse_args()
    # Set random seed
    set_seed(opt.seed)

    opt = apply_ablation(opt)

    # Set up logging first (this will create the output directory)
    logger = setup_logging(opt)
    # setup_logging falls back to <output_dir>/run.log when opt.log_path is
    # unset. Resolve the path the same way so the logs never say "None".
    log_path = opt.log_path or os.path.join(opt.output_dir, 'run.log')

    _log_run_configuration(logger, opt, log_path)

    try:
        gfedcl = ParallelServerGFedCL(opt)

        logger.info('Starting training with centralized server and Ray parallelization...')
        accuracy_results, all_tasks_accuracy, quality_summary = gfedcl.train_GFedCL()
    except Exception:
        # Put the traceback in the run log too, so failed runs leave a
        # complete record. Then propagate the error unchanged.
        logger.exception('Training failed.')
        raise

    logger.info('Training completed.')

    # Plot and save the accuracy results
    plots_dir = plot_results(accuracy_results, opt.output_dir)

    # Plot all tasks accuracy over time
    if all_tasks_accuracy:
        plot_all_tasks_accuracy(opt, all_tasks_accuracy, plots_dir)

    _log_training_summary(logger, opt, accuracy_results, quality_summary, plots_dir, log_path)

if __name__ == "__main__":
    # A refresh interval of 0 switches off Ray's memory monitor. This
    # presumably has to happen before Ray is initialised inside main().
    os.environ.update({"RAY_memory_monitor_refresh_ms": "0"})
    main()
