""" Automatic keyframe selection """
import numpy as np

from traj_reconstruction import pos_only_geometric_keyframe_trajectory, reconstruct_keyframe_trajectory, geometric_keyframe_trajectory


""" Iterative keyframe selection """
def greedy_keyframe_selection(
    env=None, actions=None, gt_states=None, err_threshold=None, initial_states=None, remove_obj=None, geometry=True, pos_only=False
):
    # make the last frame a keyframe
    keyframes = [len(actions) - 1]

    # make the frames of gripper open/close keyframes
    if not pos_only:
        for i in range(len(actions) - 1):
            if actions[i, -1] != actions[i + 1, -1]:
                keyframes.append(i)
                keyframes.append(i + 1)
        keyframes.sort()

    # reconstruct the trajectory, and record the reconstruction error for each state
    for i in range(len(actions)):
        if pos_only or geometry:
            func = pos_only_geometric_keyframe_trajectory if pos_only else geometric_keyframe_trajectory
            total_traj_err, reconstruction_error = func(
                actions=actions,
                gt_states=gt_states,
                keyframes=keyframes,
                return_list=True,
            )
        else:
            _, reconstruction_error, total_traj_err = reconstruct_keyframe_trajectory(
                env=env,
                actions=actions,
                gt_states=gt_states,
                keyframes=keyframes,
                verbose=False,
                initial_state=initial_states[0],
                remove_obj=remove_obj,
            )
        # break if the reconstruction error is below the threshold
        if total_traj_err < err_threshold:
            break
        # add the frame of the highest reconstruction error as a keyframe, excluding frames that are already keyframes
        max_error_frame = np.argmax(reconstruction_error)
        while max_error_frame in keyframes:
            reconstruction_error[max_error_frame] = 0
            max_error_frame = np.argmax(reconstruction_error)
        keyframes.append(max_error_frame)
        keyframes.sort()

    print("=======================================================================")
    print(f"Selected {len(keyframes)} keyframes: {keyframes} \t total trajectory error: {total_traj_err:.6f}")
    return keyframes

def heuristic_keyframe_selection(
    env=None, actions=None, gt_states=None, err_threshold=None, initial_states=None, remove_obj=None, geometry=True, pos_only=False,
    vel_threshold=5e-3,
):
    """Pick keyframes from gripper changes and near-stationary robot states.

    No reconstruction is run. ``env``, ``err_threshold``, ``initial_states``,
    ``remove_obj`` and ``geometry`` are accepted only so the signature matches
    ``greedy_keyframe_selection``; they are ignored.

    Keyframes are:
      * the last frame;
      * unless ``pos_only`` is set, every frame ``i`` whose last action dimension
        (the gripper open/close command) differs from that of frame ``i + 1``;
      * every index ``i`` of ``gt_states`` whose ``'robot0_vel_ang'`` or
        ``'robot0_vel_lin'`` entry has Euclidean norm below ``vel_threshold``.

    Args:
        actions: per-step actions, indexed as ``actions[i, -1]``.
        gt_states: sequence of mappings with ``'robot0_vel_ang'`` and
            ``'robot0_vel_lin'`` entries.
        pos_only: skip the gripper-change keyframes (as in the greedy selector).
        vel_threshold: velocity-norm threshold below which a frame counts as stationary.

    Returns:
        Sorted list of unique keyframe indices.
    """
    num_frames = len(actions)

    # make the last frame a keyframe; a set prevents a frame from being listed twice
    keyframes = {num_frames - 1}

    # make the frames of gripper open/close keyframes
    if not pos_only:
        keyframes.update(i for i in range(num_frames - 1) if actions[i, -1] != actions[i + 1, -1])

    # if 'robot0_vel_ang' or 'robot0_vel_lin' in gt_states is close to 0, make the frame a keyframe
    # NOTE(review): indices range over len(gt_states), which may differ from len(actions) — confirm alignment
    for i, state in enumerate(gt_states):
        if np.linalg.norm(state['robot0_vel_ang']) < vel_threshold or np.linalg.norm(state['robot0_vel_lin']) < vel_threshold:
            keyframes.add(i)

    keyframes = sorted(keyframes)

    print("=======================================================================")
    print(f"Selected {len(keyframes)} keyframes: {keyframes}")
    return keyframes


