import datetime
import git
import sys
import torch
import visdom
from moduleloader import print_info as print

from .plot_img import plot_img
from .plot_scalar import plot_scalar
from .plot_scalars import plot_scalars
from .plot_scalar2d import plot_scalar2d
from .plot_pca import plot_pca
from .plot_gradnorm import plot_gradnorm
from .plot_static_hm import plot_static_hm


def datestr_sort():
    """Return the current local time as a compact stamp, e.g. '240131-235959' (yymmdd-HHMMSS)."""
    return format(datetime.datetime.now(), '%y%m%d-%H%M%S')
def datestr():
    """Return the current local time in readable form, e.g. '2024-01-31 23:59:59'."""
    now = datetime.datetime.now()
    return format(now, '%Y-%m-%d %H:%M:%S')

def get_git_revision_short_hash():
    """Return the hex SHA of the commit HEAD points to in the enclosing git repo.

    The repository is located with ``search_parent_directories=True``, i.e. by
    walking up from GitPython's default start path (the working directory).
    Despite the function name, the *full* hex SHA is returned, not an
    abbreviated one; it is stored as run metadata as-is, so that is kept.

    Returns:
        The hex SHA string, or "unknown" when no repository can be found or
        HEAD does not resolve to a commit yet (e.g. a repo without commits).
        This is only metadata, so a missing repo should not stop the module
        from registering.
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        return repo.head.object.hexsha
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError, ValueError):
        # ValueError: GitPython raises it when HEAD's reference does not exist
        return "unknown"

def _info_env_name(log_dir):
    """Return the visdom env name used for the settings text of *log_dir*.

    "-info" is inserted before the first underscore, or appended when there is
    none: "log" -> "log-info", "exp_a" -> "exp-info_a".
    """
    head, sep, tail = log_dir.partition("_")
    return head + "-info" + sep + tail


def init(state, event):
    """Handle the 'init' event: finalize log settings and connect to visdom.

    Exposes the run-metadata helpers as module defaults, optionally suffixes
    ``log.dir`` with a timestamp, creates the visdom client (stored in
    ``state["vis"]``) and writes the current settings as HTML text into a
    separate env named by ``_info_env_name``.
    """
    # expose the metadata helpers as defaults so they appear in settings lists
    state.default.update({state.module_name + "." + k: state[k]
                          for k in ["git_hash", "cli_overwrites", "date", "python", "pytorch"]})

    # a trailing "!" on log.dir requests a unique dir: replace it with a timestamp
    if state.all["log.dir"].endswith("!"):
        state.all["log.dir"] = state.all["log.dir"][:-1] + "-" + datestr_sort()

    # silence visdom's own log messages
    visdom.logger.setLevel(visdom.logging.CRITICAL)

    # the visdom env is named after the log dir
    state["vis"] = visdom.Visdom(env=state.all["log.dir"], port=state["port"])

    # save current settings in the companion "-info" env
    state["vis"].text(text=event.settings_html(), win="settings",
                      env=_info_env_name(state.all["log.dir"]))


def before_training(state):
    """Announce the run and persist both its visdom env and its settings env."""
    log_dir = state.all["log.dir"]
    print(log_dir, msg="Running")
    # must use the same name init() wrote the settings text to
    state["vis"].save([log_dir, _info_env_name(log_dir)])


def save_log(state):
    """Persist the run's visdom env (used for after_epoch / after_training)."""
    state["vis"].save([state.all["log.dir"]])



# NoneDict: a dict whose [] lookup returns None for missing keys
class NoneDict(dict):
    """A ``dict`` where ``d[key]`` yields ``None`` for absent keys instead of raising ``KeyError``."""

    def __getitem__(self, key):
        # delegate to dict.get, whose default for a missing key is None
        return super().get(key)

def register(mf):
    """Hook this module into the module loader *mf*.

    Declares module defaults (visdom ``port``, ``steps.scalar``) and the global
    ``log.dir`` setting, loads ``tensorhelpers``, stores helper containers plus
    run metadata (git revision, command line, start date, Python and PyTorch
    versions) as helpers, and binds every event handler of this module.
    """
    mf.register_defaults({"port": 8097, "steps.scalar": 100})
    mf.register_globals({"log.dir": "log"})
    mf.load("tensorhelpers")

    helpers = {
        # NoneDict: reading a name that was never stored yields None
        "WINDOWS": NoneDict(),
        "Scalar": {},
        "Scalar2D": {},
    }
    # run metadata; init() copies these keys into state.default
    helpers["git_hash"] = get_git_revision_short_hash()
    helpers["cli_overwrites"] = " ".join(sys.argv)
    helpers["date"] = datestr()
    helpers["python"] = sys.version.replace("\n", " ")
    helpers["pytorch"] = torch.__version__
    mf.register_helpers(helpers)

    handlers = (
        ("init", init),
        ("before_training", before_training),
        ("after_epoch", save_log),
        ("after_training", save_log),
        # generic 'plot' event gets a no-op handler returning True;
        # the original author marked this as not to be changed
        ("plot", lambda state, event: True),
        ("plot_img", plot_img),
        ("plot_scalar", plot_scalar),
        ("plot_scalars", plot_scalars),
        ("plot_scalar2d", plot_scalar2d),
        ("plot_static_hm", plot_static_hm),
        ("plot_pca", plot_pca),
        ("plot_gradnorm", plot_gradnorm),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler)

