import os
import torch
from colorama import Fore
import numpy as np
import random

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns


def set_plot(fontsize):
    """Configure global matplotlib/seaborn styling for figures.

    Selects a serif Computer Modern font rendered through LaTeX, applies
    ``fontsize`` to axes titles, axis labels, legends and tick labels,
    switches pyplot to the non-interactive ``agg`` backend and sets seaborn's
    ``colorblind`` palette. Everything is set through global rc settings, so it
    applies to every figure created after this call.

    Args:
        fontsize: Font size passed straight to matplotlib's rc settings
            (a number in points, or a size name such as ``'large'``).
    """
    plt.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
    plt.rc('text', usetex=True)
    # Title/label sizes go through rcParams so they affect figures created
    # later, instead of only an Axes that ``plt.axes()`` would create here.
    plt.rc('axes', titlesize=fontsize, labelsize=fontsize)
    plt.rc('legend', fontsize=fontsize)
    plt.rc('xtick', labelsize=fontsize)
    plt.rc('ytick', labelsize=fontsize)
    # Note: switching backend may close figures that are already open.
    plt.switch_backend('agg')
    sns.set_palette('colorblind')

def get_grad(params):
    """Concatenate the gradients of ``params`` into one flat CPU tensor.

    Args:
        params: A single tensor, or an iterable of tensors (for example
            ``model.parameters()``). Tensors whose ``.grad`` is ``None`` are
            skipped.

    Returns:
        A 1-D CPU tensor holding every available gradient, each flattened and
        concatenated in iteration order. The result is detached from autograd.

    Raises:
        RuntimeError: If no tensor has a gradient (``torch.cat`` of an empty
            list).
    """
    if isinstance(params, torch.Tensor):
        params = [params]
    # ``reshape`` rather than ``view``: a gradient follows its parameter's
    # memory layout, so a non-contiguous parameter yields a non-contiguous
    # gradient that ``view(-1)`` cannot flatten.
    grads = [
        p.grad.detach().cpu().reshape(-1)
        for p in params
        if p.grad is not None
    ]
    return torch.cat(grads)

def write_to_txt(name, content):
    """Write the string ``content`` to the file at path ``name``.

    The file is opened in text mode ``'w'``, so existing contents are replaced.
    """
    with open(name, mode='w') as out_file:
        out_file.write(content)

def my_makedir(name):
    """Create directory ``name`` and any missing parents (like ``mkdir -p``).

    Does nothing if the directory already exists, or if ``name`` is empty
    (the current directory, e.g. ``os.path.dirname`` of a bare file name).

    Raises:
        OSError: If the directory cannot be created, e.g. permission denied
            or ``name`` already exists as a non-directory.
    """
    if not name:
        return
    # ``exist_ok`` replaces the old blanket ``except OSError: pass``, which
    # also hid real failures such as permission errors.
    os.makedirs(name, exist_ok=True)

def print_args(opt):
    """Print each attribute of ``opt`` as ``<name> <value>``, one per line.

    Attributes are listed in the order of ``vars(opt)``; values are read
    with ``getattr``.
    """
    attrs = vars(opt)
    for attr_name in attrs:
        attr_value = getattr(opt, attr_name)
        print('%s %s' % (attr_name, attr_value))

def mean(ls):
    """Return the arithmetic mean ``sum(ls) / len(ls)``.

    ``ls`` must support both ``sum`` and ``len``. An empty sequence raises
    ZeroDivisionError.
    """
    total = sum(ls)
    count = len(ls)
    return total / count

def normalize(v):
    """Standardize ``v`` to zero mean and unit standard deviation.

    Relies on ``v.mean()`` and ``v.std()`` of the input's own type, so the
    std convention follows that library. PyTorch's default ``std`` is
    unbiased (N-1), while NumPy's default is the population std (N).
    """
    centered = v - v.mean()
    spread = v.std()
    return centered / spread

def flat_grad(grad_tuple):
    """Flatten and concatenate a sequence of tensors into one 1-D tensor.

    Intended for gradient tuples such as those returned by
    ``torch.autograd.grad``. The autograd graph is preserved, so the result
    remains differentiable when the inputs are.

    Raises:
        RuntimeError: If ``grad_tuple`` is empty (``torch.cat`` of nothing).
    """
    # ``reshape`` handles non-contiguous tensors where ``view(-1)`` would
    # raise, and still returns a view whenever one is possible.
    return torch.cat([g.reshape(-1) for g in grad_tuple])

def print_nparams(model):
    """Print the total element count across ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.nelement()
    print('number of parameters: %d' % (total))

def print_color(color, string):
    """Print ``string`` in the colorama foreground colour named ``color``.

    ``color`` is an attribute name of ``colorama.Fore`` (for example
    ``'RED'``). The colour is reset after the text.
    """
    start = getattr(Fore, color)
    print(''.join([start, string, Fore.RESET]))


def init_random_seed(seed=None):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed, or ``None`` to draw one with
            ``random.randint(1, 10000)``.

    Returns:
        The seed actually used, so callers can log it and reproduce the run.
    """
    if seed is None:
        seed = random.randint(1, 10000)
    print("use random seed: {}".format(seed))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Only influences Python subprocesses started after this point; the
    # current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed
