import copy
import logging
import os
import random

import matplotlib
import numpy as np
import torch

from configs.CIFAR100 import DEFAULT_OUTPUT_DIR, parse_args
from server_gfedcl import ParallelServerGFedCL
from utils.plot_utils import plot_all_tasks_accuracy, plot_results

# Select the non-interactive "Agg" backend so figures can be rendered to files
# on machines without a display.
# NOTE(review): this runs after utils.plot_utils has been imported; if that
# module imports pyplot at import time the backend is switched after the
# fact — confirm that is acceptable.
matplotlib.use("Agg")


def set_seed(seed):
    """Seed the NumPy, PyTorch and Python RNGs with the same value.

    When CUDA is available the CUDA generator is seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Seed value handed unchanged to every generator.
    """
    # Same order as before: NumPy, then PyTorch, then the stdlib generator.
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def setup_logging(output_dir):
    """Configure root logging to the console and ``<output_dir>/run.log``.

    Handlers already attached to the root logger are detached *and closed*
    before the new configuration is installed, so calling this function more
    than once does not leak the previous log file's open descriptor.

    Args:
        output_dir: Directory that will hold ``run.log``; created if missing.

    Returns:
        tuple: ``(logger, log_file_path)`` where ``logger`` is the
        ``"GFedCL-Runner"`` logger and ``log_file_path`` is the path of the
        log file, which is opened in ``"w"`` mode (truncated on each call).
    """
    os.makedirs(output_dir, exist_ok=True)
    log_file_path = os.path.join(output_dir, "run.log")

    # Iterate over a copy because removeHandler() mutates the handler list.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        # Release the handler's resources (e.g. an open FileHandler stream);
        # removing it from the logger alone leaves the file open.
        handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler(log_file_path, mode="w")],
    )
    logger = logging.getLogger("GFedCL-Runner")
    logger.info(f"Logging initialized. Log file: {log_file_path}")
    return logger, log_file_path


def apply_cli_overrides(opt):
    """Apply the selected ablation to ``opt`` and derive its log path.

    The matching feature flag is set to ``False`` for the ablations
    ``"no_graph"``, ``"no_temporal"`` and ``"no_dp"``. If ``opt`` carries a
    truthy ``output_dir_is_default`` attribute and the ablation is not
    ``"none"``, the ablation name is appended to ``opt.output_dir``.
    ``opt.log_path`` is always set to ``run.log`` inside the output directory.

    Args:
        opt: Options namespace; modified in place.

    Returns:
        The same ``opt`` object that was passed in.
    """
    # (ablation name, flag attribute switched off by that ablation)
    ablation_switches = (
        ("no_graph", "use_graph"),
        ("no_temporal", "use_temporal"),
        ("no_dp", "dp"),
    )
    for ablation_name, flag_attr in ablation_switches:
        if opt.ablation == ablation_name:
            setattr(opt, flag_attr, False)
            break

    on_default_dir = getattr(opt, "output_dir_is_default", False)
    if on_default_dir and opt.ablation != "none":
        opt.output_dir = f"{opt.output_dir}_{opt.ablation}"

    opt.log_path = os.path.join(opt.output_dir, "run.log")
    return opt


def _mean_valid_fid(quality_summary):
    """Return the mean of the usable FID scores, or ``None`` if there are none.

    An entry of ``quality_summary["fid_scores"]`` is usable when its
    ``"fid_score"`` value is neither ``None`` nor NaN.
    """
    fid_entries = quality_summary.get("fid_scores") if quality_summary else None
    if not fid_entries:
        return None

    usable = [
        entry["fid_score"]
        for entry in fid_entries
        if entry.get("fid_score") is not None and not np.isnan(entry["fid_score"])
    ]
    if not usable:
        return None
    return sum(usable) / len(usable)


def main():
    """Run one GFedCL experiment end to end.

    Parses the options, applies the ablation overrides to a deep copy of
    them, sets up logging and seeding, trains via ``ParallelServerGFedCL``,
    produces the accuracy plots and logs a summary of the run.
    """
    parsed = parse_args()
    opt = apply_cli_overrides(copy.deepcopy(parsed))

    logger, log_path = setup_logging(opt.output_dir)
    set_seed(opt.seed)

    logger.info("Initializing Parallel Server-based GFedCL...")
    logger.info(f"Output directory: {opt.output_dir}")
    logger.info(f"Ablation: {opt.ablation}")

    server = ParallelServerGFedCL(opt)

    logger.info("Starting training with centralized server and Ray parallelization...")
    accuracy_results, all_tasks_accuracy, quality_summary = server.train_GFedCL()

    logger.info("Training completed.")

    plots_dir = plot_results(accuracy_results, opt.output_dir)
    plot_all_tasks_accuracy(opt, all_tasks_accuracy, plots_dir)

    logger.info("===== TRAINING SUMMARY =====")
    logger.info(f"Number of clients: {opt.num_clients}")
    logger.info(f"Number of tasks: {opt.num_task}")
    logger.info(f"Classes per task: {opt.class_per_task}")
    logger.info(f"Overall average accuracy: {accuracy_results['overall_avg_acc']:.2f}%")

    # The FID line is only emitted when at least one usable score exists.
    mean_fid = _mean_valid_fid(quality_summary)
    if mean_fid is not None:
        logger.info(f"Average FID Score: {mean_fid:.2f} (lower is better)")

    logger.info(f"Results saved to {opt.output_dir}")
    logger.info(f"Log file: {log_path}")
    logger.info(f"Accuracy plots: {plots_dir}")
    logger.info("===========================")


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Set before main() so the value is already in the environment when the
    # training code starts.
    # NOTE(review): per Ray's documentation a refresh interval of 0 turns off
    # the memory monitor's task killing — confirm this is intended.
    os.environ["RAY_memory_monitor_refresh_ms"] = "0"
    main()
