import torch


def action_to_letter(action):
    """Translate a discrete action index into its single-letter label.

    Indices 0, 1, 2 and 3 correspond to "A", "B", "C" and "D" respectively.

    Args:
        action: Value compared (with ``==``) against 0 through 3, in order.

    Returns:
        str: The letter paired with the first index that equals *action*.

    Raises:
        ValueError: If *action* equals none of 0, 1, 2 or 3.
    """
    # Scan the candidates in ascending order so that the comparison sequence
    # is exactly 0, 1, 2, 3; equality (not hashing) decides the match.
    for index, letter in enumerate("ABCD"):
        if action == index:
            return letter
    raise ValueError("Invalid action")


def tensor_to_tuple(tensor):
    """Flatten *tensor* into a hashable tuple of native Python scalars.

    Args:
        tensor: A ``torch.Tensor`` of any shape (including 0-dim) on any
            device.

    Returns:
        tuple: The tensor's elements in row-major order as Python
        ``int``/``float``/``bool`` values.
    """
    # ``detach()`` lets tensors that track gradients be converted too.
    # ``tolist()`` yields plain Python scalars instead of NumPy scalar objects;
    # the latter print as e.g. ``np.int32(1)`` under NumPy >= 2, which would
    # leak into any text built from ``repr`` of the returned tuple.
    return tuple(tensor.detach().cpu().flatten().tolist())


def get_public_prompt():
    """Build, print, and return the static preamble of the refinement prompt.

    The text introduces the Fetch manipulation task and explains each variable
    named in the prompt: *total_predicates*, *key_predicates_id*,
    *state_predicates*, the automaton fields, and *failure_trajectory*.

    Returns:
        str: The assembled prompt text. The same text is also printed to
        stdout, preceded by a banner line.
    """
    segments = [
        "You are an expert in refining the automaton of a robot for the task Fetch available in OpenAI Gym repository.\n",
        "You need to first understand the task and the automaton of the robot.\n\n",
        "# Task Description\n",
        # Consecutive sentences of the description are joined on one line, so
        # each one carries its own trailing space.
        "The task in the environment is for a manipulator to move a block to a target position on top of a table or in mid-air. ",
        "The robot is a 7-DoF Fetch Mobile Manipulator with a two-fingered parallel gripper. ",
        "The robot is controlled by small displacements of the gripper in Cartesian coordinates and the inverse kinematics are computed internally by the MuJoCo framework. ",
        "The gripper can be opened or closed in order to perform the grasping operation of pick and place. ",
        "The task is also continuing which means that the robot has to maintain the block in the target position for an indefinite period of time.\n",
        # "The robotic arm is required to approach, grab, and lift the object to its position.\n"
        # "The agent needs to solve the task in discrete steps."
        # "At each step, it takes as input the current predicate state of the automaton and outputs the next action."
        "# Total Predicates ID and Its Description\n",
        # The name must match the variable emitted in the input section
        # ("total_predicates").
        "- total_predicates is a dictionary that maps the predicate ID to its description including three types of predicates:",
        " 1. The relative distance between the end effector and the object along the x, y, and z axes.\n",
        " 2. The displacement of the left gripper and the right gripper.\n",
        " 3. The height of the object from the table.\n",
        "# Key Predicates ID\n",
        "- *key_predicates_id* is a list of predicate IDs that are important for the task.\n",
        # TODO: clearly describe the relationship between a state and its state
        # predicates, and between state predicates and key predicates.
        "# State Predicates\n",
        "- *state_predicates* is a dictionary that maps the state ID in *states* to a list of tuples of predicates.\n",
        "- The boolean values of each predicate tuple are in the order of the *key_predicates_id*.\n",
        # "- Each state consists of several tuples of predicates, a"
        "# Automaton\n",
        "- *action* is a dictionary that maps the action letter to its description.\n",
        "- *states* is a list of state IDs in the automaton.\n",
        "- *start_state* is the initial state of the automaton.\n",
        "- *accept_states* is a list of accepting states in the automaton.\n",
        "- *transitions* is a dictionary of transitions in the automaton. The key is a tuple of the state ID and the action letter, and the value is the next state ID.\n",
        "# Failure Trajectory\n",
        "- *failure_trajectory* is a list of episodes. Each episode is a dictionary of transitions. The key is a tuple of the state, and the value is the next state. The state is a tuple represented by boolean values in the order of the *key_predicates_id* and the action letter\n",
        # "# Automaton Description\n"
        # "The automaton of the robot is composed of states and transitions.\n"
        # "The states are represented by the predicates of the robot and the block in the environment from three different perspectives:\n"
        # "The transitions are represented by the abstract actions of the robot.\n"
    ]
    prompt = "".join(segments)
    print("----------------prompt----------------")
    print(prompt)
    return prompt


