import json
import random
import signal
from copy import deepcopy

import gymnasium as gym
from minigrid.utils.baby_ai_bot import BabyAIBot

from babyai_utils import LanguageObsWrapper, get_random_obj
from eval import Evaluator
from graph_creator2 import GraphCreator, gradual_change
from mistake_handling_bot import ErikBot


class Mistake_Handler:
    """Build one "mistake + recovery" data entry for a BabyAI MiniBossLevel episode.

    Workflow (see ``create_dataentry``):
      1. Reset ``BabyAI-MiniBossLevel-v0`` (by seed, or resampling until
         ``get_valid_env`` accepts an episode) and snapshot the bot's subgoals.
      2. Run ``BabyAIBot`` to obtain the optimal plan.
      3. Corrupt the plan with ``falsify_plan`` and pass it through
         ``Evaluator.check_plan_for_mistake``.
      4. Tick off the subgoals the evaluated plan already covers and replay the
         remaining ones on ``backup_env``, a deep copy taken right after reset.
    """

    def __init__(self, seed=None):
        """Create the environment, choose an episode, and snapshot the bot's subgoals.

        Args:
            seed: Optional reset seed. When ``None``, episodes are resampled
                until ``get_valid_env`` accepts one.
        """
        # Pass render_mode='human' to gym.make to watch episodes interactively.
        env = gym.make('BabyAI-MiniBossLevel-v0')
        env = LanguageObsWrapper(env)

        self.env = env

        if seed is not None:
            self.state = self.get_env_by_seed(seed)
        else:
            self.state = self.get_valid_env()

        self.mistake_handler_agent = ErikBot(env)
        self.perfect_agent = BabyAIBot(env)

        # Deep-copied snapshot of the bot's initial subgoal stack; later
        # replanning on perfect_agent works on the live stack, not on this copy.
        self.mission_subgoals = deepcopy(self.perfect_agent.stack)
        self.mission_subgoals_str = self.get_subgoal_strings()

        # Copy of the env right after reset (before get_optimal_plan steps it),
        # used by run_corrected_plan to replay the recovery.
        self.backup_env = deepcopy(env)

        print("Mission: ", self.state["mission"])
        print("Level: ", self.state["language"])
        print(self.perfect_agent.stack)

    def get_valid_env(self):
        """Reset ``self.env`` until it yields an episode this generator accepts.

        An episode is rejected when its mission text contains "front",
        "behind", "right" or "left", or when ``self.env.is_object_next_to_door()``
        or ``self.env.same_objects()`` returns truthy (both defined by
        ``LanguageObsWrapper``).

        Returns:
            The observation returned by the accepting ``reset()`` call.
        """
        rejected_words = ("front", "behind", "right", "left")
        while True:
            state, _ = self.env.reset()
            mission = state["mission"]
            if any(word in mission for word in rejected_words):
                continue
            if self.env.is_object_next_to_door() or self.env.same_objects():
                continue
            return state

    def get_env_by_seed(self, seed):
        """Reset ``self.env`` with ``seed`` and return the initial observation."""
        state, _ = self.env.reset(seed=seed)
        return state

    def get_optimal_plan(self, max_steps=100):
        """Run ``perfect_agent`` on ``self.env`` to episode end and record its plan.

        Every stack snapshot after each step is added to a ``GraphCreator``. One
        extra snapshot is taken after the final action, followed by an empty stack.

        Args:
            max_steps: Safety limit. ``RuntimeError`` is raised once more than
                ``max_steps + 1`` environment steps have been taken.

        Returns:
            ``(tree, plan)`` as produced by ``GraphCreator.print()``.

        Raises:
            RuntimeError: If the step limit is exceeded.
        """
        action = None
        graph = GraphCreator()
        steps = 0
        done = False
        while not done:
            action = self.perfect_agent.replan(action_taken=action)
            _, _, terminated, truncated, _ = self.env.step(action)
            graph.add_stack(self.perfect_agent.stack)

            done = terminated or truncated
            self.env.render()

            if steps > max_steps:
                raise RuntimeError("Too many steps")
            steps += 1

        # One more replan so the stack after the final action is recorded too.
        self.perfect_agent.replan(action_taken=action)
        graph.add_stack(self.perfect_agent.stack)
        graph.add_stack([])
        tree, plan = graph.print()
        return tree, plan

    def check_for_mistake(self, plan):
        """Return ``Evaluator.check_plan_for_mistake(plan)`` evaluated against ``backup_env``."""
        new_evaluator = Evaluator()
        new_evaluator.set_env(self.backup_env)
        return new_evaluator.check_plan_for_mistake(plan)

    def falsify_plan(self, plan, obj, mutation_probability=1.0):
        """Return a copy of ``plan`` with one action inserted or removed.

        Args:
            plan: List of action strings (e.g. ``"Pickup red ball"``).
            obj: Indexable pair used to build an inserted action as
                ``f"{action} {obj[0]} {obj[1]}"``. Presumably ``(color, type)``.
            mutation_probability: Chance that the plan is corrupted at all.
                The default 1.0 always corrupts it.

        Returns:
            A new list; ``plan`` itself is never modified. The last action of a
            non-empty plan is always kept: insertions land before it and
            removals never touch it. When there is nothing removable (fewer
            than two actions), an insertion is made instead.
        """
        falsified_plan = deepcopy(plan)
        actions = ("Open", "Pickup", "Drop", "Go next to", "Close")

        if random.random() >= mutation_probability:
            return falsified_plan

        # 50/50 between removing a real action and inserting a spurious one.
        remove = random.random() >= 0.5
        if remove and len(falsified_plan) >= 2:
            # Upper bound len-2 keeps the final action in place.
            falsified_plan.pop(random.randint(0, len(falsified_plan) - 2))
        else:
            action = random.choice(actions)
            new_action = f"{action} {obj[0]} {obj[1]}"
            # Upper bound len-1 inserts before the final action; 0 for an empty plan.
            insert_position = random.randint(0, max(len(falsified_plan) - 1, 0))
            falsified_plan.insert(insert_position, new_action)
        return falsified_plan

    def check_the_still_to_do_mission_subgoals(self, done_plan):
        """Remove already-accomplished subgoals from ``self.mission_subgoals_str`` in place.

        The last element is treated as the next subgoal to satisfy. Walking
        ``done_plan`` in order, every action equal to that last element pops it.

        Args:
            done_plan: Sequence of action strings that were executed.
        """
        remaining = self.mission_subgoals_str
        for action in done_plan:
            # Once every subgoal is done, further actions have nothing to match.
            if remaining and action == remaining[-1]:
                remaining.pop()

    def get_wrong_plan_with_correction(self):
        """Produce the optimal, falsified and evaluated plans for this episode.

        Side effect: narrows ``self.mission_subgoals_str`` to the subgoals the
        evaluated plan has not yet accomplished.

        Returns:
            ``(optimal_tree, optimal_plan, false_plan, evaluated_plan,
            remaining_subgoal_strings)``.
        """
        print("Mission Subgoals:\t", self.mission_subgoals_str)
        optimal_tree, optimal_plan = self.get_optimal_plan()
        print("Optimal Plan:\t", optimal_plan)
        false_plan = self.falsify_plan(optimal_plan, get_random_obj(self.perfect_agent))
        print("False Plan:\t\t", false_plan)
        evaluated_plan = self.check_for_mistake(false_plan)
        print("Evaluated Plan:\t", evaluated_plan)
        self.check_the_still_to_do_mission_subgoals(evaluated_plan)
        print("Still to do:\t", self.mission_subgoals_str)
        print("Current State", self.state["language"])

        return optimal_tree, optimal_plan, false_plan, evaluated_plan, self.mission_subgoals_str

    def run_corrected_plan(self):
        """Replay the still-to-do mission subgoals on ``self.backup_env``.

        A fresh ``BabyAIBot`` bound to ``backup_env`` replaces
        ``self.mistake_handler_agent``. Its stack is set to the first
        ``len(self.mission_subgoals_str)`` entries of ``self.mission_subgoals``,
        i.e. the subgoals not removed by ``check_the_still_to_do_mission_subgoals``.
        The bot is then stepped until the episode terminates or is truncated.

        Returns:
            ``(tree, plan)`` as produced by ``GraphCreator.print()`` for the
            recovery run.
        """
        action = None
        graph = GraphCreator()
        self.mistake_handler_agent = BabyAIBot(self.backup_env)

        # Re-bind the snapshotted subgoals to the new recovery bot.
        for subgoal in self.mission_subgoals:
            subgoal.bot = self.mistake_handler_agent

        self.mistake_handler_agent.stack = self.mission_subgoals[:len(self.mission_subgoals_str)]

        done = False
        while not done:
            action = self.mistake_handler_agent.replan(action_taken=action)
            _, _, terminated, truncated, _ = self.backup_env.step(action)
            graph.add_stack(self.mistake_handler_agent.stack)

            done = terminated or truncated
            self.backup_env.render()

        # Record the recovery bot's stack after its final action. Use this bot,
        # not perfect_agent, whose episode ran on self.env rather than backup_env.
        self.mistake_handler_agent.replan(action_taken=action)
        graph.add_stack(self.mistake_handler_agent.stack)
        graph.add_stack([])
        tree, plan = graph.print()
        return tree, plan

    def propagate_subgoal_info(self, mission_subgoals):
        """Fill in missing target information on subgoals, in place.

        Pass 1 (right to left): each ``GoNextToSubgoal`` with a ``datum`` gives
        that datum to the nearest lower-index ``PickupSubgoal`` or
        ``OpenSubgoal`` that has no ``datum`` yet.

        Pass 2 (right to left): each ``DropSubgoal`` gets ``carrying`` set to
        the ``datum`` of the nearest higher-index ``PickupSubgoal`` that has one.

        Subgoal kinds are recognised by substring match on ``str(subgoal)``.

        Args:
            mission_subgoals: List of subgoal objects; mutated in place.

        Returns:
            The same list object.
        """
        def has_datum(subgoal):
            return getattr(subgoal, "datum", None) is not None

        for i in range(len(mission_subgoals) - 1, -1, -1):
            subgoal = mission_subgoals[i]
            if "GoNextToSubgoal" not in str(subgoal) or not has_datum(subgoal):
                continue
            for earlier in reversed(mission_subgoals[:i]):
                name = str(earlier)
                if ("PickupSubgoal" in name or "OpenSubgoal" in name) and not has_datum(earlier):
                    earlier.datum = subgoal.datum
                    break

        found_pickup = None
        for subgoal in reversed(mission_subgoals):
            name = str(subgoal)
            if "PickupSubgoal" in name and has_datum(subgoal):
                found_pickup = subgoal.datum
            if "DropSubgoal" in name and found_pickup is not None:
                subgoal.carrying = found_pickup
        return mission_subgoals

    def create_string_for_mission_subgoals(self, entry):
        """Render one subgoal as a short action string such as ``"Pickup red ball"``.

        Pickup/Open/Close/Go-next-to subgoals are described by ``entry.datum``.
        Drop subgoals are described by ``entry.carrying``, which is set by
        ``propagate_subgoal_info``. Unknown kinds become
        ``"Not implemented <str(entry)>"``. If ``entry.reason == "PutNext"``,
        ``" next to to PutNext"`` is appended.
        """
        name = str(entry)
        for marker, label in (
            ("PickupSubgoal", "Pickup"),
            ("DropSubgoal", "Drop"),
            ("GoNextTo", "Go next to"),
            ("OpenSubgoal", "Open"),
            ("CloseSubgoal", "Close"),
            ("FindDropLocationSubgoal", "Go somewhere else"),
        ):
            if marker in name:
                action = label
                break
        else:
            action = f"Not implemented {name}"

        info = [action]
        if action in ("Pickup", "Open", "Close", "Go next to"):
            info.append(f"{entry.datum.color} {entry.datum.type}")
        elif action == "Drop":
            info.append(f"{entry.carrying.color} {entry.carrying.type}")
        if entry.reason == "PutNext":
            # NOTE(review): yields "... next to to PutNext"; looks like a
            # placeholder for the target object. Kept as-is because these
            # strings are compared against plan strings elsewhere.
            info.append("next to")
            info.append("to PutNext")
        return " ".join(info)

    def get_subgoal_strings(self):
        """Propagate subgoal info in place, then render every mission subgoal as a string."""
        return [
            self.create_string_for_mission_subgoals(entry)
            for entry in self.propagate_subgoal_info(self.mission_subgoals)
        ]

    def create_dataentry(self):
        """Run the full mistake/recovery pipeline and return a JSON-serialisable-ish dict.

        Returns:
            Dict with keys ``mission``, ``level``, ``optimal_tree``,
            ``optimal_plan``, ``false_plan``, ``evaluated_plan``,
            ``recover_tree`` and ``recover_plan``.
        """
        optimal_tree, optimal_plan, false_plan, evaluated_plan, _ = self.get_wrong_plan_with_correction()
        recover_tree, recover_plan = self.run_corrected_plan()

        return {
            "mission": self.state["mission"],
            "level": self.state["language"],
            "optimal_tree": optimal_tree,
            "optimal_plan": optimal_plan,
            "false_plan": false_plan,
            "evaluated_plan": evaluated_plan,
            "recover_tree": recover_tree,
            "recover_plan": recover_plan,
        }



