import os
from copy import deepcopy

import numpy as np


def filter_robomimic_XYZ_only(data):
    """
    Keep only the entries whose ``"feedback"`` value is one of 0-5.

    Both the integer and the string spelling of each value are accepted
    (e.g. ``3`` and ``"3"``). Presumably these codes denote the XYZ-axis
    feedback types -- confirm against the dataset definition.

    Args:
        data: Iterable of mappings that each have a ``"feedback"`` key.

    Returns:
        np.ndarray: The retained entries, in their original order,
        wrapped with ``np.asarray``.
    """
    accepted = (0, 1, 2, 3, 4, 5, "0", "1", "2", "3", "4", "5")
    kept = [entry for entry in data if entry["feedback"] in accepted]
    return np.asarray(kept)

def fix_robomimic_XYZ_feedback_out_range(data):
    """
    Clip every entry's ``"delta_actions"`` into the closed range [-1, 1].

    Each entry is shallow-copied before its ``"delta_actions"`` value is
    replaced, so the input mappings keep their original values.

    Args:
        data: Iterable of mappings that each have a ``"delta_actions"`` key
            holding array-like numeric data.

    Returns:
        list: One shallow copy per input entry, with ``"delta_actions"``
        replaced by the result of ``np.clip(..., -1, 1)``.
    """

    def _with_clipped_actions(sample):
        patched = sample.copy()
        patched["delta_actions"] = np.clip(sample["delta_actions"], -1, 1)
        return patched

    return [_with_clipped_actions(sample) for sample in data]

def robomimic_gripper_binary(data):
    """
    Binarize the last column of each entry's ``"delta_actions"`` to +1 / -1.

    For every entry the ``"delta_actions"`` value is converted to a fresh
    2-D array, then:

    * every value in the last column becomes ``1`` if it is ``> 0`` and
      ``-1`` otherwise;
    * the last row's last column is forced to ``-1``;
    * the second-to-last row's last column is forced to ``1``.

    The last column is presumably the gripper command -- confirm against
    the dataset definition.

    The input entries and their arrays are left untouched: each entry is
    shallow-copied and the action data is copied before being modified.

    Args:
        data: Iterable of mappings that each have a ``"delta_actions"`` key
            holding 2-D array-like numeric data with at least two rows.

    Returns:
        np.ndarray: The processed copies wrapped with ``np.asarray``.

    Raises:
        IndexError: If a ``"delta_actions"`` array has fewer than two rows
            (the second-to-last row does not exist) or is not 2-D.
    """
    new_list = []
    for item in data:
        item_copy = item.copy()
        # np.array always copies, whereas np.asarray would hand back the
        # caller's own ndarray and the in-place writes below would then
        # silently modify the input data.
        actions = np.array(item["delta_actions"])
        actions[:, -1] = np.where(actions[:, -1] > 0, 1, -1)
        # The final two steps are overridden regardless of the recorded values.
        actions[-1, -1] = -1
        actions[-2, -1] = 1
        item_copy["delta_actions"] = actions
        new_list.append(item_copy)
    return np.asarray(new_list)

def filter_robomimic_XYZ_gripper_only(data):
    """
    Keep only the entries whose ``"feedback"`` value is 0-5, 12 or 13.

    Both the integer and the string spelling of each value are accepted.
    Presumably 0-5 are XYZ feedback codes and 12/13 are gripper codes --
    confirm against the dataset definition.

    Args:
        data: Iterable of mappings that each have a ``"feedback"`` key.

    Returns:
        np.ndarray: The retained entries, in their original order,
        wrapped with ``np.asarray``.
    """
    accepted = (
        0, 1, 2, 3, 4, 5, 12, 13,
        "0", "1", "2", "3", "4", "5", "12", "13",
    )
    kept = [entry for entry in data if entry["feedback"] in accepted]
    return np.asarray(kept)


def filter_robomimic_XYZ_gripper_only_with_history(data):
    """
    Binarize the last action component of every history step, in place.

    For each entry that has a ``"history"`` key, the last element of every
    step's ``"action"`` is overwritten with ``1`` if it is ``> 0`` and
    ``-1`` otherwise. No entry is dropped: despite the name, every input
    entry is returned.

    NOTE: the input entries are modified in place; the returned array holds
    the very same objects.

    Args:
        data: Iterable of mappings; an optional ``"history"`` value is an
            iterable of mappings with a mutable, indexable ``"action"``.

    Returns:
        np.ndarray: All input entries wrapped with ``np.asarray``.
    """

    def _binarize_steps(steps):
        for step in steps:
            step_action = step["action"]
            if step_action[-1] > 0:
                step_action[-1] = 1
            else:
                step_action[-1] = -1
            # Write the (same) object back, mirroring the original flow.
            step["action"] = step_action

    collected = []
    for sample in data:
        if "history" in sample:
            _binarize_steps(sample["history"])
        collected.append(sample)
    return np.asarray(collected)

