import os
import numpy as np
import time
import _pickle as cPickle
# import cPickle
import tensorflow as tf

def save_params(sess, filename):
    """Pickle the current values of all trainable variables.

    The values are stored as a dict mapping each variable's ``name`` to the
    value returned by ``sess.run`` for it, written to ``filename + '.pkl'``.

    Args:
        sess: Session used to evaluate the variables.
        filename: Output path without the ``.pkl`` extension.
    """
    params = tf.trainable_variables()
    # Fetch every variable in one session call rather than one call each.
    values = sess.run(params) if params else []
    param_dict = {v.name: value for v, value in zip(params, values)}

    path = filename + '.pkl'
    # The context manager closes the file even if pickling fails.
    with open(path, 'wb') as f:
        cPickle.dump(param_dict, f)
    print('parameters saved at ' + path)

def load_params(sess, filename, init_all=True):
    """Restore trainable variables from a pickle written by ``save_params``.

    Reads a name -> value dict from ``filename + '.pkl'`` and assigns each
    stored value to the trainable variable with the same name. Trainable
    variables that have no entry in the file are left untouched.

    Args:
        sess: Session in which the assignments (and initializers) are run.
        filename: Input path without the ``.pkl`` extension.
        init_all: If true, also run the initializer of every global variable
            that is not a trainable variable.
    """
    params = tf.trainable_variables()
    path = filename + '.pkl'

    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # come from a trusted source.
    with open(path, 'rb') as f:
        param_dict = cPickle.load(f)
    print('param loaded', len(param_dict))

    # Build one assign op per variable found in the file and run them together.
    ops = [tf.assign(v, param_dict[v.name])
           for v in params if v.name in param_dict]
    sess.run(ops)

    # Initialise the remaining (non-trainable) global variables.
    if init_all:
        # Compare by name: a set lookup avoids the quadratic list scan and
        # does not depend on how Variable objects implement equality.
        trainable_names = {v.name for v in params}
        var = [v for v in tf.global_variables()
               if v.name not in trainable_names]
        # tf.variables_initializer replaces the deprecated
        # tf.initialize_variables.
        sess.run(tf.variables_initializer(var))
    print('loaded parameters from ' + path)

