# Deep LogReg-IRL
# Chainer implementation

import numpy as np
import cupy as xp

import os

import chainer
import chainer.links as L
import chainer.functions as F
from chainer import optimizers, Chain, serializers

from tqdm import tqdm

import matplotlib.pyplot as plt

from sampling import generate_qv_data, sample_data, sample_data2

# Adam step size (alpha) shared by the three optimizers below.
_ADAM_ALPHA = 0.00004

# Module-level Adam optimizers. DeepLogRegIRL attaches them to the
# f-network, the reward network and the value network via setup().
optimizer_D = optimizers.Adam(alpha=_ADAM_ALPHA)
optimizer_reward = optimizers.Adam(alpha=_ADAM_ALPHA)
optimizer_value = optimizers.Adam(alpha=_ADAM_ALPHA)


###### DEFINE NETWORK ######

class Network(Chain):
    """Multi-layer perceptron with three leaky-ReLU hidden layers.

    The network consists of four linear links named ``layer1`` .. ``layer4``;
    the last one is a plain linear read-out without activation.

    Args:
        n_input: number of units in the input layer.
        n_hidden: number of units in each hidden layer.
        n_output: number of units in the output layer.
        initW: weight initializer handed to every ``L.Linear`` link.
    """

    def __init__(self, n_input, n_hidden, n_output, initW):
        super(Network, self).__init__()
        # (fan-in, fan-out) of layer1 .. layer4, listed in creation order.
        layer_shapes = (
            (n_input, n_hidden),
            (n_hidden, n_hidden),
            (n_hidden, n_hidden),
            (n_hidden, n_output),
        )
        with self.init_scope():
            for index, (fan_in, fan_out) in enumerate(layer_shapes, start=1):
                link = L.Linear(fan_in, fan_out, initialW=initW)
                setattr(self, "layer{}".format(index), link)

    def __call__(self, x):
        """Return the network output for the input batch ``x``."""
        hidden = x
        for link in (self.layer1, self.layer2, self.layer3):
            hidden = F.leaky_relu(link(hidden))
        return self.layer4(hidden)

    def compute(self, x):
        """Return the network output for ``x``; same computation as calling the network."""
        return self(x)

class Network2(Chain):
    """Multi-layer perceptron with dropout on the 2nd and 3rd hidden layers.

    Calling the network applies ``F.dropout`` (ratio 0.7) after the second
    and third leaky-ReLU activations; ``compute`` evaluates the same layers
    without any dropout.

    Args:
        n_input: number of units in the input layer.
        n_hidden: number of units in each hidden layer.
        n_output: number of units in the output layer.
        initW: weight initializer handed to every ``L.Linear`` link.
    """

    def __init__(self, n_input, n_hidden, n_output, initW):
        super(Network2, self).__init__()
        # (fan-in, fan-out) of layer1 .. layer4, listed in creation order.
        layer_shapes = (
            (n_input, n_hidden),
            (n_hidden, n_hidden),
            (n_hidden, n_hidden),
            (n_hidden, n_output),
        )
        with self.init_scope():
            for index, (fan_in, fan_out) in enumerate(layer_shapes, start=1):
                link = L.Linear(fan_in, fan_out, initialW=initW)
                setattr(self, "layer{}".format(index), link)

    def _forward(self, x, dropout_ratio):
        """Run the four layers; ``dropout_ratio=None`` skips dropout entirely."""
        hidden = F.leaky_relu(self.layer1(x))
        for link in (self.layer2, self.layer3):
            hidden = F.leaky_relu(link(hidden))
            if dropout_ratio is not None:
                hidden = F.dropout(hidden, ratio=dropout_ratio)
        return self.layer4(hidden)

    def __call__(self, x):
        """Return the output for ``x`` with dropout (ratio 0.7) on layers 2 and 3."""
        return self._forward(x, 0.7)

    def compute(self, x):
        """Return the output for ``x`` without dropout."""
        return self._forward(x, None)

def _ensure_dir(path):
    """Create directory *path* (including missing parents) unless it exists."""
    if not os.path.isdir(path):
        os.makedirs(path)


