#!/usr/bin/env python3



import json
import os

import pandas as pd
import regex as re

# Phrases that mark a step as a high-level "goal" statement.
# validate_instruction() searches the lower-cased first remaining step for each
# entry (wrapped in \b word boundaries, so entries are used as regex patterns).
# When one matches, the instruction must contain at least one further step.
goal_strings = [
    "help me clean up",
    "help me tidy",
    "help me prepare",
    "help me create",
    "help me set",
    "help me make",
    "help me organize",
    "let's prepare",
    "let's decorate",
    "let's create",
    "let's tidy",
    "let's clean",
    "let's set",
    "let's organize",
    "let's make",
    "put away",
]

# Words that disqualify a step. validate_instruction() lower-cases each step,
# splits it on single spaces and drops the step when any resulting token equals
# one of these entries exactly (a token with attached punctuation, e.g. "open,",
# does not match).
# Entries that are commented out below are currently NOT filtered.
invalid_actions = [
    "squeeze",
    "answer",
    "plug",
    "unplug",
    "connect",
    "call",
    "hang",
    "read",
    "charge",
    "feed",
    "bake",
    "peel",
    "light",
    "stir",
    "fix",
    "eat",
    "use",
    "assemble",
    "human",
    "robot",
    "wrist",
    "wrap",
    "draw",
    "old",
    # "turn",
    # "fill",
    "open",
    "close",
    # "wash",
    # "pour",
    # "water",
    # "clean",
]

# Sequencing lead-ins that validate_instruction() strips from the first
# remaining step (and from the second step when the first is a goal statement).
# Matching is a case-sensitive substring test that includes the trailing comma.
invalid_ordering_terms = ["Then,", "After that,", "Next,", "Finally,"]


def filter_dataset(samples_dict):
    """Validate every sample's instruction and keep the unique valid ones.

    Each sample's ``"instruction"`` is passed through ``validate_instruction``.
    A sample is kept when the instruction is valid and its rewritten text has
    not been seen before (case-insensitive comparison). Kept samples are
    modified in place: their ``"instruction"`` is replaced by the rewritten
    text. Rejected samples (invalid or duplicate) are counted as invalid.

    Args:
        samples_dict: Iterable of samples, each indexable by ``"instruction"``
            (presumably a list of dicts loaded from JSON -- verify against
            the caller).

    Returns:
        A ``(valid_episodes, overview)`` tuple where ``valid_episodes`` is the
        list of kept samples and ``overview`` is a ``pandas.DataFrame`` with
        one row per input sample and the columns ``"old instructions"`` and
        ``"new instructions"`` (``"invalid instruction!"`` for rejected rows).
    """
    valid_episodes = []
    existing_instructions = []
    valid_instructions = []
    # Set instead of list: the lower-cased texts are only used for membership
    # tests, so this keeps duplicate detection O(1) per sample.
    seen_instructions_lower = set()
    total_invalid_count = 0

    for sample in samples_dict:
        instruction = sample["instruction"]
        existing_instructions.append(instruction)
        print(instruction)
        valid, new_instruction = validate_instruction(instruction)
        normalized = new_instruction.lower()
        # Exact (case-insensitive) duplicates are rejected as well.
        if valid and normalized not in seen_instructions_lower:
            sample["instruction"] = new_instruction
            valid_episodes.append(sample)
            valid_instructions.append(new_instruction)
            seen_instructions_lower.add(normalized)
        else:
            valid_instructions.append("invalid instruction!")
            total_invalid_count += 1

    modified_instruction_overview = pd.DataFrame(
        {
            "old instructions": existing_instructions,
            "new instructions": valid_instructions,
        }
    )

    total_count = len(existing_instructions)
    print(
        "Found",
        total_invalid_count,
        "invalid instructions out of total:",
        total_count,
    )
    if total_count > 0:
        # NOTE: the printed value is a fraction in [0, 1], not a percentage.
        print(
            "Total percent of valid instructions:",
            (total_count - total_invalid_count) / total_count,
        )

    return valid_episodes, modified_instruction_overview


