import os
import pickle
import numpy as np
import argparse
import matplotlib.pyplot as plt
from matplotlib import cm


def main():
    """Interpret recorded rollouts with fitted decision trees and report statistics.

    Loads pickled rollout data (``--results-path``) and a pickled dict of
    decision trees (``--dt-path``), infers a leaf / decision path for every
    step and every tree, accumulates decision-path violations and logic
    conflicts, and plots one figure per episode. When ``--out-dir`` is given
    the figures and ``prop_seq.pkl`` are written there.

    Raises:
        ValueError: for an environment or mode this script does not handle.
    """
    # parse argument and read data
    parser = argparse.ArgumentParser()
    parser.add_argument('--results-path', type=str, required=True)
    parser.add_argument('--dt-path', type=str, required=True)
    parser.add_argument('--mode', type=str, default='neuron_to_action')
    parser.add_argument('--out-dir', type=str, default=None)
    parser.add_argument('--model', type=str, default='ncp')
    parser.add_argument('--env', type=str, default='canyonrun')
    parser.add_argument('--use-inv-dt', action='store_true', default=False)
    parser.add_argument('--not-use-capped-data', action='store_true', default=False)
    args = parser.parse_args()

    # NOTE: unpickling executes arbitrary code; only point these paths at trusted files.
    with open(args.results_path, 'rb') as f:
        data = pickle.load(f)

    # per-episode, per-step observation vectors; used below to check decision paths
    if args.env in ['cartpole', 'halfcheetah', 'pendulum']:
        state_data = [[step_v[0] for step_v in ep_v] for ep_v in data]
    elif args.env == 'driving':
        vec_obs_names = ['v', 'delta', 'dx', 'dy', 'dyaw', 'kappa'] # 'x', 'y', 'd', 'mu'

        state_data = []
        for ep_data in data:
            ep_obs = []
            for step_data in ep_data:
                ego_logs = step_data['logs'][step_data['ego_agent_id']]
                ep_obs.append([ego_logs[vec_obs_k] for vec_obs_k in vec_obs_names])
            state_data.append(ep_obs)
    else:
        # NOTE(review): the default --env 'canyonrun' ends up here, so the
        # canyonrun-specific branches further down are currently unreachable.
        raise ValueError(f'No state extraction implemented for environment {args.env}')

    with open(args.dt_path, 'rb') as f:
        dt_all = pickle.load(f)

    if 'cmd_neuron_indices' in dt_all:
        cmd_neuron_indices = dt_all.pop('cmd_neuron_indices')
        has_cmd_neuron_indices = True
    else:
        has_cmd_neuron_indices = False

    # split forward trees from the inverse trees (stored under 'inv:'-tagged names)
    new_dt_all = dict()
    dt_inv_all = dict()
    for dt_name, dt in dt_all.items():
        if 'inv:' in dt_name:
            dt_inv_all[dt_name.replace('inv:', '')] = dt
        else:
            new_dt_all[dt_name] = dt
    dt_all = new_dt_all

    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # parse decision trees
    dt_info_all = {dt_name: parse_decision_tree(dt) for dt_name, dt in dt_all.items()}

    # plt.get_cmap instead of the deprecated/removed matplotlib.cm.get_cmap;
    # the palette does not depend on the episode, so build it once
    colors = (list(plt.get_cmap('Set1').colors) + list(plt.get_cmap('Set2').colors)
              + list(plt.get_cmap('Set3').colors))

    # interpret neuron activation along time
    prop_seq_all = []
    decision_path_violation = {k: [] for k in dt_all}
    logic_program = []
    for ep_i, ep_data in enumerate(data):
        # NOTE: during actual deployment, x_test should always be neuron activation
        if args.mode not in ['neuron_to_action', 'action_to_neuron_inv', 'action_to_neuron_to_action',
                             'info_to_neuron_inv']:
            raise ValueError(f'Unrecognized mode {args.mode}')

        if args.env == 'driving':
            ep_states = np.array([[vv.squeeze() for vv in v['rnn_state']] for v in ep_data])
            ep_actions = np.array([v['actions'][v['ego_agent_id']] for v in ep_data])
        else:
            ep_states = np.array([v[2] for v in ep_data])
            ep_actions = np.array([v[1] for v in ep_data])
        if args.model == 'ncp':
            assert has_cmd_neuron_indices, 'Need to have cmd neuron indices'
            if args.env == 'canyonrun':
                x_test = ep_states[:, 0, 8:18] # HACK: hardcoded indexing to get cmd neuron activation
            else:
                if len(ep_actions.shape) == 1:
                    raw_act_dim = 1
                else:
                    raw_act_dim = ep_actions.shape[1]
                assert cmd_neuron_indices[0] == raw_act_dim, 'cmd neuron indices not consistent'
                x_test = ep_states[:, 0, cmd_neuron_indices[0]:cmd_neuron_indices[1]]
        elif args.model == 'lstm':
            x_test = np.concatenate(np.split(ep_states, 2, axis=1), axis=-1)
            x_test = x_test[:, 0, :]
        else:
            x_test = ep_states[:, 0, :]
        y_test = ep_actions

        # NOTE(review): feature_name is only assigned for some env/mode
        # combinations below; other combinations fail with a NameError once it
        # is used. Left as-is because the intended names are not visible here.
        if args.env == 'canyonrun':
            if args.mode == 'neuron_to_action':
                feature_name = [f'cmd{i}' for i in range(x_test.shape[1])]
            elif args.mode == 'action_to_neuron_inv':
                feature_name = ['pitch', 'roll', 'yaw', 'throttle']
            elif args.mode == 'action_to_neuron_to_action':
                feature_name = ['pitch', 'roll', 'yaw', 'throttle']
            elif args.mode == 'info_to_neuron_inv':
                feature_name = ['throttle', 'normed_position_y', 'normed_forward_speed', 'rotation_0',
                    'rotation_1', 'rotation_2', 'rotation_3', 'engine_power', 'forward_direction_x', 'forward_direction_y',
                    'forward_direction_z', 'lift_power', 'lift_direction_x', 'lift_direction_y', 'lift_direction_z',
                    'pitch_input', 'yaw_input', 'roll_input', 'drag', 'angular_drag', 'roll_angle', 'pitch_angle']
            y_name = ['pitch', 'roll', 'yaw', 'throttle']

            if not args.not_use_capped_data:
                x_test = x_test[:1000] # NOTE: only take the first 1000 steps (~20s)
                y_test = y_test[:1000]
        elif args.env == 'cartpole':
            if args.mode == 'info_to_neuron_inv':
                feature_name = ['x', 'xdot', 'theta', 'thetadot']
            y_name = ['Push cart']

            y_test = y_test[:,None]
        elif args.env == 'halfcheetah':
            if args.mode == 'info_to_neuron_inv':
                feature_name = ['rootz_p', 'rooty_p', 'bthigh_p', 'bshin_p', 'bfoot_p',
                             'fthigh_p', 'fshin_p', 'ffoot_p', 'rootx_v', 'rootz_v', 'rooty_v',
                             'bthigh_v', 'bshin_v', 'bfoot_v', 'fthigh_v', 'fshin_v', 'ffoot_v'] # no rootx_p
            y_name = ['torque_bthigh', 'torque_bshin', 'torque_bfoot',
                      'torque_fthigh', 'torque_fshin', 'torque_ffoot']

            if not args.not_use_capped_data:
                x_test = x_test[:50] # DEBUG cap
                y_test = y_test[:50] # DEBUG cap
        elif args.env == 'driving':
            if args.mode == 'info_to_neuron_inv':
                feature_name = vec_obs_names

            y_name = ['omega', 'a'] # NOTE: hardcoded
        elif args.env == 'pendulum':
            if args.mode == 'info_to_neuron_inv':
                # NOTE: hardcoded
                feature_name = ['theta', 'theta_dot']

            y_name = ['theta_ddot']
        else:
            raise ValueError(f'Unrecognized environment {args.env}')

        # get leaf node and decision path
        infer_out = dict()
        for dt_name, dt in dt_all.items():
            if args.mode == 'neuron_to_action':
                infer_out[dt_name] = infer_decision_tree(x_test, dt, dt_info_all[dt_name], feature_name)
            elif args.mode in ['action_to_neuron_inv', 'info_to_neuron_inv']:
                if dt_info_all[dt_name]['n_nodes'] < 2: # skip with trees with only root
                    continue

                # the last two characters of the tree name give the neuron column
                neuron_i = int(dt_name[-2:])
                if args.use_inv_dt:
                    infer_out[dt_name] = infer_decision_tree_inv(x_test[:,neuron_i:neuron_i+1], dt,
                        dt_info_all[dt_name], feature_name, dt_inv=dt_inv_all[dt_name])
                else:
                    infer_out[dt_name] = infer_decision_tree_inv(x_test[:,neuron_i:neuron_i+1], dt,
                        dt_info_all[dt_name], feature_name)
            elif args.mode == 'action_to_neuron_to_action':
                neuron_i = int(dt_name[-2:])
                infer_out[dt_name] = infer_decision_tree(x_test[:,neuron_i:neuron_i+1], dt,
                    dt_info_all[dt_name], feature_name, logic_mode='value')
            else:
                raise ValueError(f'Unrecognized mode {args.mode}')

        # compute decision path violation: for every step, sum how far the
        # recorded state lies outside each proposition on the inferred path
        for dt_name in dt_all:
            if dt_name not in infer_out:
                continue
            for step_i, step_path in enumerate(infer_out[dt_name]['path']):
                violation = 0.
                for prop in step_path:
                    prop_val = state_data[ep_i][step_i][feature_name.index(prop[0])]
                    prop_op = prop[1]
                    prop_thresh = prop[2]
                    if prop_op == '<=':
                        vio = max(0., prop_val - prop_thresh)
                    elif prop_op == '>':
                        vio = min(0., prop_val - prop_thresh)
                    else:
                        raise ValueError(f'unrecognized operator {prop_op}')
                    violation += np.abs(vio)
                decision_path_violation[dt_name].append(violation)

        # compute logic conflict: merge the propositions of all trees per step
        ep_logic_program = list(zip(*[v['path'] for v in infer_out.values()]))
        ep_logic_program = [[vvv for vv in v for vvv in vv] for v in ep_logic_program]
        logic_program.extend(reduce_logic_program(ep_lp_v) for ep_lp_v in ep_logic_program)

        # plot
        y_dim = y_test.shape[1]
        fig, axes = plt.subplots(y_dim + 1, 1, figsize=(6.4*1.4, 2.0*(y_dim+1)))
        ts = np.arange(y_test.shape[0]) * (1 / 50.)
        for y_test_i in range(y_dim):
            axes[y_test_i].plot(ts, y_test[:, y_test_i])
            axes[y_test_i].set_ylabel(y_name[y_test_i])

        y_step = 0.02 * x_test.shape[1]
        for dt_i, dt_out in enumerate(infer_out.values()):
            unique_leaf_id = np.unique(dt_out['leaf_id'])
            # one base colour per tree (wrapping around the palette); the
            # alpha channel distinguishes the leaves of that tree
            base_rgb = list(colors[dt_i % len(colors)])
            color_base = {
                uli: base_rgb + [alpha_i]
                for uli, alpha_i in zip(unique_leaf_id, np.linspace(0.1, 1., len(unique_leaf_id)))
            }
            color = [color_base[v] for v in dt_out['leaf_id']]

            marker_size = 1000 / x_test.shape[0] / y_step
            axes[-1].scatter(ts, np.ones_like(ts) * (dt_i+1) * y_step, s=marker_size, color=color)
        span = 1.0
        last_dt_i = len(infer_out) - 1
        if last_dt_i > 0:
            ylim_0 = min((0.5 - span) * last_dt_i * y_step, axes[-1].get_ylim()[0])
            ylim_1 = max((0.5 + span) * last_dt_i * y_step * 1.2, axes[-1].get_ylim()[1])
            axes[-1].set_ylim(ylim_0, ylim_1)
        axes[-1].set_ylabel('decision\npath')

        if args.out_dir:
            fig.tight_layout()
            fig.savefig(os.path.join(args.out_dir, f'ep_{ep_i:02d}.png'))

            if len(infer_out) > 0:
                prop_seq = infer_out_to_prop_seq(infer_out)
                prop_seq_all.append(prop_seq)

        # release the figure; otherwise one open figure accumulates per episode
        plt.close(fig)

    # compute statistics of decision path violation
    vio_val = []
    vio_acc = []
    for dt_name, ep_dp_violation in decision_path_violation.items():
        if len(ep_dp_violation) == 0:
            print('has empty ep_dp_violation. skip for now') # TODO:
            continue
        ep_dp_violation = np.array(ep_dp_violation)
        vio_val.append(np.mean(np.abs(ep_dp_violation)))
        vio_acc.append(np.mean(ep_dp_violation == 0.0))
    print(f'Decision path violation (val / acc): {np.mean(vio_val)} / {np.mean(vio_acc)}')

    # compute logic conflicts
    logic_conflict = [np.mean([v == 'conflict' for v in step_lp]) for step_lp in logic_program]
    print(f'Logic conflict: {np.mean(logic_conflict)}')

    if args.out_dir:
        fpath = os.path.join(args.out_dir, 'prop_seq.pkl')
        # A list with each element of shape (n_steps, n_propositions)
        with open(fpath, 'wb') as f:
            pickle.dump(prop_seq_all, f)


