from hcraft.examples import MineHcraftEnv
from hcraft.examples.minecraft.items import DIAMOND
from hcraft.task import GetItemTask
from tqdm import tqdm

from typing import List, Dict
import numpy as np

from DataGenerators.DataGenerator import DataGenerator
from config import MY_DEBUG, PERSISTENT_DATA_PATH


def get_masked_action_list(env):
    """
    Return the actions that are currently legal in ``env``.

    :param env: environment exposing ``world.transformations`` (objects with a
        ``name`` attribute) and ``action_masks()`` (one truthy/falsy entry per
        transformation).
    :return: tuple ``(legal_actions, legal_actions_index)``: the names of the
        legal transformations and their positions in
        ``env.world.transformations``, both in ascending index order.
    """
    action_names = [transformation.name for transformation in env.world.transformations]
    mask = env.action_masks()

    legal_actions = []
    legal_actions_index = []
    # Walk the mask once and fill both result lists in lockstep.
    for index, name in enumerate(action_names):
        if mask[index]:
            legal_actions.append(name)
            legal_actions_index.append(index)
    return legal_actions, legal_actions_index

def hierachy_craft_get_obs_string(env):
    """
    Render the current state of a HierarchyCraft environment as a text prompt.

    The state of every HierarchyCraft environment is composed of three parts:
    * The player's inventory: `state.player_inventory`
    * The one-hot encoded player's position: `state.position`
    * All zones inventories: `state.zones_inventories`

    The returned text names the current zone, lists the items of that zone,
    lists the items the player holds (only those with a positive count) and
    lists the names of the currently legal actions. Empty item sections are
    rendered as "Nothing".

    :param env: environment exposing ``state``, ``world.items``,
        ``world.transformations`` and ``action_masks()``.
    :return: obs string
    """
    state = env.state

    # The position vector is one-hot, so its argmax is the index of the current zone.
    position_str = state.world.zones[np.argmax(state.position)].name

    # One "name: count" line per held item; every line keeps its trailing newline.
    player_inventory = state.player_inventory
    inventory_str = "".join(
        f"{item.name}: {player_inventory[index]}\n"
        for index, item in enumerate(env.world.items)
        if player_inventory[index] > 0
    )
    if not inventory_str:
        inventory_str = "Nothing"

    # One "name: count" line per entry of the current zone; every line has a leading newline.
    current_zone_inventory = state.zones_inventories_dict[state.current_zone]
    zone_inventory_str = "".join(
        "\n" + key.name + ": " + str(current_zone_inventory[key])
        for key in current_zone_inventory
    )
    if not zone_inventory_str:
        zone_inventory_str = "Nothing"

    legal_actions, _ = get_masked_action_list(env)
    possible_actions_str = "\n".join(legal_actions)

    # The prompt is assembled in a single pass so that text inserted for one
    # section is never re-scanned for another section's placeholder.
    return (
        f"\nYou are in {position_str}.\n"
        "Items in your current zone are:\n"
        f"{zone_inventory_str}\n"
        "You have the following items in your inventory:\n"
        f"{inventory_str}\n"
        "You can take these actions:\n"
        f"{possible_actions_str}\n"
        "    "
    )


