import logging
from logging import Logger
from pathlib import Path
from typing import Dict, Optional

from pythonjsonlogger.jsonlogger import JsonFormatter

from algorithms.space.base_space import IdentifiableSpace


class AllParamsLogger(Logger):
    """Logger that accepts arbitrary keyword arguments as structured fields.

    Any keyword argument passed to a logging call that is not one of the
    standard logging parameters is merged into ``extra`` and therefore ends
    up as an attribute on the emitted ``LogRecord``, e.g.
    ``logger.info("step done", step=3)``.
    """

    def _log(
        self,
        level: int,
        msg: object,
        args,
        exc_info=None,
        extra: Optional[Dict[str, object]] = None,
        stack_info: Optional[bool] = None,
        stacklevel: int = 1,
        **kwargs,
    ):
        """Merge ``kwargs`` into ``extra`` and delegate to ``Logger._log``.

        Args:
            level: Numeric logging level of the event.
            msg: Message object or format string.
            args: Arguments merged into ``msg``.
            exc_info: Exception information forwarded unchanged.
            extra: Optional mapping of additional record attributes. It is
                never modified; on key collisions the keyword arguments win.
            stack_info: Forwarded unchanged to the base implementation.
            stacklevel: Forwarded unchanged to the base implementation.
            **kwargs: Additional fields to attach to the log record.

        Raises:
            KeyError: Raised by the standard library when a field name
                collides with a built-in ``LogRecord`` attribute.
        """
        # NOTE(review): this override adds a frame between the public logging
        # methods and ``Logger._log``; caller attribution (funcName/lineno)
        # may be affected depending on the Python version - confirm before
        # relying on those record fields.

        # Build a fresh dict instead of updating ``extra`` in place, so a
        # mapping the caller reuses across calls does not accumulate the
        # keyword fields of earlier calls.
        merged = {**extra, **kwargs} if extra else kwargs

        return super()._log(level, msg, args, exc_info, merged, stack_info, stacklevel)


class RunJsonFormatter(JsonFormatter):
    """JSON formatter that stamps every record with run-identifying fields.

    The run name, algorithm name and the identifying attributes of the given
    space are captured at construction time and written into each record.
    """

    # Instance attributes copied onto every record, in the order they are
    # written into the record dict.
    _RUN_FIELDS = (
        "run_name",
        "alg_name",
        "suite",
        "func_id",
        "dimension",
        "func_instance",
    )

    def __init__(
        self,
        run_name: str,
        alg_name: str,
        space: IdentifiableSpace,
        *args,
        **kwargs,
    ):
        """Store the run metadata; remaining arguments go to ``JsonFormatter``."""
        super().__init__(*args, **kwargs)
        self.run_name, self.alg_name = run_name, alg_name
        self.suite, self.func_id = space.suite, space.func_id
        self.dimension, self.func_instance = space.dimension, space.func_instance

    def add_fields(self, log_record, record, message_dict):
        """Populate ``log_record`` with the base fields plus run metadata."""
        super().add_fields(log_record, record, message_dict)

        exc_text = record.exc_text
        if exc_text:
            log_record["exc_info"] = exc_text

        log_record["levelname"] = record.levelname
        for field in self._RUN_FIELDS:
            log_record[field] = getattr(self, field)


def _attach_file_handler(logger: Logger, path: Path, formatter: logging.Formatter) -> None:
    """Attach an INFO-level file handler writing to ``path`` using ``formatter``."""
    path.parent.mkdir(exist_ok=True, parents=True)
    handler = logging.FileHandler(path)
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)


def create_logger(
    log_path: Optional[Path],
    splunk_path: Optional[Path],
    run_name: str,
    alg_name: str,
    space: IdentifiableSpace,
    logger_name: str = "",
) -> Logger:
    """Configure and return the logger used for a single run.

    The logger is set to DEBUG, any handlers already attached to it are
    removed and closed, and up to two INFO-level file handlers are added.

    Note that this registers ``AllParamsLogger`` as the process-wide logger
    class via ``logging.setLoggerClass``. With the default ``logger_name``
    of ``""`` the standard library returns the root logger, which is not an
    ``AllParamsLogger`` instance; the same holds for a named logger that was
    already created before this call.

    Args:
        log_path: File for human-readable log lines, or ``None`` to skip it.
        splunk_path: File for JSON log lines, or ``None`` to skip it.
        run_name: Run name written into every JSON record.
        alg_name: Algorithm name written into every JSON record.
        space: Space whose identifying attributes are written into every
            JSON record.
        logger_name: Name passed to ``logging.getLogger``.

    Returns:
        The configured logger.
    """
    logging.setLoggerClass(AllParamsLogger)
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Detach AND close the previous handlers: merely clearing the list would
    # leave the underlying log files open every time a logger is recreated.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    if log_path:
        normal_formatter = logging.Formatter(
            "%(levelname)s - %(asctime)s.%(msecs)03d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        _attach_file_handler(logger, log_path, normal_formatter)
    if splunk_path:
        json_formatter = RunJsonFormatter(run_name, alg_name, space)
        _attach_file_handler(logger, splunk_path, json_formatter)

    return logger


def create_base_log_path(base: Path, alg_name: str, run_name: str) -> Path:
    """Return the log directory of one run: ``<base>/logs/<alg_name>/<run_name>``."""
    return base.joinpath("logs", alg_name, run_name)


def create_file_log_path(base: Path, alg_name: str, run_name: str) -> Path:
    """Return the path of the plain-text log (``normal``) inside the run's log directory."""
    run_log_dir = create_base_log_path(base, alg_name, run_name)
    return run_log_dir / "normal"


def create_splunk_path(base: Path, alg_name: str, run_name: str, signature: str) -> Path:
    """Return the path of the JSON log file under the run's ``splunk`` subdirectory.

    The file name combines ``run_name`` and ``signature``.
    """
    splunk_dir = create_base_log_path(base, alg_name, run_name) / "splunk"
    file_name = f"{run_name}_splunk_logs_{signature}"
    return splunk_dir / file_name
