import sys

import os
import pickle
import tensorflow as tf
import numpy as np
import gym
import joblib
import argparse
import matplotlib.pyplot as plt
from tensorflow.distributions import Normal


def save_variables(save_path, variables, sess):
    """Evaluate `variables` in `sess` and persist their values with joblib.

    The file holds a plain dict mapping each variable's ``.name`` to the value
    returned by ``sess.run`` (the format ``load_variables`` reads back).

    Args:
        save_path: destination file path passed straight to ``joblib.dump``.
        variables: list of variables to snapshot.
        sess: session used to evaluate the variables.
    """
    values = sess.run(variables)
    names = (var.name for var in variables)
    snapshot = dict(zip(names, values))
    joblib.dump(snapshot, save_path)
    print('Variables saved!')


def load_variables(load_path, variables=None, sess=None):
    """Restore variable values from a joblib dump written by ``save_variables``.

    The dump is expected to be a dict keyed by variable ``.name``; each
    requested variable is assigned the stored value in a single ``sess.run``.

    Args:
        load_path: path to the joblib file; ``~`` is expanded.
        variables: variables to restore. Any falsy value (None, empty list)
            falls back to every variable in ``GLOBAL_VARIABLES``.
        sess: session to run the assignments in; defaults to
            ``tf.get_default_session()``.

    Raises:
        ValueError: if no session is given and no default session is registered.
        KeyError: if one or more variable names are absent from the file.
    """
    # The original left this fallback commented out, so sess=None crashed on
    # `None.run`; use the default session like `copy_variables` does.
    sess = sess or tf.get_default_session()
    if sess is None:
        raise ValueError(
            'load_variables needs a session: pass `sess` or call it inside a '
            'default session (`with sess.as_default():`).')
    variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    # SECURITY: joblib.load unpickles the file; only load trusted checkpoints.
    loaded_params = joblib.load(os.path.expanduser(load_path))

    # Report every missing name at once instead of failing on the first one.
    missing = [v.name for v in variables if v.name not in loaded_params]
    if missing:
        raise KeyError('Variables not found in {}: {}'.format(load_path, missing))

    sess.run([v.assign(loaded_params[v.name]) for v in variables])

def copy_variables(from_variables, to_variables, sess=None):
    """Assign the current values of `from_variables` to `to_variables`, pairwise.

    The two lists are zipped in order, so they must have the same length and
    matching structure (e.g. both collected from scopes with the same layout).

    Args:
        from_variables: source variables, read in one ``sess.run`` call.
        to_variables: destination variables, assigned positionally.
        sess: session to read and assign in; defaults to
            ``tf.get_default_session()``.

    Raises:
        AssertionError: if the two lists differ in length.
        ValueError: if no session is given and no default session is registered.
    """
    assert len(from_variables) == len(to_variables), "Different length in two set of variables, are you sure you collected same structure?"
    sess = sess or tf.get_default_session()
    if sess is None:
        raise ValueError(
            'copy_variables needs a session: pass `sess` or call it inside a '
            'default session.')
    # Read through `sess` itself: `v.eval()` would silently use the default
    # session even when a different `sess` was passed in. One run also
    # replaces a separate round-trip per variable.
    values = sess.run(list(from_variables))
    sess.run([v_t.assign(value) for v_t, value in zip(to_variables, values)])

def save_video(frames, save_num, save_name):
    """Export the first `save_num` frame sequences as MP4 files at 15 fps.

    Sequence ``frames[i]`` is written to ``./videos/<save_name>_<i>.mp4`` via
    ``save_sequences``.

    Args:
        frames: indexable collection of frame sequences; needs at least
            `save_num` entries.
        save_num: number of sequences to export.
        save_name: filename prefix for the exported videos.
    """
    # NOTE(review): `save_sequences` is not defined or imported in the visible
    # part of this file — confirm it is provided elsewhere, otherwise this
    # raises NameError.
    # Make sure the output directory exists before writing into it.
    os.makedirs('./videos', exist_ok=True)
    for i in range(save_num):
        save_sequences(
            frames[i],
            export_dir='./videos/{}_{}.mp4'.format(save_name, i),
            fps=15)
        print("{} video saved!".format(save_name))


def ortho_init(scale=1.0):
    """Create a Lasagne-style orthogonal initializer usable with ``tf.get_variable``.

    Args:
        scale: gain multiplied into the orthogonal matrix.

    Returns:
        A callable ``(shape, dtype, partition_info=None) -> np.ndarray``. Only
        2-D and 4-D shapes are supported. ``dtype`` and ``partition_info`` are
        accepted for signature compatibility but ignored; the result is always
        float32.
    """
    def _ortho_init(shape, dtype, partition_info=None):
        dims = tuple(shape)
        rank = len(dims)
        if rank == 2:
            matrix_shape = dims
        elif rank == 4:
            # 4-D kernels (original note: assumes NHWC): fold all but the last
            # axis into the row dimension.
            matrix_shape = (np.prod(dims[:-1]), dims[-1])
        else:
            raise NotImplementedError

        gaussian = np.random.normal(0.0, 1.0, matrix_shape)
        left, _, right = np.linalg.svd(gaussian, full_matrices=False)
        # Use whichever thin-SVD factor has the requested 2-D shape
        # (the left factor wins when both do, i.e. for square matrices).
        if left.shape == matrix_shape:
            ortho = left
        else:
            ortho = right
        ortho = ortho.reshape(dims)
        return (scale * ortho[:dims[0], :dims[1]]).astype(np.float32)

    return _ortho_init