def _try_create_entry(seed, output_path):
    """Generate one data entry and append it to ``output_path`` as a JSON line.

    Generation is best-effort: any exception raised while building or
    serialising the entry is reported and ``None`` is returned. Errors while
    writing the output file are not swallowed.

    Args:
        seed: Environment seed, or ``None`` to let ``Mistake_Handler`` sample one.
        output_path: JSONL file to append to.

    Returns:
        The entry dict, or ``None`` if generation failed.
    """
    try:
        entry = Mistake_Handler(seed).create_dataentry()
        line = json.dumps(entry, ensure_ascii=False)
    except Exception as e:  # bad episodes are expected; skip them but say why
        print(f"Skipping sample (seed={seed}): {type(e).__name__}: {e}")
        return None
    # Open per entry so every finished sample is on disk even if a later one crashes.
    with open(output_path, "a", encoding="utf-8") as f:
        f.write(line + "\n")
    return entry


def create_dataset(num_samples=5, seeds_path=None, output_path="data/MiniBossLevel_mistake.jsonl"):
    """Generate mistake/recovery entries and append them to a JSONL file.

    Args:
        num_samples: Number of successful entries to generate when
            ``seeds_path`` is ``None``. Ignored when seeds are given.
        seeds_path: Optional path to a JSON list of env seeds. One attempt is
            made per seed; failed seeds are skipped.
        output_path: JSONL file that entries are appended to.

    Returns:
        The list of successfully generated entries.
    """
    data = []
    if seeds_path is not None:
        with open(seeds_path, "r", encoding="utf-8") as f:
            seeds = json.load(f)
        for seed in seeds:
            entry = _try_create_entry(seed, output_path)
            if entry is not None:
                data.append(entry)
    else:
        # NOTE: retries until num_samples succeed; never ends if every attempt fails.
        while len(data) < num_samples:
            entry = _try_create_entry(None, output_path)
            if entry is not None:
                data.append(entry)
    return data



if __name__ == "__main__":
    # Regenerate the dataset from the fixed list of environment seeds.
    create_dataset(num_samples=10000, seeds_path="code_submission/data/seeds.json")









