import os
import pickle
import json
import numpy as np
import argparse
import matplotlib.pyplot as plt
from sklearn import tree
import graphviz
from itertools import combinations
from sklearn.metrics.cluster import mutual_info_score, adjusted_mutual_info_score

import misc

# --- Per-environment `min_samples_leaf` for the forward decision trees ------
# Passed to sklearn as-is, so each value may be an int (absolute count) or a
# float (fraction of the samples).
min_samples_leaf_canyonrun = 0.1
min_samples_leaf_cartpole = 0.1
min_samples_leaf_pendulum = 0.2  # 0.1 was tried earlier
min_samples_leaf_halfcheetah = 0.1  # 1 was tried earlier
min_samples_leaf_driving = 0.1  # 0.01 was tried earlier

# --- Per-environment cost-complexity pruning strength (`ccp_alpha`) ---------
ccp_alpha_canyonrun = 0.01
ccp_alpha_cartpole = 0.003  # 0.005 was tried earlier
ccp_alpha_pendulum = 0.003  # 0.005 was tried earlier
ccp_alpha_halfcheetah = 0.001  # 0.0 and 0.005 were tried earlier
ccp_alpha_driving = 0.0001  # 0.00001 was tried earlier

# NOTE(review): the previous comment here read "default to 0.1; set to 0.01
# for driving", which does not match the value below -- confirm the intent.
quantile_alpha_default = 0.001

# When True, pendulum observations (cos, sin, theta_dot) are collapsed into
# (theta, theta_dot) before any tree is fitted.
pendulum_use_theta_only = True

# --- Run-time switches -------------------------------------------------------
FAST_DEBUG = False  # only use the first 2000 samples when fitting
LOAD_EXISTING_TREE = False  # reuse <out-dir>/dt.pkl instead of fitting new trees
SKIP_FORWARD_BACKWARD_CONSISTENCY = True  # skip the inverse-tree accuracy check
VERBOSE = False  # print a line after every tree fit
LEAF_IMPURITY_CRITERION = 'dist'  # one of: 'std', 'var', 'dist', 'entropy'
LEAF_IMPURITY_CRITERION_N_BINS = 20
MUTUAL_INFO_METRIC_MODE = 1
SAVE_METRICS = True  # write <out-dir>/metrics.pkl
SAVE_TREE = True  # write <out-dir>/dt.pkl


