# Deep SOLO-IRL
# Chainer implementation

import numpy as np
import cupy as xp

import os

import chainer
import chainer.links as L
import chainer.functions as F
from chainer import optimizers, Chain, serializers

from tqdm import tqdm

import matplotlib.pyplot as plt

from sampling import generate_qv_data, sample_data, sample_data2

# Adam optimizers used by SOLOIRL(); each one is attached to its network(s)
# with setup() inside that function.

# Phase 1 (f-network): discriminator D and its reconstruction autoencoder.
optimizer_D = optimizers.Adam(alpha=4e-05)
optimizer_R = optimizers.Adam(alpha=1e-04)

# Phase 2 (q-/V-network): reward and value networks ...
optimizer_reward = optimizers.Adam(alpha=4e-05)
optimizer_value = optimizers.Adam(alpha=4e-05)

# ... and the two phase-2 autoencoders.
optimizer_R_2 = optimizers.Adam(alpha=1e-04)
optimizer_R_3 = optimizers.Adam(alpha=1e-04)


###### DEFINE NETWORK ######

class Network(Chain):
    """Four-layer perceptron with leaky-ReLU hidden activations.

    Args:
        n_input: number of units in the input layer.
        n_hidden: number of units in each of the three hidden layers.
        n_output: number of units in the (linear, un-activated) output layer.
        initW: initializer passed as ``initialW`` to every ``L.Linear``.
    """

    def __init__(self, n_input, n_hidden, n_output, initW):
        super(Network, self).__init__()
        with self.init_scope():
            self.layer1 = L.Linear(n_input, n_hidden, initialW=initW)
            self.layer2 = L.Linear(n_hidden, n_hidden, initialW=initW)
            self.layer3 = L.Linear(n_hidden, n_hidden, initialW=initW)
            self.layer4 = L.Linear(n_hidden, n_output, initialW=initW)

    def __call__(self, x):
        """Return the raw output of the last linear layer for input ``x``."""
        h = F.leaky_relu(self.layer1(x))
        h = F.leaky_relu(self.layer2(h))
        h = F.leaky_relu(self.layer3(h))
        return self.layer4(h)

    def compute(self, x):
        """Identical to ``__call__``.

        Kept so that ``Network`` exposes the same ``compute`` method as
        ``Network2``, where ``compute`` is the dropout-free forward pass.
        """
        return self(x)

class Network2(Chain):
    """Four-layer perceptron like ``Network``, with dropout after hidden layers 2 and 3.

    ``__call__`` applies ``F.dropout`` after the 2nd and 3rd hidden
    activations (Chainer's dropout is only active while
    ``chainer.config.train`` is True). ``compute`` runs the same layers with
    dropout always disabled.

    Args:
        n_input: number of units in the input layer.
        n_hidden: number of units in each of the three hidden layers.
        n_output: number of units in the (linear, un-activated) output layer.
        initW: initializer passed as ``initialW`` to every ``L.Linear``.
        dropout_ratio: drop probability used by ``__call__``. Defaults to
            0.7, the value that was previously hard-coded. ``None`` disables
            dropout entirely.
    """

    def __init__(self, n_input, n_hidden, n_output, initW, dropout_ratio=0.7):
        super(Network2, self).__init__()
        # Plain attribute (not a link/parameter), so it is not serialized.
        self.dropout_ratio = dropout_ratio
        with self.init_scope():
            self.layer1 = L.Linear(n_input, n_hidden, initialW=initW)
            self.layer2 = L.Linear(n_hidden, n_hidden, initialW=initW)
            self.layer3 = L.Linear(n_hidden, n_hidden, initialW=initW)
            self.layer4 = L.Linear(n_hidden, n_output, initialW=initW)

    def _forward(self, x, dropout_ratio):
        """Shared forward pass; ``dropout_ratio=None`` skips dropout."""
        h = F.leaky_relu(self.layer1(x))
        h = F.leaky_relu(self.layer2(h))
        if dropout_ratio is not None:
            h = F.dropout(h, ratio=dropout_ratio)
        h = F.leaky_relu(self.layer3(h))
        if dropout_ratio is not None:
            h = F.dropout(h, ratio=dropout_ratio)
        return self.layer4(h)

    def __call__(self, x):
        """Forward pass with dropout (ratio ``self.dropout_ratio``)."""
        return self._forward(x, self.dropout_ratio)

    def compute(self, x):
        """Forward pass without dropout."""
        return self._forward(x, None)

