import os
import re
import math
import torch
import datetime
import numpy as np
try:
    import moviepy.editor as mpy
except Exception:
    # moviepy is an optional dependency (used by save_mp4). The import stays
    # best-effort, but a bare ``except:`` would also swallow
    # KeyboardInterrupt / SystemExit, so only ordinary exceptions are ignored.
    pass


def logging(*msg):
    """Print ``msg`` to stdout, prefixed with the current local time.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS>`` and the positional
    arguments are forwarded to ``print`` unchanged (space separated).
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    prefix = "{}>".format(timestamp)
    print(prefix, *msg)


def mean_dict(_list_dict: list):
    """Average a list of dicts key by key.

    Values are summed per key and divided by the total number of dicts in
    ``_list_dict`` (not by how many dicts contain the key), so a key missing
    from a dict contributes 0 for it. An empty list yields an empty dict.
    """
    count = float(len(_list_dict))

    totals = {}
    for entry in _list_dict:
        for key, value in entry.items():
            totals[key] = totals.get(key, 0) + value

    return {key: float(total) / count for key, total in totals.items()}


def softmax(_vec):
    """Return the softmax of ``_vec`` as a NumPy array.

    The maximum entry is subtracted before exponentiating so that large
    inputs do not overflow; this does not change the result.
    """
    shifted = np.array(_vec) - np.max(_vec)
    exponentials = np.exp(shifted)
    return exponentials / np.sum(exponentials)


def min_max_scale_vec(_vec: np.ndarray, _min: float, _max: float):
    """Linearly rescale ``_vec`` so its values span ``[_min, _max]``.

    The smallest entry maps to ``_min`` and the largest to ``_max``.

    When every entry is equal (zero range) the division is skipped and every
    entry maps to ``_min``. The result then still has the shape of ``_vec``
    instead of collapsing to a bare scalar, so callers can index it the same
    way in both cases.
    """
    _vec = np.asarray(_vec)
    _low = np.min(_vec)
    _den = np.max(_vec) - _low
    if _den:
        _fraction = (_vec - _low) / _den
    else:
        # Safe division: constant input has no range to normalise by.
        _fraction = np.zeros(_vec.shape, dtype=float)
    return _fraction * (_max - _min) + _min


def scale_number(x, to_min, to_max, from_min, from_max):
    """Map ``x`` linearly from ``[from_min, from_max]`` onto ``[to_min, to_max]``.

    ``from_min`` maps to ``to_min`` and ``from_max`` to ``to_max``; values
    outside the source range are extrapolated, not clipped.
    """
    to_span = to_max - to_min
    offset = x - from_min
    from_span = from_max - from_min
    return to_span * offset / from_span + to_min


def softplus_np(x):
    """Element-wise softplus ``log(1 + exp(x))`` using NumPy.

    Computed as ``log1p(exp(-|x|)) + max(x, 0)`` so the exponential never
    receives a positive argument and cannot overflow.
    """
    positive_part = np.maximum(x, 0)
    correction = np.log1p(np.exp(-np.abs(x)))
    return correction + positive_part


def softplus_math(x):
    """Scalar softplus ``log(1 + exp(x))`` using the ``math`` module.

    Uses the same overflow-safe form ``log1p(exp(-|x|)) + max(x, 0)``.
    """
    positive_part = max(x, 0)
    correction = math.log1p(math.exp(-abs(x)))
    return correction + positive_part


def scaled_sigmoid(x, _min, _max):
    """Logistic sigmoid of ``x`` rescaled to the interval ``(_min, _max)``.

    Returns ``_min + (_max - _min) * sigmoid(x)``.

    The exponential is always evaluated on a non-positive argument. The
    direct form ``1 / (1 + exp(-x))`` raises ``OverflowError`` once ``x``
    drops below roughly -709, whereas ``exp(x)`` simply underflows to 0.0.
    """
    if x >= 0:
        return _min + (_max - _min) / (1 + math.exp(-x))
    # Equivalent form for negative x: sigmoid(x) = exp(x) / (1 + exp(x)).
    _exp_x = math.exp(x)
    return _min + (_max - _min) * _exp_x / (1 + _exp_x)


def atoi(text):
    """Return ``int(text)`` when ``text`` is a decimal number, else ``text``.

    ``str.isdecimal`` is used rather than ``str.isdigit``: ``isdigit`` also
    accepts characters such as superscript digits ('²') that ``int`` cannot
    parse, which would raise ``ValueError`` instead of returning the text.
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """Build a sort key that orders embedded numbers numerically.

    ``text`` is split into alternating non-digit and digit chunks; each chunk
    is passed through ``atoi`` so digit runs become integers. Intended for
    use as ``sorted(names, key=natural_keys)``.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))


def tile_images(img_nhwc, _size=None):
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.

    input:
        img_nhwc, list or array of images, ndim=4 once turned into array
            n = batch index, h = height, w = width, c = channel
            Two further layouts are accepted:
            * ndim=5 (n, t, h, w, c): every step t is tiled separately.
            * a ragged list / 1-D object array of sequences (t_i, h, w, c):
              shorter sequences are extended to the longest one by
              repeating their last frame, then handled as ndim=5.
        _size, optional (P, Q) grid shape; when None a near-square grid is
            derived from N. It is applied to every step of a series too.
    returns:
        bigim_HWc, ndarray with ndim=3 (ndim=4 with a leading t axis when
        the input is a series). The last row and column of every tile are
        zeroed as a border; the input itself is never modified.
    raises:
        ValueError if the grid has fewer than N cells or the sequences do
        not share a common frame shape.
    """
    try:
        img_nhwc = np.asarray(img_nhwc)
    except ValueError:
        # NumPy >= 1.24 refuses to build a ragged array implicitly; keep the
        # sequences in a plain list and equalise their lengths below.
        img_nhwc = [np.asarray(seq) for seq in img_nhwc]

    # NOTE: Added to support renderings that return a series of images
    if isinstance(img_nhwc, list) or img_nhwc.ndim == 1:
        sequences = [np.asarray(seq) for seq in img_nhwc]
        max_dim = max(seq.shape[0] for seq in sequences)
        padded = []
        for seq in sequences:
            missing = max_dim - seq.shape[0]
            if missing > 0:
                # Pad the whole deficit in one go, along the leading axis
                # only. A new array is built, so the caller's data is kept.
                pad_width = [(0, missing)] + [(0, 0)] * (seq.ndim - 1)
                seq = np.pad(seq, pad_width, 'edge')
            padded.append(seq)
        stacked = np.array(padded)
        if stacked.ndim == 1:
            # Still ragged after length equalisation: frame shapes differ.
            raise ValueError("image sequences must share the same frame shape")
        return tile_images(stacked, _size)

    if img_nhwc.ndim == 5:
        return np.array([tile_images(img_nhwc[:, t], _size)
                         for t in range(img_nhwc.shape[1])])

    N, h, w, c = img_nhwc.shape

    if _size is None:
        H = int(np.ceil(np.sqrt(N)))
        W = int(np.ceil(float(N) / H))
    else:
        H, W = _size
    if H * W < N:
        raise ValueError(
            "grid of size {}x{} cannot hold {} images".format(H, W, N))

    # Fill the unused grid cells with true zeros (multiplying an image by 0
    # would carry over NaN/inf). Concatenating also yields a fresh array, so
    # drawing the borders below does not write into the caller's images.
    blanks = np.zeros((H * W - N, h, w, c), dtype=img_nhwc.dtype)
    img_HWhwc = np.concatenate([img_nhwc, blanks]).reshape(H, W, h, w, c)
    # add boundary to distinguish environments
    img_HWhwc[:, :, :, -1, :] = 0
    img_HWhwc[:, :, -1, :, :] = 0
    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
    img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c)
    return img_Hh_Ww_c