def main():
    """Fit decision trees that relate rollout data to neuron activations.

    Command-line entry point. Reads pickled rollouts (``--pkl-path``), flattens
    them to per-step arrays, slices the recurrent state into the neurons of
    interest and, depending on ``--mode``, fits one decision tree per neuron or
    per action dimension. When ``--out-dir`` is given the trees are rendered
    with graphviz and, subject to the ``SAVE_TREE`` / ``SAVE_METRICS`` module
    flags, ``dt.pkl`` and ``metrics.pkl`` are written there.

    Supported modes:
      * NCP models: ``neuron_to_action``, ``action_to_neuron``,
        ``action_to_neuron_to_action``, ``info_to_neuron``, ``state_to_neuron``.
      * lstm / cfc / gru / rnn / fc / ode_rnn models: ``state_to_neuron`` only.

    Leaf-impurity and mutual-information metrics need the per-neuron training
    data, which only ``state_to_neuron`` collects; other modes skip them.

    ``--use-old-dt`` is parsed but not read anywhere in this function.

    Raises:
        ValueError: on an unrecognized env, mode or model type, when the RNN
            type of a driving model cannot be identified, or when the selected
            mode needs information that is unavailable for the selected env.
    """
    # parse argument and read data
    parser = argparse.ArgumentParser()
    parser.add_argument('--pkl-path', type=str, required=True)
    parser.add_argument('--mode', type=str, default='neuron_to_action')
    parser.add_argument('--out-dir', type=str, default=None)
    parser.add_argument('--env', type=str, default='canyonrun')
    parser.add_argument('--use-old-dt', action='store_true', default=False)
    args = parser.parse_args()

    # NOTE(security): rollouts and cached trees are read with pickle, which can
    # execute arbitrary code -- only point this script at trusted files.
    eval_config = None
    vec_obs_names = None
    if args.env == 'driving':
        data, config, eval_config, vec_obs_names, fl_data = _load_driving_rollouts(args.pkl_path)
    else:
        data, config, fl_data = _load_step_rollouts(args.pkl_path, args.env)

    n_episodes = len(data)
    print(f'Total number of episodes: {n_episodes}')

    fl_data = {k: np.array(v) for k, v in fl_data.items()}

    # get action dimension (where raw action dimension is (mean, logstd), (alpha, beta), etc)
    act_dim, raw_act_dim, act_dim_names = _get_action_dims(args.env, config)
    if fl_data['action'].ndim == 1:
        assert act_dim == 1
    else:
        assert act_dim == fl_data['action'].shape[1]

    # extract decision tree
    rnn_config = _get_rnn_config(args.env, config)
    train_data = dict()
    forward_backward_acc = []  # stays empty unless the consistency check runs
    if rnn_config['type'] == 'ncp':
        out_data = _load_or_init_trees(args.out_dir)
        n_cmd = rnn_config['command_neurons']

        # parse state data to neuron activation
        fl_data['mot_neuron'] = fl_data['state'][:, :, :raw_act_dim]
        idx_offset = raw_act_dim
        fl_data['cmd_neuron'] = fl_data['state'][:, :, idx_offset:idx_offset+n_cmd]
        out_data['cmd_neuron_indices'] = [idx_offset, idx_offset+n_cmd]
        idx_offset += n_cmd
        fl_data['int_neuron'] = fl_data['state'][:, :, idx_offset:idx_offset+rnn_config['inter_neurons']]

        # run tree analysis
        if args.mode == 'neuron_to_action':
            _require_act_dim_names(act_dim_names, args.mode)
            feature_names = [f'cmd{i}' for i in range(n_cmd)]

            for act_i in range(act_dim): # loop through every action dimension (and correlate neuron activation)
                xdata = fl_data['cmd_neuron'][:, 0]
                ydata = fl_data['action'][:, act_i]

                if config['env_config'].get('is_bang_off_bang', False): # discrete action space
                    # NOTE: there is some bug of bob training and those trained models are actually with cont. action space
                    assert False # TODO: due to the above reason, don't use this for now
                    dt_kwargs = dict(max_depth=3, class_weight='balanced', min_samples_leaf=0.1, ccp_alpha=0.03)
                    dt = tree.DecisionTreeClassifier(**dt_kwargs).fit(xdata, ydata)
                    if VERBOSE: print('fit tree')
                else: # continuous action space
                    dt = _fit_regressor(xdata, ydata, ccp_alpha=0.02)

                act_dim_name = '_'.join(act_dim_names[act_i])
                out_data[act_dim_name] = dt
                if args.out_dir:
                    _render_tree(dt, feature_names, os.path.join(args.out_dir, act_dim_name))

                # TODO: plot decision boundary
        elif args.mode == 'action_to_neuron':
            _require_act_dim_names(act_dim_names, args.mode)
            feature_names = ['-'.join(v) for v in act_dim_names]

            for neuron_i in range(n_cmd): # loop through neuron
                _assert_continuous_actions(config)
                dt = _fit_regressor(fl_data['action'], fl_data['cmd_neuron'][:, 0, neuron_i])

                neuron_name = f'cmd{neuron_i:02d}'
                out_data[neuron_name] = dt
                if args.out_dir:
                    _render_tree(dt, feature_names, os.path.join(args.out_dir, neuron_name))
        elif args.mode == 'action_to_neuron_to_action':
            # find a subset of correlated feature (action or state)
            related_feature = dict()
            for neuron_i in range(n_cmd): # loop through neuron
                _assert_continuous_actions(config)
                dt = _fit_regressor(fl_data['action'], fl_data['cmd_neuron'][:, 0, neuron_i])

                # Leaves carry a negative placeholder in `tree_.feature`; every
                # index >= 0 is a real split feature. The comparison must be
                # `>= 0` (not `> 0`) or feature 0 is silently discarded.
                split_features = dt.tree_.feature
                related_feature[f'cmd{neuron_i:02d}'] = np.unique(split_features[split_features >= 0])

            # use neuron activation to infer relevant "event" defined on the correlated feature
            for neuron_i in range(n_cmd):
                feature_names = [f'cmd{neuron_i}'] # fit every neuron independently

                neuron_name = f'cmd{neuron_i:02d}'
                r_feat = related_feature[neuron_name]
                if r_feat.shape[0] == 0: # no correlated feature
                    continue

                xdata = fl_data['cmd_neuron'][:, 0, neuron_i:neuron_i+1]
                ydata = fl_data['action'][:, r_feat]

                _assert_continuous_actions(config)
                dt = _fit_regressor(xdata, ydata)

                dt.r_feat = r_feat # customly added field
                out_data[neuron_name] = dt
                if args.out_dir:
                    _render_tree(dt, feature_names, os.path.join(args.out_dir, neuron_name))
        elif args.mode == 'info_to_neuron':
            feature_names = ['throttle', 'normed_position_y', 'normed_forward_speed', 'rotation_0',
            'rotation_1', 'rotation_2', 'rotation_3', 'engine_power', 'forward_direction_x', 'forward_direction_y',
            'forward_direction_z', 'lift_power', 'lift_direction_x', 'lift_direction_y', 'lift_direction_z',
            'pitch_input', 'yaw_input', 'roll_input', 'drag', 'angular_drag', 'roll_angle', 'pitch_angle']

            # The input matrix is identical for every neuron, so build it once.
            info_columns = []
            for name in ['throttle', 'normed_position_y', 'normed_forward_speed', 'rotation',
                         'engine_power', 'forward_direction', 'lift_power', 'lift_direction',
                         'pitch_input', 'yaw_input', 'roll_input', 'drag', 'angular_drag',
                         'roll_angle', 'pitch_angle']:
                column = fl_data[name]
                info_columns.append(column[:, None] if column.ndim == 1 else column)
            xdata = np.concatenate(info_columns, axis=1)

            for neuron_i in range(n_cmd): # loop through neuron
                _assert_continuous_actions(config)
                dt = _fit_regressor(xdata, fl_data['cmd_neuron'][:, 0, neuron_i])

                neuron_name = f'cmd{neuron_i:02d}'
                out_data[neuron_name] = dt
                if args.out_dir:
                    _render_tree(dt, feature_names, os.path.join(args.out_dir, neuron_name))
        elif args.mode == 'state_to_neuron':
            forward_backward_acc = _run_state_to_neuron(
                args, config, eval_config, vec_obs_names, fl_data,
                fl_data['cmd_neuron'][:, 0], n_cmd, 'cmd', out_data, train_data)
        else:
            raise ValueError(f'Unrecognized mode {args.mode}')
    elif rnn_config['type'] in ['lstm', 'cfc', 'gru', 'rnn', 'fc', 'ode_rnn']:
        if rnn_config['type'] == 'lstm':
            fl_data['state_h'] = fl_data['state'][:, 0:1, :]
            fl_data['state_c'] = fl_data['state'][:, 1:2, :]

            neuron_prefix = ['h', 'c', 'h_c'][2]
            if neuron_prefix == 'h':
                state_of_interest = fl_data['state_h']
            elif neuron_prefix == 'c':
                state_of_interest = fl_data['state_c']
            elif neuron_prefix == 'h_c':
                state_of_interest = np.concatenate([fl_data['state_h'], fl_data['state_c']], axis=-1)
        else:
            neuron_prefix = 'state'
            state_of_interest = fl_data['state']

        # run tree analysis
        out_data = _load_or_init_trees(args.out_dir)
        if args.mode == 'state_to_neuron':
            forward_backward_acc = _run_state_to_neuron(
                args, config, eval_config, vec_obs_names, fl_data,
                state_of_interest[:, 0], state_of_interest.shape[-1], neuron_prefix,
                out_data, train_data)
        else:
            raise ValueError(f'Unrecognized mode {args.mode}')
    else:
        raise ValueError('Unrecognized model type {}'.format(rnn_config['type']))

    metrics = dict()

    if train_data:
        leaf_impurity = dict()
        for name, dt in out_data.items():
            if 'inv:' not in name and name != 'cmd_neuron_indices':
                _, leaf_impurity[name] = compute_leaf_impurity(dt, train_data[name])
        print('Average leaf impurity', np.mean([vv for v in leaf_impurity.values() for vv in v]))
        # np.mean([]) also yields nan but emits a RuntimeWarning; avoid the noise.
        mean_fb_acc = np.mean(forward_backward_acc) if forward_backward_acc else float('nan')
        print('Average forward-backward accuracy', mean_fb_acc)

        metrics['leaf_impurity'] = list(leaf_impurity.values())

        # analysis based on mutual information
        mig, modularity = compute_mutual_info_metrics(out_data, train_data)

        metrics['mig'] = mig
        metrics['modularity'] = modularity
    else:
        # Only 'state_to_neuron' records (xdata, ydata) per tree; without it the
        # metric helpers have nothing to evaluate the trees on.
        print(f"Skip leaf impurity / mutual information metrics: mode '{args.mode}' collected no training data")

    if args.env == 'driving':
        max_episode_len = 100
        complete_ratio = []
        for ep_data in data:
            ep_crashed = []
            for step_data in ep_data:
                step_infos = step_data['infos'][step_data['ego_agent_id']]
                ep_crashed.append(step_infos['out_of_lane'] or step_infos['exceed_max_rot'] or step_infos['crashed'])
            # an episode whose last step is a failure counts as incomplete
            complete_ratio.append(len(ep_crashed) / max_episode_len if ep_crashed[-1] else 1.)

        metrics['performance'] = np.mean(np.array(complete_ratio))
    else:
        # each step is (obs, action, state, next_obs, reward, done, info); [-3] is the reward
        ep_reward = [np.sum([step_v[-3] for step_v in ep_v]) for ep_v in data]
        metrics['performance'] = np.mean(ep_reward)

    if args.out_dir and not LOAD_EXISTING_TREE and SAVE_TREE:
        with open(os.path.join(args.out_dir, f'dt.pkl'), 'wb') as f:
            pickle.dump(out_data, f)

    if args.out_dir and SAVE_METRICS:
        with open(os.path.join(args.out_dir, f'metrics.pkl'), 'wb') as f:
            pickle.dump(metrics, f)


