import numpy as np
import os, torch, pickle, zipfile
import imageio, shutil
import scipy, scipy.misc, scipy.integrate
import matplotlib.pyplot as plt
solve_ivp = scipy.integrate.solve_ivp


def integrate_model(model, t_span, y0, fun=None, **kwargs):
    """Integrate a system with ``scipy.integrate.solve_ivp``.

    When ``fun`` is None, the right-hand side is taken from
    ``model.time_derivative``. The state is wrapped as a float32 tensor of
    shape (1, n) with ``requires_grad=True``, and the result is returned as a
    flat numpy array. Any other ``fun(t, y)`` is passed straight through.
    Extra keyword arguments are forwarded to ``solve_ivp``.
    """
    def model_rhs(t, state_np):
        state = torch.tensor(state_np, requires_grad=True, dtype=torch.float32)
        batch = state.view(1, state.numel())  # treat the state as a batch of one
        derivative = model.time_derivative(batch)
        return derivative.detach().numpy().reshape(-1)

    if fun is None:
        fun = model_rhs
    return solve_ivp(fun=fun, t_span=t_span, y0=y0, **kwargs)


def rk4(fun, y0, t, dt, *args, **kwargs):
    """Return the classical 4th-order Runge-Kutta increment for one step.

    ``fun`` is called as ``fun(y, t, *args, **kwargs)``. Note the (y, t)
    argument order, which is the reverse of solve_ivp's convention.
    The return value is the increment ``dy``, not the updated state
    ``y0 + dy``.
    """
    def f(y, s):
        return fun(y, s, *args, **kwargs)

    half = dt / 2.0
    t_mid = t + half
    k1 = f(y0, t)
    k2 = f(y0 + half * k1, t_mid)
    k3 = f(y0 + half * k2, t_mid)
    k4 = f(y0 + dt * k3, t + dt)
    weighted = k1 + 2 * k2 + 2 * k3 + k4
    return dt / 6.0 * weighted


def L2_loss(u, v):
    """Mean squared difference between tensors ``u`` and ``v``."""
    residual = u - v
    return residual.pow(2).mean()


def read_lipson(experiment_name, save_dir):
    """Read ``<experiment_name>.txt`` out of ``<save_dir>/invar_datasets.zip``.

    Returns:
        ``str()`` of the member's raw bytes, i.e. a ``"b'...'"`` string with
        escaped ``\\n``/``\\r`` sequences. That exact representation is what
        ``str2array`` in this module splits and strips, so it is kept on purpose.

    Raises:
        FileNotFoundError: if the archive has no member with that name.
            Previously this case surfaced as an ``UnboundLocalError``.
    """
    desired_file = experiment_name + ".txt"
    with zipfile.ZipFile('{}/invar_datasets.zip'.format(save_dir)) as z:
        try:
            # Looks the name up inside the archive itself. When a name is
            # duplicated, the last entry wins, matching the old linear scan.
            data = z.read(desired_file)
        except KeyError as err:
            raise FileNotFoundError(
                "'{}' not found in {}/invar_datasets.zip".format(desired_file, save_dir)
            ) from err
    return str(data)


def str2array(string):
    """Parse the ``str(bytes)`` text of a Lipson dataset file into an array.

    Rows are separated by the two-character escape sequence ``\\n`` (a
    backslash and an ``n``), not by real newlines. The first row is the
    header of column names. The last row, the leftover closing quote, is
    dropped.

    Returns:
        (data, names): ``data`` is a float array with one row per data line.
        ``names`` is ``['trial', 't'] + header + ['d' + h for h in header]``.
    """
    rows = string.split('\\n')
    header, body = rows[0], rows[1:-1]
    # NOTE(review): strip() takes a *set* of characters, so leading/trailing
    # 'b' or 'r' letters of the first/last column name are stripped too.
    columns = header.strip("b'% \\r").split(' ')
    names = ['trial', 't'] + columns + ['d' + c for c in columns]
    data = []
    for row in body:
        fields = row.strip("' \\r,").split()
        data.append([float(field) for field in fields])
    return np.asarray(data), names