def _as_batch(samples, batch_size, n_states):
    """Wrap sampled rows as a float32 Variable of shape (batch_size, n_states)."""
    batch = xp.asarray(samples, dtype=xp.float32).reshape(batch_size, n_states)
    return chainer.Variable(batch)


def _classification_loss(D_exp, D_base):
    """Logistic loss treating expert outputs as positives, baseline as negatives."""
    # Clipping keeps log() finite when the sigmoid saturates.
    return (- F.mean(F.log(F.clip(1 - D_base, 1e-10, 1.0)))
            - F.mean(F.log(F.clip(D_exp, 1e-10, 1.0))))


def _run_epochs(n_epoch, n_step, train_step):
    """Call train_step() n_step times per epoch; return the per-epoch mean losses."""
    history = []
    for epoch in range(n_epoch):
        step_losses = [train_step() for _ in range(n_step)]
        # With n_step == 0 there is nothing to average; report NaN instead
        # of raising ZeroDivisionError.
        if step_losses:
            mean_loss = sum(step_losses) / len(step_losses)
        else:
            mean_loss = float("nan")
        history.append(mean_loss)
        print(" -- End epoch: {0} /  loss_D: {1}".format(epoch, mean_loss))
    return history


def _save_loss_curve(losses, path):
    """Plot the per-epoch losses against the epoch index and save the figure."""
    plt.plot(list(range(len(losses))), losses, label="loss_D")
    plt.legend(loc="upper right")
    plt.savefig(path)
    plt.close()