def reduce_logic_program(lp):
    """Collapse a list of threshold propositions to the tightest bound per variable.

    Args:
        lp: iterable of propositions ``[name, op, threshold]`` where ``op`` is
            ``'<='`` or ``'>'``.

    Returns:
        A list, ordered by variable name. For each variable it holds either the
        string ``'conflict'`` (when its tightest upper and lower bounds cannot
        both hold) or up to two propositions: ``[name, '<=', smallest upper
        bound]`` followed by ``[name, '>', largest lower bound]``.

    Raises:
        KeyError: if a proposition uses an operator other than ``'<='``/``'>'``.
    """
    # single pass: keep the tightest threshold seen so far for each operator
    bounds = {}
    for prop in lp:
        name, op, thresh = prop[0], prop[1], prop[2]
        var_bounds = bounds.setdefault(name, {'<=': None, '>': None})
        current = var_bounds[op]
        if current is None:
            var_bounds[op] = thresh
        elif op == '<=':
            var_bounds[op] = min(thresh, current)
        else:  # '>'
            var_bounds[op] = max(thresh, current)

    reduced_lp = []
    for name in sorted(bounds):
        upper = bounds[name]['<=']
        lower = bounds[name]['>']
        # `x <= upper` and `x > lower` are jointly satisfiable only when
        # lower < upper, so equal thresholds are a conflict as well
        if upper is not None and lower is not None and upper <= lower:
            reduced_lp.append('conflict')
            continue
        if upper is not None:
            reduced_lp.append([name, '<=', upper])
        if lower is not None:
            reduced_lp.append([name, '>', lower])

    return reduced_lp