def _load_driving_rollouts(pkl_path):
    """Load driving rollouts plus the hparams.json / eval_config.json stored next to them.

    Returns ``(data, config, eval_config, vec_obs_names, fl_data)`` where
    ``fl_data`` holds per-step lists under 'action', 'state' and 'vec_obs'.
    """
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    run_dir = os.path.dirname(pkl_path)
    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        config = json.load(f)
    with open(os.path.join(run_dir, 'eval_config.json'), 'r') as f:
        eval_config = json.load(f)

    # convert to consistent data
    vec_obs_names = ['v', 'delta', 'dx', 'dy', 'dyaw', 'kappa'] # 'x', 'y', 'd', 'mu'
    fl_data = dict(action=[], state=[], vec_obs=[])
    for ep_data in data:
        for step_data in ep_data:
            ego_agent_id = step_data['ego_agent_id']
            rnn_state = [v.squeeze() for v in step_data['rnn_state']]

            fl_data['action'].append(step_data['actions'][ego_agent_id])
            fl_data['state'].append(rnn_state)
            fl_data['vec_obs'].append([step_data['logs'][ego_agent_id][k] for k in vec_obs_names])
    return data, config, eval_config, vec_obs_names, fl_data


def _load_step_rollouts(pkl_path, env):
    """Load non-driving rollouts and the '<pkl prefix>_eval_config.yaml' stored next to them.

    Each step is expected to unpack as
    ``(obs, action, state, next_obs, reward, done, info)``. Returns
    ``(data, config, fl_data)`` where ``fl_data`` holds per-step lists.
    """
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    f_prefix = os.path.splitext(pkl_path)[0]
    config = misc.load_yaml(f_prefix + '_eval_config.yaml')

    # flatten data in the unit of a step
    info_keys = list(data[0][0][-1].keys())[1:] # exclude DecisionStep data
    fl_data = {key: [] for key in ['action', 'state', 'vec_obs'] + info_keys}

    for ep_data in data:
        for step_data in ep_data:
            obs, action, state, _next_obs, _reward, _done, info = step_data

            if env == 'pendulum' and pendulum_use_theta_only:
                cos_theta, sin_theta, theta_dot = obs
                obs = np.array([np.arctan2(sin_theta, cos_theta), theta_dot])

            fl_data['action'].append(action)
            fl_data['state'].append(state)
            # canyonrun observations are indexable; element 1 is used as the vector observation
            fl_data['vec_obs'].append(obs[1] if env == 'canyonrun' else obs)
            for info_key, info_val in info.items():
                if info_key in ['step', 'TimeLimit.truncated', 'reward_run']:
                    continue
                fl_data[info_key].append(info_val)
    return data, config, fl_data


