import os
import argparse
import numpy as np
import pickle

from deploy_decision_tree import parse_decision_tree, reduce_logic_program


def main():
    """Load per-neuron decision trees and print explanation-quality statistics.

    Command-line arguments:
        --dt-path: path to a pickle holding a mapping from neuron name to a
            decision tree (whatever ``parse_decision_tree`` accepts).
        --model: model type; selects the neuron-name prefix
            ('ncp' -> 'cmd', 'lstm' -> 'h_c', anything else -> 'state').
        --env: environment name ('pendulum', 'halfcheetah' or 'driving');
            selects the number of neurons and the state feature names.
        --out-dir: optional directory, created if it does not exist.
            NOTE(review): nothing is written to it in this function.

    Raises:
        ValueError: if ``--env`` is not a recognized environment.
    """
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dt-path', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--env', type=str, required=True)
    parser.add_argument('--out-dir', type=str, default=None)
    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(args.dt_path, 'rb') as f:
        dt_all = pickle.load(f)

    if args.out_dir:
        # exist_ok avoids the race between a separate isdir() check and makedirs().
        os.makedirs(args.out_dir, exist_ok=True)

    state_names, neuron_names = _resolve_names(args.model, args.env)

    # Extract one logic program per neuron.
    logic_programs = [
        extract_logic_program(dt_all[neuron_name], state_names)
        for neuron_name in neuron_names
    ]

    # Get measure of explanation quality
    # TODO: check logic program overlap
    (avg_n_cases, avg_case_complexity,
     avg_n_unique_symbols, all_dt_size) = _explanation_metrics(logic_programs)
    print(f"Average number of cases: {avg_n_cases:.2f}")
    print(f"Average case complexity: {avg_case_complexity:.2f}")
    print(f"Average number of unique symbols: {avg_n_unique_symbols:.2f}")
    print(f"All DT size: {all_dt_size}")


def _resolve_names(model, env):
    """Return ``(state_names, neuron_names)`` for a model type and environment.

    Raises:
        ValueError: if ``env`` is not a recognized environment.
    """
    if model == 'ncp':
        neuron_prefix = 'cmd'
    elif model == 'lstm':
        neuron_prefix = 'h_c'
    else:
        neuron_prefix = 'state'

    if env == 'pendulum':
        n_neurons = 4
        state_names = ['theta', 'theta_dot']
    elif env == 'halfcheetah':
        n_neurons = 10
        state_names = ['rootz_p', 'rooty_p', 'bthigh_p', 'bshin_p', 'bfoot_p',
                       'fthigh_p', 'fshin_p', 'ffoot_p', 'rootx_v', 'rootz_v', 'rooty_v',
                       'bthigh_v', 'bshin_v', 'bfoot_v', 'fthigh_v', 'fshin_v', 'ffoot_v']
    elif env == 'driving':
        n_neurons = 8
        state_names = ['v', 'delta', 'dx', 'dy', 'dyaw', 'kappa']
    else:
        raise ValueError(f'Unrecognized env {env}')

    neuron_names = [f'{neuron_prefix}{i:02d}' for i in range(n_neurons)]
    return state_names, neuron_names


def _explanation_metrics(logic_programs):
    """Summarize a list of per-neuron logic programs.

    Each logic program is a dict whose values are paths (sequences of
    conditions); element 0 of each condition is the feature name.

    Returns:
        Tuple ``(avg_n_cases, avg_case_complexity, avg_n_unique_symbols,
        all_dt_size)`` where ``all_dt_size`` is the total number of
        conditions across all paths of all programs.
    """
    # Number of cases (paths) per neuron.
    n_cases = [len(program) for program in logic_programs]
    # Number of conditions in every individual path, across all neurons.
    case_complexity = [
        len(path) for program in logic_programs for path in program.values()
    ]
    # Distinct feature names used within each single-neuron explanation.
    n_unique_symbols = [
        len(np.unique([cond[0] for path in program.values() for cond in path]))
        for program in logic_programs
    ]
    all_dt_size = sum(case_complexity)
    return (np.mean(n_cases), np.mean(case_complexity),
            np.mean(n_unique_symbols), all_dt_size)


def extract_logic_program(dt, state_names):
    """Convert a decision tree into one reduced decision path per leaf.

    For every leaf, walk from the leaf up to the root through ``parent`` and
    record the split taken at each ancestor as
    ``[feature_name, sign, threshold]``. ``sign`` is ``'<='`` when the walk
    came from the parent's left child and ``'>'`` when it came from the right
    child. The collected path is ordered from leaf to root. It is then passed
    through ``reduce_logic_program``.

    Args:
        dt: decision tree object accepted by ``parse_decision_tree``.
        state_names: feature names, indexed by the tree's feature ids.

    Returns:
        dict mapping each leaf node index to the reduced path returned by
        ``reduce_logic_program``.

    Raises:
        ValueError: if a node is neither the left nor the right child of its
            recorded parent (inconsistent tree structure).
    """
    dt_info = parse_decision_tree(dt)
    feature = dt_info['feature']
    parent = dt_info['parent']
    threshold = dt_info['threshold']
    children_left = dt_info['children_left']
    children_right = dt_info['children_right']
    leaf_idcs = np.where(dt_info['is_leaves'])[0]

    # get traversed path for each leaf node by backtracing
    decision_path = dict()
    for leaf_idx in leaf_idcs:
        path = []
        node_idx = leaf_idx
        while parent[node_idx] != -2:  # -2 is the identifier of parent for root node
            parent_idx = parent[node_idx]
            # Look only at the parent's own child slots. This is O(1), whereas a
            # membership scan over the whole children arrays is O(n). It also
            # guarantees threshold_sign is never unbound or stale.
            # NOTE(review): assumes children_left/children_right are indexed by
            # node id, like feature/threshold (sklearn tree_ layout) — confirm.
            if children_left[parent_idx] == node_idx:
                threshold_sign = '<='
            elif children_right[parent_idx] == node_idx:
                threshold_sign = '>'
            else:
                raise ValueError(
                    f'Node {node_idx} is not a child of its parent {parent_idx}')

            path.append([state_names[feature[parent_idx]], threshold_sign,
                         threshold[parent_idx]])
            node_idx = parent_idx

        decision_path[leaf_idx] = reduce_logic_program(path)

    return decision_path


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()
