import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import neptune.new as neptune
from tensorboardX import SummaryWriter


@dataclass
class Metric:
    """A single named scalar data point handed to a metric logger."""

    # Name the value is logged under (used as the tag / field name by the loggers).
    name: str
    # The scalar value to record.
    value: float
    # Position of the data point on the x-axis; the loggers in this file
    # forward it as the step of the logged value.
    x_axis_value: int


class MetricLogger(ABC):
    """Abstract base class defining the interface shared by all metric loggers."""

    @abstractmethod
    def start(self) -> None:
        """Start the metric logger; in general this must be called before logging."""

    @abstractmethod
    def log(self, metric: Metric) -> None:
        """Log a metric, i.e. an int or float value."""

    @abstractmethod
    def log_image(self, image: Any, filename: str) -> None:
        """Log an image under the given filename."""

    @abstractmethod
    def log_artifact(self, filename: str) -> None:
        """Log an artifact, i.e. a file."""


@dataclass
class NeptuneConfig:
    """Config to define the arguments used when starting a Neptune run."""

    # Combined with project_name as "<user_name>/<project_name>" to identify the project.
    user_name: str
    project_name: str
    # Passed to Neptune as the name of the run.
    experiment_name: str
    # When set, stored under the run's "parameters" field right after start.
    params: Optional[Dict[str, float]] = None
    # Tags attached to the run; None is treated as an empty list.
    tags: Optional[List[str]] = None
    api_token: Optional[str] = None  # if not set, uses $NEPTUNE_API_TOKEN


class NeptuneMetricLogger(MetricLogger):
    """Logger for the Neptune interface.

    ``start()`` must be called before ``log``, ``log_image`` or
    ``log_artifact``; until then no run exists and those methods fail their
    ``assert`` with "call start() before logging".
    """

    def __init__(self, neptune_config: NeptuneConfig) -> None:
        """Store the config; the Neptune run itself is only created by ``start()``."""
        self.neptune_config = neptune_config
        self._run: Optional[neptune.Run] = None

    def start(self) -> None:
        """Create the Neptune run and, if configured, record its parameters."""
        config = self.neptune_config
        self._run = neptune.init(
            f"{config.user_name}/{config.project_name}",
            tags=config.tags if config.tags is not None else [],
            name=config.experiment_name,
            api_token=config.api_token,
        )
        if config.params is not None:
            self._run["parameters"] = config.params

    def log(self, metric: Metric) -> None:
        """Append ``metric.value`` to the ``logs/<name>`` field at step ``x_axis_value``."""
        assert self._run is not None, "call start() before logging"

        self._run[f"logs/{metric.name}"].log(
            metric.value,
            step=metric.x_axis_value,
        )

    def log_image(self, image: Any, filename: str) -> None:
        """Upload ``image`` to the ``images/<filename>`` field of the run."""
        assert self._run is not None, "call start() before logging"

        # Field paths are "/"-separated just like in log() and log_artifact().
        # os.path.join would use the host OS separator ("\\" on Windows) and
        # would drop the "images" prefix entirely for an absolute filename.
        self._run[f"images/{filename}"].upload(image)

    def log_artifact(self, filename: str) -> None:
        """Upload the file at ``filename`` to the ``artifacts/<basename>`` field."""
        assert self._run is not None, "call start() before logging"

        self._run[f"artifacts/{os.path.basename(filename)}"].upload(filename)


class TensorboardMetricLogger(MetricLogger):
    """Logger for the Tensorboard interface, backed by a ``SummaryWriter``.

    Only scalar metrics are written; ``log_image`` and ``log_artifact`` are
    no-ops.
    """

    def __init__(self, save_dir: str) -> None:
        """Remember the output directory, which must already exist."""
        # isdir() is already False for a missing path, so it also covers the
        # plain existence check.
        assert os.path.isdir(save_dir), "Directory does not exist"
        self._save_dir = save_dir
        self._summary_writer: Optional[SummaryWriter] = None

    def start(self) -> None:
        """Create the summary writer that targets the configured directory."""
        self._summary_writer = SummaryWriter(self._save_dir)

    def log(self, metric: Metric) -> None:
        """Write ``metric`` as a scalar.

        An int ``x_axis_value`` is passed as the global step, a float one as
        the walltime; whichever does not apply is passed as ``None``.
        """
        writer = self._summary_writer
        assert writer is not None, "call start() before logging"

        x_value = metric.x_axis_value
        # isinstance() is False for None, so no separate None check is needed.
        step = x_value if isinstance(x_value, int) else None
        wall_time = x_value if isinstance(x_value, float) else None

        writer.add_scalar(
            tag=metric.name,
            scalar_value=metric.value,
            global_step=step,
            walltime=wall_time,
        )

    def log_artifact(self, _filename: str) -> None:
        """Do nothing: tensorboard can't log artifacts."""

    def log_image(self, image: Any, filename: str) -> None:
        """Do nothing: image logging is not implemented for this logger."""
