import time

import gymnasium as gym
from numpy.f2py.auxfuncs import throw_error
from transformers.testing_utils import require_spacy

# from babyai.levels.verifier import ObjDesc
from babyai_utils import LanguageObsWrapper
import regex as re
from dotmap import DotMap

from minigrid.utils.baby_ai_bot import BabyAIBot, ObjDesc
from SingleSubgoals import OpenSubgoal, GoNextToSubgoal, ActionInfo, CloseSubgoal, DropSubgoal, PickupSubgoal, \
    FinishSubgoal, FindDropLocationSubgoal
from graph_creator2 import GraphCreator
from my_llm import client
from matplotlib import pyplot as plt

from mistake_handling_bot import RecoveryBot
import json


class Evaluator:
    def __init__(self):
        self.state = None
        self.env = None
        self.seed = 0

    def parse_goal(self, action, agent):
        action = action.split(" ")
        match action[0]:
            case "Pickup":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return PickupSubgoal(agent, datum=obj)
            case "Open":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return OpenSubgoal(agent, datum=obj)
            case "Close":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return CloseSubgoal(agent, datum=obj)
            case "Go":
                if len(action) == 3:
                    return FindDropLocationSubgoal(agent)
                color = action[3]
                color = None if color == "the" else color
                obj = ObjDesc(action[4], color)
                obj.find_matching_objs(self.env)
                return GoNextToSubgoal(agent, datum=obj)
            case "Drop":
                obj = ObjDesc(action[2], action[1])
                obj.find_matching_objs(self.env)
                return DropSubgoal(agent, datum=obj)
            case "Mission":
                return FinishSubgoal(agent)
            case _:
                obj = None
                return None


    def env_step(self, action):
        """ one action step using LM """

        act = self.parse_action(action)  # (action_index, target_object_pos)
        state, reward, terminated, truncated, info = self.env.env.env.step(act)
        done = terminated or truncated
        return state, reward, done, info

    def create_subgoal(self, action):
        act = self.parse_action(action)
        return act

    def get_llama_plan(self):
        mission = self.state["mission"]
        level = self.state["language"]
        prompt = f"Mission: {mission}\nLevel: {level}▼"
        print(prompt)
        result = ""
        response = client.generate_stream(prompt, max_new_tokens=100)
        for token in response:
            result += token.token.text
        plan = result.split("⑃")
        return plan

    def set_seed(self, seed):
        valid = False
        env, state = None, None
        env_name = "UnlockPickup"
        env = gym.make(f'BabyAI-{env_name}', render_mode='human')
        # env = gym.make(f'BabyAI-{env_name}')
        env = LanguageObsWrapper(env)

        state, info = env.reset(seed=seed)
        mission = state["mission"]
        self.seed = seed
        self.state = state
        self.env = env
        self.env.render()

    def get_seeds(self):
        with open("data/old_data/MiniBossLevel_train.jsonl", "r") as f:
            data = []
            for line in f:
                data.append(json.loads(line))
        seeds = []

        for entry in data:
            seeds.append(entry["seed"])
        return seeds[:100]
    def evaluate(self, env_name="MiniBossLevel", num_runs=100):
        success_count = 0
        results = []

        # seeds = self.get_seeds()
        seeds = [i for i in range(num_runs)]
        for seed in seeds:
            self.set_seed(seed)
            plan = self.get_llama_plan()
            for i in plan:
                print(i)

            success = self.test_plan(plan)
            results.append(success)
            if success:
                success_count += 1

        # Create a graphic showing the results
        plt.figure(figsize=(10, 6))
        plt.bar(range(len(results)), results, color='blue')
        plt.xlabel('Run')
        plt.ylabel('Success (1) / Failure (0)')
        plt.title(f'Successful Runs: {success_count}/{len(results)}')
        plt.show()



    def test_plan(self, plan):
        # create subgoals for each action in plan
        agent = RecoveryBot(self.env)
        subgoals = []
        for action in plan:
            goal = self.parse_goal(action, agent)
            if goal is not None:
                subgoals.append(goal)

        agent.stack = list(reversed(subgoals))
        done = False
        total_reward = 0
        action = None
        steps = 0
        while not done:
            print("Steps:", steps)
            action = agent.replan(action_taken=action)
            if isinstance(action, ActionInfo):
                if not action.success:
                    print(action.data)
                    # ask llama for recovery plan
                action = action.data

            observation, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            total_reward += reward
            self.env.render()

            if steps == 100:
                return False

        return True

    def check_plan_for_mistake(self, plan):
        """Execute ``plan`` until the bot reports a failed action.

        The plan is converted to subgoals with ``parse_goal`` (steps that
        yield ``None`` are skipped) and run with a ``RecoveryBot``. Before
        every replan the bot's stack is recorded in a ``GraphCreator``. When
        ``replan`` returns an unsuccessful ``ActionInfo`` the loop stops.

        Returns:
            list: The first ``graph.subgoals_done() + 1`` entries of ``plan``
            followed by the ``data`` of the failing ``ActionInfo``.

        Raises:
            Exception: If the loop ends without a failed action, either
                because the episode finished with an empty stack or because
                the step counter reached 100.
        """
        # create subgoals for each action in plan
        agent = RecoveryBot(self.env)
        subgoals = []
        for action in plan:
            goal = self.parse_goal(action, agent)
            if goal is not None:
                subgoals.append(goal)

        # Reversed so that the first plan step ends up last in the list.
        agent.stack = list(reversed(subgoals))
        done = False
        total_reward = 0
        action = None
        steps = 0
        done_plan = None
        graph = GraphCreator()
        moves_to_mistake = None
        while not done:
            print("Steps:", steps)
            # Record the stack before replanning so the graph can later
            # report how many subgoals were completed.
            graph.add_stack(agent.stack)
            action = agent.replan(action_taken=action)
            if isinstance(action, ActionInfo):

                if not action.success:
                    # print(action.data)
                    # Leave the loop with `action` still bound to the failing
                    # ActionInfo; its `data` is read after the loop.
                    moves_to_mistake = graph.subgoals_done() + 1
                    break
                action = action.data
            observation, reward, terminated, truncated, info = self.env.step(action)

            # print(terminated, truncated, len(agent.stack) == 0)
            # The run only counts as done once the episode has ended AND the
            # bot has no subgoals left.
            done = (terminated or truncated) and len(agent.stack) == 0

            total_reward += reward
            self.env.render()


            if steps == 100:
                # NOTE(review): done_plan is only assigned after the loop, so
                # it is always None here and this append never runs; hitting
                # the step limit therefore ends in the exception below.
                # Confirm whether a "Mission not fulfilled" result was meant
                # to be returned instead.
                if done_plan is not None:
                    done_plan.append("Mistake: Mission not fulfilled")
                break
            steps += 1


        if moves_to_mistake is not None:
            # Slicing copies, so appending does not modify the caller's plan.
            done_plan = plan[:moves_to_mistake]
            done_plan.append(action.data)
        else:
            # No failed action was observed during execution.
            raise Exception("No mistake found in plan")
        return done_plan

    def set_env(self, env):
        self.env = env

if __name__ == "__main__":
    # Run the default evaluation. The environment level is selected inside
    # Evaluator.set_seed, so nothing needs to be configured here.
    evaluator = Evaluator()
    evaluator.evaluate()

    # Example hand-written plans for exercising a single run manually with
    # evaluator.test_plan(plan) after evaluator.set_seed(<seed>):
    # plan = ['Go next to green key','Pickup green key', 'Go next to green door', 'Open green door','Go somewhere else', 'Drop green key', 'Go next to purple box', 'Pickup purple box', 'Mission Complete']
    # plan = ["Go next to green door", "Open green door", "Go next to green box", "Pickup green box"]

