import argparse
import copy
import datetime
import logging
import os
import random
import sys

import numpy as np
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from mmengine.config import Config

from fedplora import fedplora


# ===========================  Utils  =========================== #

def create_log_dir(args):
    """Create and return a fresh, timestamped log directory for one experiment.

    Layout (relative to the current working directory)::

        ./log/<dataset>/<model>/<model_heterogeneity>/<config-stem>_<timestamp>/
            checkpoints/

    Args:
        args: object exposing ``dataset``, ``model``, ``model_heterogeneity``
            and ``config_name`` attributes.

    Returns:
        The experiment directory path (the parent of ``checkpoints/``).
    """
    base = f"./log/{args.dataset}/{args.model}/{args.model_heterogeneity}"

    # Strip only a *trailing* YAML extension. A plain str.replace would also
    # delete ".yaml"/".yml" appearing in the middle of the config name.
    stem, ext = os.path.splitext(args.config_name)
    name = stem if ext.lower() in (".yaml", ".yml") else args.config_name
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    path = os.path.join(base, f"{name}_{timestamp}")
    # makedirs creates `base` and `path` along the way. The timestamp has
    # one-second resolution, so exist_ok lets two runs started within the
    # same second reuse the directory instead of raising.
    os.makedirs(os.path.join(path, "checkpoints"), exist_ok=True)

    return path


def merge_config(config, args):
    """Copy every command-line option onto ``config`` and return it.

    Each attribute of the argparse namespace ``args`` is set on ``config``,
    overwriting any value of the same name loaded from the YAML file, so
    command-line options take precedence. ``config`` is modified in place.
    """
    for option_name, option_value in vars(args).items():
        setattr(config, option_name, option_value)
    return config


# ===========================  Main  =========================== #

def _parse_args():
    """Parse the command-line options for an experiment and validate them."""
    parser = argparse.ArgumentParser()
    # NOTE(review): `type=list` splits the raw string into single characters
    # ("--gpu 01" -> ['0', '1']) while the default is the int 0. Kept as-is
    # for compatibility; confirm how downstream code consumes `args.gpu`.
    parser.add_argument("--gpu", type=list, default=0)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--repeat", type=int, default=1)
    parser.add_argument("--config_name", type=str, default="text_mrpc_bert_fedplora.yaml")
    args = parser.parse_args()

    # The result summary in main() reads the metric spec returned by the last
    # run; with zero runs it would be unbound (NameError), so reject early.
    if args.repeat < 1:
        parser.error("--repeat must be >= 1")
    return args


def _setup_logging(args):
    """Create the log directory and attach paths and the logger to ``args``.

    Sets ``args.log_path``, ``args.model_save_path`` and ``args.logger``. On
    the local main process it also appends the log directory path to
    ``exp_log.txt`` and mirrors logger output into that file.
    """
    # NOTE(review): every process that runs main() calls create_log_dir with
    # its own timestamp, so ranks could end up with different directories if
    # they straddle a second boundary. Confirm whether fedplora needs a shared
    # `model_save_path` across ranks.
    log_dir = create_log_dir(args)
    args.log_path = log_dir
    args.model_save_path = os.path.join(log_dir, "checkpoints")

    logging.basicConfig(level=logging.INFO)
    args.logger = get_logger("fedplora")
    txtlog = os.path.join(log_dir, "exp_log.txt")

    if args.accelerator.is_local_main_process:
        with open(txtlog, "a+", encoding="utf-8") as f:
            f.write(log_dir + "\n")

        # `get_logger` returns an adapter; the handler goes on the wrapped
        # stdlib logger.
        handler = logging.FileHandler(txtlog, encoding="utf-8")
        args.logger.logger.addHandler(handler)


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    """Run ``repeat`` federated training runs and log per-metric results.

    Each run gets a deep copy of the merged config (YAML + CLI options) and
    seed ``seed + r``. ``fedplora`` returns ``(best_result, metrics)``. For
    every entry of ``metrics`` whose flag equals 1, the value at the same
    position in ``best_result`` is collected. After all runs, the list and
    mean of each collected metric are logged on the main process.
    """
    # --------------------- ArgParse --------------------- #
    args = _parse_args()
    # NOTE(review): the reason for the raised recursion limit is not visible
    # in this file; confirm it is still needed.
    sys.setrecursionlimit(10000)

    # ------------------- Load config file ------------------- #
    config = Config.fromfile(f"config/{args.config_name}")
    args = merge_config(config, args)

    # ---------------- Accelerate + Logger ---------------- #
    ddp = DistributedDataParallelKwargs(find_unused_parameters=True)
    args.accelerator = Accelerator(kwargs_handlers=[ddp])
    args.device = args.accelerator.device

    _setup_logging(args)

    # ------------------------ Training ------------------------ #
    score_box = None
    for r in range(args.repeat):
        run_cfg = copy.deepcopy(args)
        _seed_everything(run_cfg.seed + r)

        best_result, metrics = fedplora(run_cfg)

        # One result list per metric slot, sized from the first run.
        if score_box is None:
            score_box = [[] for _ in metrics]

        # Keep only metrics enabled with flag == 1; values are matched to
        # metric names by position.
        for idx, (key, enable) in enumerate(metrics.items()):
            if enable == 1:
                score_box[idx].append(best_result[idx])

    # --------------------- Final Result Print --------------------- #
    # `metrics` is from the last run; _parse_args guarantees repeat >= 1, so
    # it is always bound here.
    for idx, (key, enable) in enumerate(metrics.items()):
        if enable == 1:
            vals = score_box[idx]
            args.logger.info(f"[{key}] {vals}", main_process_only=True)
            args.logger.info(f"[AVG-{key}]  {np.mean(vals):.4f}", main_process_only=True)


if __name__ == "__main__":
    # Script entry point: export NCCL / CUDA-allocator settings into the
    # process environment before main() builds the Accelerator and starts
    # training. Existing values of these variables are overwritten.
    os.environ.update(
        NCCL_P2P_DISABLE="1",
        PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128",
    )
    main()