def _get_action_dims(env, config):
    """Return ``(act_dim, raw_act_dim, act_dim_names)``; names are None except for canyonrun."""
    if env == 'canyonrun':
        input_mode = config['env_config']['input_mode']
        if input_mode == 3:
            act_dim_names = [['pitch'], ['roll'], ['yaw'], ['throttle']]
        elif input_mode == 2:
            act_dim_names = [['roll'], ['pitch']]
        else:
            act_dim_names = [['pitch']]
        dof = len(act_dim_names)
        act_dim_per_dof = 3 if config['env_config'].get('is_bang_off_bang', False) else 2
        return dof, dof * act_dim_per_dof, act_dim_names

    fixed_act_dim = {'cartpole': 1, 'pendulum': 1, 'halfcheetah': 6, 'driving': 2}
    if env not in fixed_act_dim:
        raise ValueError(f'Unrecognized env {env}')
    return fixed_act_dim[env], fixed_act_dim[env], None


def _get_rnn_config(env, config):
    """Return the recurrent-module config dict (with a 'type' entry) for the given env."""
    if env == 'canyonrun':
        return config['model']['custom_model_config']['module_4']
    if env != 'driving':
        return config['model']['custom_model_config']['module_2']

    if ('not_use_lstm' in config.keys() and not config['not_use_lstm']) or \
       ('use_lstm' in config.keys() and config['use_lstm']):
        rnn_type = 'lstm'
    else:
        for candidate in ('fc_fake_rnn', 'gru', 'cfc', 'ncp', 'ode_rnn'):
            if config[f'use_{candidate}']:
                rnn_type = candidate
                break
        else:
            raise ValueError(f'Cannot identify RNN type')

    rnn_config = {k.replace(rnn_type+'_', ''): v for k, v in config.items() if rnn_type in k}
    if rnn_type == 'fc_fake_rnn':
        rnn_type = 'fc'
    rnn_config['type'] = rnn_type
    return rnn_config


def _load_or_init_trees(out_dir):
    """Return the trees pickled in ``<out_dir>/dt.pkl`` if LOAD_EXISTING_TREE is set, else an empty dict."""
    if not LOAD_EXISTING_TREE:
        return dict()
    if not out_dir:
        raise ValueError('--out-dir is required when LOAD_EXISTING_TREE is enabled')

    dt_fpath = os.path.join(out_dir, f'dt.pkl')
    assert os.path.exists(dt_fpath), f'{dt_fpath} does not exist'

    with open(dt_fpath, 'rb') as f:
        out_data = pickle.load(f)

    print(f'load existing tree {dt_fpath}')
    return out_data


def _require_act_dim_names(act_dim_names, mode):
    """Raise a clear error when a mode needs action-dimension names that the env does not define."""
    if act_dim_names is None:
        raise ValueError(f"mode '{mode}' needs named action dimensions, which are only defined for env 'canyonrun'")


def _assert_continuous_actions(config):
    """Assert that the env config does not enable the bang-off-bang (discrete) action space."""
    assert not config['env_config'].get('is_bang_off_bang', False) # cont. action space only


def _fit_regressor(xdata, ydata, ccp_alpha=0.01):
    """Fit the depth-3 friedman_mse regression tree shared by the NCP-only modes."""
    dt_kwargs = dict(max_depth=3, min_samples_leaf=0.01, ccp_alpha=ccp_alpha, criterion='friedman_mse')
    dt = tree.DecisionTreeRegressor(**dt_kwargs).fit(xdata, ydata)
    if VERBOSE: print('fit tree')
    return dt