def _as_batch(arr, batch_size, n_states):
    """Copy a host batch to the current CuPy device as a float32 Variable.

    The array is reshaped to ``(batch_size, n_states)``. ``xp.asarray``
    allocates on CuPy's *current* device, so the caller must first select
    the device the networks live on.
    """
    return chainer.Variable(
        xp.asarray(arr, dtype=xp.float32).reshape(batch_size, n_states))


def _add_noise(batch, batch_size, n_states, scale=0.001):
    """Return ``batch`` plus zero-mean Gaussian noise (std ``scale``) from NumPy's global RNG."""
    return batch + np.random.normal(0, scale, size=(batch_size, n_states))


def _scalar(var):
    """Return the value of a one-element Variable as a Python float."""
    return float(xp.asnumpy(var.data).reshape(1)[0])


def _save_learning_curves(epochs, prob_d, prob_g, loss_r, out_dir, prefix):
    """Write ``<prefix>_learning_DG.png`` and ``<prefix>_learning_R.png`` into ``out_dir``.

    The first plot shows the per-epoch mean discriminator output on expert
    (prob_D) and reconstructed (prob_G) samples. The second shows the
    reconstruction loss on a log scale.
    """
    plt.plot(epochs, prob_d, label="prob_D")
    plt.plot(epochs, prob_g, label="prob_G")
    plt.legend(loc="upper right")
    plt.savefig(os.path.join(out_dir, prefix + "_learning_DG.png"))
    plt.close()

    plt.plot(epochs, loss_r, label="loss(R)")
    plt.legend(loc="upper right")
    plt.yscale('log')
    plt.savefig(os.path.join(out_dir, prefix + "_learning_R.png"))
    plt.close()


