from datetime import datetime
import yaml
import random

from yacs.config import CfgNode as CN
import torch
import numpy as np
import matplotlib.pyplot as plt


def set_seed(seed, fully_deterministic=True):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (and, when available, every CUDA device).
        fully_deterministic: When True and CUDA is available, configure
            cuDNN for run-to-run reproducibility. This can make GPU code
            slower.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed all GPUs as well (harmless if manual_seed already did).
        torch.cuda.manual_seed_all(seed)
        if fully_deterministic:
            torch.backends.cudnn.deterministic = True
            # The cuDNN auto-tuner may choose different kernels on different
            # runs; it must be disabled for results to be reproducible.
            torch.backends.cudnn.benchmark = False


def get_date_time_str(add_hash=True):
    """Build a filesystem-friendly timestamp from the current local time.

    The result looks like ``date_31_12_2024_time_23_59``. When ``add_hash``
    is True, ``_hash_`` plus the six-digit microsecond field of the same
    timestamp is appended. Despite the name, this suffix is not a real hash.
    """
    stamp = datetime.now()
    pieces = ['date', stamp.strftime('%d_%m_%Y'), 'time', stamp.strftime('%H_%M')]
    if add_hash:
        pieces += ['hash', stamp.strftime('%f')]
    return '_'.join(pieces)


def save_config(config, path):
    """Write a yacs ``CfgNode`` to ``path`` as block-style YAML.

    Nested ``CfgNode`` values are first converted to plain ``dict`` objects.
    ``CfgNode`` is a ``dict`` subclass, which the default PyYAML dumper would
    otherwise emit with a Python-object tag instead of a plain mapping.
    Leaf values are written unchanged.

    Args:
        config: The configuration node to save.
        path: Destination file path; an existing file is overwritten.
    """
    def convert_config_to_dict(cfg_node):
        # Anything that is not a CfgNode (numbers, strings, lists, ...) is a leaf.
        if not isinstance(cfg_node, CN):
            return cfg_node
        return {k: convert_config_to_dict(v) for k, v in cfg_node.items()}

    with open(path, 'w', encoding='utf-8') as f:
        yaml.dump(convert_config_to_dict(config), f, default_flow_style=False)


def print_stats(x, name):
    """Print the min, max, mean and std of ``x`` to stdout, to three decimals.

    ``x`` must expose ``min``/``max``/``mean``/``std`` methods whose results
    accept the ``.3f`` format spec (e.g. a NumPy array).
    NOTE(review): for torch tensors ``std`` defaults to the unbiased
    estimator, unlike NumPy's population std; confirm which is intended.
    """
    lowest, highest = x.min(), x.max()
    average, spread = x.mean(), x.std()
    print(f"{name} stats: min = {lowest:.3f}, max = {highest:.3f}, mean = {average:.3f}, std = {spread:.3f}")


def sample_from_simplex(size, random_sign=False):
    """Draw a random point on the probability simplex via normalized Exp(1) draws.

    The normalization divides by the sum over *all* entries. A
    multi-dimensional ``size`` therefore yields a single point spread over
    the whole array, not one point per row.

    Args:
        size: Output shape, forwarded to ``np.random.exponential``.
        random_sign: If True, multiply each entry independently by +1 or -1
            (drawn with ``np.random.randint``). The absolute values still sum
            to 1.

    Returns:
        Float ndarray of shape ``size``.
    """
    weights = np.random.exponential(scale=1.0, size=size)
    weights /= weights.sum()
    if not random_sign:
        return weights
    signs = 2 * np.random.randint(2, size=size) - 1
    return weights * signs


def show_3d_fig(x, t, y, title=None):
    """Plot ``y`` as a 3-D surface over the (x, t) grid, then call ``plt.show()``.

    Args:
        x: Coordinates along the first horizontal axis.
        t: Coordinates along the second horizontal axis.
        y: Surface heights. ``np.meshgrid(x, t)`` uses the default 'xy'
            indexing, so ``y`` must have shape ``(len(t), len(x))``.
        title: Optional title for the axes.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    X, T = np.meshgrid(x, t)
    ax.plot_surface(X, T, y, cmap='viridis')
    # Configure the 3-D axes object directly rather than through pyplot's
    # implicit "current axes" state.
    ax.set_xlabel('x')
    ax.set_ylabel('t')
    if title is not None:
        ax.set_title(title)
    plt.show()
