import random

from gymnasium.envs.registration import register
import gymnasium as gym

from babyai_utils import LanguageObsWrapper

from minigrid.utils.baby_ai_bot import BabyAIBot, ObjDesc
from SingleSubgoals import OpenSubgoal, GoNextToSubgoal, ActionInfo, CloseSubgoal, DropSubgoal, PickupSubgoal, \
    FinishSubgoal, FindDropLocationSubgoal
from graph_creator2 import GraphCreator
from my_llm import client
from matplotlib import pyplot as plt

from mistake_handling_bot import ErikBot
import json

# Register the custom UnlockToUnlockN level (implemented in the multi_open
# module) under a gymnasium id.
register(
    id="BabyAI-UnlockToUnlockN-v0",
    entry_point="multi_open:UnlockToUnlockN",
)

# Module-level seed counter; Evaluator.get_valid_env reads and increments it.
# NOTE(review): the random draw on the next line is immediately overwritten by
# the fixed value below, so every run starts from the same seed.
seed = random.randint(0, 100000000) # eval on same seeds
seed = 84117778

class Evaluator:
    def __init__(self) -> None:
        """Create an evaluator with no environment attached yet."""
        self.state = None  # observation from the last env.reset (set by set_seed / get_valid_env)
        self.env = None  # environment under evaluation (set by set_seed / get_valid_env / set_env)
        self.seed = 0  # seed recorded by set_seed / get_valid_env
        self.prompt =""  # running LLM prompt; completions are appended in get_next_llama_step
        self.poss_act = ""  # suffix added to the prompt for a single LLM query

    def parse_goal(self, action, agent):
        action = action.split(" ")
        match action[0]:
            case "Pickup":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return PickupSubgoal(agent, datum=obj)
            case "Open":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return OpenSubgoal(agent, datum=obj)
            case "Close":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return CloseSubgoal(agent, datum=obj)
            case "Go":
                if len(action) == 3:
                    return FindDropLocationSubgoal(agent)
                color = action[3]
                color = None if color == "the" else color
                obj = ObjDesc(action[4], color)
                obj.find_matching_objs(self.env)
                if action[-1] == "PutNext":
                    reason = "PutNext"
                else:
                    reason = None
                return GoNextToSubgoal(agent, datum=obj, reason=reason)
            case "Drop":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return DropSubgoal(agent, datum=obj)
            case "Mission":
                return FinishSubgoal(agent)
            case _:
                obj = None
                return None


    def env_step(self, action):
        """Execute one LM-chosen action in the environment.

        Returns:
            ``(state, reward, done, info)`` where ``done`` is true when the
            step reported the episode as terminated or truncated.

        NOTE(review): ``parse_action`` is not defined anywhere in the visible
        class, so this call raises AttributeError unless the method is
        supplied elsewhere — confirm before relying on this method.
        """

        act = self.parse_action(action)  # (action_index, target_object_pos)
        # Steps the environment reached through two ``.env`` hops from
        # self.env, i.e. not the outermost wrapper.
        state, reward, terminated, truncated, info = self.env.env.env.step(act)
        done = terminated or truncated
        return state, reward, done, info

    def create_subgoal(self, action):
        """Return whatever ``self.parse_action`` produces for ``action``.

        NOTE(review): ``parse_action`` is not defined in the visible class, so
        this raises AttributeError as written; possibly ``parse_goal`` was
        intended, but that also needs an agent argument — confirm.
        """
        act = self.parse_action(action)
        return act

    def get_llama_plan(self):
        mission = self.state["mission"]
        level = self.state["language"]
        prompt = f"Mission: {mission}\nLevel: {level}▼"
        print(prompt)
        result = ""
        response = client.generate_stream(prompt, max_new_tokens=100)
        for token in response:
            result += token.token.text
        plan = result.split("⑃")
        return plan

    def get_next_llama_step(self):
        result = ""
        print("Prompt: ", self.prompt + self.poss_act)
        response = client.generate_stream(self.prompt + self.poss_act, max_new_tokens=100)
        for token in response:
            result += token.token.text

        # self.prompt += self.poss_act
        self.prompt += result
        print("LLLAAAAAMAAA: ", result)
        return result


    def set_seed(self, seed):
        valid = False
        env, state = None, None
        env_name = "MiniBossLevel"
        env = gym.make(f'BabyAI-{env_name}', render_mode='human')
        # env = gym.make(f'BabyAI-{env_name}')
        env = LanguageObsWrapper(env)

        state, info = env.reset(seed=seed)
        mission = state["mission"]
        self.seed = seed
        self.state = state
        self.env = env
        self.env.render()

    def get_valid_env(self, env="MiniBossLevel", render=False, **kwargs):
        valid = False
        state = None
        global seed
        print("Seed: ", self.seed)
        while not valid:
            if render:
                env = gym.make(f'BabyAI-{env_name}', render_mode='human', **kwargs)
            else:
                env = gym.make(f'BabyAI-{env_name}', **kwargs)
            env = LanguageObsWrapper(env)
            state, info = env.reset(seed=seed)
            mission = state["mission"]
            if 'front' in mission or 'behind' in mission or 'right' in mission or 'left' in mission:
                valid = False
            else:
                valid = True

            if env.is_object_next_to_door():
                valid = False

            if env.same_objects():
                valid = False
            seed += 1
            print("Seed: ", seed)

        self.seed = seed
        self.state = state
        self.env = env
        self.env.render()


    def get_seeds(self):
        with open("data/old_data/test-env-locked-doors.jsonl", "r") as f:
            data = []
            for line in f:
                data.append(json.loads(line))
        seeds = []

        for entry in data:
            seeds.append(entry["seed"])
        return seeds[:100]


    def evaluate(self, env_name="UnlockToUnlockN", num_runs=100, **kwargs):
        success_count = 0
        results = []

        while len(results) < num_runs:
            self.get_valid_env(env=env_name, **kwargs)
            success = self.test_plan()
            results.append(success)
            if success:
                success_count += 1

        # Create a graphic showing the results
        plt.figure(figsize=(10, 6))
        plt.title(f"{env_name} : {kwargs.get("n_extra_doors", 0)}, Seed: {self.seed}")
        plt.bar(range(len(results)), results, color='blue')
        plt.xlabel('Run')
        plt.ylabel('Success (1) / Failure (0)')
        plt.title(f'Successful Runs: {success_count}/{len(results)}')
        plt.show()



    def test_plan(self):
        try:
            # create subgoals for each action in plan
            agent = ErikBot(self.env)

            done = False
            total_reward = 0
            action = None
            steps = 0

            mission = self.state["mission"]
            level = self.state["language"]
            poss_act = agent.get_possible_actions()
            self.prompt = ""
            # self.prompt = f"Mission: {mission}\nLevel: {level}\nPossible Actions: {poss_act}▼"
            self.prompt = f"Mission: {mission}\nLevel: {level} ▼"
            self.poss_act = ""
            action_response = self.get_next_llama_step()
            action_response = action_response.split("Next Step: ")[1]
            action_response = action_response.split("▼")[0]
            action_response = action_response.strip()
            print("I want to do: ", action_response)
            subgoal = self.parse_goal(action_response, agent)
            if subgoal is not None:
                agent.stack = [subgoal]

            shoud_be_done = False
            while not done:
                if len(agent.stack) == 0:
                    poss_act = agent.get_possible_actions()
                    # self.poss_act = f"Possible Actions: {poss_act}▼"
                    self.poss_act = f""
                    action_response = self.get_next_llama_step()
                    action_response = action_response.split("Next Step: ")[1]
                    action_response = action_response.split("▼")[0]
                    action_response = action_response.strip()
                    if action_response == "Mission Complete":
                        shoud_be_done = True
                    subgoal = self.parse_goal(action_response, agent)
                    if subgoal is not None:
                        agent.stack = [subgoal]

                action = agent.replan(action_taken=action)
                if isinstance(action, ActionInfo):
                    if not action.success:
                        print(action.data)
                        self.prompt += "\n" + action.data + " "
                        agent.stack = []
                        continue
                        # ask llama for recovery plan
                    action = action.data

                observation, reward, terminated, truncated, info = self.env.step(action)
                done = terminated or truncated
                if shoud_be_done and not done:
                    print("I should be done, but I'm not")
                    return False
                total_reward += reward
                self.env.render()
                if steps == 300:
                    return False

            return True
        except Exception as e:
            print("Error during evaluation:", e)
            print(e)
            return False

    def check_plan_for_mistake(self, plan):
        # create subgoals for each action in plan
        agent = ErikBot(self.env)
        subgoals = []
        for action in plan:
            goal = self.parse_goal(action, agent)
            if goal is not None:
                subgoals.append(goal)

        agent.stack = list(reversed(subgoals))
        done = False
        total_reward = 0
        action = None
        steps = 0
        done_plan = None
        graph = GraphCreator()
        moves_to_mistake = None
        while not done:
            graph.add_stack(agent.stack)
            action = agent.replan(action_taken=action)
            if isinstance(action, ActionInfo):
                if not action.success:
                    moves_to_mistake = graph.subgoals_done() + 1
                    break
                action = action.data
            observation, reward, terminated, truncated, info = self.env.step(action)

            # print(terminated, truncated, len(agent.stack) == 0)
            done = (terminated or truncated) and len(agent.stack) == 0

            total_reward += reward
            self.env.render()


            if steps == 100:
                if done_plan is not None:
                    done_plan.append("Mistake: Mission not fulfilled")
                break
            steps += 1


        if moves_to_mistake is not None:
            done_plan = plan[:moves_to_mistake]
            done_plan.append(action.data)
        else:
            # throw Exception
            raise Exception("No mistake found in plan")
        return done_plan

    def set_env(self, env) -> None:
        """Replace the environment used by this evaluator with ``env``."""
        self.env = env

if __name__ == "__main__":
    # Script entry point: evaluate 100 runs on MiniBossLevel without rendering.
    env_name = "MiniBossLevel"
    evaluator = Evaluator()
    # evaluator.evaluate(num_runs=10, env_name=env_name, render=False, n_extra_doors=3)
    evaluator.evaluate(num_runs=100, env_name=env_name, render=False)