def validate_instruction(instruction):
    """Filter and normalize one multi-step instruction.

    The instruction is split into steps on ".", "?" and "!". Steps are then
    dropped or rewritten by a sequence of rules:

    1. Steps containing a word from ``invalid_actions`` are dropped.
    2. A number directly following a room name is removed.
    3. Steps using bring/fetch/pass/get/give without a disambiguating
       put/place/to after it (or "pack" before it) are dropped.
    4. Leading ordering terms ("Then,", ...) are stripped from the first step
       (and the second one when the first is a goal statement).
    5. The instruction is rejected when too few steps remain, when a non-first
       step says "away" without a later "to"/"in", or when it contains
       "respective" without "respectively".

    Args:
        instruction: Raw instruction text.

    Returns:
        A ``(is_valid, text)`` tuple. ``text`` is the re-joined instruction
        with "." or "?" appended to every non-empty step; it is ``""`` when
        the instruction is rejected by the step-count or "away" rules.
    """
    # Treat "!" as a sentence terminator so it splits like "." and "?".
    instruction = re.sub(r"[!]", ". ", instruction)
    all_steps = re.split(r"[.?]", instruction)

    # Drop steps that contain a banned word as a space-separated token.
    # A step is also kept when its first "spray" is the start of "spray bottle".
    banned_words = set(invalid_actions)
    valid_steps_0 = [
        step
        for step in all_steps
        if not (banned_words & set(step.lower().split(" ")))
        or (
            "spray" in step
            and "spray bottle" in step
            and step.lower().index("spray") == step.lower().index("spray bottle")
        )
    ]

    # Remove numbering present after rooms in instructions as this is unnatural.
    # NOTE: steps are tokenized on whitespace, so the two-word entry
    # "living room" can never match a single token.
    room_list = [
        "bedroom",
        "living room",
        "kitchen",
        "laundryroom",
        "mudroom",
        "bathroom",
        "hallway",
        "garage",
        "closet",
        "office",
        "shower",
        "entryway",
        "foyer",
        "lobby",
    ]
    for stp_idx, stp in enumerate(valid_steps_0):
        stp_words = stp.split()
        for room in room_list:
            # With multiple references to a room type, numbers are kept to
            # avoid confusion.
            if stp_words.count(room) != 1:
                continue
            room_idx = stp_words.index(room)
            if room_idx < len(stp_words) - 1 and stp_words[room_idx + 1].isdigit():
                del stp_words[room_idx + 1]
        # Re-joining also collapses runs of whitespace and trims the step.
        valid_steps_0[stp_idx] = " ".join(stp_words)

    # Check that "bring/fetch/pass/get/give" is used without ambiguity: it must
    # be followed by put/place/to, or preceded by pack (and no "me" in the step).
    # All tests here are plain substring tests on the lower-cased step.
    post_specifier = ["put", "place", "to"]
    pre_specifier = ["pack"]
    specifiers = post_specifier + pre_specifier
    getter = ["bring", "fetch", "pass", "get", "give"]
    valid_steps = valid_steps_0.copy()

    for stp in valid_steps_0:
        stp_lower = stp.lower()
        for get_spec in getter:
            if get_spec not in stp_lower:
                continue
            if not any(spec in stp_lower for spec in specifiers):
                valid_steps.remove(stp)
                break

            get_idx = stp_lower.index(get_spec)
            followed_by_post = any(
                stp_lower.index(spec) > get_idx
                for spec in post_specifier
                if spec in stp_lower
            )
            preceded_by_pre = any(
                stp_lower.index(spec) < get_idx
                for spec in pre_specifier
                if spec in stp_lower and "me" not in stp_lower
            )
            if not followed_by_post and not preceded_by_pre:
                valid_steps.remove(stp)
                break

    # Every step may have been dropped (e.g. a single banned step without a
    # trailing period); there is nothing left to validate in that case.
    if not valid_steps:
        return False, ""

    # Check if the dropped steps led to incorrect ordering in the overall
    # instruction: the new first step must not start with "Then," etc.
    first_step_lower = valid_steps[0].lower()
    goal_present = any(
        re.search(r"\b" + goal + r"\b", first_step_lower) for goal in goal_strings
    )
    for order in invalid_ordering_terms:
        if order in valid_steps[0]:
            # str.capitalize also lower-cases the remainder of the step.
            valid_steps[0] = str.capitalize(
                valid_steps[0].replace(order, "").lstrip()
            )
        if len(valid_steps) > 1 and order in valid_steps[1] and goal_present:
            valid_steps[1] = str.capitalize(
                valid_steps[1].replace(order, "").lstrip()
            )

    # Check if the instruction is too abstract: a goal statement needs at
    # least one more step explaining what it entails.
    # Note: the last step is usually "" (text after the final period).
    required_steps_len = 1
    if valid_steps[-1] == "":
        required_steps_len += 1
    if goal_present:
        required_steps_len += 1
    if len(valid_steps) < required_steps_len:
        valid_steps = [""]

    # Special case #1: "put away" is allowed in the goal description (first
    # step); elsewhere "away" must be followed by the word "to" or "in".
    put_away_specifier = ["to", "in"]
    for ind, stp in enumerate(valid_steps):
        if ind == 0 or "away" not in stp:
            continue
        stp_lower = stp.lower()
        away_idx = stp_lower.index("away")
        # Compare positions of whole-word matches, so that e.g. the "in"
        # inside "things" is not mistaken for the specifier.
        has_target = any(
            match.start() > away_idx
            for spec in put_away_specifier
            for match in re.finditer(r"\b" + spec + r"\b", stp_lower)
        )
        if not has_target:
            return False, ""

    # Add the terminating punctuation back; steps containing a question
    # specifier (substring test) get "?" instead of ".".
    ques_specifiers = ["can", "could", "would", "will"]
    for stp_id, stp in enumerate(valid_steps):
        if any(spec in stp.lower() for spec in ques_specifiers):
            valid_steps[stp_id] += "?"
        elif stp == "":
            continue
        else:
            valid_steps[stp_id] += "."

    validated_instruction = " ".join(valid_steps)
    found_valid_instruction = False
    if len(validated_instruction) > 1:
        # Special case #2: "respective" is only accepted as part of
        # "respectively".
        found_valid_instruction = not (
            "respective" in validated_instruction
            and "respectively" not in validated_instruction
        )

        print("validated instruction! ", validated_instruction)

    return found_valid_instruction, validated_instruction


if __name__ == "__main__":
    # NOTE(review): these paths look like placeholders -- point them at the
    # real dataset folder before running.
    root_path = "/"
    # os.path.join avoids the doubled separator that plain string
    # concatenation produces when root_path already ends with "/".
    folder_name = os.path.join(root_path, "folder_name")
    instructions_csv_path = os.path.join(folder_name, "filtered_instructions.csv")

    # JSON text is UTF-8; do not rely on the locale's default encoding.
    with open(
        os.path.join(folder_name, "init_state_dicts.json"), encoding="utf-8"
    ) as f:
        init_state_dicts = json.load(f)

    # Obtain the filtered dataset and the old/new instruction overview.
    # Only the overview is written to disk; the filtered samples in
    # `dataset` are not persisted by this script.
    dataset, instructions_csv = filter_dataset(init_state_dicts)

    instructions_csv.to_csv(instructions_csv_path, index=False)
