import sys
sys.dont_write_bytecode = True

import os
import h5py
import time
import argparse
import numpy as np
import tensorflow as tf

import core
import config

# Print NumPy arrays with 3 digits of precision (affects the arrays shown in
# the training log printed by main()).
np.set_printoptions(3)

def parse_args():
    """Parse the command-line options of the training script.

    Options:
        --num_agents  int, required.
        --model_path  str, defaults to None.
        --obstacle    int, defaults to 0.
        --save_traj   int, defaults to 0.
        --gpu         str, defaults to '0'.
        --tag         str, defaults to 'default'.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, keyword arguments for add_argument), registered in this order.
    option_table = (
        ('--num_agents', {'type': int, 'required': True}),
        ('--model_path', {'type': str, 'default': None}),
        ('--obstacle', {'type': int, 'default': 0}),
        ('--save_traj', {'type': int, 'default': 0}),
        ('--gpu', {'type': str, 'default': '0'}),
        ('--tag', {'type': str, 'default': 'default'}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def build_optimizer(loss):
    """Build gradient-accumulation ops and two Adam update steps for `loss`.

    Gradients of `loss` with respect to every trainable variable are summed
    into non-trainable accumulator variables. The two train steps apply the
    accumulated sum divided by the number of accumulation calls, each to its
    own group of variables, using one shared ``AdamOptimizer`` configured with
    ``config.LEARNING_RATE``.

    Args:
        loss: loss tensor whose gradients are accumulated.

    Returns:
        A tuple ``(zero_ops, accumulate_ops, train_step_h, train_step_a)``:
            zero_ops: list of ops resetting every accumulator and the
                accumulation counter to zero.
            accumulate_ops: list of ops adding the current gradients to the
                accumulators, followed by the op incrementing the counter.
            train_step_h: op applying the averaged gradients to variables
                whose name contains 'cbf'.
            train_step_a: op applying the averaged gradients to the remaining
                variables whose name contains 'action'.

    Raises:
        ValueError: if a trainable variable's name contains neither 'cbf'
            nor 'action'.

    Note:
        The averaged gradient is ``accumulator / counter``; running a train
        step right after ``zero_ops`` (counter == 0) divides by zero.
    """
    optimizer = tf.train.AdamOptimizer(learning_rate=config.LEARNING_RATE)
    trainable_vars = tf.trainable_variables()

    # One non-trainable accumulator per trainable variable, same shape/dtype.
    accumulators = [
        tf.Variable(
            tf.zeros_like(tv.initialized_value()),
            trainable=False
        ) for tv in trainable_vars]

    # Number of accumulate calls since the last reset (float for the division).
    accumulation_counter = tf.Variable(0.0, trainable=False)
    grad_pairs = optimizer.compute_gradients(loss, trainable_vars)

    # grad_pairs is paired positionally with accumulators (both follow the
    # order of trainable_vars).
    accumulate_ops = [
        accumulator.assign_add(
            grad
        ) for (accumulator, (grad, var)) in zip(accumulators, grad_pairs)]

    accumulate_ops.append(accumulation_counter.assign_add(1.0))

    # Mean gradient over the accumulation window, paired with its variable.
    gradient_vars = [(accumulator / accumulation_counter, var)
                     for (accumulator, (grad, var)) in zip(accumulators, grad_pairs)]

    # Split the (gradient, variable) pairs by network; 'cbf' is checked first,
    # so a name containing both substrings lands in the 'cbf' group.
    gradient_vars_h = []
    gradient_vars_a = []
    for accumulate_grad, var in gradient_vars:
        if 'cbf' in var.name:
            gradient_vars_h.append((accumulate_grad, var))
        elif 'action' in var.name:
            gradient_vars_a.append((accumulate_grad, var))
        else:
            raise ValueError(
                "Trainable variable '{}' contains neither 'cbf' nor 'action' "
                "in its name; cannot assign it to a train step.".format(var.name))

    train_step_h = optimizer.apply_gradients(gradient_vars_h)
    train_step_a = optimizer.apply_gradients(gradient_vars_a)

    zero_ops = [
        accumulator.assign(
            tf.zeros_like(tv)
        ) for (accumulator, tv) in zip(accumulators, trainable_vars)]
    zero_ops.append(accumulation_counter.assign(0.0))

    return zero_ops, accumulate_ops, train_step_h, train_step_a


def build_training_graph(num_agents):
    """Construct the training graph for a fixed number of agents.

    Args:
        num_agents: first dimension of the state and reference placeholders.

    Returns:
        Tuple ``(s, z_ref, o, u, loss_list, loss, acc_list)``:
            s: float32 placeholder of shape [num_agents, 5].
            z_ref: float32 placeholder of shape [num_agents, 6].
            o: float32 placeholder of shape [None, 2].
            u: output of ``core.network_action``.
            loss_list: the six weighted loss terms (barrier dangerous/safe,
                3x dangerous derivative, safe and medium derivative,
                0.1x action loss).
            loss: 10 times the sum of ``loss_list`` and the L2 weight-decay
                terms of all trainable variables.
            acc_list: the five accuracy tensors returned by the barrier and
                derivative losses.
    """
    # Input placeholders.
    states = tf.placeholder(tf.float32, [num_agents, 5])
    reference = tf.placeholder(tf.float32, [num_agents, 6])
    obstacles = tf.placeholder(tf.float32, [None, 2])

    nearest = core.detect_nearest_obstacles(states, obstacles)

    # Pairwise state differences: entry [i, j] is states[i] - states[j].
    pairwise = tf.expand_dims(states, 1) - tf.expand_dims(states, 0)

    # Barrier network, then the action network reusing its indices.
    barrier, _, neighbor_idx = core.network_cbf(
        x=pairwise, d=nearest, r=config.DIST_MIN_THRES, indices=None)
    control = core.network_action(
        s=states, z_ref=reference, d=nearest,
        obs_radius=config.OBS_RADIUS, indices=neighbor_idx)

    barrier_terms = core.loss_barrier(
        h=barrier, s=states, d=nearest, indices=neighbor_idx)
    loss_dang, loss_safe, acc_dang, acc_safe = barrier_terms

    derivative_terms = core.loss_derivatives(
        s=states, u=control, h=barrier, x=pairwise, d=nearest,
        indices=neighbor_idx)
    (loss_dang_deriv, loss_safe_deriv, loss_medium_deriv,
     acc_dang_deriv, acc_safe_deriv, acc_medium_deriv) = derivative_terms

    loss_action = core.loss_actions(
        s=states, u=control, z_ref=reference, d=nearest, indices=neighbor_idx)

    # Individual loss terms with their fixed weights.
    scaled_dang_deriv = 3 * loss_dang_deriv
    scaled_action = 0.1 * loss_action
    loss_list = [
        loss_dang,
        loss_safe,
        scaled_dang_deriv,
        loss_safe_deriv,
        loss_medium_deriv,
        scaled_action,
    ]
    acc_list = [
        acc_dang,
        acc_safe,
        acc_dang_deriv,
        acc_safe_deriv,
        acc_medium_deriv,
    ]

    # L2 regularisation over every trainable variable created above.
    weight_loss = []
    for variable in tf.trainable_variables():
        weight_loss.append(config.WEIGHT_DECAY * tf.nn.l2_loss(variable))

    loss = 10 * tf.math.add_n(loss_list + weight_loss)
    return states, reference, obstacles, control, loss_list, loss, acc_list


def count_accuracy(accuracy_lists):
    """Average each accuracy column over its non-negative samples.

    Args:
        accuracy_lists: sequence of equally long rows of accuracy values, one
            row per recorded step. Negative entries are excluded from the
            average.

    Returns:
        A list with one mean per column. A column without any non-negative
        entry yields NaN. Empty input yields an empty list.
    """
    acc = np.array(accuracy_lists)
    # Nothing recorded yet: there are no columns to average.
    if acc.size == 0:
        return []

    acc_list = []
    for i in range(acc.shape[1]):
        column = acc[:, i]
        valid = column[column >= 0]
        # Avoid np.mean on an empty slice (RuntimeWarning); report NaN directly.
        acc_list.append(np.mean(valid) if valid.size else np.nan)
    return acc_list


def main():
    """Run the training loop for the barrier-function and action networks.

    Each training step samples a scenario with ``core.generate_data``, rolls
    the learned controller forward for ``config.INNER_LOOPS`` explicit Euler
    steps while accumulating gradients, and then applies one optimizer
    update. Updates alternate between the 'cbf' and the 'action' train step
    in blocks of 10 training steps. A ``core.car_controller_np`` baseline
    rollout is evaluated on the first training step only. Statistics are
    printed every ``config.DISPLAY_STEPS`` steps; checkpoints are written
    every ``config.SAVE_STEPS`` steps and at the final step.

    Side effects: copies ``*py`` files into ``src_<tag>``, sets
    ``CUDA_VISIBLE_DEVICES``, writes checkpoints under ``models/`` and, when
    ``--save_traj`` is non-zero, writes ``./trajectory/traj.hdf5``.
    """
    args = parse_args()

    # Keep a copy of the scripts used for this run in src_<tag>.
    src_dir = os.path.join('src_{}'.format(args.tag))
    if not os.path.exists(src_dir):
        os.mkdir(src_dir)
    os.system('cp *py {}'.format(src_dir))

    # tf.train.Saver.save() does not create a missing parent directory, and
    # the first checkpoint is already written at step 0.
    if not os.path.exists('models'):
        os.mkdir('models')

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    s, z_ref, o, u, loss_list, loss, acc_list = build_training_graph(args.num_agents)
    zero_ops, accumulate_ops, train_step_h, train_step_a = build_optimizer(loss)
    # Fetch loss and accuracy values in the same run call as the accumulation
    # ops; they are read back below as the last two entries of the result.
    accumulate_ops.append(loss_list)
    accumulate_ops.append(acc_list)

    accumulation_steps = config.INNER_LOOPS

    run_baseline = True
    save_traj = int(args.save_traj)
    # Number of simulation steps the trajectory datasets can hold.
    traj_capacity = config.SAVE_TRAJ * config.INNER_LOOPS
    f_traj = None

    try:
        if save_traj:
            if not os.path.exists('./trajectory'):
                os.mkdir('./trajectory')
            f_traj = h5py.File('./trajectory/traj.hdf5', 'w')
            dataset_state = f_traj.create_dataset(
                'state', (traj_capacity, args.num_agents, 5), dtype='float')
            dataset_reference = f_traj.create_dataset(
                'reference', (traj_capacity, args.num_agents, 6), dtype='float')
            dataset_pointer = 0

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()

            if args.model_path:
                saver.restore(sess, args.model_path)

            loss_lists_np = []
            acc_lists_np = []
            dist_errors_np = []
            dist_errors_baseline_np = []

            safety_ratios_epoch = []
            safety_ratios_epoch_baseline = []

            start_time = time.time()
            for istep in range(config.TRAIN_STEPS):
                scale, omega, phase, trans, init, obstacles = core.generate_data(args.num_agents)
                sess.run(zero_ops)
                s_np = np.copy(init)
                for i in range(accumulation_steps):
                    z, dzdt, ddzdt = core.reference_trajectory_np(
                        i * config.TIME_STEP + config.TIME_OFFSET, scale, omega, phase, trans)
                    z_ref_np = np.concatenate([z, dzdt, ddzdt], axis=1)
                    u_np, out = sess.run(
                        [u, accumulate_ops],
                        feed_dict={s: s_np, z_ref: z_ref_np, o: obstacles})
                    # Explicit Euler integration of the dynamics.
                    dsdt = core.car_dynamics_np(s_np, u_np)
                    s_np = s_np + dsdt * config.TIME_STEP
                    # Fraction of rows whose dangerous mask is entirely clear.
                    safety_ratio = 1 - np.mean(
                        core.dangerous_mask_np(s_np, obstacles, config.DIST_MIN_CHECK), axis=1)
                    safety_ratio = np.mean(safety_ratio == 1)
                    safety_ratios_epoch.append(safety_ratio)
                    loss_list_np, acc_list_np = out[-2], out[-1]
                    loss_lists_np.append(loss_list_np)
                    acc_lists_np.append(acc_list_np)
                    dist_errors_np.append(np.mean(np.linalg.norm(s_np[:, :2] - z, axis=1)))

                    if save_traj:
                        if dataset_pointer < traj_capacity:
                            dataset_state[dataset_pointer] = s_np
                            dataset_reference[dataset_pointer] = z_ref_np
                            dataset_pointer = dataset_pointer + 1
                        # Close as soon as the datasets are full so the file
                        # is complete on disk even if training stops later.
                        if dataset_pointer >= traj_capacity:
                            f_traj.close()
                            f_traj = None
                            print('Trajectories saved.')
                            save_traj = 0

                # Baseline controller rollout; only done on the first step.
                if run_baseline:
                    s_np = np.copy(init)
                    for i in range(accumulation_steps):
                        z, dzdt, ddzdt = core.reference_trajectory_np(
                            i * config.TIME_STEP + config.TIME_OFFSET, scale, omega, phase, trans)
                        u_np = core.car_controller_np(s_np, z, dzdt, ddzdt)
                        dsdt = core.car_dynamics_np(s_np, u_np)
                        s_np = s_np + dsdt * config.TIME_STEP
                        safety_ratio = 1 - np.mean(
                            core.dangerous_mask_np(s_np, obstacles, config.DIST_MIN_CHECK), axis=1)
                        safety_ratio = np.mean(safety_ratio == 1)
                        safety_ratios_epoch_baseline.append(safety_ratio)
                        dist_errors_baseline_np.append(
                            np.mean(np.linalg.norm(s_np[:, :2] - z, axis=1)))
                    run_baseline = False

                # Alternate the updated network every 10 training steps.
                if np.mod(istep // 10, 2) == 0:
                    sess.run(train_step_h)
                else:
                    sess.run(train_step_a)

                if np.mod(istep, config.DISPLAY_STEPS) == 0:
                    print('Step: {}, Time: {:.1f}, Loss: {}, Acc: {}, Dist: {:.3f} | {:.3f}, Safe: {:.3f} | {:.3f}'.format(
                        istep, time.time() - start_time, np.mean(loss_lists_np, axis=0),
                        np.array(count_accuracy(acc_lists_np)),
                        np.mean(dist_errors_np), np.mean(dist_errors_baseline_np),
                        np.mean(safety_ratios_epoch), np.mean(safety_ratios_epoch_baseline)))
                    start_time = time.time()
                    # Baseline statistics are intentionally not reset here:
                    # they are only collected during the first training step.
                    (loss_lists_np, acc_lists_np, dist_errors_np, safety_ratios_epoch) = [], [], [], []

                if np.mod(istep, config.SAVE_STEPS) == 0 or istep + 1 == config.TRAIN_STEPS:
                    saver.save(sess, 'models/model_{}_iter_{}'.format(args.tag, istep))
    finally:
        # Training ended (or failed) before the trajectory file filled up.
        if f_traj is not None:
            f_traj.close()


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