def set_input_prompt(
    p_dict, key_predicates, automaton, letter_state_mapping, failure_trajectory
):
    """Render the concrete "# Input" section of the prompt, print and return it.

    Args:
        p_dict: Mapping from predicate ID to a mapping that has a ``"desc"``
            entry; rendered as ``'P<id>' : '<desc>'``.
        key_predicates: Iterable of predicate IDs; rendered as ``'P<id>'``.
        automaton: Mapping with the entries ``"states"`` (a mapping whose keys
            are iterables of letters and whose values are state IDs),
            ``"start_state"``, ``"accept_states"`` and ``"transitions"``.
        letter_state_mapping: Mapping from each letter found in the keys of
            ``automaton["states"]`` to an iterable of printable predicate
            tuples.
        failure_trajectory: Iterable of episodes. Each episode is an iterable
            of steps where ``step[0]`` is the action index and ``step[3]`` is
            a mapping whose values each have a ``"bool"`` entry.
            (Presumably a step is ``(action, transition, reward,
            p_star_dict)`` - confirm against the caller.)

    Returns:
        str: The rendered section; it is also printed to stdout.
    """
    parts = [
        "# Input\n",
        "## Total Predicates ID and Its Description\n",
        "- total_predicates = {",
    ]
    parts.extend(f"'P{pid}' : '{info['desc']}', " for pid, info in p_dict.items())
    parts.append("}\n")

    parts.append("## Key Predicates ID\n")
    parts.append("- key_predicates_id = [")
    parts.extend(f"'P{pid}', " for pid in key_predicates)
    parts.append("]\n")

    parts.append("## State Predicates\n")
    parts.append("- state_predicates = {")
    for letters, state_id in automaton["states"].items():
        parts.append(f"{state_id} : [")
        parts.extend(
            f"{predicate}, "
            for letter in letters
            for predicate in letter_state_mapping[letter]
        )
        parts.append("], ")
    parts.append("}\n")

    parts.append("## Automaton\n")
    # NOTE(review): action_to_letter can also yield 'C', which has no
    # description in this legend - confirm whether that is intentional.
    parts.append("- action = {'A': 'approach', 'B': 'grab', 'D': 'lift'}\n")
    parts.append(f"- states = {list(automaton['states'].values())}\n")
    parts.append(f"- start_state = {automaton['start_state']}\n")
    parts.append(f"- accept_states = {automaton['accept_states']}\n")
    parts.append(f"- transitions = {automaton['transitions']}\n")

    parts.append("## Failure Trajectory\n")
    parts.append("- failure_trajectory = [")
    for episode in failure_trajectory:
        transitions = {}
        # (state, action letter) of the preceding step; None before the first.
        previous_key = None
        for step in episode:
            raw_action = step[0]
            predicates = step[3]
            flags = [entry["bool"] for entry in predicates.values()]
            flag_tensor = torch.tensor(flags, dtype=torch.int32)
            letter = action_to_letter(raw_action)
            current_state = tensor_to_tuple(flag_tensor)
            # Each step's state is recorded as the successor of the previous
            # (state, action) pair; the first step has no predecessor.
            if previous_key is not None:
                transitions[previous_key] = current_state
            previous_key = (current_state, letter)
        parts.append(f"{transitions}, ")
    parts.append("]\n")

    prompt = "".join(parts)
    print(prompt)
    return prompt


def get_task_prompt():
    """Build, print, and return the task-instruction part of the prompt.

    The text lists the numbered rules the model must follow and ends with an
    "## Output" template containing three ``{ChatGPT response}`` placeholders,
    one per analysis step (a), (b) and (c).

    Returns:
        str: The assembled instruction text; it is also printed to stdout.
    """
    segments = [
        "# Your Task\n",
        "You need to analyze and refine the automaton of the robot. You must follow the following rules.\n",
        "1. You can also leverage your own knowledge about the goal of the task, but the conclusions must be based on the Input.\n",
        # Three steps - (a), (b) and (c) - are enumerated right after this.
        "2. You need to analyze the automaton in these three steps: ",
        "(a) analyze the logical relation between key predicates of states, ",
        "(b) analyze why failed trajectories are failed to reach accept states, ",
        "and (c) refine the automaton by proposing new key predicates.\n",
        # and (b) refine the automaton.\n"
        "3. When performing (a), you can first consider relationship between predicate tuples in the list and then consider using the logical operators AND, OR, and NOT to combine the predicates in the tuple.\n",
        "4. When performing (c), you should reduce the number of state to four by removing counterintuitive transitions.\n",
        # "4. When performing (b), you can add new states and transitions to the automaton,
        #     you can change the predicates of the states to make it more reasonable,
        #     you cannot remove any existing states or transitions from the automaton,
        #     and you cannot change the start state or the accepting states of the automaton.\n"
        # TODO: 1. Give the LLM successful or failed trajectories and let it
        #          analyze them.
        # TODO: 2. When generating the automaton, delete the trajectories that
        #          never reach 5 (i.e. the failed trajectories) to filter out
        #          some wrong answers.
        "5. When performing (c), the format of the new states and transitions must be consistent with the existing automaton.\n",
        "6. The state ID must be a unique integer, and the action letter must be a unique character.\n",
        # "7. Be specific the effect of each term.\n"
        "## Output\n",
        # "You need to output the following:\n"
        "Now, analyze the logical relation between key predicates of states.\n",
        "{ChatGPT response}\n",
        # "Refine the automaton.\n"
        "Analyze why failed trajectories are failed to reach accept states.\n",
        "{ChatGPT response}\n",
        "Refine the automaton by proposing new key predicates.\n",
        "{ChatGPT response}\n",
    ]
    prompt = "".join(segments)
    print(prompt)
    return prompt


def main():
    """Build the public prompt and echo it with a "You entered: " prefix.

    ``get_public_prompt`` prints the prompt itself while building it, so the
    text appears on stdout once from that call and once more from here.
    """
    public_prompt = get_public_prompt()
    message = f"You entered: {public_prompt}"
    print(message)


# Run the demo entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
