"""
Dumps things to TensorBoard and the console, and saves logged images to disk.
"""

import os
import logging
import datetime
from typing import Dict, Optional

import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from feature_extractor.cutie.cutie.utils.time_estimator import TimeEstimator


def tensor_to_numpy(image):
    """Convert a float image with values in [0, 1] to a ``uint8`` NumPy array.

    ``image`` must provide ``.numpy()`` (presumably a CPU ``torch.Tensor``).
    Values are scaled by 255, clipped to [0, 255] and truncated toward zero.
    The shape is preserved.
    """
    # Clip before casting. A float->uint8 conversion of out-of-range values
    # (e.g. 1.01 * 255, or small negatives from interpolation) wraps around or
    # is platform-dependent instead of saturating at 0 / 255.
    scaled = np.clip(image.numpy() * 255, 0, 255)
    return scaled.astype('uint8')


def detach_to_cpu(x):
    """Return ``x`` detached from autograd and moved to the CPU.

    ``x`` is presumably a ``torch.Tensor``. Only ``.detach()`` and then
    ``.cpu()`` are called on it.
    """
    detached = x.detach()
    return detached.cpu()


def fix_width_trunc(x):
    """Format ``x`` with 9 decimals, then keep at most the first 9 characters.

    Gives a compact, fixed-width number for console output, e.g.
    ``fix_width_trunc(3.14159265358979) == '3.1415926'``.
    """
    full = f'{x:0.9f}'
    return full[:9]


class TensorboardLogger:
    """Send run metrics and messages to a Python logger and, optionally, TensorBoard.

    Text always goes to ``py_logger``. Scalars and text are also written to a
    ``SummaryWriter`` when ``enabled_tb`` is true. Images are saved as PNG
    files under ``run_dir`` regardless of ``enabled_tb``.
    """

    def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb):
        """
        Args:
            run_dir: Directory for TensorBoard event files (if enabled) and
                for images written by ``log_image``.
            py_logger: Logger that receives every text message.
            enabled_tb: If falsy, no ``SummaryWriter`` is created and all
                TensorBoard calls become no-ops.
        """
        self.run_dir = run_dir
        self.py_log = py_logger
        self.tb_log = SummaryWriter(run_dir) if enabled_tb else None

        # Record which code version produced this run.
        self.log_string('git', self._fetch_git_info())

        # Optional; once assigned by the caller, log_metrics also reports
        # average iteration time, remaining time and ETA.
        self.time_estimator: Optional[TimeEstimator] = None

    def _fetch_git_info(self) -> str:
        """Return ``'<branch> <commit sha>'`` for the repo in the CWD, or ``'None'``.

        Best effort: any failure is logged as a warning and never aborts the run.
        """
        try:
            import git  # GitPython, an optional dependency
            repo = git.Repo('.')
            return f'{repo.active_branch} {repo.head.commit.hexsha}'
        except Exception as e:  # pylint: disable=broad-except
            # GitPython reports the common failure modes with unrelated types:
            # ImportError (package or git binary missing),
            # InvalidGitRepositoryError / NoSuchPathError (CWD is not a repo),
            # TypeError (detached HEAD), ValueError (repo has no commits).
            self.py_log.warning(f'Failed to fetch git info ({e!r}). Defaulting to None')
            return 'None'

    def log_scalar(self, tag, x, it):
        """Write scalar ``x`` under ``tag`` at step ``it``; no-op without TensorBoard."""
        if self.tb_log is None:
            return
        self.tb_log.add_scalar(tag, x, it)

    def log_metrics(self, exp_id, prefix, metrics: Dict, it):
        """Log ``metrics`` as TensorBoard scalars and as one console line.

        Args:
            exp_id: Experiment identifier shown at the start of the line.
            prefix: Scalar tag namespace; each metric is logged as ``prefix/name``.
            metrics: Mapping of metric name to value. Names must be sortable
                and values must support the ``.7f`` format spec.
            it: Iteration number (an int; formatted with ``6d``).
        """
        msg = f'{exp_id}-{prefix} - it {it:6d}: '
        metric_parts = []
        for k, v in sorted(metrics.items()):
            self.log_scalar(f'{prefix}/{k}', v, it)
            metric_parts.append(f'{k: >10}:{v:.7f},\t')
        metrics_msg = ''.join(metric_parts)

        if self.time_estimator is not None:
            # update() must run before the average is read and reset.
            self.time_estimator.update()
            avg_time = self.time_estimator.get_and_reset_avg_time()
            # Remaining time is interpreted as seconds.
            est = datetime.timedelta(seconds=self.time_estimator.get_est_remaining(it))
            if est.days > 0:
                remaining_str = f'{est.days}d {est.seconds // 3600}h'
            else:
                remaining_str = f'{est.seconds // 3600}h {(est.seconds % 3600) // 60}m'
            eta_str = (datetime.datetime.now() + est).strftime('%Y-%m-%d %H:%M:%S')
            time_msg = f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
            msg = f'{msg} {time_msg}'

        msg = f'{msg} {metrics_msg}'
        self.py_log.info(msg)

    def log_image(self, stage_name, tag, image, it):
        """Save ``image`` as ``<run_dir>/<stage_name>_images/<tag>_<it>.png``.

        ``image`` is passed to ``PIL.Image.fromarray``; presumably an HxW or
        HxWxC ``uint8`` array (see ``tensor_to_numpy``) — TODO confirm.
        """
        image_dir = os.path.join(self.run_dir, f'{stage_name}_images')
        os.makedirs(image_dir, exist_ok=True)

        Image.fromarray(image).save(os.path.join(image_dir, f'{tag}_{it}.png'))

    def log_string(self, tag, x):
        """Log ``'<tag> - <x>'`` to the console and ``x`` as TensorBoard text."""
        self.py_log.info(f'{tag} - {x}')
        if self.tb_log is None:
            return
        self.tb_log.add_text(tag, x)

    def close(self):
        """Flush and close the TensorBoard writer, if one was created."""
        if self.tb_log is not None:
            self.tb_log.close()

    # Thin pass-throughs to the underlying Python logger.
    def debug(self, x):
        self.py_log.debug(x)

    def info(self, x):
        self.py_log.info(x)

    def warning(self, x):
        self.py_log.warning(x)

    def error(self, x):
        self.py_log.error(x)

    def critical(self, x):
        self.py_log.critical(x)