def filter_robomimic_XYZ_gripper_only_binary_gripper_preference(data):
    """
    Reduce ``"action1"`` / ``"action2"`` to four components with a binary last one.

    For each entry a shallow copy is made; in that copy both actions are
    replaced by their components at indices 0, 1, 2 and 6, and the last of
    those four is set to ``1`` if it is ``> 0`` and ``-1`` otherwise.
    Indices 0-2 are presumably XYZ and index 6 the gripper -- confirm
    against the dataset definition.

    Debug output (the feedback value before and after, and the raw entry)
    is printed for every entry.

    Args:
        data: Iterable of mappings with ``"feedback"``, ``"action1"`` and
            ``"action2"`` keys; the actions must support NumPy-style
            list indexing (e.g. 1-D ``np.ndarray`` of length >= 7).

    Returns:
        np.ndarray: The processed copies wrapped with ``np.asarray``.
    """

    def _xyz_and_binary_gripper(vector):
        # Index-list selection yields a new array, so the source is untouched.
        reduced = vector[[0, 1, 2, 6]]
        reduced[-1] = 1 if reduced[-1] > 0 else -1
        return reduced

    processed = []
    for sample in data:
        print("Before ", sample["feedback"])
        print(sample)

        converted = sample.copy()
        for key in ("action1", "action2"):
            converted[key] = _xyz_and_binary_gripper(converted[key])
        processed.append(converted)

        print("After ", converted["feedback"])
        print("\n")

    return np.asarray(processed)


def filter_robomimic_XYZ_gripper_only_binary_gripper_binary_feedback(data):
    """
    Reduce each entry's ``"action"`` to four components with a binary last one.

    For each entry a shallow copy is made; in that copy ``"action"`` is
    replaced by its components at indices 0, 1, 2 and 6, and the last of
    those four is set to ``1`` if it is ``> 0`` and ``-1`` otherwise.
    Indices 0-2 are presumably XYZ and index 6 the gripper -- confirm
    against the dataset definition.

    Debug output (the feedback value before and after, and the raw entry)
    is printed for every entry.

    Args:
        data: Iterable of mappings with ``"feedback"`` and ``"action"``
            keys; the action must support NumPy-style list indexing
            (e.g. a 1-D ``np.ndarray`` of length >= 7).

    Returns:
        np.ndarray: The processed copies wrapped with ``np.asarray``.
    """
    processed = []
    for sample in data:
        print("Before ", sample["feedback"])
        print(sample)

        converted = sample.copy()
        # Index-list selection yields a new array, so the source is untouched.
        reduced = converted["action"][[0, 1, 2, 6]]
        converted["action"] = reduced
        if reduced[-1] > 0:
            reduced[-1] = 1
        else:
            reduced[-1] = -1
        processed.append(converted)

        print("After ", converted["feedback"])
        print("\n")

    return np.asarray(processed)

def filter_robomimic_XYZ_gripper_only_binary_gripper_action_advising(data):
    """
    Reduce the first ``"feedback"`` element to four components with a binary last one.

    For each entry, ``feedback[0]`` is replaced by its components at indices
    0, 1, 2 and 6, and the last of those four is set to ``1`` if it is
    ``> 0`` and ``-1`` otherwise. Indices 0-2 are presumably XYZ and index 6
    the gripper of an advised action -- confirm against the dataset
    definition.

    The input is not modified: the entry is shallow-copied and its
    ``"feedback"`` container is copied as well before element 0 is replaced,
    so calling this function repeatedly on the same data gives the same
    result.

    Args:
        data: Iterable of mappings with a ``"feedback"`` key whose value is
            a mutable container offering ``.copy()`` (e.g. a list) and
            whose element 0 supports NumPy-style list indexing (e.g. a
            1-D ``np.ndarray`` of length >= 7).

    Returns:
        np.ndarray: The processed copies wrapped with ``np.asarray``.
    """
    new_list = []
    for item in data:
        item_copy = item.copy()
        # item.copy() is shallow, so the feedback container is still shared
        # with the input; copy it before replacing its first element.
        feedback = item["feedback"].copy()
        # Index-list selection yields a new array, so the source is untouched.
        advised = feedback[0][[0, 1, 2, 6]]
        advised[-1] = 1 if advised[-1] > 0 else -1
        feedback[0] = advised
        item_copy["feedback"] = feedback
        new_list.append(item_copy)

    return np.asarray(new_list)


def filter_cliffwalking_impossible_states(data):
    """
    Drop every entry whose ``"state"`` is one of the states 37 through 47.

    Args:
        data: Iterable of mappings that each have a ``"state"`` key.

    Returns:
        np.ndarray: The remaining entries, in their original order,
        wrapped with ``np.asarray``.
    """
    impossible_states = tuple(range(37, 48))  # 37, 38, ..., 47
    reachable = [entry for entry in data if entry["state"] not in impossible_states]
    return np.asarray(reachable)