def infer_out_to_prop_seq(infer_out):
    """Encode per-step leaf ids of every tree as boolean bit vectors.

    For each entry of ``infer_out`` the ``'leaf_id'`` sequence is written in
    binary, zero-padded to the bit width of that entry's largest leaf id, and
    turned into booleans (``True`` for a ``1`` bit). The per-entry blocks are
    concatenated along axis 1 in the dict's iteration order.

    Args:
        infer_out: mapping from tree name to a dict holding ``'leaf_id'``.

    Returns:
        Boolean array with one row per step and one column per bit.
    """
    blocks = []
    for tree_out in infer_out.values():
        ids = tree_out['leaf_id']
        # number of binary digits of the largest id (strip the '0b' prefix)
        width = len(bin(max(ids))) - 2
        rows = [[bit == '1' for bit in bin(leaf)[2:].zfill(width)] for leaf in ids]
        blocks.append(np.array(rows))
    return np.concatenate(blocks, axis=1)


def infer_decision_tree_inv(data, dt, dt_info, feature_name=None, verbose=False, dt_inv=None):
    # parse arguments
    feature = dt_info['feature']
    parent = dt_info['parent']
    threshold = dt_info['threshold']
    children_left = dt_info['children_left']
    children_right = dt_info['children_right']
    leaf_idcs = np.where(dt_info['is_leaves'])[0]
    if feature_name is None:
        feature_name = [f'x[{i}]' for i in range(data.shape[1])]

    # get traversed path for each leaf node by backtracing
    decision_path = dict()
    for leaf_idx in leaf_idcs:
        path = []
        backtrace_idx = leaf_idx
        while parent[backtrace_idx] != -2: # -2 is the identifier of parent for root node
            parent_idx = parent[backtrace_idx]
            parent_feature_name = feature_name[feature[parent_idx]]
            parent_threshold = threshold[parent_idx]

            if backtrace_idx in children_left:
                threshold_sign = '<='
            elif backtrace_idx in children_right:
                threshold_sign = '>'

            path.append([parent_feature_name, threshold_sign, parent_threshold])
            backtrace_idx = parent_idx

        decision_path[leaf_idx] = path

    # run
    if dt_inv is not None:
        closest_leaf_idcs = dt_inv.predict(data)

    leaf_values = dt.tree_.value[leaf_idcs].squeeze()
    out = dict(leaf_id=[], logic=[], path=[])
    for sample_id in range(data.shape[0]):
        x = data[sample_id]
        if dt_inv is None:
            closest_idx = np.argmin(np.abs(x - leaf_values))
            closest_leaf_id = leaf_idcs[closest_idx]
        else:
            closest_leaf_id = closest_leaf_idcs[sample_id]
        path = decision_path[closest_leaf_id]
        logic_str = ' & '.join([f'({v[0]} {v[1]} {v[2]:.4f})' for v in path])

        out['leaf_id'].append(closest_leaf_id)
        out['logic'].append(logic_str)
        out['path'].append(path)

    return out


