import logging
import multiprocessing
import os
import shutil
import sys
import time
from collections import ChainMap
from typing import Any, Callable, Dict, List, Tuple

import GPUtil
import wandb
from torch import cuda
from torch.utils.tensorboard import SummaryWriter

from args import TrainerArguments, ModelArguments


# Public API of this module. `close` is the teardown counterpart of
# `create_logger` and must be reachable by star-import users too.
__all__ = (
    'MetricTracker',
    'AvgMetricTracker',
    'BestMetricTracker',
    'create_logger',
    'log',
    'log_metric',
    'close',
)


class MetricTracker:
    """Hold the most recent value of a named metric.

    Attributes:
        name: metric name, used as (the tail of) the key in ``format_dict``.
        prefix: optional namespace; when truthy the key becomes
            ``"<prefix>/<name>"``, otherwise just ``"<name>"``.
        value: last value passed to ``update``; unset until the first call.
    """
    name: str
    prefix: str
    value: Any

    def __init__(self, name: str, prefix: str = None) -> None:
        self.name = name
        self.prefix = prefix

    def update(self, val: Any) -> None:
        """Replace the stored value with ``val``."""
        self.value = val

    def format_dict(self) -> Dict[str, Any]:
        """Return a one-entry ``{key: value}`` dict ready for a logger backend."""
        key = f"{self.prefix}/{self.name}" if self.prefix else f"{self.name}"
        return {key: self.value}


class AvgMetricTracker(MetricTracker):
    """Running, sample-weighted mean of a metric.

    Each ``update(val_avg, num)`` contributes ``val_avg`` as the mean of
    ``num`` samples, so ``value`` is the mean over every sample seen so far.
    ``value`` is 0 until the first non-empty update.
    """

    def __init__(self, name: str, prefix: str = None) -> None:
        super().__init__(name, prefix)
        self.value = 0
        self.sum = 0
        self.count = 0

    def update(self, val_avg: Any, num: int = 1) -> None:
        """Fold in a batch whose mean is ``val_avg`` over ``num`` samples.

        Args:
            val_avg: mean of the metric over the batch.
            num: number of samples in the batch. Defaults to 1 so single
                values can be fed through the base-class ``update(val)`` form.
        """
        if not num:
            # An empty batch carries no information; skipping it also avoids
            # dividing by zero while nothing has been accumulated yet.
            return
        self.sum += val_avg * num
        self.count += num
        self.value = self.sum / self.count


class BestMetricTracker(MetricTracker):
    """Track the best value seen for a metric and the step it was seen at.

    "Best" is decided by ``best_fn(candidate, current)``, which must return
    True when ``candidate`` should replace ``current``. The default keeps the
    maximum. To keep a minimum, pass e.g. ``lambda x, y: x < y`` together
    with a suitable initial ``value``.
    """
    metric: "BestMetric"
    best_fn: Callable[[Any, Any], bool]

    class BestMetric:
        """A (value, step) record, ordered so that a better value compares greater."""

        def __init__(
            self,
            value,
            step,
            best_fn: Callable[[Any, Any], bool] = lambda x, y: x > y,
        ) -> None:
            self.value = value
            self.step = step
            # The comparator travels with the record, so ordering never
            # depends on which tracker happened to be constructed last.
            self.best_fn = best_fn

        def __eq__(self, other: object) -> bool:
            if not isinstance(other, BestMetricTracker.BestMetric):
                return NotImplemented
            return self.value == other.value

        def __lt__(self, other: "BestMetricTracker.BestMetric") -> bool:
            # self < other  <=>  other's value is better than self's.
            return self.best_fn(other.value, self.value)

        def get(self, round_digits: int = 4) -> Tuple[int, int]:
            """Return ``(value * 10**round_digits rounded to an int, step)``.

            Rounds rather than truncates, so float artifacts such as
            ``0.29 * 10000 == 2899.9999999999995`` yield 2900, not 2899.
            """
            return int(round(self.value * (10 ** round_digits))), self.step

    def __init__(
        self,
        name: str,
        prefix: str = None,
        value: float = 0.,
        best_fn: Callable[[Any, Any], bool] = lambda x, y: x > y,
    ) -> None:
        super().__init__(name, prefix)
        # Stored per instance so that trackers with different comparators
        # (e.g. maximising accuracy vs. minimising loss) don't overwrite
        # each other.
        self.best_fn = best_fn
        self.metric = BestMetricTracker.BestMetric(value, 0, best_fn)

    def update(self, val: float, step: int) -> None:
        """Record ``val`` at ``step`` if ``best_fn`` says it beats the current best."""
        if self.best_fn(val, self.value):
            self.metric = BestMetricTracker.BestMetric(val, step, self.best_fn)

    @property
    def value(self) -> float:
        """Best value seen so far (the initial ``value`` before any improvement)."""
        return self.metric.value

    @property
    def step(self) -> int:
        """Step at which the current best value was recorded (0 initially)."""
        return self.metric.step