def filter_results_cliffwalking_impossible_states(data):
    """
    Drop every result whose ``["State"]["state"]`` is one of the states 37 through 47.

    Args:
        data: Iterable of mappings with a nested ``"State"`` mapping that
            has a ``"state"`` key.

    Returns:
        list: The remaining results, in their original order.
    """
    impossible_states = tuple(range(37, 48))  # 37, 38, ..., 47
    return [
        result
        for result in data
        if result["State"]["state"] not in impossible_states
    ]


def minigrid_observation_to_flattened_array(minigrid_array):
    """
    One-hot encode a 2-D grid of symbols and flatten the result.

    Every cell is mapped to one of nine channels via a fixed symbol table
    (``"v"`` and ``"V"`` share a channel) and the resulting
    ``(H, W, 9)`` array is flattened to one dimension.

    Args:
        minigrid_array: 2-D array whose cells are symbols from the table
            below.

    Returns:
        np.ndarray: 1-D float array of length ``H * W * 9`` holding a
        single ``1`` per cell.

    Raises:
        KeyError: If a cell holds a symbol that is not in the table.
        ValueError: If ``minigrid_array`` is not 2-D.
    """
    symbol_to_channel = {
        "^": 0,
        "v": 1, "V": 1,
        "<": 2,
        ">": 3,
        "#": 4,
        "K": 5,
        "G": 6,
        "D": 7,
        ".": 8,
    }
    channel_count = 9

    rows, cols = minigrid_array.shape
    encoded = np.zeros((rows, cols, channel_count))
    for row, col in np.ndindex(rows, cols):
        channel = symbol_to_channel[minigrid_array[row, col]]
        encoded[row, col, channel] = 1

    return encoded.reshape(-1)


def replace_minigrid_observation_to_one_hot(data):
    """
    Replace each result's observation image with a flattened one-hot grid.

    Every entry is deep-copied; in the copy,
    ``["State"]["state"]["Observation"]["image"]`` is overwritten with the
    flattened one-hot encoding of ``["State"]["state"]["ObsString"]``
    produced by ``minigrid_observation_to_flattened_array``. The input
    entries are left unchanged.

    Args:
        data: Iterable of nested mappings with the keys described above.

    Returns:
        list: The deep-copied, updated entries in their original order.
    """
    converted = []
    for record in data:
        record_copy = deepcopy(record)
        inner_state = record_copy["State"]["state"]
        one_hot = minigrid_observation_to_flattened_array(inner_state["ObsString"]).reshape(-1)
        # Only the image is replaced; the rest of the state mapping is kept.
        inner_state["Observation"]["image"] = one_hot
        converted.append(record_copy)
    return converted

def filter_results_wrong_away(data):
    """
    Keep only the results whose ``"Correct"`` value equals 1.

    Args:
        data: Iterable of mappings that each have a ``"Correct"`` key.

    Returns:
        list: The retained results, in their original order.
    """
    return [result for result in data if result["Correct"] == 1]


def remove_neutral_preference(data):
    """
    Drop every entry whose ``"feedback"`` value equals 0.

    Args:
        data: Iterable of mappings that each have a ``"feedback"`` key.

    Returns:
        list: The remaining entries, in their original order.
    """
    return [entry for entry in data if entry["feedback"] != 0]



def seed_every_where(seed):
    """
    Seed the ``random``, NumPy and PyTorch random number generators.

    When CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Seed value passed unchanged to every generator.
    """
    import random
    import numpy as np
    import torch

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_latest_modified_file_with_prefix(directory_path, prefix, excluding="uniform_indexes.npy", n=1):
    """
    Returns the path of the most recently modified file in the given directory
    that matches the specified prefix.

    Args:
        directory_path (str): The path to the directory.
        prefix (str): The prefix to filter files.

    Returns:
        str: The path to the latest modified file with the given prefix, or None if no match is found.
    """
    # print("Searching for: ", directory_path, prefix)
    if not os.path.isdir(directory_path):
        # raise ValueError(f"The path {directory_path} is not a valid directory.")
        return None  # The directory does not exist
    # Get all files in the directory with their full paths that match the prefix
    files = [
        os.path.join(directory_path, file)
        for file in os.listdir(directory_path)
        if os.path.isfile(os.path.join(directory_path, file)) and file.startswith(prefix) and excluding not in file
    ]

    # print("These are found: ", files)
    if not files:
        return None  # No matching files in the directory

    # Find the latest modified file
    sorted_files = sorted(files, key=os.path.getmtime, reverse=True)
    # Return the n-th latest file if it exists, else None
    if 0 < n <= len(sorted_files):
        return sorted_files[n-1]
    elif n < 0:
        return sorted_files
    return None




# Script entry point: running this module directly only prints a placeholder message.
if __name__ == "__main__":
    print("Test")