import os
import sys
import tempfile
import logging
import argparse
import atexit
import signal
import warnings

from pathlib import Path
from typing import Callable, Dict, Optional

import nxcl
from nxcl.core.config import ConfigDict
from nxcl.config import load_config, save_config, ConfigDict
from nxcl.config.argparse import add_config_arguments
from nxcl.experimental import utils as dev_utils


def setup_signal_handler():
    """Install handlers that restore the terminal cursor on exit or interruption.

    The cursor is re-shown (ANSI ``ESC[?25h``) at interpreter exit via
    ``atexit``. SIGINT raises ``KeyboardInterrupt`` and SIGTERM raises
    ``SystemExit(0)``, each after re-showing the cursor.
    """
    show_cursor = "\x1b[?25h"

    def make_handler(exc):
        # Build a signal handler that restores the cursor, then raises `exc`.
        def _on_signal(signum, frame):
            print(show_cursor)
            raise exc

        return _on_signal

    # Show cursor when exit (`rich` hides cursor by default)
    atexit.register(print, show_cursor)

    handlers = {
        signal.SIGINT: make_handler(KeyboardInterrupt),
        signal.SIGTERM: make_handler(SystemExit(0)),
    }
    for signum, handler in handlers.items():
        signal.signal(signum, handler)


def setup_argparse(
    setup_fn: Optional[Callable[[argparse.ArgumentParser], None]] = None,
    aliases: Optional[Dict] = None,
):
    """Parse launcher flags, load the experiment config, and apply CLI overrides.

    Parsing happens in two passes:

    1. A help-less parser reads the launcher flags (``-f/--config-file`` or
       ``-r/--resume``, ``--name``, ``--base-dir``, ``--no-progress``,
       ``--debug``, plus anything ``setup_fn`` registers). Unknown arguments
       are kept for the second pass.
    2. A second parser, populated from the loaded config by
       ``add_config_arguments``, parses the remaining arguments as config
       overrides and merges them into the config.

    Args:
        setup_fn: Optional callback that receives the first-pass parser to
            register extra launcher arguments.
        aliases: Forwarded to ``add_config_arguments``.

    Returns:
        ``(args, config)``: the first-pass namespace and the merged config.

    Raises:
        ValueError: If ``--resume`` is given but the directory contains no
            ``config.yaml``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("-f", "--config-file", type=str)
    group.add_argument("-r", "--resume", type=str, help="Resume from a checkpoint")
    parser.add_argument("-n", "--name", type=str)
    parser.add_argument("-bd", "--base-dir", default="results")
    parser.add_argument("--no-progress", action="store_true")
    parser.add_argument("--debug", action="store_true")
    if setup_fn is not None:
        setup_fn(parser)
    args, rest_args = parser.parse_known_args()

    if args.config_file is not None:
        config_source = args.config_file
    else:
        # The group is required and exclusive, so we are resuming here. The
        # resolved config of a run is saved as `<save_dir>/config.yaml` by
        # `setup_logger`, and a resumed run reuses `args.resume` as save_dir.
        config_path = Path(args.resume) / "config.yaml"
        if not config_path.exists():
            raise ValueError(
                f"Cannot resume from {args.resume} because {config_path} does not exist."
            )
        config_source = str(config_path)

    config: ConfigDict = load_config(config_source)

    parser = argparse.ArgumentParser()
    add_config_arguments(parser, config, aliases=aliases)
    cfg_args = parser.parse_args(rest_args)

    config.update(vars(cfg_args))
    return args, config


def setup_output_dir(args, default_base_dir: str = "results"):
    """Resolve the experiment output directory and build a symlink helper.

    When ``args.resume`` is set, the existing directory is reused. ``--name``,
    ``--base-dir`` (if changed from its default) and ``--debug`` are then
    ignored with a warning. Otherwise a fresh directory is created:
    ``<base_dir>/_/<exp_id>``, or, with ``--debug``, a temporary directory
    that is removed at interpreter exit.

    Args:
        args: Namespace from ``setup_argparse``; reads ``resume``, ``name``,
            ``base_dir`` and ``debug``.
        default_base_dir: The parser default for ``--base-dir``. When resuming,
            a ``base_dir`` equal to this value is treated as "not specified",
            so no spurious warning is emitted.

    Returns:
        ``(save_dir, link_output_dir)``. ``save_dir`` is a ``Path``.
        ``link_output_dir(alias_path)`` creates a symlink
        ``<base_dir>/<alias_path>/<exp_id>`` pointing at ``save_dir`` and
        returns it. When resuming, it only warns and returns ``save_dir``.

    Raises:
        ValueError: If the ``--resume`` directory does not exist.
    """
    if args.resume:
        if args.name:
            warnings.warn("You specified both `--resume` and `--name`. `--name` will be ignored.")
        # `--base-dir` always has a value (its parser default), so only warn
        # when the user actually changed it.
        if args.base_dir and args.base_dir != default_base_dir:
            warnings.warn("You specified both `--resume` and `--base-dir`. `--base-dir` will be ignored.")
        if args.debug:
            warnings.warn("You specified both `--resume` and `--debug`. `--debug` will be ignored.")

        save_dir = Path(args.resume)
        if not save_dir.exists():
            raise ValueError(f"Cannot resume from {args.resume} because it does not exist.")

        def link_output_dir(alias_path: str):
            warnings.warn("Resume from output directory. `link_output_dir` will be ignored.")
            return save_dir

        return save_dir, link_output_dir

    # NOTE(review): `args.name` is parsed but not used for new runs — confirm
    # whether it should feed into the experiment name.
    exp_id = dev_utils.get_experiment_name()

    if args.debug:
        temp_dir = tempfile.TemporaryDirectory()
        # Keep the TemporaryDirectory object alive until exit. Otherwise it is
        # finalized when this function returns and the directory is deleted
        # before anything is written into it.
        atexit.register(temp_dir.cleanup)
        save_dir = Path(temp_dir.name)
    else:
        save_dir = Path(args.base_dir, "_", exp_id)

    save_dir.mkdir(parents=True, exist_ok=True)

    def link_output_dir(alias_path: str):
        link_path = Path(args.base_dir, alias_path, exp_id)
        link_path.parent.mkdir(parents=True, exist_ok=True)
        # Relative target so the link survives moving the whole base_dir.
        link_path.symlink_to(os.path.relpath(save_dir, link_path.parent), target_is_directory=True)
        return link_path

    return save_dir, link_output_dir


def setup_logger(save_dir, config: ConfigDict, suppress=None):
    """Create the experiment logger, log the run configuration, and save it.

    Args:
        save_dir: Output directory (``str`` or ``Path``). It is passed to
            ``dev_utils.setup_logger``, and ``config.yaml`` is written into it.
        config: The resolved experiment config; every flattened key/value is
            logged.
        suppress: Optional iterable of extra modules forwarded to
            ``dev_utils.setup_logger`` as ``suppress``; ``nxcl`` is always
            appended.

    Returns:
        The logger returned by ``dev_utils.setup_logger``.
    """
    # Accept any iterable (list, tuple, ...) without mutating the caller's object.
    suppress = [*(suppress or ()), nxcl]
    logger = dev_utils.setup_logger(__name__, save_dir, suppress=suppress)
    logger.debug("python " + " ".join(sys.argv))

    config_lines = [f"    {k:<25}: {v}" for k, v in config.items(flatten=True)]
    logger.info("\n".join(["Configs:", *config_lines]))
    logger.info(f"Output directory: \"{save_dir}\"")

    # Path() so a plain-string save_dir also works with the `/` operator.
    save_config(config, Path(save_dir) / "config.yaml")

    return logger


def launch(
    func: Callable[[argparse.Namespace, ConfigDict, logging.Logger, Path, Callable], None],
    argparse_fn: Optional[Callable[[argparse.ArgumentParser], None]] = None,
    aliases: Optional[Dict] = None,
):
    """Set up an experiment run and invoke ``func`` inside it.

    Installs the signal handlers, parses arguments and config, resolves the
    output directory, and creates the logger. It then calls
    ``func(args, config, logger, save_dir, link_output_dir)``.

    Args:
        func: The experiment entry point.
        argparse_fn: Optional callback that registers extra launcher arguments.
        aliases: Config-argument aliases forwarded to ``setup_argparse``.

    Returns:
        ``0`` on success, ``-1`` if interrupted by ``KeyboardInterrupt``,
        ``-2`` if ``func`` raised any other ``Exception``. ``SystemExit``
        (e.g. from SIGTERM) is not caught and propagates.
    """
    setup_signal_handler()
    args, config = setup_argparse(argparse_fn, aliases)
    save_dir, link_output_dir = setup_output_dir(args)
    logger = setup_logger(save_dir, config)

    try:
        func(args, config, logger, save_dir, link_output_dir)
    except KeyboardInterrupt:
        # Use the experiment logger: the module-level `logging.*` helpers write
        # to the root logger, whose default level (WARNING) drops INFO messages,
        # and they bypass the handlers configured by `setup_logger`.
        logger.info("Interrupted")
        return -1
    except Exception as e:
        logger.exception(e)
        return -2

    return 0
