import sys
sys.dont_write_bytecode = True

import os
import time
import argparse
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pickle

import core
import config


def parse_args():
    """Parse the evaluation script's command-line options.

    Returns:
        The ``argparse.Namespace`` holding ``num_agents`` (required int),
        ``model_path`` (str or None), ``obstacle`` (int), ``vis`` (int) and
        ``gpu`` (str).
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--num_agents', dict(type=int, required=True)),
        ('--model_path', dict(type=str, default=None)),
        ('--obstacle', dict(type=int, default=0)),
        ('--vis', dict(type=int, default=0)),
        ('--gpu', dict(type=str, default='0')),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def build_evaluation_graph(num_agents):
    """Build the TensorFlow graph used to evaluate the learned controller.

    The graph computes the network action ``u`` and then refines it with a
    ``tf.while_loop`` that takes gradient steps on a residual ``u_res`` to
    reduce a violation term built from the CBF output ``h``.

    Args:
        num_agents: number of agents; fixes the first dimension of the
            state and reference placeholders.

    Returns:
        A tuple ``(s, z_ref, o, u_opt, loss_list, acc_list)`` where ``s``
        (``[num_agents, 5]``), ``z_ref`` (``[num_agents, 6]``) and ``o``
        (``[None, 2]``) are the input placeholders, ``u_opt`` is the refined
        action ``u + u_res``, ``loss_list`` holds six loss tensors and
        ``acc_list`` holds five accuracy tensors.
    """
    # Input placeholders: agent states, reference signal and obstacles.
    s = tf.placeholder(tf.float32, [num_agents, 5])
    z_ref = tf.placeholder(tf.float32, [num_agents, 6])
    o = tf.placeholder(tf.float32, [None, 2])
    d = core.detect_nearest_obstacles(s, o)
    
    # Pairwise state differences: x[i, j] = s[i] - s[j].
    x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
    h, mask, indices = core.network_cbf(
        x=x, d=d, r=config.DIST_MIN_THRES, indices=None)
    u = core.network_action(
        s=s, z_ref=z_ref, d=d, obs_radius=config.OBS_RADIUS, indices=indices)
    safe_mask = core.compute_safe_mask(s, d=d, r=config.DIST_SAFE, indices=indices)
    # True only when every entry of safe_mask is set (its mean equals 1).
    is_safe = tf.equal(tf.reduce_mean(tf.cast(safe_mask, tf.float32)), 1)

    # Residual added to the network action, plus the refinement step counter.
    u_res = tf.Variable(tf.zeros_like(u), name='u_res')
    loop_count = tf.Variable(0, name='loop_count')
   
    def opt_body(u_res, loop_count, is_safe):
        """Take one gradient step on ``u_res``; ``is_safe`` passes through unchanged."""
        # One explicit Euler step of the car dynamics under the refined action.
        dsdt = core.car_dynamics_tf(s, u + u_res)
        s_next = s + dsdt * config.TIME_STEP_EVAL
        d_next = core.detect_nearest_obstacles(s_next, o)
        x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0)
        # Re-evaluate the CBF at the predicted state, reusing the same indices.
        h_next, mask_next, _ = core.network_cbf(
            x=x_next, d=d_next, r=config.DIST_MIN_THRES, indices=indices)
        # Discrete-time derivative term: (h_next - h) + dt * alpha * h.
        deriv = h_next - h + config.TIME_STEP_EVAL * config.ALPHA_CBF * h
        # Keep only entries that are active in both the current and next mask.
        deriv = deriv * mask * mask_next
        # Penalise negative values of the term, summed along axis 1.
        error = tf.reduce_sum(tf.math.maximum(-deriv, 0), axis=1)
        error_gradient = tf.gradients(error, u_res)[0]
        u_res = u_res - config.REFINE_LEARNING_RATE * error_gradient
        loop_count = loop_count + 1
        return u_res, loop_count, is_safe

    def opt_cond(u_res, loop_count, is_safe):
        """Continue while fewer than REFINE_LOOPS steps ran and ``is_safe`` is false."""
        cond = tf.logical_and(
            tf.less(loop_count, config.REFINE_LOOPS), 
            tf.logical_not(is_safe))
        return cond
    
    # The assign ops resetting u_res and loop_count are control dependencies
    # of the ops created inside this block (the loop and u_opt).
    with tf.control_dependencies([
        u_res.assign(tf.zeros_like(u)), loop_count.assign(0)]):
        u_res, _, _ = tf.while_loop(opt_cond, opt_body, [u_res, loop_count, is_safe])
        u_opt = u + u_res

    loss_dang, loss_safe, acc_dang, acc_safe = core.loss_barrier(
        h=h, s=s, d=d, indices=indices)
    # Derivative and action losses are evaluated on the refined action u_opt.
    (loss_dang_deriv, loss_safe_deriv, loss_medium_deriv, acc_dang_deriv, 
    acc_safe_deriv, acc_medium_deriv) = core.loss_derivatives(
        s=s, u=u_opt, h=h, x=x, d=d, indices=indices)

    loss_action = core.loss_actions(s=s, u=u_opt, z_ref=z_ref, d=d, indices=indices)

    loss_list = [loss_dang, loss_safe, loss_dang_deriv, 
                 loss_safe_deriv, loss_medium_deriv, loss_action]
    acc_list = [acc_dang, acc_safe, acc_dang_deriv, acc_safe_deriv, acc_medium_deriv]

    return s, z_ref, o, u_opt, loss_list, acc_list

    
def print_accuracy(accuracy_lists):
    """Print the mean of each accuracy column, ignoring non-positive entries.

    Args:
        accuracy_lists: sequence of equal-length rows (one row per recorded
            step, one column per accuracy metric). May be empty.

    Each column is averaged over its strictly positive entries only
    (presumably non-positive values mark steps where the metric did not
    apply -- confirm against ``core``). A column with no positive entry is
    reported as NaN.
    """
    acc = np.array(accuracy_lists)
    if acc.size == 0:
        # No step was recorded: np.array([]) is 1-D, so indexing shape[1]
        # below would raise IndexError. Report an empty result instead.
        print('Accuracy: {}'.format([]))
        return
    acc_list = []
    for i in range(acc.shape[1]):
        acc_i = acc[:, i]
        positive = acc_i[acc_i > 0]
        # np.mean of an empty selection yields NaN plus a RuntimeWarning;
        # produce the NaN explicitly so the output stays warning-free.
        acc_list.append(np.mean(positive) if positive.size else np.nan)
    print('Accuracy: {}'.format(acc_list))


def render_init(num_agents):
    """Create the matplotlib figure used for live visualisation.

    Args:
        num_agents: number of agents; the figure width is
            ``4 * num_agents // 2`` and the height is fixed at 5.

    Returns:
        The newly created figure.
    """
    width = (4 * num_agents) // 2
    height = 5
    return plt.figure(figsize=(width, height))


def write_trajectory(data_path, traj):
    """Pickle a list of trajectories together with their start and end points.

    Args:
        data_path: path of the pickle file to write (overwritten if present).
        traj: sequence of arrays indexed as ``t[step, agent, feature]``; the
            first two features of the first and last step are stored as the
            start and end points.

    The written dictionary has the keys ``'trajectory'``, ``'start_points'``
    and ``'end_points'``.
    """
    write_dict = {
        'trajectory': traj,
        'start_points': np.array([t[0, :, :2] for t in traj]),
        'end_points': np.array([t[-1, :, :2] for t in traj]),
    }
    # Use a context manager so the file is flushed and closed even if
    # pickling fails, instead of relying on garbage collection.
    with open(data_path, 'wb') as f:
        pickle.dump(write_dict, f)


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN without a NumPy warning when empty."""
    if len(values) == 0:
        return float('nan')
    return np.mean(values)