def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
    """Fully connected layer computing ``x @ w + b``.

    Variables ``w`` (orthogonal init) and ``b`` (constant init) are created
    inside ``tf.variable_scope(scope)``.

    Args:
        x: input tensor; its dimension 1 must be statically known (presumably
            a 2-D ``(batch, features)`` tensor — confirm at call sites).
        scope: variable scope name for ``w`` and ``b``.
        nh: number of output units.
        init_scale: gain passed to ``ortho_init`` for ``w``.
        init_bias: constant initial value for ``b``.
    """
    with tf.variable_scope(scope):
        in_dim = x.get_shape()[1].value
        weight = tf.get_variable(
            "w", [in_dim, nh], initializer=ortho_init(init_scale))
        bias_init = tf.constant_initializer(init_bias)
        bias = tf.get_variable("b", [nh], initializer=bias_init)
        return tf.matmul(x, weight) + bias


def pi(X,
       act_dim,
       scope,
       num_layers=2,
       num_hidden=64,
       activation=tf.nn.tanh,
       layer_norm=False):
    """MLP head that returns raw linear outputs of width `act_dim`.

    Hidden layer ``i`` lives in ``<scope>/mlp_fc<i>`` (orthogonal init with
    gain sqrt(2), optional layer norm, then `activation`). The output layer's
    variables are created directly under ``<scope>``; these names must stay
    stable for ``load_variables`` to find them.
    """
    hidden = tf.layers.flatten(X)
    for layer_idx in range(num_layers):
        layer_scope = scope + '/mlp_fc{}'.format(layer_idx)
        hidden = fc(hidden, layer_scope, nh=num_hidden, init_scale=np.sqrt(2))
        if layer_norm:
            hidden = tf.contrib.layers.layer_norm(hidden, center=True, scale=True)
        hidden = activation(hidden)
    # Output projection uses fc's default init_scale / init_bias.
    return fc(hidden, scope, nh=act_dim)


def student_pi(X,
               act_dim,
               scope,
               num_layers=2,
               num_hidden=64,
               activation=tf.nn.tanh,
               layer_norm=False):
    """MLP head for the student policy; returns the linear `mean` output.

    Same hidden stack as ``pi`` (``<scope>/mlp_fc<i>`` layers), but the output
    layer's variables live under ``<scope>/mean``.
    """
    hidden = tf.layers.flatten(X)
    for layer_idx in range(num_layers):
        layer_scope = scope + '/mlp_fc{}'.format(layer_idx)
        hidden = fc(hidden, layer_scope, nh=num_hidden, init_scale=np.sqrt(2))
        if layer_norm:
            hidden = tf.contrib.layers.layer_norm(hidden, center=True, scale=True)
        hidden = activation(hidden)
    return fc(hidden, scope + '/mean', nh=act_dim)


def attack_pi(X,
              act_dim,
              scope,
              epsilon,
              num_layers=2,
              num_hidden=32,
              activation=tf.nn.tanh,
              layer_norm=False):
    """MLP head producing a bounded perturbation ``epsilon * tanh(logits)``.

    Hidden layers live in ``<scope>/mlp_fc<i>``; the output layer's variables
    are created directly under ``<scope>``. Because tanh lies in (-1, 1), each
    output element has magnitude below ``|epsilon|``, applied elementwise.
    """
    hidden = tf.layers.flatten(X)
    for layer_idx in range(num_layers):
        layer_scope = scope + '/mlp_fc{}'.format(layer_idx)
        hidden = fc(hidden, layer_scope, nh=num_hidden, init_scale=np.sqrt(2))
        if layer_norm:
            hidden = tf.contrib.layers.layer_norm(hidden, center=True, scale=True)
        hidden = activation(hidden)
    logits = tf.layers.flatten(fc(hidden, scope, nh=act_dim))
    return epsilon * tf.tanh(logits)


def plot_reward(args, mean, std, test_mean):
    """Plot mean rewards with std error bars against a dashed reference line.

    The point at index ``i`` is drawn at ``x = i`` with error bar ``std[i]``. A
    dashed horizontal line at `test_mean` spans ``x`` from 0 to ``len(mean)``.
    The figure is saved to ``args.plot_save_path`` and then closed.

    Args:
        args: object with a ``plot_save_path`` attribute (the output file).
        mean: array of mean rewards (presumably 1-D — confirm at call sites).
        std: error-bar magnitudes matching `mean`.
        test_mean: reference value drawn as the dashed line (assumed scalar).
    """
    num_points = mean.shape[0]
    fig, ax = plt.subplots()
    test_x = np.linspace(0, num_points, 1000)
    ax.plot(test_x, np.repeat(test_mean, 1000), '--')
    ax.errorbar(np.arange(num_points), mean, std, linestyle='None', marker="o")
    fig.savefig(args.plot_save_path)
    # Close explicitly: pyplot keeps every figure alive until closed, so
    # repeated calls would otherwise leak memory.
    plt.close(fig)

