import torch
import os
import shutil

from .base import StepCallback

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


class CsvWriter(StepCallback):
    def __init__(self, log_dir, cover=False, flush=True, suffix='.csv'):
        r"""Write the metrics to csv(txt) files.

        One file per metric is opened lazily under `log_dir`, named
        ``train_<name><suffix>`` or ``eval_<name><suffix>``. Each file starts
        with a ``step,value`` header line, followed by one ``<step>,<value>``
        row per logged value.

        Args:
            log_dir: write root dir; created if it does not exist.
            cover: if True, will remove `log_dir` (and all files in it) first.
            flush: if True, will flush after each written row.
            suffix: saved files' suffix name.
        """
        super().__init__()
        self.log_dir = log_dir
        self.flush = flush
        self.suffix = suffix
        # Created before any filesystem work so close()/__del__ always find it.
        self._log_writers = {}

        if cover and os.path.exists(log_dir):
            shutil.rmtree(log_dir, ignore_errors=True)
        # exist_ok avoids a race if the dir appears between a check and mkdir.
        os.makedirs(log_dir, exist_ok=True)

    def close(self):
        """Close every metric file opened so far.

        Calling this more than once is harmless, because closing an already
        closed file is a no-op. Writing after close() raises ``ValueError``
        rather than silently truncating the file by reopening it.
        """
        # getattr: __init__ may have failed before the dict was created.
        writers = getattr(self, '_log_writers', None)
        if not writers:
            return
        for writer in writers.values():
            writer.close()

    def __del__(self):
        self.close()

    def _add_scalar(self, name, value, step):
        """Append one ``step,value`` row to the file for metric `name`.

        The file is created (truncating any existing one) on first use.
        """
        writer = self._log_writers.get(name)
        if writer is None:
            path = os.path.join(self.log_dir, name + self.suffix)
            # Kept open for the callback's lifetime; released in close().
            writer = open(path, 'w', encoding='utf-8')
            writer.write('step,value\n')
            self._log_writers[name] = writer
        writer.write('{},{}\n'.format(step, value))
        if self.flush:
            writer.flush()

    def on_train_batch_end(self, batch_number: int, logs):
        """Advance the base callback, then write each train metric at the current step."""
        super().on_train_batch_end(batch_number, logs)

        for name, value in logs.items():
            self._add_scalar('train_{}'.format(name), value, self.step)

    def on_eval_end(self, logs):
        """Write each eval metric at the current step (the step is not advanced here)."""
        for name, value in logs.items():
            self._add_scalar('eval_{}'.format(name), value, self.step)


class BaseTBWriter(StepCallback):
    def __init__(self, log_dir, cover=False):
        r"""Base class for callbacks that log through a TensorBoard ``SummaryWriter``.

        Args:
            log_dir: tensorboard log dir, passed to ``SummaryWriter``.
            cover: if True, remove `log_dir` (and all files in it) first.
        """
        super().__init__()

        self.log_dir = log_dir

        if cover and os.path.exists(log_dir):
            shutil.rmtree(log_dir, ignore_errors=True)

        self._log_writer = SummaryWriter(log_dir)

    def close(self):
        """Flush pending events and close the underlying ``SummaryWriter``.

        ``SummaryWriter.close`` returns early once it has already been
        closed, so calling this repeatedly is safe.
        """
        # getattr: __init__ may have failed before the writer was created.
        writer = getattr(self, '_log_writer', None)
        if writer is not None:
            writer.close()

    def __del__(self):
        # Without an explicit close, events still buffered by the writer
        # could be lost when the callback is discarded.
        self.close()


class TBLogWriter(BaseTBWriter):
    r"""Write every logged metric as a TensorBoard scalar.

    Training metrics are tagged ``train/<name>`` and evaluation metrics
    ``eval/<name>``. Both use the callback's current ``step`` as global step.
    """

    def __init__(self, log_dir, cover=False):
        super().__init__(log_dir, cover=cover)

    def _write_scalars(self, prefix, logs):
        """Emit each ``(name, value)`` pair of `logs` as scalar ``<prefix>/<name>``."""
        current_step = self.step
        for metric, amount in logs.items():
            tag = '{}/{}'.format(prefix, metric)
            self._log_writer.add_scalar(tag, amount, current_step)

    def on_train_batch_end(self, batch_number: int, logs):
        # Let the base class advance its step counter before logging.
        super().on_train_batch_end(batch_number, logs)
        self._write_scalars('train', logs)

    def on_eval_end(self, logs):
        self._write_scalars('eval', logs)


class TBHistogramWriter(BaseTBWriter):
    def __init__(self, log_dir, log_step=1, log_names=None, bins='tensorflow', cover=False):
        r"""Log histograms of the net's parameters to TensorBoard.

        Args:
            log_dir: tensorboard log dir.
            log_step: write interval. Histograms are written on train batches
                where ``step % log_step == 0``. Must be >= 1.
            log_names: if None, write all parameters; otherwise only write the
                parameters whose name is in the `log_names` list.
            bins: One of {'tensorflow', 'auto', 'fd', ...}. This determines how
                the bins are made. You can find other options in:
                https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
            cover: if True, remove `log_dir` (and all files in it) first.

        Raises:
            ValueError: if `log_step` is less than 1. The check runs before
                anything is removed or created on disk.
        """
        # Fail fast. Otherwise `step % log_step` would divide by zero (or never
        # match for negative values) on the first training batch.
        if log_step < 1:
            raise ValueError('log_step must be >= 1, got {!r}'.format(log_step))
        super().__init__(log_dir, cover)
        self.bins = bins
        self.log_names = log_names
        self.log_step = log_step

    def on_train_batch_end(self, batch_number: int, logs):
        """Advance the step, then every `log_step` steps write parameter histograms.

        Parameter names have '.' replaced by '/' so TensorBoard groups them by
        module. `self.trainer.net` is presumably attached by the base callback
        (NOTE(review): confirm against `StepCallback`).
        """
        super().on_train_batch_end(batch_number, logs)

        if self.step % self.log_step != 0:
            return
        for name, param in self.trainer.net.named_parameters():
            if self.log_names is not None and name not in self.log_names:
                continue
            self._log_writer.add_histogram(name.replace('.', '/'), param, self.step, bins=self.bins)