def _agent_safety(s_np, obstacles):
    """Return a boolean vector that is True where an agent's danger-mask row is all zero."""
    safety_ratio = 1 - np.mean(
        core.dangerous_mask_np(s_np, obstacles, config.DIST_MIN_CHECK), axis=1)
    return safety_ratio == 1


def _render_frame(fig, z_accumulate, z, s_np, obstacles, individual_safety, num_agents):
    """Redraw the live plot: reference points, agents and obstacles.

    ``z`` is appended to ``z_accumulate`` so that every reference point drawn
    so far stays visible. Agents flagged in ``individual_safety`` are drawn
    in dark orange, the others in dark blue.
    """
    plt.clf()
    colors = ['darkorange' if flag == 1 else 'darkblue' for flag in individual_safety]
    z_accumulate.append(z)
    z_concat = np.concatenate(z_accumulate, axis=0)
    plt.scatter(z_concat[:, 0], z_concat[:, 1], color='grey', s=20)
    plt.scatter(s_np[:, 0], s_np[:, 1], color=colors, s=100)
    plt.scatter(obstacles[:, 0], obstacles[:, 1], color='red', s=8, alpha=0.5)
    plt.xlim(-12, (num_agents // 2 - 1) * 12 + 12)
    plt.ylim(-10, 10)
    ax = plt.gca()
    for side in ax.spines.keys():
        ax.spines[side].set_linewidth(2)
        ax.spines[side].set_color('grey')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    fig.canvas.draw()


def main():
    """Evaluate the learned controller against the baseline controller.

    For each of ``config.EVALUATE_STEPS`` episodes a scenario is generated
    and rolled out for ``config.INNER_LOOPS_EVAL`` steps, first with the
    restored network and then with ``core.car_controller_np``. Safety ratios,
    tracking errors and rewards are printed at the end, and the trajectories
    are pickled under ``trajectory/``.

    ``--vis 1`` animates the learned rollout. ``--vis 2`` animates the
    baseline rollout and skips the learned rollout entirely, so the learned
    statistics are reported as NaN and ``rings.pkl`` is not rewritten.
    """
    args = parse_args()
    # --gpu was parsed but never applied. Export it before any session is
    # created so TensorFlow only sees the requested device(s).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    s, z_ref, o, u, loss_list, acc_list = build_evaluation_graph(args.num_agents)

    # Only the action and CBF network weights are stored in the checkpoint.
    vars_restore = [
        v for v in tf.trainable_variables()
        if 'action' in v.name or 'cbf' in v.name]

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=vars_restore)
    saver.restore(sess, args.model_path)

    safety_ratios_epoch = []
    safety_ratios_epoch_baseline = []

    dist_errors = []
    dist_errors_baseline = []
    accuracy_lists = []

    safety_reward = []
    dist_reward = []

    fig = None
    z_accumulate = []
    if args.vis > 0:
        plt.ion()
        plt.close()
        fig = render_init(args.num_agents)

    if not os.path.exists('trajectory'):
        os.mkdir('trajectory')
    traj_dict = {'ours': {'z': [], 's': []},
                 'baseline': {'z': [], 's': []},
                 'obstacles': []}

    # With --vis 2 only the baseline is shown. Previously every learned step
    # was skipped with `continue`, leaving empty lists that crashed the
    # np.concatenate calls after the loop; skip the whole rollout instead.
    run_learned = args.vis != 2

    for istep in range(config.EVALUATE_STEPS):
        start_time = time.time()
        scale, omega, phase, trans, init, obstacles = core.generate_data(args.num_agents)
        traj_dict['obstacles'].append(np.array(obstacles, dtype=np.float32))

        if run_learned:
            s_np = np.copy(init)
            z_traj = []
            s_traj = []
            safety_info = []
            dist_info = []
            for i in range(config.INNER_LOOPS_EVAL):
                z, dzdt, ddzdt = core.reference_trajectory_np(
                    i * config.TIME_STEP_EVAL + config.TIME_OFFSET, scale, omega, phase, trans)
                z_ref_np = np.concatenate([z, dzdt, ddzdt], axis=1)
                u_network, acc_list_np = sess.run(
                    [u, acc_list], feed_dict={s: s_np, z_ref: z_ref_np, o: obstacles})
                # Explicit Euler step with the network's (refined) action.
                dsdt = core.car_dynamics_np(s_np, u_network)
                s_np = s_np + dsdt * config.TIME_STEP_EVAL
                individual_safety = _agent_safety(s_np, obstacles)
                safety_info.append(individual_safety.astype(np.float32).reshape((1, -1)))
                safety_ratios_epoch.append(np.mean(individual_safety))
                accuracy_lists.append(acc_list_np)

                if args.vis == 1 and np.mod(i, 5) == 0:
                    _render_frame(fig, z_accumulate, z, s_np, obstacles,
                                  individual_safety, args.num_agents)

                # Per-agent distance to the reference position.
                dist = np.linalg.norm(s_np[:, :2] - z, axis=1)
                z_traj.append(np.expand_dims(z, axis=0))
                s_traj.append(np.expand_dims(s_np[:, :4], axis=0))
                dist_errors.append(np.mean(dist))
                dist_info.append(np.reshape(dist, (1, -1)))
            traj_dict['ours']['z'].append(np.concatenate(z_traj, axis=0))
            traj_dict['ours']['s'].append(np.concatenate(s_traj, axis=0))
            # Each unsafe step costs an agent -1; a mean tracking error below
            # 1.0 over the episode earns an agent +10.
            episode_safe_reward = np.sum(np.concatenate(safety_info, axis=0) - 1, axis=0)
            episode_dist = np.mean(np.concatenate(dist_info, axis=0), axis=0)
            episode_dist_reward = (episode_dist < 1.0).astype(np.float32) * 10
            safety_reward.append(np.mean(episode_safe_reward))
            dist_reward.append(np.mean(episode_dist_reward))
        end_time = time.time()

        # Baseline rollout from the same initial state and scenario.
        s_np = np.copy(init)
        z_traj = []
        s_traj = []
        for i in range(config.INNER_LOOPS_EVAL):
            z, dzdt, ddzdt = core.reference_trajectory_np(
                i * config.TIME_STEP_EVAL + config.TIME_OFFSET, scale, omega, phase, trans)
            u_np = core.car_controller_np(s_np, z, dzdt, ddzdt)
            dsdt = core.car_dynamics_np(s_np, u_np)
            s_np = s_np + dsdt * config.TIME_STEP_EVAL
            individual_safety = _agent_safety(s_np, obstacles)
            safety_ratios_epoch_baseline.append(np.mean(individual_safety))
            dist_errors_baseline.append(np.mean(np.linalg.norm(s_np[:, :2] - z, axis=1)))

            if args.vis == 2 and np.mod(i, 5) == 0:
                _render_frame(fig, z_accumulate, z, s_np, obstacles,
                              individual_safety, args.num_agents)
                # Pace the animation by the measured per-step time of the
                # first phase of this episode.
                time.sleep((end_time - start_time) / config.INNER_LOOPS_EVAL)
            z_traj.append(np.expand_dims(z, axis=0))
            s_traj.append(np.expand_dims(s_np[:, :2], axis=0))
        traj_dict['baseline']['z'].append(np.concatenate(z_traj, axis=0))
        traj_dict['baseline']['s'].append(np.concatenate(s_traj, axis=0))
        print('Evaluation Step: {} | {}, Time: {:.4f}'.format(
            istep + 1, config.EVALUATE_STEPS, end_time - start_time))

    # The learned-controller lists are empty when its rollout was skipped.
    if accuracy_lists:
        print_accuracy(accuracy_lists)
    print('Distance Error (Learning | Baseline): {:.4f} | {:.4f}'.format(
        _mean_or_nan(dist_errors), _mean_or_nan(dist_errors_baseline)))
    print('Mean Safety Ratio (Learning | Baseline): {:.4f} | {:.4f}'.format(
        _mean_or_nan(safety_ratios_epoch), _mean_or_nan(safety_ratios_epoch_baseline)))
    print('Reward Safety : {:.4f}, Reward Distance: {:.4f}'.format(
        _mean_or_nan(safety_reward), _mean_or_nan(dist_reward)))

    # Close the file explicitly instead of leaking the handle.
    with open('trajectory/traj_eval.pkl', 'wb') as f:
        pickle.dump(traj_dict, f)
    # Do not overwrite an earlier rings.pkl with an empty trajectory list.
    if traj_dict['ours']['s']:
        write_trajectory('./trajectory/rings.pkl', traj_dict['ours']['s'])

    sess.close()


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