def SOLOIRL(exp_name, base_name,
            n_hidden_f=1000, n_hidden_qv=1000,
            n_epoch1=50, n_epoch2=50, n_step1=1000, n_step2=1000,
            gamma=0.95, batch_size=16, gpu_device=0,
            model_path="./model_SOLO/",
            save_path="./saved_model"):
    """Train Deep SOLO-IRL in two adversarial phases and save the networks.

    Phase 1 trains the f-network ``D_net`` with a least-squares GAN loss on
    sigmoid outputs. Expert states are pushed towards 1, and states produced
    by a denoising autoencoder (noisy input, reconstruction loss against the
    clean input) are pushed towards 0. The autoencoder is trained to fool
    ``D_net`` and to reconstruct its input.

    Phase 2 freezes ``D_net`` and trains ``reward_net`` and ``value_net`` so
    that ``sigmoid(f(s) + r(s) + gamma * V(s') - V(s))`` separates expert
    transitions from ones built by two further autoencoders: one
    reconstructs ``s``, the other predicts ``s'`` from the noisy ``s``.

    Args:
        exp_name: CSV file name inside ``./dataset_train/``. The network
            input width is (number of columns - 1). How the extra column is
            handled is up to the ``sampling`` helpers (not visible here).
        base_name: unused; kept for interface compatibility.
        n_hidden_f: hidden width of the phase-1 networks.
        n_hidden_qv: hidden width of the phase-2 networks.
        n_epoch1, n_epoch2: number of epochs in phase 1 / phase 2.
        n_step1, n_step2: minibatch updates per epoch in phase 1 / phase 2.
        gamma: discount factor applied to ``V(s')``.
        batch_size: minibatch size.
        gpu_device: CUDA device id used for all networks and minibatches.
        model_path: directory for learning-curve plots. Plots go to its
            ``f`` and ``qv`` subdirectories.
        save_path: directory for the ``.npz`` network weights.
    """
    # ndmin=2 keeps a single-row CSV two-dimensional so ``shape[1]`` works.
    exp_data = np.loadtxt("./dataset_train/" + exp_name, delimiter=",", ndmin=2)

    # os.path.join so a model_path without a trailing slash still yields
    # subdirectories (plain concatenation produced e.g. "./modelf").
    model_path1 = os.path.join(model_path, "f")
    model_path2 = os.path.join(model_path, "qv")

    # Set seed of random num.
    np.random.seed(0)

    n_states = exp_data.shape[1] - 1
    n_input = n_states

    # Minibatches are created with xp.asarray on CuPy's *current* device;
    # make that the same device the networks are moved to, otherwise any
    # gpu_device other than 0 fails with a device mismatch.
    xp.cuda.Device(gpu_device).use()

    print("Checking directories...")
    for path in (model_path, model_path1, model_path2, save_path):
        os.makedirs(path, exist_ok=True)
    print("... done.")

    ########################################################################
    # Phase 1: f-network (discriminator) vs. reconstruction autoencoder.

    print("Building f-network...")

    D_net = Network2(n_input, n_hidden_f, 1, chainer.initializers.HeNormal())
    D_net.to_gpu(gpu_device)
    optimizer_D.setup(D_net)

    print("... done.")

    # AutoEncoder
    z_size = 100
    R_enc_net = Network(n_input, n_hidden_f, z_size, chainer.initializers.HeNormal())
    R_dec_net = Network(z_size, n_hidden_f, n_input, chainer.initializers.HeNormal())
    # An optimizer has a single target link. Calling setup() twice (encoder,
    # then decoder) silently dropped the encoder, so it was never trained.
    # Group both halves so one optimizer updates all their parameters.
    R_net = chainer.ChainList(R_enc_net, R_dec_net)
    R_net.to_gpu(gpu_device)
    optimizer_R.setup(R_net)

    print("===> START: TRAINING of f-network <===")

    epoch_his, prob_d_his, prob_g_his, loss_r_his = [], [], [], []
    for epoch in range(n_epoch1):
        prob_d_list, prob_g_list, loss_r_list = [], [], []

        for _ in range(n_step1):
            # --- discriminator update ---
            D_net.cleargrads()

            s_exp = sample_data(exp_data, batch_size)
            s_exp_data = _as_batch(s_exp, batch_size, n_states)
            D_exp = F.sigmoid(D_net(s_exp_data))

            s_exp_data_n = _as_batch(_add_noise(s_exp, batch_size, n_states),
                                     batch_size, n_states)
            s_exp_reconst = R_dec_net(R_enc_net(s_exp_data_n))
            D_reconst = F.sigmoid(D_net(s_exp_reconst))

            # Least-squares GAN loss: expert -> 1, reconstruction -> 0.
            loss_Dis = 0.5 * F.mean((D_exp - 1) ** 2) + 0.5 * F.mean(D_reconst ** 2)
            prob_d_list.append(_scalar(F.mean(D_exp)))

            loss_Dis.backward()
            optimizer_D.update()

            # --- autoencoder (generator) update ---
            # Discard the gradients the discriminator loss pushed into R.
            R_net.cleargrads()

            loss_Gen = 0.5 * F.mean((D_reconst - 1) ** 2)
            loss_r2 = F.mean((s_exp_data - s_exp_reconst) ** 2)
            prob_g_list.append(_scalar(F.mean(D_reconst)))
            loss_r_list.append(_scalar(loss_r2))

            # D_net also receives gradients here; they are not applied and are
            # cleared by D_net.cleargrads() at the start of the next step.
            (loss_Gen + loss_r2).backward()
            optimizer_R.update()

        epoch_his.append(epoch)
        prob_d_his.append(sum(prob_d_list) / len(prob_d_list))
        prob_g_his.append(sum(prob_g_list) / len(prob_g_list))
        loss_r_his.append(sum(loss_r_list) / len(loss_r_list))
        print(" -- End epoch: {0} /  prob_D: {1} / prob_G: {2} / loss_R: {3}".format(
            epoch, prob_d_his[-1], prob_g_his[-1], loss_r_his[-1]))

    print("=== END OF TRAINING f ===")

    _save_learning_curves(epoch_his, prob_d_his, prob_g_his, loss_r_his,
                          model_path1, "d")

    serializers.save_npz(os.path.join(save_path, "d_net.npz"), D_net)

    print("... done.")

    ########################################################################
    # Phase 2: reward/value networks on top of the frozen f-network.

    print("Building q-/V-network...")

    reward_net = Network2(n_input, n_hidden_qv, 1, chainer.initializers.HeNormal())
    value_net = Network2(n_input, n_hidden_qv, 1, chainer.initializers.HeNormal())
    reward_net.to_gpu(gpu_device)
    value_net.to_gpu(gpu_device)
    optimizer_reward.setup(reward_net)
    optimizer_value.setup(value_net)

    # Autoencoder 2 reconstructs s; autoencoder 3 predicts s' from noisy s.
    R_enc_2_net = Network(n_input, n_hidden_qv, z_size, chainer.initializers.HeNormal())
    R_dec_2_net = Network(z_size, n_hidden_qv, n_input, chainer.initializers.HeNormal())

    R_enc_3_net = Network(n_input, n_hidden_qv, z_size, chainer.initializers.HeNormal())
    R_dec_3_net = Network(z_size, n_hidden_qv, n_input, chainer.initializers.HeNormal())

    # Same single-target issue as in phase 1: one group per optimizer.
    R_2_net = chainer.ChainList(R_enc_2_net, R_dec_2_net)
    R_3_net = chainer.ChainList(R_enc_3_net, R_dec_3_net)
    R_2_net.to_gpu(gpu_device)
    R_3_net.to_gpu(gpu_device)
    optimizer_R_2.setup(R_2_net)
    optimizer_R_3.setup(R_3_net)

    print("... done.")

    # disable update of f-network
    D_net.disable_update()

    print("===> START: TRAINING of q-/V-network <===")

    epoch_his, prob_d_his, prob_g_his, loss_r_his = [], [], [], []

    # presumably aligned arrays of current / next states; see sampling module.
    cur_e, next_e = generate_qv_data(exp_data)

    for epoch in range(n_epoch2):
        prob_d_list, prob_g_list, loss_r_list = [], [], []

        for _ in range(n_step2):
            # --- reward/value (discriminator) update ---
            reward_net.cleargrads()
            value_net.cleargrads()

            s_exp, s_exp_next = sample_data2(cur_e, next_e, batch_size)
            s_exp_data = _as_batch(s_exp, batch_size, n_states)
            s_exp_data_next = _as_batch(s_exp_next, batch_size, n_states)

            D_exp = F.sigmoid(D_net.compute(s_exp_data) + reward_net(s_exp_data)
                              + gamma * value_net(s_exp_data_next) - value_net(s_exp_data))

            s_exp_data_n = _as_batch(_add_noise(s_exp, batch_size, n_states),
                                     batch_size, n_states)
            s_exp_reconst = R_dec_2_net(R_enc_2_net(s_exp_data_n))
            s_exp_next_reconst = R_dec_3_net(R_enc_3_net(s_exp_data_n))

            D_reconst = F.sigmoid(D_net.compute(s_exp_reconst) + reward_net(s_exp_reconst)
                                  + gamma * value_net(s_exp_next_reconst) - value_net(s_exp_reconst))

            loss_Dis = 0.5 * F.mean((D_exp - 1) ** 2) + 0.5 * F.mean(D_reconst ** 2)
            prob_d_list.append(_scalar(F.mean(D_exp)))

            loss_Dis.backward()
            optimizer_reward.update()
            optimizer_value.update()

            # --- autoencoders (generator) update ---
            R_2_net.cleargrads()
            R_3_net.cleargrads()

            loss_Gen = 0.5 * F.mean((D_reconst - 1) ** 2)
            loss_r2 = (F.mean((s_exp_data - s_exp_reconst) ** 2)
                       + F.mean((s_exp_data_next - s_exp_next_reconst) ** 2))
            prob_g_list.append(_scalar(F.mean(D_reconst)))
            loss_r_list.append(_scalar(loss_r2))

            (loss_Gen + loss_r2).backward()
            optimizer_R_2.update()
            optimizer_R_3.update()

        epoch_his.append(epoch)
        prob_d_his.append(sum(prob_d_list) / len(prob_d_list))
        prob_g_his.append(sum(prob_g_list) / len(prob_g_list))
        loss_r_his.append(sum(loss_r_list) / len(loss_r_list))
        print(" -- End epoch: {0} /  prob_D: {1} / prob_G: {2} / loss_R: {3}".format(
            epoch, prob_d_his[-1], prob_g_his[-1], loss_r_his[-1]))

    _save_learning_curves(epoch_his, prob_d_his, prob_g_his, loss_r_his,
                          model_path2, "qv")

    for file_name, net in (("r_net.npz", reward_net),
                           ("v_net.npz", value_net),
                           ("r_enc_2.npz", R_enc_2_net),
                           ("r_dec_2.npz", R_dec_2_net),
                           ("r_enc_3.npz", R_enc_3_net),
                           ("r_dec_3.npz", R_dec_3_net),
                           ("d_net.npz", D_net)):
        serializers.save_npz(os.path.join(save_path, file_name), net)