""" Backtrack keyframe selection """
def backtrack_keyframe_selection(env, actions, gt_states, err_threshold, initial_states, remove_obj):
    # add heuristic keyframes
    num_frames = len(actions)
    
    # make the last frame a keyframe
    keyframes = [num_frames - 1]

    # make the frames of gripper open/close keyframes
    for i in range(num_frames - 1):
        if actions[i, -1] != actions[i + 1, -1]:
            keyframes.append(i)
    keyframes.sort()

    # backtracing to find the optimal keyframes
    start = 0
    while start < num_frames - 1:
        for end in range(num_frames - 1, 0, -1):
            rel_keyframes = [k - start for k in keyframes if k >= start and k < end] + [end - start]
            _, _, total_traj_err = reconstruct_keyframe_trajectory(
                env=env,
                actions=actions[start:end+1],
                gt_states=gt_states[start+1:end+2],
                keyframes=rel_keyframes,
                verbose=False,
                initial_state=initial_states[start],
                remove_obj=remove_obj,
            )
            if total_traj_err < err_threshold:
                keyframes.append(end)
                keyframes = list(set(keyframes))
                keyframes.sort()
                break
        start = end

    print("=======================================================================")
    print(f"Selected {len(keyframes)} keyframes: {keyframes} \t total trajectory error: {total_traj_err:.6f}")
    return keyframes
        

""" DP keyframe selection """
# use geometric interpretation
def dp_keyframe_selection(env=None, actions=None, gt_states=None, err_threshold=None, initial_states=None, remove_obj=None, pos_only=False):
    num_frames = len(actions)
    
    # make the last frame a keyframe
    initial_keyframes = [num_frames - 1]

    # make the frames of gripper open/close keyframes
    if not pos_only:
        for i in range(num_frames - 1):
            if actions[i, -1] != actions[i + 1, -1]:
                initial_keyframes.append(i)
                # initial_keyframes.append(i + 1)
        initial_keyframes.sort()

    # Memoization table to store the keyframe sets for subproblems
    memo = {}

    # Initialize the memoization table
    for i in range(num_frames):
        memo[i] = (0, [])

    memo[1] = (1, [1])
    func = pos_only_geometric_keyframe_trajectory if pos_only else geometric_keyframe_trajectory
        
    # Populate the memoization table using an iterative bottom-up approach
    for i in range(1, num_frames):
        min_keyframes_required = float('inf')
        best_keyframes = []
        
        for k in range(1, i):
            # keyframes are relative to the subsequence
            keyframes = [j - k for j in initial_keyframes if j >= k and j < i] + [i - k]
            
            total_traj_err = func(
                actions=actions[k:i+1],
                gt_states=gt_states[k:i+1],
                keyframes=keyframes,
            )

            if total_traj_err < err_threshold:
                subproblem_keyframes_count, subproblem_keyframes = memo[k - 1]
                total_keyframes_count = 1 + subproblem_keyframes_count

                if total_keyframes_count < min_keyframes_required:
                    min_keyframes_required = total_keyframes_count
                    best_keyframes = subproblem_keyframes + [i]

        memo[i] = (min_keyframes_required, best_keyframes)

    min_keyframes_count, keyframes = memo[num_frames - 1]
    keyframes += initial_keyframes
    # remove duplicates
    keyframes = list(set(keyframes))
    keyframes.sort()
    print(f"Minimum number of keyframes: {len(keyframes)} \tTrajectory Error: {total_traj_err}")
    print(f"Keyframe positions: {keyframes}")

    return keyframes


# iterative version, bottom-up
def dp_reconstruct_keyframe_selection(env, actions, gt_states, err_threshold, initial_states, remove_obj):
    """Bottom-up DP keyframe selection using simulated reconstruction.

    ``memo[i]`` holds ``(count, keyframes)``: the fewest DP keyframes found for the
    prefix ending at frame ``i`` with ``i`` itself a keyframe. For each ``i`` and each
    ``k`` in ``[1, i)``, ``reconstruct_keyframe_trajectory`` is run on
    ``actions[k-1:i]`` from ``initial_states[k-1]`` against ``gt_states[k:i+1]``. The
    candidate is accepted when its total error is below ``err_threshold``, and it
    extends ``memo[k - 1]``.

    The last frame, and every frame where the last action dimension (gripper command)
    changes, are forced into each tested segment and merged into the result.

    Returns:
        Sorted list of unique keyframe indices.
    """
    num_frames = len(actions)

    # make the last frame a keyframe
    initial_keyframes = [num_frames - 1]

    # make the frames of gripper open/close keyframes
    for i in range(num_frames - 1):
        if actions[i, -1] != actions[i + 1, -1]:
            initial_keyframes.append(i)
    initial_keyframes.sort()

    # Base cases: frame 0 needs no keyframe; frame 1 is reached with the single keyframe [1].
    # The table is filled from i = 2 so that memo[1] is not overwritten by an (inf, [])
    # entry, which would stop frame 1 from ever serving as a previous keyframe.
    memo = {0: (0, []), 1: (1, [1])}

    # Populate the memoization table using an iterative bottom-up approach
    for i in range(2, num_frames):
        min_keyframes_required = float('inf')
        best_keyframes = []

        for k in range(1, i):
            # keyframes are relative to k.
            # NOTE(review): the action slice below starts at k - 1, so relative index r
            # refers to action k - 1 + r — confirm this offset is intended.
            keyframes = [j - k for j in initial_keyframes if k <= j < i] + [i - k]

            _, _, total_traj_err = reconstruct_keyframe_trajectory(
                env=env,
                actions=actions[k-1:i],
                gt_states=gt_states[k:i+1],
                keyframes=keyframes,
                verbose=False,
                initial_state=initial_states[k-1],
                remove_obj=remove_obj,
            )

            print(f"i: {i}, k: {k}, total_traj_err: {total_traj_err}")

            if total_traj_err < err_threshold:
                subproblem_keyframes_count, subproblem_keyframes = memo[k - 1]
                total_keyframes_count = 1 + subproblem_keyframes_count

                if total_keyframes_count < min_keyframes_required:
                    min_keyframes_required = total_keyframes_count
                    best_keyframes = subproblem_keyframes + [i]

                    print(f"min_keyframes_required: {min_keyframes_required}, best_keyframes: {best_keyframes}")

        memo[i] = (min_keyframes_required, best_keyframes)

    min_keyframes_count, dp_keyframes = memo[num_frames - 1]
    if min_keyframes_count == float('inf'):
        print("Warning: no keyframe segmentation met the error threshold; using only the initial keyframes.")

    # merge with the forced keyframes and remove duplicates (without mutating memo entries)
    keyframes = sorted(set(dp_keyframes) | set(initial_keyframes))
    print(f"Minimum number of keyframes: {len(keyframes)}")
    print(f"Keyframe positions: {keyframes}")

    return keyframes


