import random
from typing import List, Iterable
from jpype import JArray, JString

from pipelines.prompta.agent.skill_manager import SkillManager
from prompta.core.alphabet.events import is_event
from prompta.core.alphabet.minecraft_alphabet import MinecraftAlphabet
from prompta.core.alphabet.verbs import VERBS
from prompta.core.env import MinecraftSimpleEnv
from prompta.core.language import BaseLanguage
from prompta.utils.java_libs import Word, DefaultQuery



class MinecraftSimpleLanguage(BaseLanguage):
    """Language of Minecraft action/skill sequences, judged by running them in an environment.

    Membership queries (``in_language``) replay a word in a ``MinecraftSimpleEnv``
    while driving any skills registered in the ``SkillManager``. Equivalence
    queries (``counterexample``) random-walk a hypothesis automaton and compare
    its verdict with the environment's ``done`` flag.
    """

    def __init__(self, whole_cfg) -> None:
        """Build the environment, skill manager and alphabet.

        Args:
            whole_cfg: Configuration object forwarded unchanged to
                ``MinecraftSimpleEnv`` and ``SkillManager``.
        """
        self.env = MinecraftSimpleEnv(whole_cfg)
        self.skill_manager = SkillManager(whole_cfg)
        self._alphabet = MinecraftAlphabet(self.skill_manager)
        self.whole_cfg = whole_cfg

    def in_language(self, word: Iterable[str]) -> bool:
        """Execute ``word`` in a freshly reset environment and report membership.

        Every symbol must belong to the alphabet. A symbol registered in the
        skill manager is expanded by repeatedly asking the skill on top of a
        stack for its ``next_symbol`` (a ``(symbol, is_action)`` pair):

        * a nested skill that is not already on the stack is pushed and expanded;
        * an action symbol is executed in the environment via its verb's
          ``to_code``;
        * an event symbol is checked against the events the environment last
          logged, and is replaced by ``"NOOP"`` when none of them is accepted.

        The resulting symbol is fed to the skill's ``step``. When a skill reports
        ``done``, it is popped and its name is fed to the parent skill's ``step``.

        Args:
            word: Sequence of alphabet symbols.

        Returns:
            The ``success`` flag of the first top-level skill once it finishes,
            or ``True`` when every symbol was processed without an early return.

        Raises:
            AssertionError: If a symbol is not in the alphabet.
        """
        self.env.reset()
        for symbol in word:
            assert symbol in self._alphabet, f"[Error] {symbol} not in {self._alphabet}"

            skill_stack = [symbol] if symbol in self.skill_manager else []
            while skill_stack:
                curr_symbol, symbol_is_action = self.skill_manager.skills[skill_stack[-1]].next_symbol
                # Descend into nested skills; checking the stack prevents a skill
                # from being re-entered recursively.
                while self._alphabet.is_skill(curr_symbol) and curr_symbol not in skill_stack:
                    skill_stack.append(curr_symbol)
                    curr_symbol, symbol_is_action = self.skill_manager.skills[skill_stack[-1]].next_symbol

                if symbol_is_action:
                    verb = self._alphabet.get_verb(curr_symbol)
                    self.env.step(verb.to_code(curr_symbol))
                else:
                    # Keep the expected event only if the environment actually
                    # logged a matching one; otherwise report "NOOP" to the skill.
                    real_events = self.env.event_logger.last_events_list
                    if not any(
                        self.skill_manager.skills[skill_stack[-1]].events[curr_symbol].accept(real_event)
                        for real_event in real_events
                    ):
                        curr_symbol = "NOOP"

                print(curr_symbol, skill_stack)
                success, done = self.skill_manager.skills[skill_stack[-1]].step(curr_symbol)
                # Propagate completion upward through the stack of nested skills.
                while done:
                    skill_done = skill_stack.pop()
                    if not skill_stack:
                        # NOTE(review): this returns as soon as the top-level skill
                        # finishes, so symbols after it in ``word`` are never
                        # executed -- confirm this is intended.
                        return success
                    parent = skill_stack[-1]
                    success, done = self.skill_manager.skills[parent].step(skill_done)
                    print("Finished", skill_stack, done, success)

            # NOTE(review): only reached for symbols that are not skills, yet the
            # symbol is looked up in the skill manager; confirm that
            # ``SkillManager.__getitem__`` supports non-skill symbols.
            self.env.step(self.skill_manager[symbol].to_code(symbol))

        return True

    def counterexample(self, aut, _type=str, max_steps=100, num_trajectories=100):
        """Search for a word on which hypothesis ``aut`` disagrees with the environment.

        If ``aut`` has no accepting state, it rejects everything, so the known
        positive example is returned. Otherwise up to ``num_trajectories``
        random walks of at most ``max_steps`` steps are run. Each walk follows
        the automaton's transitions and executes the chosen inputs in the
        environment. A counterexample is the first prefix whose acceptance
        by ``aut`` differs from the environment's ``done`` flag.

        Args:
            aut: Hypothesis automaton (Java object accessed through JPype).
            _type: ``str`` to get a plain tuple of symbols (or ``pos_example``);
                anything else to get a ``DefaultQuery``.
            max_steps: Maximum length of each random walk.
            num_trajectories: Number of random walks to attempt.

        Returns:
            The counterexample in the requested form, or ``None`` if none was found.
        """
        if not any(aut.isAccepting(s) for s in aut.getStates()):
            if _type == str:
                return self.pos_example
            return DefaultQuery(Word.fromArray(JArray(JString)(list(self.pos_example))), True)

        for _ in range(num_trajectories):
            env_state = self.env.reset()
            aut_state = aut.getInitialState()
            trajectory = []

            for _ in range(max_steps):
                transition = self._sample_transition(aut, aut_state, env_state)
                if transition is None:
                    # The automaton offers no continuation from here; end this walk.
                    break
                action = str(transition.getInput())
                # Advance both the environment and the automaton so that the next
                # choice is made from the state actually reached.
                env_state, _, done, _ = self.env.step(action)
                aut_state = transition.getTarget()
                trajectory.append(action)

                # NOTE(review): when an 'Enough' check transition was followed,
                # the check symbol is not recorded in ``trajectory``, so ``aut``
                # judges a word without it -- confirm this is intended.
                word = Word.fromArray(JArray(JString)(trajectory), 0, len(trajectory))
                if aut.accepts(word) != done:
                    query = DefaultQuery(word, done)
                    print("Counterexample found:", query)

                    if _type == str:
                        return tuple(trajectory)
                    return query
                if done:
                    # The episode has ended; further steps would act on a finished env.
                    break

        return None

    def _sample_transition(self, aut, aut_state, env_state):
        """Pick a random outgoing transition of ``aut_state``, resolving 'Enough' checks.

        If the first outgoing input contains ``'Enough'``, the state is treated as
        a two-way inventory check. The check function looked up in the
        environment is evaluated on ``env_state['inventory']``. The matching
        branch is followed, and the random choice is made among the transitions
        of the branch target.

        Returns:
            The chosen transition, or ``None`` if there is nothing to choose from.
        """
        transitions = list(aut.getTransitions(aut_state))
        if not transitions:
            return None

        first_input = str(transitions[0].getInput())
        if 'Enough' in first_input:
            # Index of the transition taken when the check function returns True.
            pos_idx = 0 if 'IsEnough' in first_input else 1
            # NOTE(review): attribute name looks like a typo of ``learned_skills``;
            # kept as-is -- confirm against MinecraftSimpleEnv.
            chk_func = self.env.leanred_skills[first_input]
            if chk_func(env_state['inventory']):
                next_aut_state = transitions[pos_idx].getTarget()
            else:
                next_aut_state = transitions[(pos_idx + 1) % 2].getTarget()
            transitions = list(aut.getTransitions(next_aut_state))
            if not transitions:
                return None

        return random.choice(transitions)

    def select_action(self, aut, aut_state, env_state):
        """Return the input symbol of a random transition out of ``aut_state``.

        'Enough' inventory-check states are resolved against ``env_state``
        before sampling (see ``_sample_transition``).

        Raises:
            IndexError: If no transition is available to sample from.
        """
        transition = self._sample_transition(aut, aut_state, env_state)
        if transition is None:
            raise IndexError(f"automaton state {aut_state} has no outgoing transitions")
        return str(transition.getInput())

    @property
    def alphabet(self):
        """The ``MinecraftAlphabet`` built from the skill manager."""
        return self._alphabet

    @property
    def definition(self):
        """Natural-language description of the language, for prompting."""
        return f"""A player is trying to complete a task in Minecraft. The task is about getting a {self.env.goal}. The player can call the pre-defined APIs to interact with the world. The valid APIs are: {self.env.action_set}. A valid word in this language is defined by the following criteria: it consists solely of sequences that can help the player complete the task but do not contain any unnecessary action."""

    @property
    def examples(self):
        """Positive/negative few-shot examples as query/answer prompt strings.

        Reads ``self.pos_example`` and ``self.neg_example``, which this class
        does not define -- presumably provided by ``BaseLanguage`` or a
        subclass (confirm).
        """
        return {
            'pos': {
                'query': f"Given a action sequence {self.pos_example}, does this sequence belongs to the language?",
                'answer': str({'reason': 'It does not contain any unnecessary action and reached the goal.', 'answer': True}),
            },
            'neg': {
                'query': f"Given a action sequence {self.neg_example}, does this sequence belongs to the language?",
                'answer': str({'reason': 'It contains wrong action and does not reach the goal.', 'answer': False}),
            },
        }
