from typing import Any

import wandb.wandb_run

import wandb


class Logger:
    def log(self, items: dict[str, Any], step: int | None = None) -> None:
        raise NotImplementedError()

    def flush(self) -> None:
        raise NotImplementedError()

    def finish(self) -> None:
        pass


class PrintLogger(Logger):
    """Logger that writes everything to standard output via ``print``."""

    # Label printed in front of the step number (e.g. "Step 12: ...").
    _step_name: str

    def __init__(self, run_name: str, step_name: str = "Step") -> None:
        """Store the step label and print a banner announcing the run.

        Args:
            run_name: Name shown in the start-up banner.
            step_name: Prefix used when a step number accompanies a record.
        """
        self._step_name = step_name
        print("\nStarting run ", run_name, "...\n", sep="")

    def log(self, items: dict[str, Any], step: int | None = None) -> None:
        """Print ``items`` as comma-separated ``key: value`` pairs on one line.

        When ``step`` is given, the line is prefixed with the step label and
        number; otherwise only the pairs are printed.
        """
        rendered = ", ".join(f"{name}: {value}" for name, value in items.items())
        if step is None:
            print(rendered)
            return
        print(f"{self._step_name} {step}:", rendered)

    def flush(self) -> None:
        """Flush standard output without emitting any visible text."""
        print("", end="", flush=True)

    def finish(self) -> None:
        """Print a closing message."""
        print("Finished.")


class WandbLogger(Logger):
    """Logger that forwards records to a Weights & Biases run.

    In addition to sending data to W&B, ``log`` prints a short progress line
    to stdout whenever the step it receives differs from the previous call's.
    Once ``finish`` has been called the logger is closed: ``log`` and
    ``flush`` raise ``RuntimeError``, while further ``finish`` calls are no-ops.
    """

    # Open W&B run; None once ``finish`` has closed it.
    _run: wandb.wandb_run.Run | None
    # ``step`` argument of the most recent ``log`` call (None if omitted).
    _last_step: int | None
    # Label printed in front of the step number in progress lines.
    _step_name: str

    def __init__(
        self,
        run_name: str,
        config: dict | None = None,
        step_name: str = "Step",
        project: str = "robust-icrl",
    ) -> None:
        """Start a W&B run.

        Args:
            run_name: Passed to ``wandb.init`` as ``name``.
            config: Passed to ``wandb.init`` as ``config``.
            step_name: Prefix for the progress lines printed by ``log``.
            project: Passed to ``wandb.init`` as ``project``.
        """
        self._run = wandb.init(project=project, name=run_name, config=config)
        self._last_step = None
        self._step_name = step_name

    def _active_run(self) -> wandb.wandb_run.Run:
        """Return the open run, or raise if ``finish`` already closed it."""
        if self._run is None:
            raise RuntimeError("WandbLogger cannot be used after finish() was called")
        return self._run

    def log(self, items: dict[str, Any], step: int | None = None) -> None:
        """Send ``items`` to the run and print the step when it changes.

        Raises:
            RuntimeError: If the logger has already been finished.
        """
        self._active_run().log(items, step=step)
        # Print each new step only once, even if several log calls share it.
        if step is not None and step != self._last_step:
            print(f"{self._step_name} {step}")
        self._last_step = step

    def flush(self) -> None:
        """Log an empty record with ``commit=True``.

        NOTE(review): per the W&B docs this finalizes the data accumulated
        for the current step -- confirm against the installed wandb version.

        Raises:
            RuntimeError: If the logger has already been finished.
        """
        self._active_run().log({}, commit=True)

    def finish(self) -> None:
        """Finish the W&B run; calling this more than once is harmless."""
        if self._run is None:
            return
        self._run.finish()
        self._run = None
