
import os
import importlib
import numpy as np
from collections import OrderedDict


def get_dump_file_paths(out_path, fold=None):
    """
    Compose the parameter dump and log file paths below an output directory

    :param out_path: directory both files are placed in
    :param fold: optional fold number; if given it is inserted into both
                 file names (params_<fold>.pt / log_<fold>.npy)
    :return: tuple (parameter dump file path, log file path)
    """
    if fold is None:
        param_name, log_name = 'params.pt', 'log.npy'
    else:
        # explicit None check above, so fold 0 still gets its own suffix
        param_name, log_name = 'params_%d.pt' % fold, 'log_%d.npy' % fold

    return os.path.join(out_path, param_name), os.path.join(out_path, log_name)


def import_from_file(model_path):
    """
    Import a module given the path of its source file

    The path is translated into a dotted module name, e.g.
    "models/nets/resnet.py" -> "models.nets.resnet", which is then imported
    with importlib. The module therefore has to be reachable through the
    regular import system (sys.path) under that name.

    :param model_path: path to the module file, relative to an import root
    :return: the imported module
    :raises ImportError: if no module with the derived name can be imported
    """
    # everything after the first dot of the file name is dropped
    module_name = os.path.basename(model_path).split(".")[0]

    # normalize first so that "./pkg/mod.py" or "pkg//mod.py" work as well;
    # an empty directory is normalized to "." and filtered out below
    package_dir = os.path.normpath(os.path.dirname(model_path))

    # every directory level becomes one package level in the dotted name
    name_parts = [part for part in package_dir.split(os.sep)
                  if part and part != os.curdir]
    name_parts.append(module_name)

    return importlib.import_module(".".join(name_parts))


def import_from_package_by_name(object_name, package_root):
    """
    Fetch an attribute (class, function, constant, ...) from a package by name

    :param object_name: name of the attribute to look up
    :param package_root: dotted name of the package / module to import
    :return: the object found under object_name in the imported package
    """
    return getattr(importlib.import_module(package_root), object_name)


def is_better(value, all_values, mode="min"):
    """
    Check whether a value is at least as good as everything observed so far

    The comparison is non-strict, i.e. a value that ties with the current
    best still counts as an improvement.

    :param value: the value to check
    :param all_values: collection of values to compare against
    :param mode: "min" if smaller is better, "max" if larger is better
    :return: True if value is an improvement; also True if all_values is
             empty, as there is nothing to be worse than
    :raises ValueError: if mode is neither "min" nor "max"
    """
    # validate first so that a wrong mode is reported even without history
    if mode not in ("min", "max"):
        raise ValueError('is_better(.) takes mode either min or max!')

    # np.min / np.max raise on zero-size input
    if np.size(all_values) == 0:
        return True

    if mode == "min":
        return value <= np.min(all_values)
    return value >= np.max(all_values)


class BColors:
    """
    Colored command line output formatting
    """
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    def __init__(self):
        """ Constructor """
        pass

    def print_colored(self, string, color):
        """ Change color of string """
        return color + string + BColors.ENDC


class EpochLogger(object):
    """
    Collects values during an epoch and keeps their per-epoch averages
    """

    def __init__(self):
        """ Start with empty statistics """
        # key -> list holding one averaged value per summarized epoch
        self.epoch_stats = OrderedDict()
        # key -> list of raw values appended since the last summary
        self.within_epoch_stats = OrderedDict()
        # raw values of the most recently summarized epoch
        self.previous_epoch = None

    def append(self, key, value):
        """ Record a value for key within the running epoch """
        self.within_epoch_stats.setdefault(key, []).append(value)

    def summarize_epoch(self):
        """
        Close the running epoch

        The raw values are kept in previous_epoch, their NaN-ignoring mean
        is appended to epoch_stats and the per-key buffers are emptied.
        """
        # shallow copy is sufficient, the buffers get replaced (not cleared) below
        self.previous_epoch = self.within_epoch_stats.copy()

        for key, values in list(self.within_epoch_stats.items()):
            self.epoch_stats.setdefault(key, []).append(np.nanmean(values))
            self.within_epoch_stats[key] = []

    def dump(self, dump_file):
        """ Write the per-epoch statistics to dump_file using np.save """
        np.save(dump_file, self.epoch_stats)

    def load(self, dump_file):
        """
        Restore the per-epoch statistics written by dump

        NOTE: loading uses allow_pickle=True, so only load trusted files.
        """
        stored = np.load(dump_file, allow_pickle=True)
        self.epoch_stats = stored.item()
