import tensorflow as tf
import numpy as np
import torch

from risk.common.utils import get_network_object, get_env_object


# (TensorFlow variable name, PyTorch state_dict key) pairs for the policy net.
_PI_VARIABLES = (
    ('pi/dense/kernel', 'mlp.0.weight'),
    ('pi/dense/bias', 'mlp.0.bias'),
    ('pi/dense_1/kernel', 'mlp.2.weight'),
    ('pi/dense_1/bias', 'mlp.2.bias'),
    ('pi/dense_2/kernel', 'mu_layer.weight'),
    ('pi/dense_2/bias', 'mu_layer.bias'),
    ('pi/log_std', 'log_std'),
)

# (TensorFlow variable name, PyTorch state_dict key) pairs for the value net.
_V_VARIABLES = (
    ('vf/dense/kernel', 'mlp.0.weight'),
    ('vf/dense/bias', 'mlp.0.bias'),
    ('vf/dense_1/kernel', 'mlp.2.weight'),
    ('vf/dense_1/bias', 'mlp.2.bias'),
    ('vf/dense_2/kernel', 'mlp.4.weight'),
    ('vf/dense_2/bias', 'mlp.4.bias'),
)


def _load_config(config):
    """Return ``config`` unchanged, or parse it as JSON when it is a file path."""
    # Imported here so the helper does not rely on module-level imports.
    import json
    import os

    if isinstance(config, (str, os.PathLike)):
        with open(config, encoding='utf-8') as handle:
            return json.load(handle)
    return config


def _copy_tf_variables(network, tf_path, name_map):
    """Copy each checkpoint variable in ``name_map`` into ``network`` in place."""
    state = network.state_dict()
    for tf_name, torch_key in name_map:
        # The transpose turns TensorFlow's (in, out) dense-kernel layout into
        # the (out, in) layout of torch.nn.Linear weights; on 1-D variables
        # (biases, log_std) `.T` leaves the array unchanged.
        value = torch.from_numpy(tf.train.load_variable(tf_path, tf_name).T)
        target = state[torch_key]
        # copy_ broadcasts its source, so a smaller checkpoint tensor would
        # otherwise be silently repeated across the parameter.
        if value.numel() != target.numel():
            raise ValueError(
                "Cannot copy TensorFlow variable %r with shape %s into %r "
                "with shape %s" % (tf_name, tuple(value.shape),
                                   torch_key, tuple(target.shape)))
        target.copy_(value)


def convert_model(tf_path, config, pi_path, v_path):
    """Convert a TensorFlow policy/value checkpoint into PyTorch state dicts.

    Args:
        tf_path: Checkpoint location handed to ``tf.train.load_variable``.
        config: Configuration mapping providing the ``'pi_network'`` and
            ``'v_network'`` entries (it is also passed to ``get_env_object``),
            or a path to a JSON file containing that mapping.
        pi_path: Destination passed to ``torch.save`` for the policy weights.
        v_path: Destination passed to ``torch.save`` for the value weights.

    Raises:
        ValueError: If a checkpoint variable and the parameter it is copied
            into have a different number of elements.
    """
    # Indexing a path string with 'pi_network' would raise TypeError, so a
    # path is loaded into a mapping first.
    config = _load_config(config)
    env = get_env_object(config)

    # Policy:
    pi_network = get_network_object(config['pi_network'], env)
    _copy_tf_variables(pi_network, tf_path, _PI_VARIABLES)
    torch.save(pi_network.state_dict(), pi_path)

    # Value:
    v_network = get_network_object(config['v_network'], env)
    _copy_tf_variables(v_network, tf_path, _V_VARIABLES)
    torch.save(v_network.state_dict(), v_path)


if __name__ == "__main__":
    # Hard-coded locations for one specific machine.
    pre = '/Users/markojj1/'
    # `pre` ends with '/' and this suffix starts with '/', so the resulting
    # paths contain '//'; that is kept exactly as it was.
    save_dir = pre + '/Desktop/output/trpo_PointButton2_0/simple_save/'
    config_file = (pre + 'Documents/risk_sensitive_rl/'
                   'iaa-risk-sensitive-adversarial-ai/PPO/PointButton2/er32_0.json')

    # NOTE(review): `config` is handed over as a file path here, while
    # convert_model indexes it with string keys - confirm a path is accepted.
    convert_model(
        tf_path=save_dir + 'variables',
        config=config_file,
        pi_path=save_dir + 'model-latest.pt',
        v_path=save_dir + 'value-latest.pt',
    )

