import os
import pickle
import numpy as np
import argparse
import matplotlib.pyplot as plt
from matplotlib import cm

from .deploy_decision_tree import parse_decision_tree, reduce_logic_program


def main():
    r"""Print LaTeX table rows describing the logic program of each neuron's decision tree.

    Command-line arguments:
        --dt-path: path to a pickle holding a dict keyed by neuron name
            (e.g. ``cmd00``); each value is passed to ``extract_logic_program``.
        --model: model type; selects the neuron-name prefix
            (``ncp`` -> ``cmd``, ``lstm`` -> ``h_c``, anything else -> ``state``).
        --env: ``pendulum``, ``halfcheetah`` or ``driving``; selects the number
            of neurons and the state-variable (and LaTeX display) names.
        --out-dir: directory that is created if missing; it is not otherwise
            used by this script.

    Output (stdout): a ``\multirow`` cell for the model spanning one row per
    leaf across all neurons, then for each neuron its index and a
    ``\makecell`` with one numbered conjunction of conditions per leaf,
    separated by ``\cmidrule{2-3}`` between neurons.

    Raises:
        ValueError: if ``--env`` is not recognized.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dt-path', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--env', type=str, required=True)
    parser.add_argument('--out-dir', type=str, default=None)
    args = parser.parse_args()

    # NOTE(review): pickle.load can execute arbitrary code -- only pass trusted files.
    with open(args.dt_path, 'rb') as f:
        dt_all = pickle.load(f)

    if args.model == 'ncp':
        neuron_prefix = 'cmd'
    elif args.model == 'lstm':
        neuron_prefix = 'h_c'
    else:
        neuron_prefix = 'state'

    if args.env == 'pendulum':
        n_neurons = 4
        state_names = ['theta', 'theta_dot']
        print_names = [r'\theta', r'\dot{\theta}']
    elif args.env == 'halfcheetah':
        n_neurons = 10
        state_names = ['rootz_p', 'rooty_p', 'bthigh_p', 'bshin_p', 'bfoot_p',
                       'fthigh_p', 'fshin_p', 'ffoot_p', 'rootx_v', 'rootz_v', 'rooty_v',
                       'bthigh_v', 'bshin_v', 'bfoot_v', 'fthigh_v', 'fshin_v', 'ffoot_v']
        print_names = [r'h_R', r'\theta_R', r'\theta_{T,B}', r'\theta_{S,B}', r'\theta_{F,B}', r'\theta_{T,F}',
                       r'\theta_{S,F}', r'\theta_{F,F}', r'\dot{x}_R', r'\dot{h}_R', r'\dot{\theta}_R',
                       r'\dot{\theta}_{T,B}', r'\dot{\theta}_{S,B}', r'\dot{\theta}_{F,B}', r'\dot{\theta}_{T,F}',
                       r'\dot{\theta}_{S,F}', r'\dot{\theta}_{F,F}']
    elif args.env == 'driving':
        n_neurons = 8
        state_names = ['v', 'delta', 'dx', 'dy', 'dyaw', 'kappa']
        print_names = [r'v', r'\delta', r'd', r'\Delta l', r'\mu', r'\kappa']
    else:
        raise ValueError(f'Unrecognized env {args.env}')
    # Keys into the pickled dict: "<prefix>00", "<prefix>01", ...
    neuron_names = [f'{neuron_prefix}{i:02d}' for i in range(n_neurons)]

    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    logic_programs = [
        extract_logic_program(dt_all[neuron_name], state_names, print_names)
        for neuron_name in neuron_names
    ]

    # Each program has one entry per leaf, so this is the total number of
    # table rows the \multirow cell must span.
    n_rows = sum(len(program) for program in logic_programs)
    print('')
    print(r'\multirow{' + f'{n_rows}' + r'}{*}{' + args.model + '}')
    for neuron_i, logic_program in enumerate(logic_programs):
        print(f'& {neuron_i} &')

        # Each leaf's path becomes "(cond) \wedge (cond) ...", where every
        # condition's parts are joined with spaces.
        clauses = [
            r' \wedge '.join('(' + ' '.join(condition) + ')' for condition in path)
            for path in logic_program.values()
        ]
        numbered = [f'{k}:~$' + clause + '$' for k, clause in enumerate(clauses)]
        # r' \\ ' is a LaTeX line break inside the \makecell.
        print(r'\makecell[l]{' + r' \\ '.join(numbered) + '}')
        print(r'\\')
        if neuron_i < n_neurons - 1:
            print(r'\cmidrule{2-3}')


def extract_logic_program(dt, state_names, print_names):
    """Turn every root-to-leaf path of a decision tree into a list of threshold conditions.

    Args:
        dt: tree object accepted by ``parse_decision_tree``.
        state_names: feature names indexed by the tree's feature ids. Kept for
            interface compatibility; the emitted conditions use ``print_names``.
        print_names: display (LaTeX) names indexed by the tree's feature ids.

    Returns:
        dict mapping each leaf node index to ``reduce_logic_program(path)``.
        ``path`` is a list of ``[print_name, '<=' | '>', threshold]`` entries,
        with the threshold formatted to two decimals. Entries are ordered from
        the leaf's parent up to the root.

    Raises:
        ValueError: if a node is neither the left nor the right child of its
            recorded parent (inconsistent tree data).
    """
    dt_info = parse_decision_tree(dt)
    feature = dt_info['feature']
    parent = dt_info['parent']
    threshold = dt_info['threshold']
    children_left = dt_info['children_left']
    children_right = dt_info['children_right']
    leaf_idcs = np.where(dt_info['is_leaves'])[0]

    decision_path = {}
    for leaf_idx in leaf_idcs:
        # Backtrace from the leaf to the root, recording the split taken at each ancestor.
        path = []
        node_idx = leaf_idx
        while parent[node_idx] != -2:  # -2 is the parent identifier of the root node
            parent_idx = parent[node_idx]
            # Assumes children_left/children_right are indexed by node id, like
            # feature/threshold/parent. The left branch of a split is
            # "feature <= threshold"; the right branch is "feature > threshold".
            if children_left[parent_idx] == node_idx:
                threshold_sign = '<='
            elif children_right[parent_idx] == node_idx:
                threshold_sign = '>'
            else:
                raise ValueError(
                    f'node {node_idx} is not a child of its recorded parent {parent_idx}')

            path.append([print_names[feature[parent_idx]], threshold_sign,
                         f'{threshold[parent_idx]:.2f}'])
            node_idx = parent_idx

        decision_path[leaf_idx] = reduce_logic_program(path)

    return decision_path


if __name__ == '__main__':
    # Script entry point. Because of the relative import at the top of the file,
    # launch it as a module: `python -m <package>.<module> ...`.
    main()