class Logger:
    """Base class for logging backends.

    Every hook is a no-op, so a backend overrides only what it supports:
    ``log`` for text messages, ``_log_metric`` for metric dicts, and
    ``close`` for teardown.
    """

    def log(self, msg: str) -> None:
        """Log a free-form text message (ignored by default)."""
        pass

    def log_metric(self, *metrics: "MetricTracker", **kwargs) -> None:
        """Merge the trackers' ``format_dict()`` outputs and hand them to ``_log_metric``.

        If two trackers produce the same key, the value from the earlier
        tracker wins (``ChainMap`` lookup order). ``kwargs`` are forwarded
        unchanged to ``_log_metric``.
        """
        metric_dict = dict(ChainMap(*[metric.format_dict() for metric in metrics]))
        self._log_metric(metric_dict, **kwargs)

    def _log_metric(self, metric_dict: Dict[str, Any], **kwargs) -> None:
        """Backend-specific sink for an already-merged metric dict (ignored by default)."""
        pass

    def close(self) -> None:
        """Release backend resources (nothing to do by default)."""
        pass


class WandbLogger(Logger):
    """Backend that forwards metric dicts to Weights & Biases.

    Text messages are not sent anywhere (``log`` is inherited as a no-op).
    The wandb run itself is not started here; ``create_logger`` calls
    ``wandb.init`` before constructing this backend.
    """

    def __init__(self) -> None:
        """No per-instance state; all calls go to the module-level wandb run."""

    def _log_metric(self, metric_dict: Dict[str, Any], **kwargs) -> None:
        # Extra keyword arguments (e.g. ``step``) go straight through to wandb.log.
        wandb.log(metric_dict, **kwargs)

    def close(self) -> None:
        """Finish the active wandb run."""
        wandb.finish()


