from collections import defaultdict

import gin
import mrunner
import numpy as np

# Metric names that Logger.log reports as their latest raw value instead of
# the mean over the buffered window; with ``log_not_averaged_every_step`` set
# they are also sent on every call.
VARS_NOT_TO_AVERAGE = [
    "var_idx",
    "data_samples",
    "epoch",
    "ds/var_idx",
    "ds/SHD",
]


@gin.configurable
class Logger:
    """Buffer scalar metrics and periodically send them to a Neptune run.

    Every ``log`` call stores the value under its metric name. When the
    current value ``s`` of the selected step counter satisfies
    ``(s + 1) % log_frequency == 0``, the mean of the buffered values is sent
    to Neptune under ``self.prefix + name`` at step ``s`` and the buffer is
    cleared. Names listed in ``VARS_NOT_TO_AVERAGE`` are sent as their latest
    value rather than a mean and, if ``log_not_averaged_every_step`` is true,
    are sent on every call regardless of the step.

    When ``use_neptune_logger`` is falsy, ``log`` does nothing and no Neptune
    run handle is looked up.

    Step counters are kept separately per ``step_name``; they only move through
    ``bump`` and are zeroed by ``reset``.
    """

    DEFAULT_STEP_NAME = "default"

    def __init__(
        self, use_neptune_logger, log_frequency=20, log_not_averaged_every_step=True
    ):
        """Configure the logger.

        Args:
            use_neptune_logger: Enables forwarding to Neptune; when falsy,
                ``log`` becomes a no-op.
            log_frequency: Flush interval in steps; a flush happens when
                ``(step + 1) % log_frequency == 0``.
            log_not_averaged_every_step: If true, names in
                ``VARS_NOT_TO_AVERAGE`` are sent on every ``log`` call.
        """
        super().__init__()
        self.use_neptune_logger = use_neptune_logger
        self.log_frequency = log_frequency
        self.log_not_averaged_every_step = log_not_averaged_every_step
        if use_neptune_logger:
            # NOTE(review): relies on ``import mrunner`` exposing
            # ``mrunner.helpers.client_helper``; confirm that submodule is
            # loaded before an enabled logger is created.
            self.neptune_run = mrunner.helpers.client_helper.experiment_
        self.prefix = "logs/"
        # Per-step-name counters and per-metric value buffers, both empty.
        self.steps = defaultdict(int)
        self.saved_values = defaultdict(list)

    def log(self, name, value, step_name=DEFAULT_STEP_NAME):
        """Buffer ``value`` for ``name`` and send it to Neptune when due.

        Args:
            name: Metric name; reported under ``self.prefix + name``.
            value: Scalar, converted with ``float()``.
            step_name: Step counter that decides when to flush and which step
                is reported.
        """
        if not self.use_neptune_logger:
            return

        sample = float(value)
        window = self.saved_values[name]
        window.append(sample)
        current_step = self.get_current_step(step_name)

        at_flush_step = (current_step + 1) % self.log_frequency == 0
        is_raw = name in VARS_NOT_TO_AVERAGE
        if not at_flush_step and not (is_raw and self.log_not_averaged_every_step):
            return

        # Raw metrics report the newest sample; all others report the mean of
        # the values collected since the previous flush.
        if is_raw:
            reported = sample
        else:
            reported = np.mean(window)
        self.neptune_run[f"{self.prefix}{name}"].log(
            value=reported, step=current_step
        )
        self.saved_values[name] = []

    def reset(self):
        """Set every known step counter back to 0 and drop all buffered values."""
        self.steps.update(dict.fromkeys(self.steps, 0))
        self.saved_values = defaultdict(list)

    def bump(self, value=1, step_name=DEFAULT_STEP_NAME):
        """Advance the ``step_name`` counter by ``value``."""
        self.steps[step_name] = self.steps[step_name] + value

    def get_current_step(self, step_name=DEFAULT_STEP_NAME):
        """Return the ``step_name`` counter, which is 0 if it was never bumped.

        Reading an unknown name registers it with value 0, because ``steps``
        is a ``defaultdict(int)``.
        """
        current = self.steps[step_name]
        return current


# Module-level default: a disabled logger, so ``log`` calls are no-ops until
# ``create_neptune_logger`` replaces this binding.
NEPTUNE_LOGGER = Logger(use_neptune_logger=False)


def create_neptune_logger(use_neptune_logger):
    """Replace the module-level ``NEPTUNE_LOGGER`` with a new ``Logger``.

    Args:
        use_neptune_logger: Forwarded to ``Logger``; when falsy, the new logger
            ignores every ``log`` call.

    Returns:
        The newly created ``Logger``. This function rebinds the module global,
        so any module that ran ``from <this module> import NEPTUNE_LOGGER``
        before the call still holds the previous instance. Access the logger
        through the module attribute, or keep the returned object, to use the
        new one.
    """
    global NEPTUNE_LOGGER
    NEPTUNE_LOGGER = Logger(use_neptune_logger)
    return NEPTUNE_LOGGER