# backlog: recursive version, top-down
def recursive_dp_keyframe_selection(env, actions, gt_states, err_threshold, initial_states, remove_obj):
    """Top-down (memoized recursive) DP keyframe selection using simulated reconstruction.

    ``min_keyframes(i)`` returns ``(count, keyframes)``: the fewest DP keyframes found
    for the prefix ending at frame ``i`` with ``i`` itself a keyframe. For each ``k`` in
    ``[1, i)``, the segment is reconstructed with ``reconstruct_keyframe_trajectory``.
    When its total error is below ``err_threshold``, the segment extends
    ``min_keyframes(k - 1)``.

    The last frame, and every frame where the last action dimension (gripper command)
    changes, are forced into each tested segment and merged into the result.

    NOTE(review): recursion depth can grow with the number of frames (each call may
    recurse into ``k - 1``), so long trajectories may exceed Python's recursion limit.
    NOTE(review): this version uses ``actions[k:i+1]`` / ``initial_states[k]``, while
    ``dp_reconstruct_keyframe_selection`` uses ``actions[k-1:i]`` / ``initial_states[k-1]``
    with the same ``gt_states[k:i+1]``. Confirm which alignment
    ``reconstruct_keyframe_trajectory`` expects.

    Returns:
        Sorted list of unique keyframe indices.
    """
    num_frames = len(actions)

    # make the last frame a keyframe
    initial_keyframes = [num_frames - 1]

    # make the frames of gripper open/close as keyframes
    for i in range(num_frames - 1):
        if actions[i, -1] != actions[i + 1, -1]:
            initial_keyframes.append(i)
    initial_keyframes.sort()

    # Memoization table: memo[i] = (count, keyframes) for the prefix ending at keyframe i
    memo = {}

    def min_keyframes(i):
        """Return (count, keyframes) for the prefix ending at frame ``i`` with ``i`` a keyframe."""
        if i < 1:
            return (0, [])
        if i == 1:
            # same base case as the iterative versions: frame 1 is its own single keyframe
            return (1, [1])
        if i in memo:
            return memo[i]

        min_keyframes_required = float('inf')
        best_keyframes = []

        for k in range(1, i):

            # keyframes are relative to the subsequence
            keyframes = [j - k for j in initial_keyframes if k <= j < i] + [i - k]

            _, _, total_traj_err = reconstruct_keyframe_trajectory(
                env=env,
                actions=actions[k:i+1],
                gt_states=gt_states[k:i+1],
                keyframes=keyframes,
                verbose=False,
                initial_state=initial_states[k],
                remove_obj=remove_obj,
            )

            # print some useful information for debugging
            print(f"i: {i}, k: {k}, total_traj_err: {total_traj_err}")

            if total_traj_err < err_threshold:
                subproblem_keyframes_count, subproblem_keyframes = min_keyframes(k - 1)
                total_keyframes_count = 1 + subproblem_keyframes_count

                if total_keyframes_count < min_keyframes_required:
                    min_keyframes_required = total_keyframes_count
                    best_keyframes = subproblem_keyframes + [i]

            print(f"min_keyframes_required: {min_keyframes_required}, best_keyframes: {best_keyframes}")

        memo[i] = (min_keyframes_required, best_keyframes)
        return memo[i]

    min_keyframes_count, dp_keyframes = min_keyframes(num_frames - 1)
    # merge with the forced keyframes and remove duplicates (without mutating memo entries)
    keyframes = sorted(set(dp_keyframes) | set(initial_keyframes))
    print(f"Minimum number of keyframes: {min_keyframes_count}")
    print(f"Keyframe positions: {keyframes}")

    return keyframes
