from typing import Dict, Any, Optional, Tuple
import os
from os import path
from tensorboardX import SummaryWriter
import numpy as np
import torch


class Reporter:
    """Wrapper around a tensorboard-style writer that tracks a step per tag.

    ``counter`` maps a tag (``"<prefix>/<key>"`` for the ``add_*s`` helpers,
    the raw ``tag`` for :meth:`add_embedding`) to that tag's current step.
    The counter supplies implicit steps and drives ``interval`` thinning:
    when a tag's counter is not a multiple of ``interval``, the call only
    advances the counter and writes nothing.
    """

    def __init__(self, writer, counter):
        self._writer = writer
        self._times_counter = counter

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array if it is a torch tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def _skip(self, key: str, interval: int) -> bool:
        """Return True (after advancing ``key``'s counter) if this call must be dropped.

        A key that has never been seen is never skipped.
        """
        if key in self._times_counter and self._times_counter[key] % interval != 0:
            self._times_counter[key] += 1
            return True
        return False

    def _explicit_step(self, key: str, step) -> int:
        """Validate the step of a ``(value, step)`` pair and return it as an ``int``.

        The step must be an integer (Python or NumPy) strictly greater than the
        counter currently recorded for ``key``.
        """
        assert isinstance(step, (int, np.integer))
        step = int(step)
        assert key not in self._times_counter or self._times_counter[key] < step
        return step

    def add_scalars(self, info: Dict[str, Any], prefix: str, interval: int = 1):
        """Log every entry of ``info`` as a scalar under ``"<prefix>/<key>"``.

        Args:
            info: maps a key to either a plain value (logged at the tag's
                implicit step, which then advances by one) or a
                ``(value, step)`` tuple (logged at ``step``, which becomes the
                tag's counter).
            prefix: tag prefix joined to each key with ``/``.
            interval: only write when the tag's counter is a multiple of this.
        """
        for name, value in info.items():
            key = f"{prefix}/{name}"
            if self._skip(key, interval):
                continue

            if isinstance(value, tuple):
                step = self._explicit_step(key, value[1])
                self._writer.add_scalar(key, value[0], step)
                self._times_counter[key] = step
            else:
                step = self._times_counter.setdefault(key, 0)
                self._writer.add_scalar(key, value, step)
                self._times_counter[key] = step + 1

    def add_distributions(
        self, info: Dict[str, Any], prefix: str, interval: int = 1, no_print=False
    ):
        """Log every entry of ``info`` as a histogram under ``"<prefix>/<key>"``.

        Values are NumPy arrays or torch tensors, either bare (implicit step)
        or inside a ``(values, step)`` tuple (explicit step), with the same
        step and ``interval`` semantics as :meth:`add_scalars`. Unless
        ``no_print`` is set, the values, their mean and their std are also
        printed to stdout.
        """
        for name, value in info.items():
            key = f"{prefix}/{name}"
            if self._skip(key, interval):
                continue

            if isinstance(value, tuple):
                step = self._explicit_step(key, value[1])
                # Tensors are accepted inside tuples too, not only as bare values.
                arr = self._to_numpy(value[0])
                next_count = step
            else:
                step = self._times_counter.setdefault(key, 0)
                arr = self._to_numpy(value)
                next_count = step + 1

            assert isinstance(arr, np.ndarray)
            if not no_print:
                print(f"distribution report: {key}: {arr.tolist()} @ {step}")
                print(
                    f"{key}: mean: {arr.mean().item()}, std: {arr.std().item()} @ {step}",
                    flush=True,
                )
            self._writer.add_histogram(key, arr, step)
            self._times_counter[key] = next_count

    def add_embedding(
        self, mat, tag, global_step: Optional[int] = None, interval: int = 1
    ):
        """Log ``mat`` (NumPy array or torch tensor) as an embedding under ``tag``.

        If ``global_step`` is given (including 0), it is used as the step and
        stored as the tag's counter. Otherwise the tag's implicit counter is
        used and then advanced by one, so successive calls get distinct steps
        and ``interval`` thinning works as in :meth:`add_scalars`.
        """
        if self._skip(tag, interval):
            return

        mat = self._to_numpy(mat)
        if global_step is not None:
            step = global_step
            next_count = global_step
        else:
            step = self._times_counter.get(tag, 0)
            next_count = step + 1
        self._writer.add_embedding(mat, tag=tag, global_step=step)
        self._times_counter[tag] = next_count

    def add_videos(self, info: Dict[str, Tuple[np.ndarray, int]], prefix: str):
        """Log each ``(video, step)`` entry of ``info`` under ``"<prefix>/<key>"``."""
        for name, (video, step) in info.items():
            self._writer.add_video(f"{prefix}/{name}", video, step)

    def add_params(self, params: Dict[str, Any]):
        """Log each parameter as a text entry (value converted with ``str``)."""
        for key, value in params.items():
            self._writer.add_text(key, str(value))

    def add_text(self, tag: str, text: str):
        """Log ``text`` under ``tag`` via the writer's ``add_text``."""
        self._writer.add_text(tag, text)


# Process-wide singleton state. Both names are assigned by init_reporter()
# and read back through get_reporter() / get_reporter_dir().
reporter: Optional[Reporter] = None
# Log directory taken from the underlying writer (``writer.logdir``).
reporter_dir: Optional[str] = None


def get_reporter() -> Reporter:
    """Return the module-wide Reporter created by ``init_reporter``.

    Raises:
        AssertionError: if ``init_reporter`` has not been called yet.
    """
    # Explicit raise instead of ``assert``: under ``python -O`` an assert is
    # stripped and this would silently return None. AssertionError is kept so
    # existing callers see the same exception type.
    if reporter is None:
        raise AssertionError("get_reporter() called before init_reporter()")
    return reporter


def get_reporter_dir() -> str:
    """Return the log directory recorded by ``init_reporter``.

    Raises:
        AssertionError: if ``init_reporter`` has not been called yet.
    """
    # Explicit raise so the guard survives ``python -O``; same exception type
    # as the original ``assert``.
    if reporter_dir is None:
        raise AssertionError("get_reporter_dir() called before init_reporter()")
    return reporter_dir


def init_reporter(folder: str):
    """Create the module-wide Reporter, logging to ``<folder>/tblogs``.

    Sets the module globals ``reporter`` (a Reporter with a fresh, empty step
    counter) and ``reporter_dir`` (the writer's ``logdir``), and prints where
    the logs go. It may be called only once per process.

    Raises:
        AssertionError: if a reporter has already been initialised.
    """
    global reporter, reporter_dir
    # Explicit raise rather than ``assert`` so double initialisation is still
    # caught under ``python -O``; AssertionError matches the original check.
    if reporter is not None or reporter_dir is not None:
        raise AssertionError("init_reporter() has already been called")

    log_path = path.join(folder, "tblogs")
    writer = SummaryWriter(logdir=log_path)
    print(f"tensorboard reporter created at {path.join(os.getcwd(), log_path)}")

    reporter = Reporter(writer, {})
    reporter_dir = writer.logdir
