import os
import pickle
import argparse
import tqdm
import numpy as np
from skimage.io import imsave
import skvideo.io as io
import matplotlib
import matplotlib.pyplot as plt


# Index of the command neuron under inspection; it selects both the tree stored
# under the key f'cmd{cmd_idx:02d}' and the entry of the sliced RNN state.
cmd_idx = 5

# Names of the tree input features, ordered by the feature index used by the
# tree. Each name is also used as a key into the per-step agent logs.
# Currently unused candidates: 'x', 'y', 'd', 'mu'
feature_name = [
    'v',
    'delta',
    'dx',
    'dy',
    'dyaw',
    'kappa',
]

# When True, the leaf is predicted with the tree stored under
# f'inv:cmd{cmd_idx:02d}'; otherwise the leaf whose value is closest to the
# neuron activation is used.
use_dt_inv = False


def main():
    """Save the video frames whose step data falls into a chosen tree leaf.

    Command-line arguments:
        --results-dir: directory containing ``results.pkl``, ``dt/dt.pkl`` and
            a ``video`` sub-directory with one ``.mp4`` file per episode.
        --out-dir: directory the selected frames are written to (created if
            it does not exist).
        --max-episodes: stop after this many episode videos were processed
            (at least one video is always processed).

    The target leaf is chosen interactively; the prompt repeats until a valid
    leaf index of the tree for command ``cmd_idx`` is entered.
    """
    # parse argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--results-dir', type=str, required=True)
    parser.add_argument('--out-dir', type=str, required=True)
    parser.add_argument('--max-episodes', type=int, default=10)
    args = parser.parse_args()

    # get all video paths (sorted so they line up with the episode order)
    video_root_dir = os.path.join(args.results_dir, 'video')
    video_paths = [
        os.path.join(video_root_dir, video_name)
        for video_name in sorted(os.listdir(video_root_dir))
        if os.path.splitext(video_name)[-1] == '.mp4'
    ]

    # load data
    # NOTE: pickle.load can execute arbitrary code; only point --results-dir
    # at results that come from a trusted source.
    results_path = os.path.join(args.results_dir, 'results.pkl')
    with open(results_path, 'rb') as f:
        data = pickle.load(f)

    dt_path = os.path.join(args.results_dir, 'dt/dt.pkl')
    with open(dt_path, 'rb') as f:
        dt = pickle.load(f)
    dt_of_interest_info = parse_dt(dt[f'cmd{cmd_idx:02d}'], feature_name)
    leaf_idcs = dt_of_interest_info['leaf_idcs']

    # Ask until a valid leaf index is given. The answer is validated BEFORE it
    # is used to look up the decision path, so a typo re-prompts instead of
    # raising KeyError / ValueError.
    tgt_leaf_idx = None
    while tgt_leaf_idx is None:
        answer = input(f'Choose a target leaf index {leaf_idcs}: ')
        try:
            candidate = int(answer)
        except ValueError:
            candidate = None
        if candidate is not None and candidate in leaf_idcs:
            tgt_leaf_idx = candidate
        else:
            print('Please select a number from the array')
    print(dt_of_interest_info['decision_path'][tgt_leaf_idx])

    # get frame of interest
    frame_of_interest_idcs = []  # per episode: step indices to save
    neuron_acts = []  # per episode: neuron activation of every step
    for ep_i, ep_data in enumerate(data):
        frame_of_interest_idcs.append([])
        neuron_acts.append([])
        for step_i, step_data in enumerate(ep_data):
            include_frame, neuron_act = check_step_data(
                step_data, dt, dt_of_interest_info, tgt_leaf_idx, ep_i, step_i)
            neuron_acts[-1].append(neuron_act)
            if include_frame:
                frame_of_interest_idcs[-1].append(step_i)

    # save frames
    os.makedirs(args.out_dir, exist_ok=True)

    # Number of episodes the loop below actually visits: bounded by the
    # shorter of the two zipped sequences and by --max-episodes (the break at
    # the end of the loop body always lets the first episode through).
    n_episodes = min(len(frame_of_interest_idcs), len(video_paths),
                     max(args.max_episodes, 1))
    episodes = zip(frame_of_interest_idcs, video_paths)
    for ep_i, (ep_foi_idcs, video_path) in tqdm.tqdm(enumerate(episodes), total=n_episodes):
        ep_foi_set = set(ep_foi_idcs)  # O(1) membership test per frame
        video_reader = io.FFmpegReader(video_path, inputdict={}, outputdict={})

        # The progress total is the step count of THIS episode, not of
        # whichever episode was iterated last in the loop above.
        n_steps = len(neuron_acts[ep_i])
        for step_i, frame in tqdm.tqdm(enumerate(video_reader.nextFrame()), desc='Step', total=n_steps):
            if step_i in ep_foi_set:
                neuron_act = neuron_acts[ep_i][step_i]
                frame_fpath = os.path.join(args.out_dir, f'ep{ep_i:02d}_step{step_i:02d}_nact{neuron_act:.6f}.png')
                imsave(frame_fpath, frame)

        if ep_i >= (args.max_episodes - 1):
            break