def _visible_gpus(gpus_all: List[Any]) -> List[Any]:
    """Select the GPUtil GPUs named by ``CUDA_VISIBLE_DEVICES`` (all of them if unset).

    Integer entries are treated as positions in ``gpus_all``. Other non-empty
    entries are matched against each GPU's ``uuid`` attribute. Unknown or
    out-of-range entries are skipped instead of raising.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return list(gpus_all)
    selected = []
    for entry in (part.strip() for part in visible.split(",")):
        if entry.isdigit():
            index = int(entry)
            # NOTE(review): assumes GPUtil's ordering matches CUDA's device
            # ordering (e.g. CUDA_DEVICE_ORDER=PCI_BUS_ID) — confirm.
            if index < len(gpus_all):
                selected.append(gpus_all[index])
        elif entry:
            selected.extend(gpu for gpu in gpus_all if getattr(gpu, "uuid", None) == entry)
    return selected


def gpu_reporter(writer: SummaryWriter, record_break: int) -> None:
    """Poll GPU statistics forever and write them as TensorBoard scalar groups.

    Never returns; it is meant to run until its hosting process is stopped.
    Only GPUs selected by ``CUDA_VISIBLE_DEVICES`` are reported (all GPUs when
    the variable is unset). If no GPU is selected, the poll is skipped and
    nothing is written.

    Args:
        writer: SummaryWriter the scalars are written to.
        record_break: time interval between two records in ms.
    """
    # (tag, extractor) pairs; each becomes one chart with a line per "gpu<id>".
    readings = (
        ("GPU/memory_percentage(%)", lambda gpu: gpu.memoryUtil * 100),
        ("GPU/memory_used(MB)", lambda gpu: gpu.memoryUsed),
        ("GPU/temperature(°C)", lambda gpu: gpu.temperature),
        ("GPU/utilization(%)", lambda gpu: gpu.load * 100),
    )
    step = 1
    while True:
        # Re-query every iteration: GPUtil returns a snapshot of current stats.
        gpus = _visible_gpus(GPUtil.getGPUs())
        if gpus:
            for tag, read in readings:
                writer.add_scalars(tag, {f"gpu{gpu.id}": read(gpu) for gpu in gpus}, step)

        time.sleep(record_break / 1000)
        step += 1


def _run_gpu_reporter(log_dir: str, record_break: int) -> None:
    """Child-process entry point: record GPU stats into ``log_dir`` until terminated.

    The child opens its own SummaryWriter rather than receiving the parent's.
    A forked child would inherit the writer without the background thread
    that drains its event queue, the writer cannot be pickled for the
    ``spawn`` start method, and two processes sharing one event file would
    interleave records. Each writer creates its own event file in the
    directory.
    """
    import signal

    def _on_sigterm(signum, frame):
        # Process.terminate() sends SIGTERM on POSIX. Turning it into a normal
        # exit lets the finally-block flush and close the writer.
        raise SystemExit(0)

    signal.signal(signal.SIGTERM, _on_sigterm)
    writer = SummaryWriter(log_dir=log_dir)
    try:
        gpu_reporter(writer, record_break)
    finally:
        writer.close()


class TensorboardLogger(Logger):
    """TensorBoard backend: scalar metrics, plus GPU stats from a background process.

    Args:
        log_dir: output directory; any existing directory at this path is
            deleted first.
        gpu_report_break: interval between GPU samples, in ms.
    """

    def __init__(self, log_dir: str, gpu_report_break: int = 100) -> None:
        # Start from an empty directory so files from a previous run with the
        # same log_dir are not mixed into this one.
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
        os.makedirs(log_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=log_dir)
        # daemon must be set before start(); it makes the reporter exit with
        # the parent even if close() is never called.
        self.reporter = multiprocessing.Process(
            target=_run_gpu_reporter,
            args=(log_dir, gpu_report_break),
            daemon=True,
        )
        self.reporter.start()

    def _log_metric(self, metric_dict: Dict[str, Any], step: int = None, **kwargs) -> None:
        """Write each entry as a scalar at ``step`` (None lets the writer decide)."""
        for name, value in metric_dict.items():
            self.writer.add_scalar(name, value, step, **kwargs)

    def close(self) -> None:
        """Stop the GPU reporter, wait for it to flush, then close the writer."""
        self.reporter.terminate()
        # Reap the child; the timeout keeps shutdown from hanging indefinitely.
        self.reporter.join(timeout=5)
        self.writer.close()


class SysLogger(Logger):
    """Plain-text backend that writes through the standard ``logging`` module.

    On construction it calls ``logging.basicConfig`` at DEBUG level with two
    handlers: a file at ``log_path`` (truncated, mode ``'w'``) and stdout.
    ``basicConfig`` has no effect if the root logger already has handlers.
    """

    def __init__(self, log_path: str) -> None:
        file_handler = logging.FileHandler(filename=log_path, mode='w')
        console_handler = logging.StreamHandler(stream=sys.stdout)
        logging.basicConfig(
            handlers=[file_handler, console_handler],
            level=logging.DEBUG,
            format="%(asctime)s | %(levelname)s |  %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        self.logger = logging.getLogger(__name__)

    def log(self, msg: str) -> None:
        """Emit ``msg`` at INFO level."""
        self.logger.info(msg)

    def _log_metric(self, metric_dict: Dict[str, Any], **kwargs) -> None:
        # Extra kwargs (e.g. step) are accepted for interface parity but unused.
        self.logger.info(metric_dict)


# Registered backends. create_logger() appends to this list; log(),
# log_metric() and close() fan out over it.
loggers: List[Logger] = list()


def create_logger(train_args: TrainerArguments, model_args: ModelArguments):
    """Register the logging backends requested by ``train_args``.

    Backends are appended to the module-level ``loggers`` list in this order:

    - A SysLogger writing to ``train_args.log_file_path`` is always added.
    - If ``train_args.use_wandb`` is set, a wandb run is started with both
      argument objects as its config, and a WandbLogger is added.
    - If ``train_args.use_tensorboard`` is set, a TensorboardLogger writing to
      ``train_args.tensorboard_log_dir`` is added.
    """
    # `loggers` is mutated in place, never rebound, so no `global` is needed.
    loggers.append(SysLogger(train_args.log_file_path))

    if train_args.use_wandb:
        run_config = {"train_args": train_args, "model_args": model_args}
        wandb.init(
            dir=train_args.log_dir,
            config=run_config,
            project=train_args.project,
            entity=train_args.entity,
            name=train_args.experiment_name,
        )
        loggers.append(WandbLogger())

    if train_args.use_tensorboard:
        loggers.append(TensorboardLogger(train_args.tensorboard_log_dir))


def close() -> None:
    """Close every registered logger backend.

    Every backend's ``close()`` is attempted even if an earlier one raises.
    Otherwise one failing backend (e.g. ``wandb.finish()``) would leave later
    ones, such as the TensorBoard GPU reporter process, running. The first
    exception encountered is re-raised after all backends have been visited.
    """
    first_error = None
    for logger in loggers:
        try:
            logger.close()
        except Exception as err:  # keep closing the rest; re-raised below
            if first_error is None:
                first_error = err
    if first_error is not None:
        raise first_error


def log(msg: str):
    """Send a free-form text message to every registered logger backend."""
    for backend in loggers:
        backend.log(msg)


def log_metric(*metrics: MetricTracker, **kwargs):
    """Send the given metric trackers to every registered logger backend.

    Each positional argument is one tracker. Keyword arguments are forwarded
    unchanged to every backend, so they must suit all of them. For example,
    ``step=`` is a named parameter of the TensorBoard backend and is passed
    through to ``wandb.log`` by the wandb backend.
    """
    for logger in loggers:
        logger.log_metric(*metrics, **kwargs)