def infer_decision_tree(data, dt, dt_info, feature_name=None, logic_mode='threshold', verbose=False):
    """Run a fitted decision tree on ``data`` and describe each sample's decision.

    Args:
        data: 2-D array of samples, passed to ``dt.decision_path`` and ``dt.apply``.
        dt: fitted tree exposing ``decision_path``, ``apply`` and
            ``tree_.value``; ``logic_mode='value'`` additionally reads
            ``dt.r_feat`` (indices into ``feature_name`` naming the outputs).
        dt_info: dict with 'feature' and 'threshold' per node; value mode also
            reads 'is_leaves'.
        feature_name: names used in the logic strings; defaults to ``x[i]``.
        logic_mode: ``'threshold'`` describes the split tests passed on the way
            to the leaf; ``'value'`` describes the interval around the leaf's
            predicted value, bounded by the midpoints between neighbouring
            sorted leaf values of the same output.
        verbose: print every traversed split node.

    Returns:
        dict with
          * 'leaf_id': the result of ``dt.apply(data)``,
          * 'logic': one string per sample, clauses joined by ' & ',
          * 'path': one list per sample of ``[feature name, '<=' or '>',
            threshold]`` for the traversed split nodes (root first). Only
            filled in threshold mode; value mode stores empty lists because it
            produces no threshold propositions.

    Raises:
        NotImplementedError: for an unknown ``logic_mode``.
    """
    if logic_mode not in ('threshold', 'value'):
        raise NotImplementedError(f'Unsupported logic_mode {logic_mode!r}')

    feature = dt_info['feature']
    threshold = dt_info['threshold']
    if feature_name is None:
        feature_name = [f'x[{i}]' for i in range(data.shape[1])]

    node_indicator = dt.decision_path(data)
    leaf_id = dt.apply(data)

    if logic_mode == 'value':
        # The interval boundaries depend only on the tree, so build them once.
        r_feat_name = np.array(feature_name)[dt.r_feat]
        all_leaf_ids = np.flatnonzero(dt_info['is_leaves'])
        # (n_leaves, n_outputs); sort every output column independently along
        # the leaf axis so consecutive rows are neighbouring leaf values
        all_leaf_values = dt.tree_.value[all_leaf_ids].reshape(len(all_leaf_ids), -1)
        all_leaf_values = np.sort(all_leaf_values, axis=0)
        midpoints = (all_leaf_values[:-1] + all_leaf_values[1:]) / 2.

    out = dict(leaf_id=leaf_id, logic=[], path=[])
    for sample_id in range(data.shape[0]):
        # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
        node_index = node_indicator.indices[
            node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1]
        ]

        if verbose:
            print("Rules used to predict sample {id}:\n".format(id=sample_id))
        clauses = []
        path = []
        for node_id in node_index:
            # continue to the next node if it is a leaf node
            if leaf_id[sample_id] == node_id:
                continue

            if logic_mode == 'threshold' or verbose:
                # which side of the split this sample falls on
                if data[sample_id, feature[node_id]] <= threshold[node_id]:
                    threshold_sign = "<="
                else:
                    threshold_sign = ">"

            if logic_mode == 'threshold':
                split_name = feature_name[feature[node_id]]
                path.append([split_name, threshold_sign, threshold[node_id]])
                clauses.append(f'({split_name} {threshold_sign} {threshold[node_id]:.4f})')

            if verbose:
                print(
                    "decision node {node} : (X_test[{sample}, {feature}] = {value}) "
                    "{inequality} {threshold})".format(
                        node=node_id,
                        sample=sample_id,
                        feature=feature[node_id],
                        value=data[sample_id, feature[node_id]],
                        inequality=threshold_sign,
                        threshold=threshold[node_id],
                    )
                )

        if logic_mode == 'value':
            value_grp = dt.tree_.value[leaf_id[sample_id]]
            for i, (s, v) in enumerate(zip(r_feat_name, value_grp)):
                v = float(np.squeeze(v))
                lb, ub = -np.inf, np.inf
                for iv in midpoints[:, i]:
                    if v >= iv:
                        lb = iv
                    if v <= iv:
                        ub = iv
                        break # midpoints are sorted ascending

                lstr = '('
                if lb != -np.inf:
                    lstr += f'{lb:.4f} <= '
                lstr += f'{s}'
                if ub != np.inf:
                    lstr += f' <= {ub:.4f}'
                lstr += ')'
                clauses.append(lstr)

        out['logic'].append(' & '.join(clauses))
        out['path'].append(path)

    return out