def get_optimal_actions_list(trajectory, env, solving_bahaviour, max_step=100):
    """
    Find the legal actions that lead to the shortest episode after ``trajectory``.

    For every action that is legal once ``trajectory`` has been replayed from a
    fresh reset, the environment is reset, the trajectory is replayed, the
    candidate action is taken and all remaining actions are chosen by
    ``solving_bahaviour`` until the episode reports ``done`` or ``max_step``
    steps have been taken. The candidates whose rollout needed the fewest
    steps are returned.

    NOTE(review): this relies on a replay of the same actions reproducing the
    same state after every reset — confirm the environment is deterministic.

    :param trajectory: sequence of actions already taken since the last reset.
    :param env: scratch environment used for the rollouts; it is reset and
        stepped by this function. Must expose ``reset()``, ``step(action)``
        returning a 5-tuple, ``action_masks()`` and ``world.transformations``.
    :param solving_bahaviour: callable mapping an observation to an action.
    :param max_step: maximum number of steps per rollout, replayed steps included.
    :return: list of action indices with the minimal rollout length. The list
        is empty when the decision point cannot be reached (the replay ends
        the episode or ``max_step`` is exhausted first) or when no action is
        legal there.
    """
    prefix_length = len(trajectory)
    legal_actions = None  # discovered during the first rollout
    step_count_list = []
    candidate = 0

    while legal_actions is None or candidate < len(legal_actions):
        env.reset()
        observation = None  # always set by env.step before the solver reads it
        step_count = 0
        while step_count < max_step:
            if step_count < prefix_length:
                # Replay the actions that were already taken.
                action_taken = trajectory[step_count]
            elif step_count == prefix_length:
                # Decision point: branch on the candidate action.
                if legal_actions is None:
                    legal_actions = np.nonzero(env.action_masks())[0]
                    if len(legal_actions) == 0:
                        return []
                action_taken = legal_actions[candidate]
            else:
                # After the candidate action, follow the solver.
                action_taken = solving_bahaviour(observation)

            observation, _reward, done, _, _info = env.step(action_taken)
            step_count += 1
            if done:
                break

        if legal_actions is None:
            # The rollout ended before the decision point was reached, so
            # there is nothing to compare.
            return []

        step_count_list.append((legal_actions[candidate], step_count))
        candidate += 1

    if MY_DEBUG:
        for action, steps in step_count_list:
            print("Action: ", env.world.transformations[action].name, "Step: ", steps)

    minimum_step_count = min(steps for _, steps in step_count_list)
    return [action for action, steps in step_count_list if steps == minimum_step_count]