def _render_tree(dt, feature_names, out_path):
    """Export a fitted tree to graphviz and render it at ``out_path``."""
    dot_data = tree.export_graphviz(dt, out_file=None,
                                    feature_names=feature_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    graphviz.Source(dot_data).render(out_path)


def _leaf_ids(dt, xdata):
    """Return, for every sample, the id of the last node on its decision path (its leaf)."""
    node_indicator = dt.decision_path(xdata)
    # In CSR form row i spans indices[indptr[i]:indptr[i+1]]; its last entry is the leaf.
    return node_indicator.indices[node_indicator.indptr[1:] - 1]


def _state_feature_names(env, vec_obs_names, eval_config, include_act):
    """Return observation feature names for tree rendering, or None when the env has none defined."""
    if env == 'cartpole':
        names = ['x', 'xdot', 'theta', 'thetadot']
        act_names = ['act']
    elif env == 'pendulum':
        if pendulum_use_theta_only:
            names = ['theta', 'theta_dot']
        else:
            names = ['cos_theta', 'sin_theta', 'theta_dot']
        act_names = ['theta_ddot']
    elif env == 'halfcheetah':
        # Ref: https://www.gymlibrary.ml/environments/mujoco/half_cheetah/
        #      https://github.com/openai/gym/blob/bf0113bfa15e29eef3033b0cc5815eda99a98193/gym/envs/mujoco/assets/half_cheetah.xml#L6
        names = ['rootz_p', 'rooty_p', 'bthigh_p', 'bshin_p', 'bfoot_p',
                 'fthigh_p', 'fshin_p', 'ffoot_p', 'rootx_v', 'rootz_v', 'rooty_v',
                 'bthigh_v', 'bshin_v', 'bfoot_v', 'fthigh_v', 'fshin_v', 'ffoot_v'] # no rootx_p
        act_names = ['torque_bthigh', 'torque_bshin', 'torque_bfoot',
                     'torque_fthigh', 'torque_fshin', 'torque_ffoot']
    elif env == 'driving':
        names = list(vec_obs_names) # copy so the caller's list is never extended in place
        act_names = eval_config['control_mode'].split('-') if include_act else []
    else:
        # No names are defined for this env; graphviz export falls back to generic labels.
        return None
    return names + act_names if include_act else names


def _state_to_neuron_dt_kwargs(env):
    """Return the forward-tree hyper-parameters configured at module level for ``env``."""
    per_env = {
        'canyonrun': (min_samples_leaf_canyonrun, ccp_alpha_canyonrun),
        'cartpole': (min_samples_leaf_cartpole, ccp_alpha_cartpole),
        'pendulum': (min_samples_leaf_pendulum, ccp_alpha_pendulum),
        'halfcheetah': (min_samples_leaf_halfcheetah, ccp_alpha_halfcheetah),
        'driving': (min_samples_leaf_driving, ccp_alpha_driving),
    }
    if env not in per_env:
        raise ValueError(f'Unrecognized env {env}')
    min_samples_leaf, ccp_alpha = per_env[env]
    return dict(max_depth=3, min_samples_leaf=min_samples_leaf, ccp_alpha=ccp_alpha, criterion='friedman_mse')


def _forward_backward_accuracy(dt_inv, ydata, dp_class):
    """Score how well the inverse tree recovers the forward tree's leaf from the activation alone."""
    from sklearn.metrics import accuracy_score, log_loss
    pred_dp_class = dt_inv.predict(ydata[:,None])
    unique_dp_class = np.unique(dp_class)
    if unique_dp_class.shape[0] == 1:
        # a single leaf makes the check meaningless; keep the neutral 0.5 placeholder
        return 0.5

    n_classes = unique_dp_class.max() + 1
    pred_dp_class_one_hot = np.zeros((pred_dp_class.shape[0], n_classes))
    # One entry per row. Indexing with `[:, pred_dp_class]` would instead set
    # whole columns for every row and destroy the one-hot structure.
    pred_dp_class_one_hot[np.arange(pred_dp_class.shape[0]), pred_dp_class] = 1
    ce_loss = log_loss(dp_class, pred_dp_class_one_hot, labels=np.arange(n_classes))
    accuracy = accuracy_score(dp_class, pred_dp_class)
    if VERBOSE: print(f'Forward-backward consistency (acc / ce): {accuracy} / {ce_loss}')
    return accuracy


def _run_state_to_neuron(args, config, eval_config, vec_obs_names, fl_data,
                         activations, n_neurons, neuron_prefix, out_data, train_data):
    """Fit (or reuse) a forward and an inverse tree per neuron for 'state_to_neuron' mode.

    ``activations`` is indexed as ``activations[:, neuron_i]`` for
    ``neuron_i in range(n_neurons)``. ``out_data`` and ``train_data`` are
    filled in place. Returns the list of forward-backward accuracies, which is
    empty when SKIP_FORWARD_BACKWARD_CONSISTENCY is set.
    """
    include_act = False
    feature_names = _state_feature_names(args.env, vec_obs_names, eval_config, include_act)

    forward_backward_acc = []
    for neuron_i in range(n_neurons): # loop through neuron
        # prepare training data for current neuron
        xdata = fl_data['vec_obs']
        if include_act:
            if args.env in ['cartpole', 'halfcheetah']:
                xdata = np.concatenate([xdata, fl_data['action'][:,None]], axis=-1)
        ydata = activations[:, neuron_i]

        if FAST_DEBUG: # DEBUG
            xdata = xdata[:2000]
            ydata = ydata[:2000]

        # forward tree: observation -> activation
        neuron_name = f'{neuron_prefix}{neuron_i:02d}'
        if not LOAD_EXISTING_TREE:
            if args.env == 'canyonrun':
                _assert_continuous_actions(config)
            dt = tree.DecisionTreeRegressor(**_state_to_neuron_dt_kwargs(args.env)).fit(xdata, ydata)
            if VERBOSE: print('fit tree')

            out_data[neuron_name] = dt
        else:
            dt = out_data[neuron_name]

        train_data[neuron_name] = [xdata, ydata]

        if args.out_dir and not LOAD_EXISTING_TREE:
            _render_tree(dt, feature_names, os.path.join(args.out_dir, neuron_name))

        # augment inverse decision tree for backtracking:
        # the class of a sample is the leaf (i.e. decision path) it reaches in the forward tree
        dp_class = _leaf_ids(dt, xdata)

        # construct inverse decision tree (classifier)
        # NOTE: can use any classifier here, not necessarily decision tree classifier
        if not LOAD_EXISTING_TREE:
            dt_inv_kwargs = dict(max_depth=3, min_samples_leaf=0.01, ccp_alpha=0.01, criterion='gini')
            dt_inv = tree.DecisionTreeClassifier(**dt_inv_kwargs).fit(ydata[:,None], dp_class)
            if VERBOSE: print('fit tree')

            out_data[f'inv:{neuron_name}'] = dt_inv
        else:
            dt_inv = out_data[f'inv:{neuron_name}']

        if args.out_dir and not LOAD_EXISTING_TREE:
            _render_tree(dt_inv, [neuron_name], os.path.join(args.out_dir, f'inv_{neuron_name}'))

        # check forward-backward consistency
        if not SKIP_FORWARD_BACKWARD_CONSISTENCY:
            forward_backward_acc.append(_forward_backward_accuracy(dt_inv, ydata, dp_class))
    return forward_backward_acc


def compute_mutual_info_metrics(out_data, train_data, n_bins=20):
    """Compute mutual-information-gap (MIG) and modularity scores for a set of trees.

    Every decision tree in ``out_data`` is paired with the activation it was
    fit on. The leaf each sample lands in is used as a pseudo-ground-truth
    factor, and mutual information between those factors and the discretized
    activations is turned into the two metrics. Which formulation is used is
    selected by the module-level ``MUTUAL_INFO_METRIC_MODE`` (0 or 1).

    Args:
        out_data: Mapping from name to fitted decision tree. Entries whose name
            contains ``'inv:'`` or equals ``'cmd_neuron_indices'`` are ignored.
        train_data: Mapping from name to an indexable pair where element 0 is
            the tree input and element 1 is the activation. The input of the
            first entry is used for all trees.
        n_bins: Number of histogram bins used to discretize each activation.

    Returns:
        Tuple ``(mig, modularity)``. In mode 0 both are Python lists with one
        entry per tree; in mode 1 ``mig`` is an array with one entry per
        global decision path and ``modularity`` an array with one entry per
        tree.

    Raises:
        ValueError: If ``out_data`` holds no usable tree, or if
            ``MUTUAL_INFO_METRIC_MODE`` is not 0 or 1.
    """
    # unify all dt with their corresponding data (discretized activation)
    dts_with_discrete_act = dict()
    for name, dt in out_data.items():
        if 'inv:' in name or name == 'cmd_neuron_indices':
            continue
        raw_act = train_data[name][1]
        _, bin_edges = np.histogram(raw_act, n_bins)
        # Labels fall in 1..n_bins + 1: the histogram's last edge equals the
        # maximum activation, so np.digitize puts that maximum in its own
        # extra bin. Only the grouping matters for mutual information.
        dts_with_discrete_act[name] = {
            'dt': dt,
            'discrete_act': np.digitize(raw_act, bin_edges),
        }
    if not dts_with_discrete_act:
        # Fail early with a clear message instead of an opaque error on an
        # empty array further down.
        raise ValueError('No decision tree found in out_data to compute mutual information metrics')
    names = list(dts_with_discrete_act.keys())

    # get pseudo-ground-truth factor (cartesian product of decision path)
    xdata = train_data[list(train_data.keys())[0]][0] # xdata is shared across all dt training

    multidim_labels = []
    for name in names:
        node_indicator = dts_with_discrete_act[name]['dt'].decision_path(xdata)
        # The last node stored in each CSR row is the leaf the sample ends in;
        # indptr[1:] - 1 addresses that last entry for every row at once.
        multidim_labels.append(node_indicator.indices[node_indicator.indptr[1:] - 1])
    multidim_labels = np.array(multidim_labels)
    n_dims = multidim_labels.shape[0]

    multidim_unique_labels = [np.unique(multidim_labels[i]) for i in range(n_dims)]
    n_unique_labels = [len(v) for v in multidim_unique_labels]
    m_score_fn = mutual_info_score #adjusted_mutual_info_score

    if MUTUAL_INFO_METRIC_MODE == 0: # instead of compute mutual info in multinomial distribution (need to double-check the math)
        use_calibration_term = False # True
        # to compensate for sum over |I| later
        ent_multiplier = (n_dims - 1) / n_dims

        # compute MIG
        mutual_infos_per_dp = dict()
        entropies_dp = dict()
        entropies_z = dict()
        for z_i, name in enumerate(names): # NOTE: compute mutual info at independent gt factor set
            discrete_act = dts_with_discrete_act[name]['discrete_act']
            mutual_infos_per_dp[name] = np.array([
                m_score_fn(multidim_labels[i], discrete_act) for i in range(n_dims)
            ])
            # mutual information of a variable with itself is its entropy
            entropies_dp[name] = m_score_fn(discrete_act, discrete_act)
            # each neuron dimension corresponds to each decision tree
            entropies_z[name] = m_score_fn(multidim_labels[z_i], multidim_labels[z_i])

        mig = []
        for i, name in enumerate(names): # loop through i (neuron dim)
            if use_calibration_term:
                m_info = entropies_z[name] * ent_multiplier + mutual_infos_per_dp[name]
            else:
                m_info = mutual_infos_per_dp[name]
            sorted_m_info = np.sort(m_info)[::-1]
            k_i = n_unique_labels[i] # number of decision paths
            # With a single tree there is no runner-up; treat it as 0 rather
            # than indexing past the end.
            runner_up = sorted_m_info[1] if len(sorted_m_info) > 1 else 0.
            mig_i = (sorted_m_info[0] - runner_up) / entropies_dp[name] / k_i
            mig.append(mig_i)

        # compute modularity
        mutual_infos_per_neuron_dim = dict() # enumerate across all dps
        for i, name in enumerate(names):
            mutual_infos_per_neuron_dim[name] = np.array([
                m_score_fn(multidim_labels[i], dts_with_discrete_act[name_inner]['discrete_act'])
                for name_inner in names
            ])

        modularity = []
        for i, name in enumerate(names):
            if use_calibration_term:
                # NOTE(review): this branch reuses mutual_infos_per_dp while the
                # uncalibrated branch uses mutual_infos_per_neuron_dim -- confirm
                # which is intended before enabling the calibration term.
                m_info = entropies_z[name] * ent_multiplier + mutual_infos_per_dp[name]
            else:
                m_info = mutual_infos_per_neuron_dim[name]
            sorted_m_info = np.sort(m_info)[::-1]
            k_i = n_unique_labels[i]
            best_m_info = sorted_m_info[0]

            normalized_var = ((sorted_m_info - best_m_info) ** 2) / (best_m_info**2 * (k_i - 1))
            modularity_i = 1 - np.mean(normalized_var)
            modularity.append(modularity_i)
    elif MUTUAL_INFO_METRIC_MODE == 1:
        # get global decision paths: one boolean membership mask per leaf of every tree
        global_labels = []
        for i, singledim_unique_labels in enumerate(multidim_unique_labels):
            for leaf_id in singledim_unique_labels:
                global_labels.append(multidim_labels[i] == leaf_id)
        global_labels = np.array(global_labels)

        # pairwise mutual information across all global decision paths
        # (the diagonal holds each path's entropy)
        mutual_infos_global_dp = np.array([
            [m_score_fn(row_labels, col_labels) for col_labels in global_labels]
            for row_labels in global_labels
        ])

        # pairwise mutual information between all global dp and neuron activation
        mutual_infos_z_dp = []
        for name in names:
            discrete_act_i = dts_with_discrete_act[name]['discrete_act']
            mutual_infos_z_dp.append([
                m_score_fn(discrete_act_i, global_labels[j])
                for j in range(global_labels.shape[0])
            ])
        mutual_infos_z_dp = np.array(mutual_infos_z_dp)

        n_neurons, n_factors = mutual_infos_z_dp.shape
        # most informative global decision path of every neuron (independent of k)
        best_factor_per_neuron = mutual_infos_z_dp.argmax(axis=1)

        # compute MIG
        mig = []
        for k in range(n_factors):
            ent = mutual_infos_global_dp[k, k]
            istar = mutual_infos_z_dp[:, k].argmax()
            best_m_info = mutual_infos_z_dp[istar, k]

            # Information every other neuron carries about factor k, minus what
            # is already explained by that neuron's own best factor (floored at 0).
            m_info_j = np.array([
                max(0, mutual_infos_z_dp[j, k] - mutual_infos_global_dp[k, best_factor_per_neuron[j]])
                for j in range(n_neurons) if j != istar
            ])
            # With a single neuron there is no competitor; .max() on an empty
            # array would raise, so the runner-up is taken as 0.
            best_m_info_j = m_info_j.max() if m_info_j.size else 0.

            if ent != 0.:
                mig_k = 1. / ent * (best_m_info - best_m_info_j)
            else: # handle decision path set with a single node
                mig_k = 0. # NOTE: not sure if this makes sense; or perhaps just skip it?!

            mig.append(mig_k)
        mig = np.array(mig)

        # compute modularity
        modularity = []
        for i in range(n_neurons):
            kstar = best_factor_per_neuron[i]
            best_m_info_over_k = mutual_infos_z_dp[i, kstar]

            # Squared information about every non-best factor, after removing
            # what that factor shares with the best one (floored at 0).
            spill = np.array([
                max(0, mutual_infos_z_dp[i, k] - mutual_infos_global_dp[kstar, k]) ** 2
                for k in range(n_factors) if k != kstar
            ]).sum()

            scale = (n_factors - 1) * best_m_info_over_k**2

            # 1e-8 keeps the ratio finite when the best mutual information is 0
            delta_i = spill / (scale + 1e-8)
            modularity.append(1 - delta_i)
        modularity = np.array(modularity)
    else:
        raise ValueError(f'Unrecognized MUTUAL_INFO_METRIC_MODE {MUTUAL_INFO_METRIC_MODE}')

    print(f'Mutual information gap (min, mean, max, sum): {np.min(mig)} {np.mean(mig)} {np.max(mig)} {np.sum(mig)}')
    print(f'Modularity (min, mean, max, sum): {np.min(modularity)} {np.mean(modularity)} {np.max(modularity)} {np.sum(modularity)}')

    return mig, modularity


def compute_leaf_impurity(dt, train_data, mode=1, criterion=LEAF_IMPURITY_CRITERION):
    """Measure how spread out the targets are inside every leaf of a fitted tree.

    Args:
        dt: Fitted decision tree exposing ``tree_`` and ``decision_path``.
        train_data: Pair ``(xdata, ydata)``; ``xdata`` is routed through the
            tree and ``ydata`` is grouped by the leaf each sample reaches.
        mode: ``0`` returns the impurity stored in the tree for each leaf,
            without normalization. Any other value recomputes the impurity
            from ``train_data`` using ``criterion``.
        criterion: One of ``'std'``, ``'var'``, ``'dist'`` or ``'entropy'``.
            ``'dist'`` is the variance of the histogram-bin index (scaled by
            the number of bin edges); ``'entropy'`` is the entropy of the
            histogram-bin index.

    Returns:
        Tuple ``(is_leaf, leaf_impurity)``: a boolean mask over tree nodes and
        an array with one impurity value per leaf, in node-index order.

    Raises:
        ValueError: If ``mode != 0`` and ``criterion`` is not recognized.
    """
    is_leaf = (dt.tree_.children_left == -1) & (dt.tree_.children_right == -1)
    if mode == 0:
        return is_leaf, dt.tree_.impurity[is_leaf]

    if criterion not in ('std', 'var', 'dist', 'entropy'):
        raise ValueError(f'Unrecognized criterion {criterion}')

    xdata, ydata = train_data
    node_indicator = dt.decision_path(xdata)
    # The last node stored in each CSR row is the leaf the sample ends in;
    # indptr[1:] - 1 addresses that last entry for every row at once. Keeping
    # this an ndarray makes the per-leaf comparison below element-wise by
    # construction.
    sample_leaf_idcs = np.asarray(node_indicator.indices[node_indicator.indptr[1:] - 1])
    leaf_node_idcs = np.flatnonzero(is_leaf)

    uses_bins = criterion in ('dist', 'entropy')
    if uses_bins:
        # Bin edges depend on the full ydata only, so compute them once.
        _, bin_edges = np.histogram(ydata, LEAF_IMPURITY_CRITERION_N_BINS)

    leaf_impurity = []
    for leaf_idx in leaf_node_idcs:
        masked_data = ydata[sample_leaf_idcs == leaf_idx]
        if criterion == 'std':
            leaf_impurity_i = masked_data.std()
        elif criterion == 'var':
            leaf_impurity_i = masked_data.var()
        elif criterion == 'dist':
            leaf_impurity_i = (np.digitize(masked_data, bin_edges) / len(bin_edges)).var()
        else: # 'entropy'
            disc_masked_data = np.digitize(masked_data, bin_edges)
            # mutual information of a variable with itself is its entropy
            leaf_impurity_i = mutual_info_score(disc_masked_data, disc_masked_data)
        leaf_impurity.append(leaf_impurity_i)

    quantile_alpha = quantile_alpha_default
    if uses_bins:
        # binned criteria are left unnormalized
        normalize_factor = 1.
    elif len(leaf_impurity) > 1:
        # robust range of the targets (inter-quantile spread)
        normalize_factor = np.quantile(ydata, 1 - quantile_alpha) - np.quantile(ydata, quantile_alpha)
        if criterion == 'var':
            # the range is in target units; variance is in squared units
            normalize_factor = normalize_factor ** 2
    else: # NOTE: handle tree with only one root node --> make leaf impurity to be 1
        # The normalizer is the leaf's own impurity, already in the right
        # units, so it must not be squared for 'var'.
        normalize_factor = leaf_impurity[0]

    leaf_impurity = (np.array(leaf_impurity) + 1e-8) / (normalize_factor + 1e-8) # NOTE: a trick to make 0 variance to be 1

    return is_leaf, leaf_impurity


# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