def to_pickle(thing, path):  # save something
    """Serialize ``thing`` to ``path`` using the highest pickle protocol."""
    with open(path, 'wb') as fh:
        pickle.dump(thing, fh, pickle.HIGHEST_PROTOCOL)


def from_pickle(path):  # load something
    """Load and return the object pickled at ``path``.

    Only use this on files you trust: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def choose_nonlinearity(name):
    """Map an activation name to its callable.

    Supported: 'tanh', 'relu', 'sigmoid', 'softplus', 'selu', 'elu', 'swish'
    (``x * sigmoid(x)``).

    Raises:
        ValueError: for any other name.
    """
    table = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
        'softplus': torch.nn.functional.softplus,
        'selu': torch.nn.functional.selu,
        'elu': torch.nn.functional.elu,
        'swish': lambda x: x * torch.sigmoid(x),
    }
    try:
        return table[name]
    except (KeyError, TypeError):  # TypeError: unhashable name, also unknown
        raise ValueError("nonlinearity not recognized") from None


def make_gif(frames, save_dir, name='pendulum', duration=1e-1, pixels=None, divider=0):
    '''Given a three dimensional array [frames, height, width], make
    a gif and save it.

    Each frame is processed as follows:
    - It is clipped to [-0.5, 0.5] and mapped to 0..255 grey levels.
    - Row `divider` is painted black (0) and row `divider + 1` white (255).
    - If `pixels` is given, the frame is resized using the size rules of the
      removed `scipy.misc.imresize`: an int is a percentage of the current
      size, a float is a fraction of it, and a tuple gives the output
      (height, width).

    The animation is written to `<save_dir>/<name>.gif` and then renamed
    to `<save_dir>/<name>.gif.png`. That final path is returned.

    Frames are assembled in memory. Nothing is written to or deleted from
    a `./_temp` directory any more.
    '''
    images = []
    for frame in frames:
        # np.array copies, so the caller's frames are never modified
        im = np.array((frame.clip(-.5, .5) + .5) * 255, dtype=np.float64)
        im[divider, :] = 0
        im[divider + 1, :] = 255
        if pixels is not None:
            im = _resize_frame(im, pixels)
        # round to 8-bit grey levels, as the old PNG round-trip did
        images.append(np.clip(np.rint(im), 0, 255).astype(np.uint8))

    save_path = '{}/{}.gif'.format(save_dir, name)
    png_save_path = '{}.png'.format(save_path)
    # NOTE(review): some imageio versions read GIF `duration` as milliseconds
    # rather than seconds -- confirm against the installed version.
    imageio.mimsave(save_path, images, duration=duration)
    # os.replace also overwrites an existing target on Windows, unlike os.rename
    os.replace(save_path, png_save_path)
    return png_save_path


def _resize_frame(im, pixels):
    '''Bilinearly resize a 2-D frame following scipy.misc.imresize's size rules.'''
    import scipy.ndimage
    if isinstance(pixels, (int, np.integer)):
        factors = pixels / 100.0  # percentage of current size
    elif isinstance(pixels, (float, np.floating)):
        factors = float(pixels)  # fraction of current size
    else:
        height, width = tuple(pixels)[:2]  # explicit output size
        factors = (height / im.shape[0], width / im.shape[1])
    return scipy.ndimage.zoom(im, factors, order=1)


def process_stats(stats):
    """Split recorded (mean, std) loss pairs into separate arrays.

    ``stats['train_loss']`` and ``stats['test_loss']`` are sequences of
    (mean, std) pairs. Returns
    ``[train_mean, train_std, test_mean, test_std]``.
    """
    train = np.array(stats['train_loss'])
    test = np.array(stats['test_loss'])
    return [train[:, 0], train[:, 1], test[:, 0], test[:, 1]]


def plot_results(args, stats_list, labels_list=None, total_length=None, title=None):
    '''Plot train (left) and test (right) loss curves with a +/- std band on
    log-log axes, and save the figure as a PDF.

    Args:
        args: object providing `save_dir`, `hidden_dim_ae` and `hidden_dim_mlp`,
            which are used to build the output path.
        stats_list: one entry per model, each indexable as
            [train_mean, train_std, test_mean, test_std].
        labels_list: legend label per entry of `stats_list`. Defaults to
            ['HNN', 'AE', 'NODA']. Labels and stats are paired positionally,
            and only as many curves as both lists provide are drawn.
        total_length: if given, every curve is truncated to its first
            `total_length` steps.
        title: optional title passed to `plt.title`.

    The file is written to `<save_dir>/<label1>-<label2>-...training-comparison_
    <hidden_dim_ae>_<hidden_dim_mlp>_<n_steps>.pdf`.
    '''
    if labels_list is None:
        labels_list = ['HNN', 'AE', 'NODA']
    # x[:None] keeps the whole sequence, so one slice handles both cases
    limit = total_length
    step = np.arange(1, len(stats_list[0][2]) + 1)[:limit]
    save_path = args.save_dir + '/'
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), sharey='row')
    # zip stops at the shorter list: the 3-entry default labels no longer
    # raise IndexError when fewer models are passed
    for label, stats in zip(labels_list, stats_list):
        train_mean, train_std, test_mean, test_std = (stats[k][:limit] for k in range(4))
        axes[0].loglog(step, train_mean, label=label)
        axes[0].fill_between(step, train_mean - train_std,
                             train_mean + train_std, alpha=0.3)
        axes[1].loglog(step, test_mean, label=label)
        axes[1].fill_between(step, test_mean - test_std,
                             test_mean + test_std, alpha=0.3)
        save_path += label + '-'
    axes[0].set_xlabel('Step')
    axes[1].set_xlabel('Step')
    axes[0].set_ylabel('Loss')
    axes[0].legend(loc='best')
    axes[1].legend(loc='best')
    if title is not None:
        plt.title(title)
    save_path += 'training-comparison_' + str(args.hidden_dim_ae) + '_' + \
                 str(args.hidden_dim_mlp) + '_' + str(len(step)) + '.pdf'
    fig.savefig(save_path)
    plt.close(fig)


def plot_results_testing(args, stats_list, labels_list=None, total_length=None, title=None):
    '''Plot test loss curves with a +/- std band on one log-log axes, and
    save the figure as a PDF.

    Args:
        args: object providing `save_dir`, `hidden_dim_ae` and `hidden_dim_mlp`,
            which are used to build the output path.
        stats_list: one entry per model, each indexable as
            [train_mean, train_std, test_mean, test_std]. Only the test
            entries (indices 2 and 3) are plotted.
        labels_list: legend label per entry of `stats_list`. Defaults to
            ['HNN', 'AE', 'NODA']. Labels and stats are paired positionally,
            and only as many curves as both lists provide are drawn.
        total_length: if given, every curve is truncated to its first
            `total_length` steps.
        title: optional title passed to `plt.title`.

    The file is written to `<save_dir>/<label1>-<label2>-...testing-comparison_
    <hidden_dim_ae>_<hidden_dim_mlp>_<n_steps>.pdf`.
    '''
    if labels_list is None:
        labels_list = ['HNN', 'AE', 'NODA']
    # x[:None] keeps the whole sequence, so one slice handles both cases
    limit = total_length
    step = np.arange(1, len(stats_list[0][2]) + 1)[:limit]
    save_path = args.save_dir + '/'
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
    # zip stops at the shorter list: the 3-entry default labels no longer
    # raise IndexError when fewer models are passed
    for label, stats in zip(labels_list, stats_list):
        test_mean = stats[2][:limit]
        test_std = stats[3][:limit]
        ax.loglog(step, test_mean, label=label, linewidth=3.5)
        ax.fill_between(step, test_mean - test_std, test_mean + test_std, alpha=0.3)
        save_path += label + '-'
    ax.set_xlabel('Steps')
    ax.set_ylabel('Testing Loss')
    ax.legend(loc='best')
    ax.grid(True)
    if title is not None:
        plt.title(title)
    save_path += 'testing-comparison_' + str(args.hidden_dim_ae) + '_' + \
                 str(args.hidden_dim_mlp) + '_' + str(len(step)) + '.pdf'
    fig.savefig(save_path)
    plt.close(fig)