def parse_decision_tree(dt):
    """Extract the node structure of a fitted tree into plain arrays.

    Reads ``node_count``, ``children_left``, ``children_right``, ``feature``
    and ``threshold`` from ``dt.tree_`` and walks the tree from node 0 to
    derive each node's depth, parent and leaf flag.

    Args:
        dt: object exposing the attributes above on ``dt.tree_``.

    Returns:
        dict with 'n_nodes', 'children_left', 'children_right', 'parent'
        (list indexed by node id, -2 for the root), 'feature', 'threshold',
        'node_depth', 'is_leaves' and 'stack' (the exhausted traversal stack).
    """
    tree = dt.tree_
    n_nodes = tree.node_count
    children_left = tree.children_left
    children_right = tree.children_right

    depth_of = np.zeros(shape=n_nodes, dtype=np.int64)
    leaf_mask = np.zeros(shape=n_nodes, dtype=bool)
    parent_of = {0: -2}  # -2 marks "no parent" for the root

    # depth-first walk; every entry is (node id, depth) and is popped once
    pending = [(0, 0)]
    while pending:
        node, level = pending.pop()
        depth_of[node] = level

        left, right = children_left[node], children_right[node]
        # leaves record their (absent) child id here too; dropped further down
        parent_of[left] = node
        parent_of[right] = node

        # identical child ids mean the node does not split, i.e. it is a leaf
        if left == right:
            leaf_mask[node] = True
            continue
        pending.extend([(left, level + 1), (right, level + 1)])

    # order by node id and discard the placeholder keys written for leaves
    parent = [parent_of[k] for k in sorted(parent_of) if k not in (-1, -2)]

    return {
        'n_nodes': n_nodes,
        'children_left': children_left, # left children for each node; -1 means no children
        'children_right': children_right,
        'parent': parent,
        'feature': tree.feature,
        'threshold': tree.threshold,
        'node_depth': depth_of,
        'is_leaves': leaf_mask,
        'stack': pending
    }


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