def save_mp4(frames, vid_dir, name, fps=10.0):
    """Write ``frames`` as ``<vid_dir>/<name>.mp4`` using moviepy.

    ``frames`` is converted to a NumPy array and indexed along its first
    axis, one entry per video frame at ``fps`` frames per second. The target
    directory is created when missing and an existing file with the same
    name is removed before writing.
    """
    frames = np.array(frames)
    n_frames = len(frames)

    def make_frame(t):
        # Clamp the index so a timestamp at the very end of the clip still
        # maps to the last available frame.
        return frames[min(int(t * fps), n_frames - 1)]

    if not os.path.exists(vid_dir):
        os.makedirs(vid_dir)

    vid_file = os.path.join(vid_dir, name + '.mp4')
    if os.path.exists(vid_file):
        os.remove(vid_file)

    clip = mpy.VideoClip(make_frame, duration=n_frames / fps)
    clip.write_videofile(vid_file, fps, verbose=False, logger=None)


def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired in the order returned by ``parameters()``;
    ``source`` is left untouched.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)


def hard_update(target, source):
    """Copy all buffers and parameters of ``source`` into ``target``.

    Sub-modules and parameters are paired in the order returned by
    ``modules()`` / ``parameters()``.

    Buffers are cloned rather than shared. Assigning a shallow copy of the
    source's buffer dict would leave both networks pointing at the same
    tensors, so a later in-place change to a source buffer would silently
    change the target as well.
    """
    for m1, m2 in zip(target.modules(), source.modules()):
        # Start from a copy of the dict to keep its type, keys and order.
        buffers = m2._buffers.copy()
        for name, buf in list(buffers.items()):
            if buf is not None:
                buffers[name] = buf.detach().clone()
        m1._buffers = buffers
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


if __name__ == '__main__':
    # Smoke test: tile 11 random 28x28 RGB images into one grid image and
    # print the resulting shape. (The stray ``asdf`` statement that followed
    # was removed because it raised NameError.)
    img = np.random.random((11, 28, 28, 3))
    img = tile_images(img_nhwc=img)
    print(img.shape)