def check_step_data(step_data, dt, dt_of_interest_info, tgt_leaf_idx, ep_i, step_i):
    """Tell whether a single step belongs to the target leaf of the tree.

    The activation of command neuron ``cmd_idx`` is read from the step's RNN
    state and mapped to a leaf: either through the inverse tree
    (``use_dt_inv``) or by picking the leaf whose value is closest to the
    activation. When that leaf is the target leaf, every condition on the
    leaf's decision path is evaluated against the ego agent's logs.

    Args:
        step_data: mapping with the keys 'ego_agent_id', 'logs' and
            'rnn_state' for one step.
        dt: mapping holding 'cmd_neuron_indices' and the fitted trees.
        dt_of_interest_info: output of ``parse_dt`` for the inspected tree.
        tgt_leaf_idx: index of the leaf selected by the user.
        ep_i: episode index (not used here).
        step_i: step index (not used here).

    Returns:
        A ``(include, neuron_act)`` pair: ``include`` is ``False`` when the
        predicted leaf differs from the target, otherwise the conjunction of
        all decision-path conditions; ``neuron_act`` is the activation read
        from the RNN state.

    Raises:
        ValueError: if a decision-path entry has an operator other than
            '<=' or '>'.
    """
    agent_logs = step_data['logs'][step_data['ego_agent_id']]

    # restrict the RNN state to the command neurons and pick the one of interest
    neuron_range = dt['cmd_neuron_indices']
    full_state = np.stack(step_data['rnn_state'][0]).squeeze()
    neuron_act = full_state[neuron_range[0]:neuron_range[1]][cmd_idx]

    if not use_dt_inv:
        # nearest leaf in terms of |activation - leaf value|
        candidates = dt_of_interest_info['leaf_idcs']
        candidate_values = dt[f'cmd{cmd_idx:02d}'].tree_.value[candidates].squeeze()
        nearest = np.argmin(np.abs(neuron_act - candidate_values))
        pred_leaf_id = candidates[nearest]
    else:
        pred_leaf_id = dt[f'inv:cmd{cmd_idx:02d}'].predict(np.array([[neuron_act]]))[0]

    # a step that maps to another leaf is never of interest
    if tgt_leaf_idx != pred_leaf_id:
        return False, neuron_act

    satisfied = True
    for log_key, op, thresh in dt_of_interest_info['decision_path'][pred_leaf_id]:
        val = agent_logs[log_key]
        if op == '<=':
            holds = val <= thresh
        elif op == '>':
            holds = val > thresh
        else:
            raise ValueError
        satisfied = satisfied & holds

    return satisfied, neuron_act


def parse_dt(dt, feature_name):
    """Extract the structure of a fitted decision tree and its leaf paths.

    Args:
        dt: object exposing a ``tree_`` attribute with ``node_count``,
            ``children_left``, ``children_right``, ``feature`` and
            ``threshold`` (node 0 is the root; a node whose left and right
            child ids are equal is treated as a leaf).
        feature_name: sequence mapping a feature index to its name.

    Returns:
        A dict with the raw tree arrays ('n_nodes', 'children_left',
        'children_right', 'feature', 'threshold') plus:
            'parent': list giving the parent id of every node (-2 for root),
            'node_depth': depth of every node (root is 0),
            'is_leaves': boolean mask of leaf nodes,
            'stack': the (empty) traversal stack, kept for compatibility,
            'leaf_idcs': indices of the leaf nodes,
            'decision_path': for every leaf index, the list of
                ``[feature name, '<=' or '>', threshold]`` conditions ordered
                from the leaf up to the root.
    """
    # get basic information
    tree = dt.tree_
    n_nodes = tree.node_count
    children_left = tree.children_left
    children_right = tree.children_right
    feature = tree.feature
    threshold = tree.threshold

    # parent[i] is the parent of node i; -2 marks "no parent" (the root)
    parent = [-2] * n_nodes

    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
    while stack:
        # `pop` ensures each node is only visited once
        node_id, depth = stack.pop()
        node_depth[node_id] = depth

        left = children_left[node_id]
        right = children_right[node_id]
        # If the left and right child of a node is not the same we have a
        # split node; only then do the child ids refer to real nodes.
        if left != right:
            parent[left] = node_id
            parent[right] = node_id
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
        else:
            is_leaves[node_id] = True

    # get traversed path for each leaf node by backtracing
    leaf_idcs = np.where(is_leaves)[0]
    decision_path = dict()
    for leaf_idx in leaf_idcs:
        path = []
        backtrace_idx = leaf_idx
        while parent[backtrace_idx] != -2:  # -2 is the identifier of parent for root node
            parent_idx = parent[backtrace_idx]
            # A node is either the left ('<=') or the right ('>') child of
            # its parent, so one direct comparison decides the sign.
            if children_left[parent_idx] == backtrace_idx:
                threshold_sign = '<='
            else:
                threshold_sign = '>'

            path.append([feature_name[feature[parent_idx]], threshold_sign, threshold[parent_idx]])
            backtrace_idx = parent_idx

        decision_path[leaf_idx] = path

    return {
        'n_nodes': n_nodes,
        'children_left': children_left,  # left children for each node; -1 means no children
        'children_right': children_right,
        'parent': parent,
        'feature': feature,
        'threshold': threshold,
        'node_depth': node_depth,
        'is_leaves': is_leaves,
        'stack': stack,
        'decision_path': decision_path,
        'leaf_idcs': leaf_idcs,
    }


# Run the command-line tool only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()