class HierachyCraftDataGenerator(DataGenerator):
    """
    Generate expert-feedback datasets from the MineHcraft "get diamond" task.

    Two environments are used: ``self.env`` is the environment the data is
    sampled from, and ``self.test_env`` is a scratch copy on which the actions
    taken so far (``self.current_trajectory``) are replayed to evaluate every
    legal action (see ``get_optimal_actions_list``).
    """

    # Step budget of the sampling environment; also the default budget of the
    # scratch environment used for expert rollouts.
    _ENV_MAX_STEP = 100

    def __init__(self, env="HierachyCraft", type="binary_feedback", distribution=1):
        """
        :param env: environment label; kept for interface compatibility, not used here.
        :param type: kind of records produced by ``sample_data``:
            "binary_feedback", "preference" or "action_advising".
        :param distribution: probability of following the expert action at each
            step (1 full expert, 0 fully random).
        """
        # NOTE(review): the base-class initialiser is not called (it was already
        # disabled) — confirm DataGenerator needs no initialisation of its own.
        # super().__init__(env, type, distribution)
        get_diamond = GetItemTask(DIAMOND)
        self.goal = get_diamond
        self.env = MineHcraftEnv(purpose=get_diamond, max_step=self._ENV_MAX_STEP)
        self.solving_behavior = self.env.solving_behavior(get_diamond)
        self.type = type
        self.distribution = distribution

        # Actions taken in self.env since its last reset.
        self.current_trajectory = []
        # Scratch environment and solver for expert rollouts; created on demand.
        self.test_env = None
        self.test_solving_behaviour = None
        self._expert_max_step = self._ENV_MAX_STEP

    def _init_expert_env(self, max_step):
        """Create the scratch environment and solver used for expert rollouts."""
        self.test_env = MineHcraftEnv(purpose=self.goal, max_step=max_step)
        self.test_solving_behaviour = self.test_env.solving_behavior(self.goal)
        self._expert_max_step = max_step

    def _choose_action(self, expert_action, legal_actions_index):
        """Return the expert action with probability ``self.distribution``, else a random legal action."""
        # The random action is drawn unconditionally so that the consumption
        # of the global NumPy random stream does not depend on the branch taken.
        random_action = np.random.choice(legal_actions_index)
        if np.random.rand() < self.distribution:
            return expert_action
        return random_action

    def sample_array_of_states(self, number=48, cutoff_length=50, seed_list=None) -> np.ndarray:
        """
        This function samples an array of states from the environment.

        :param number: number of states to sample
        :param cutoff_length: maximum number of steps per episode before the
            environment is reset
        :param seed_list: currently unused
        :return: np.array of states (object array of dicts with the keys
            "Observation", "ObsString", "expert_actions",
            "possible_actions_index" and "possible_actions")
        """
        # The expert needs its own environment and an up-to-date trajectory.
        self._init_expert_env(self._ENV_MAX_STEP)
        state_list = []
        episode_count = 0
        # reset() returns a tuple whose first element is the observation
        # (sample_data unpacks it the same way).
        obs = self.env.reset()[0]
        self.current_trajectory = []
        for _ in range(number):
            matrix_state = hierachy_craft_get_obs_string(self.env)
            expert_actions = self.get_expert_actions(obs)
            # With action mask
            legal_actions, legal_actions_index = get_masked_action_list(self.env)
            action_taken = self._choose_action(expert_actions[0], legal_actions_index)

            state_list.append({"Observation": obs, "ObsString": matrix_state,
                               "expert_actions": expert_actions,
                               "possible_actions_index": legal_actions_index,
                               "possible_actions": legal_actions})

            next_obs, _reward, done, truncated, _info = self.env.step(action_taken)
            self.current_trajectory.append(action_taken)
            episode_count += 1
            if done or truncated or episode_count >= cutoff_length:
                obs = self.env.reset()[0]
                self.current_trajectory = []
                episode_count = 0
            else:
                obs = next_obs
        self.state_array = np.asarray(state_list)
        return self.state_array

    def sample_array_of_trajectories(self, number=48) -> List[Dict]:
        """Not implemented; returns ``None``."""
        pass

    def _feedback_records(self, obs, str_state, expert_actions, legal_actions,
                          legal_actions_index, history):
        """Build the records of the configured feedback type for one state."""
        records = []
        if self.type == "binary_feedback":
            # One record per legal action.
            # NOTE(review): get_expert_binary_feedback is inherited; if it calls
            # get_expert_actions, the expert rollouts are recomputed for every
            # action — confirm and consider reusing expert_actions.
            for action in legal_actions_index:
                records.append(
                    {"state": str_state, "action": action,
                     "feedback": self.get_expert_binary_feedback(obs, action),
                     "expert_actions": expert_actions,
                     "possible_actions_index": legal_actions_index,
                     "possible_actions": legal_actions,
                     "history": history.copy()})
        elif self.type == "preference":
            # One record per ordered pair of distinct legal actions. The
            # already computed expert_actions are reused to avoid repeating
            # the expert rollouts for every pair.
            for first in legal_actions_index:
                for second in legal_actions_index:
                    if first == second:
                        continue
                    # 1: only first is expert, -1: only second is expert,
                    # 0: both or neither.
                    preference = int(first in expert_actions) - int(second in expert_actions)
                    records.append({"state": str_state, "action1": first, "action2": second,
                                    "feedback": preference,
                                    "expert_actions": expert_actions,
                                    "possible_actions_index": legal_actions_index,
                                    "possible_actions": legal_actions,
                                    "history": history.copy()})
        elif self.type == "action_advising":
            # One record per state; the feedback is the list of expert actions.
            records.append({"state": str_state, "feedback": expert_actions,
                            "expert_actions": expert_actions,
                            "possible_actions_index": legal_actions_index,
                            "possible_actions": legal_actions,
                            "history": history.copy()})
        return records

    def sample_data(self, number=48, cutoff_length=100) -> List[Dict]:
        """
        Roll out episodes and collect feedback records for the visited states.

        :param number: number of states to visit. "binary_feedback" and
            "preference" produce several records per state, so the returned
            list can be longer than ``number``.
        :param cutoff_length: maximum number of steps per episode; also the
            step budget of the expert rollouts.
        :return: list of record dicts; every record carries a copy of the
            episode history up to (not including) its state.
        :raises ValueError: if ``cutoff_length`` is smaller than 1 while
            ``number`` is positive (no state could ever be visited).
        """
        if number > 0 and cutoff_length < 1:
            raise ValueError("cutoff_length must be at least 1")

        ret_list = []
        total_count = 0
        self._init_expert_env(cutoff_length)
        self.current_trajectory = []
        with tqdm(total=number, desc="Processing") as progress_bar:
            while total_count < number:
                obs = self.env.reset()[0]
                self.test_env.reset()
                self.current_trajectory = []
                history = []  # stores the previous state-action-next_state dict
                for _ in range(cutoff_length):
                    str_state = hierachy_craft_get_obs_string(self.env)
                    expert_actions = self.get_expert_actions(obs)
                    # With action mask
                    legal_actions, legal_actions_index = get_masked_action_list(self.env)
                    action_taken = self._choose_action(expert_actions[0], legal_actions_index)

                    ret_list.extend(self._feedback_records(
                        obs, str_state, expert_actions, legal_actions,
                        legal_actions_index, history))

                    total_count += 1
                    progress_bar.update(1)
                    history_dict = {'state': str_state, "original_state": obs, 'action': action_taken,
                                    "legal_actions": legal_actions_index, "extra": ""}

                    obs, _reward, done, truncated, _info = self.env.step(action_taken)
                    history_dict["done"] = done
                    history_dict['next_state'] = hierachy_craft_get_obs_string(self.env)
                    history_dict['original_next_state'] = obs
                    history.append(history_dict)
                    self.current_trajectory.append(action_taken)

                    if done or truncated or total_count >= number:
                        break

        return ret_list

    def get_expert_actions(self, state) -> List:
        """
        This function returns the expert actions for the current state.

        The result is computed by replaying ``self.current_trajectory`` on the
        scratch environment, so it describes the state reached by that
        trajectory; ``state`` itself is not read.

        :param state: unused, kept for interface compatibility
        :return: list of the action indices with the shortest expert rollout
        """
        if self.test_env is None:
            self._init_expert_env(self._expert_max_step)
        # return [self.solving_behavior(state)]
        return get_optimal_actions_list(trajectory=self.current_trajectory, env=self.test_env,
                                        solving_bahaviour=self.test_solving_behaviour,
                                        max_step=self._expert_max_step)

    def get_expert_qvalue(self, state, action) -> float:
        """
        This function returns the expert q-value for a given state-action pair.
        Not actual learned qvalues but a heuristic.

        :param state: forwarded to ``get_expert_actions``
        :param action: action index to rate
        :return: 1 if ``action`` is one of the expert actions, otherwise -1
        """
        return 1 if action in self.get_expert_actions(state) else -1

    def get_expert_value(self, state) -> float:
        """Not implemented; returns the ``NotImplemented`` singleton."""
        # NOTE(review): raising NotImplementedError would be the conventional
        # signal; the return value is kept so existing callers are unaffected.
        return NotImplemented



if __name__ == "__main__":
    # (distribution, feedback type) pairs, in the order the datasets are written.
    # For a quick smoke test, keep a single pair and lower the sample count.
    run_configs = [
        (0, "action_advising"),
        (0.5, "action_advising"),
        (1, "action_advising"),
        (1, "binary_feedback"),
        (0.5, "binary_feedback"),
        (0, "binary_feedback"),
        (0, "preference"),
        (0.5, "preference"),
        (1, "preference"),
    ]

    for distribution, feedback_type in run_configs:
        generator = HierachyCraftDataGenerator("HierachyCraft", feedback_type, distribution)
        samples = generator.sample_data(1000)
        output_path = f"{PERSISTENT_DATA_PATH}/HierachyCraft/HierachyCraft{feedback_type}_{distribution}.npy"
        np.save(output_path, samples)