def DeepLogRegIRL(exp_name, base_name,
                  n_hidden_f=1000, n_hidden_qv=1000,
                  n_epoch1=50, n_epoch2=50, n_step1=1000, n_step2=1000,
                  gamma=0.95, batch_size=16, gpu_device=0,
                  model_path="./model_LogReg/",
                  save_path="./saved_model"):
    """Train the f-network, then the reward and value networks, on the GPU.

    Phase 1 trains an f-network (``Network2``) with a logistic loss to
    separate samples drawn from the expert data from samples drawn from the
    baseline data.  Phase 2 freezes the f-network and trains a reward network
    ``r`` and a value network ``V`` so that
    ``sigmoid(f(s) + r(s) + gamma * V(s') - V(s))`` separates expert from
    baseline transitions with the same loss.

    Args:
        exp_name: file name of the expert CSV inside ``./dataset_train/``.
        base_name: file name of the baseline CSV inside ``./dataset_train/``.
        n_hidden_f: hidden-layer width of the f-network.
        n_hidden_qv: hidden-layer width of the reward and value networks.
        n_epoch1: number of epochs for phase 1.
        n_epoch2: number of epochs for phase 2.
        n_step1: optimizer steps per epoch in phase 1.
        n_step2: optimizer steps per epoch in phase 2.
        gamma: factor applied to the value of the next state.
        batch_size: number of samples drawn per class and step.
        gpu_device: device id passed to ``to_gpu``.
        model_path: directory for the loss-curve images; the sub-directories
            ``f`` and ``qv`` are created inside it (a trailing path separator
            is optional).
        save_path: directory receiving ``d_net.npz``, ``r_net.npz`` and
            ``v_net.npz``.

    Returns:
        None.  Results are written to ``model_path`` and ``save_path``.

    Note:
        The network input size is the number of CSV columns minus one; which
        column is excluded is decided by the helpers in ``sampling``
        (TODO confirm against ``sample_data`` / ``generate_qv_data``).
    """
    exp_data = np.loadtxt("./dataset_train/" + exp_name, delimiter=",")
    base_data = np.loadtxt("./dataset_train/" + base_name, delimiter=",")

    # os.path.join keeps "f"/"qv" inside model_path even when the caller
    # omits the trailing separator.
    model_path1 = os.path.join(model_path, "f")
    model_path2 = os.path.join(model_path, "qv")

    # Set seed of random num.
    np.random.seed(0)

    _, n_columns = np.shape(exp_data)
    n_states = n_columns - 1
    n_input = n_states

    print("Checking directories...")

    for directory in (model_path, model_path1, model_path2, save_path):
        _ensure_dir(directory)

    print("... done.")

    ########################################################################

    print("Building f-network...")

    D_net = Network2(n_input, n_hidden_f, 1, chainer.initializers.HeNormal())
    optimizer_D.setup(D_net)
    optimizer_D.add_hook(chainer.optimizer.WeightDecay(1e-3))
    D_net.to_gpu(gpu_device)

    print("... done.")

    print("===> START: TRAINING of f-network <===")

    def f_step():
        """One optimizer step of the f-network; returns the loss as a float."""
        D_net.cleargrads()

        # f on expert samples
        s_exp = sample_data(exp_data, batch_size)
        D_exp = F.sigmoid(D_net(_as_batch(s_exp, batch_size, n_states)))

        # f on baseline samples
        s_base = sample_data(base_data, batch_size)
        D_base = F.sigmoid(D_net(_as_batch(s_base, batch_size, n_states)))

        loss = _classification_loss(D_exp, D_base)
        loss_value = float(xp.asnumpy(loss.data))

        loss.backward()
        optimizer_D.update()
        return loss_value

    f_losses = _run_epochs(n_epoch1, n_step1, f_step)

    print("=== END OF TRAINING f ===")

    _save_loss_curve(f_losses, os.path.join(model_path1, "d_learning_D.png"))

    serializers.save_npz(os.path.join(save_path, "d_net.npz"), D_net)

    print("... done.")

    print("Building q-/V-network...")

    reward_net = Network2(n_input, n_hidden_qv, 1, chainer.initializers.HeNormal())
    value_net = Network2(n_input, n_hidden_qv, 1, chainer.initializers.HeNormal())

    optimizer_reward.setup(reward_net)
    optimizer_reward.add_hook(chainer.optimizer.WeightDecay(1e-3))
    optimizer_value.setup(value_net)
    optimizer_value.add_hook(chainer.optimizer.WeightDecay(1e-3))

    reward_net.to_gpu(gpu_device)
    value_net.to_gpu(gpu_device)

    print("... done.")

    # disable update of f-network
    D_net.disable_update()

    print("===> START: TRAINING of q-/V-network <===")

    cur_e, next_e = generate_qv_data(exp_data)
    cur_b, next_b = generate_qv_data(base_data)

    def classifier_output(cur_samples, next_samples):
        """Return sigmoid(f(s) + r(s) + gamma * V(s') - V(s)) for one batch."""
        s_cur = _as_batch(cur_samples, batch_size, n_states)
        s_next = _as_batch(next_samples, batch_size, n_states)
        # The f-network is frozen: evaluate it without building a graph so
        # no gradients are computed for (or accumulated in) its parameters.
        with chainer.no_backprop_mode():
            f_value = D_net.compute(s_cur)
        return F.sigmoid(f_value + reward_net(s_cur)
                         + gamma * value_net(s_next) - value_net(s_cur))

    def qv_step():
        """One optimizer step of the reward/value networks; returns the loss."""
        reward_net.cleargrads()
        value_net.cleargrads()

        s_exp, s_exp_next = sample_data2(cur_e, next_e, batch_size)
        D_exp = classifier_output(s_exp, s_exp_next)

        s_base, s_base_next = sample_data2(cur_b, next_b, batch_size)
        D_base = classifier_output(s_base, s_base_next)

        loss = _classification_loss(D_exp, D_base)
        loss_value = float(xp.asnumpy(loss.data))

        loss.backward()
        optimizer_reward.update()
        optimizer_value.update()
        return loss_value

    qv_losses = _run_epochs(n_epoch2, n_step2, qv_step)

    _save_loss_curve(qv_losses, os.path.join(model_path2, "qv_learning_D.png"))

    serializers.save_npz(os.path.join(save_path, "r_net.npz"), reward_net)
    serializers.save_npz(os.path.join(save_path, "v_net.npz"), value_net)
    serializers.save_npz(os.path.join(save_path, "d_net.npz"), D_net)
