import gym

from DataGenerators.DataGenerator import DataGenerator
import numpy as np
from typing import List, Union, Dict

def cliff_walking_get_next_state(state, action):
    """Return the grid cell reached from ``state`` by taking ``action``.

    States are flat indices ``row * 12 + col`` on a 4x12 grid and actions are
    0=UP, 1=RIGHT, 2=DOWN, 3=LEFT. A move that would leave the grid keeps the
    agent in place, and so does any unknown action. Only the grid geometry is
    modelled: stepping onto a cliff cell does NOT send the agent back to the
    start here.
    """
    row, col = divmod(state, 12)
    if action == 0 and row > 0:      # UP
        row -= 1
    elif action == 1 and col < 11:  # RIGHT
        col += 1
    elif action == 2 and row < 3:   # DOWN
        row += 1
    elif action == 3 and col > 0:   # LEFT
        col -= 1
    return row * 12 + col


def state_to_row_col_list(state):
    """Convert a flat CliffWalking state index into ``[row, col]`` on the 12-column grid."""
    row, col = divmod(state, 12)
    return [row, col]


class CliffWalkingDataGenerator(DataGenerator):
    """Generate synthetic feedback data for the CliffWalking-v0 grid world.

    States are flat indices ``row * 12 + col`` on a 4x12 grid and actions are
    0=UP, 1=RIGHT, 2=DOWN, 3=LEFT (the convention of ``cliff_walking_get_next_state``).
    The hand-written expert steps up from the start (36), walks right along row 2
    and steps down into the goal (47).
    """

    # Number of grid cells (4 rows x 12 columns) and of discrete actions.
    _N_STATES = 48
    _N_ACTIONS = 4
    # States visited by the expert policy, from the start (36) to the goal (47).
    _EXPERT_STATES = (36, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 47)

    def __init__(self, env, type, distribution):
        # ``type`` shadows the builtin but is kept for interface compatibility.
        super().__init__(env, type, distribution)

    def sample_array_of_states(self, number=48) -> np.array:
        """
        Sample an array of states according to ``self.distribution``.

        :param number: number of distinct states to draw for the "random" distribution
            (sampled without replacement, so it must be <= 48); ignored otherwise.
        :return: np.array of states for "traverse" (all 48 states), "random" or
            "expert" (the expert path); None for any other distribution.
        """
        if self.distribution == "traverse":
            return np.arange(0, self._N_STATES, 1)
        elif self.distribution == "random":
            return np.random.choice(self._N_STATES, number, replace=False)
        elif self.distribution == "expert":
            return np.array(self._EXPERT_STATES)
        return None

    def sample_array_of_trajectories(self, number=48) -> List[Dict]:
        """Not implemented for CliffWalking; returns None."""
        return None

    @staticmethod
    def _next_positions(state, actions) -> List[List]:
        """Return the ``[row, col]`` reached from ``state`` by each action in ``actions``."""
        return [state_to_row_col_list(cliff_walking_get_next_state(state, action)) for action in actions]

    def _binary_records(self, state) -> List[Dict]:
        """Return one binary-feedback record per action at ``state``."""
        return [{"state": state, "action": action, "feedback": self.get_expert_binary_feedback(state, action)}
                for action in range(self._N_ACTIONS)]

    def _preference_records(self, state) -> List[Dict]:
        """Return one preference record per ordered pair of distinct actions at ``state``."""
        return [{"state": state, "action1": action1, "action2": action2,
                 "feedback": self.get_expert_preference(state, action1, action2)}
                for action1 in range(self._N_ACTIONS)
                for action2 in range(self._N_ACTIONS)
                if action1 != action2]

    def sample_data_from_policy(self, number=100, cutoff_length=50) -> List[Dict]:
        """
        Roll out a mixed expert/random policy and record one feedback entry per step.

        At each step the expert action is taken with probability ``self.distribution``;
        otherwise a uniformly random action is taken. An episode ends when the
        environment terminates or truncates, or after ``cutoff_length`` steps, and the
        environment is then reset.

        :param number: total number of environment steps (one record per step).
        :param cutoff_length: maximum number of steps per episode.
        :return: list of feedback dicts whose keys depend on ``self.type``. Each record
            also carries ``"history"``: the transitions taken earlier in the same episode.
        :raises TypeError: if ``self.distribution`` is not a number.
        """
        if not isinstance(self.distribution, (int, float)):
            raise TypeError(
                f"sample_data_from_policy needs a numeric distribution (expert probability), "
                f"got {self.distribution!r}"
            )

        ret_list = []
        env = gym.make("CliffWalking-v0")
        try:
            obs, _ = env.reset()
            current_history = []
            episode_step_count = 0
            for _ in range(number):
                state = obs
                expert_actions = self.get_expert_actions(state)
                expert_action = expert_actions[0]
                # Uniform over the 4 actions (randint's upper bound is exclusive); replaces
                # the deprecated np.random.random_integers(0, 3).
                random_action = np.random.randint(0, self._N_ACTIONS)
                action_taken = expert_action if np.random.rand() < self.distribution else random_action

                next_state, r, terminated, truncated, _ = env.step(action_taken)
                episode_step_count += 1
                history = current_history.copy()

                if self.type == "binary_feedback":
                    ret_list.append({"state": state, "action": action_taken,
                                     "feedback": 1 if action_taken in expert_actions else -1,
                                     "history": history})
                elif self.type == "preference":
                    good_action = np.random.choice(expert_actions)
                    bad_action = np.random.choice([a for a in range(self._N_ACTIONS) if a not in expert_actions])
                    if np.random.rand() < 0.5:
                        ret_list.append({"state": state, "action1": good_action, "action2": bad_action,
                                         "feedback": 1, "history": history})
                    else:
                        ret_list.append({"state": state, "action1": bad_action, "action2": good_action,
                                         "feedback": -1, "history": history})
                elif self.type == "action_advising":
                    ret_list.append({"state": state, "feedback": expert_actions, "history": history})
                elif self.type == "goal_advising":
                    ret_list.append({"state": state, "feedback": self._next_positions(state, expert_actions),
                                     "history": history})

                if next_state == 47:
                    extra_info = "You reached the goal."
                elif r < -1:
                    # Ordinary moves cost -1, so a larger penalty means the agent fell off the cliff.
                    extra_info = "You fell into the holes."
                elif next_state == state:
                    extra_info = "You hit the wall."
                else:
                    extra_info = ""
                current_history.append({"state": state, "action": action_taken, "next_state": next_state,
                                        "done": terminated, "extra": extra_info})

                if terminated or truncated or episode_step_count >= cutoff_length:
                    obs, _ = env.reset()
                    current_history = []
                    # Restart the per-episode counter; otherwise every episode after the
                    # first cutoff would be cut short after a single step.
                    episode_step_count = 0
                else:
                    obs = next_state
        finally:
            env.close()

        return ret_list

    def sample_data(self, number=48) -> List[Dict]:
        """
        Build feedback records directly from the expert, without environment rollouts.

        ``self.type`` selects the record format ("binary_feedback", "preference",
        "action_advising" or "goal_advising"). ``self.distribution`` selects the states:
        "traverse" (all 48 states), "random" (``number`` states drawn uniformly with
        replacement) or "expert_states" (the expert path). "binary_feedback"
        additionally supports "expert_trajectories", which pairs every state with its
        first expert action. Unsupported combinations yield an empty list.

        :param number: number of records drawn for the "random" distribution.
        :return: list of feedback dicts.
        """
        ret_list = []
        expert_states = np.array(self._EXPERT_STATES)

        if self.type == "binary_feedback":
            if self.distribution == "traverse":
                for state in range(self._N_STATES):
                    ret_list.extend(self._binary_records(state))
            elif self.distribution == "random":
                for _ in range(number):
                    # Draw scalars (not shape-(1,) arrays) so records match the other distributions.
                    state = int(np.random.choice(self._N_STATES))
                    action = int(np.random.choice(self._N_ACTIONS))
                    ret_list.append({"state": state, "action": action,
                                     "feedback": self.get_expert_binary_feedback(state, action)})
            elif self.distribution == "expert_states":
                for state in expert_states:
                    ret_list.extend(self._binary_records(state))
            elif self.distribution == "expert_trajectories":
                for state in range(self._N_STATES):
                    ret_list.append({"state": state, "action": self.get_expert_actions(state)[0], "feedback": 1})

        elif self.type == "preference":
            if self.distribution == "traverse":
                for state in range(self._N_STATES):
                    ret_list.extend(self._preference_records(state))
            elif self.distribution == "random":
                for _ in range(number):
                    state = int(np.random.choice(self._N_STATES))
                    action1, action2 = np.random.choice(self._N_ACTIONS, 2, replace=False)
                    ret_list.append({"state": state, "action1": action1, "action2": action2,
                                     "feedback": self.get_expert_preference(state, action1, action2)})
            elif self.distribution == "expert_states":
                for state in expert_states:
                    ret_list.extend(self._preference_records(state))

        elif self.type == "action_advising":
            if self.distribution == "traverse":
                for state in range(self._N_STATES):
                    ret_list.append({"state": state, "feedback": self.get_expert_action_advising(state)})
            elif self.distribution == "random":
                for _ in range(number):
                    state = int(np.random.choice(self._N_STATES))
                    ret_list.append({"state": state, "feedback": self.get_expert_action_advising(state)})
            elif self.distribution == "expert_states":
                for state in expert_states:
                    ret_list.append({"state": state, "feedback": self.get_expert_action_advising(state)})

        elif self.type == "goal_advising":
            if self.distribution == "traverse":
                for state in range(self._N_STATES):
                    expert_actions = self.get_expert_action_advising(state)
                    ret_list.append({"state": state, "feedback": self._next_positions(state, expert_actions)})
            elif self.distribution == "random":
                for _ in range(number):
                    state = int(np.random.choice(self._N_STATES))
                    expert_actions = self.get_expert_action_advising(state)
                    ret_list.append({"state": state, "feedback": self._next_positions(state, expert_actions)})
            elif self.distribution == "expert_states":
                for state in expert_states:
                    expert_actions = self.get_expert_action_advising(state)
                    ret_list.append({"state": state, "feedback": self._next_positions(state, expert_actions)})

        return ret_list

    def get_expert_actions(self, state) -> List:
        """
        Return the list of expert actions for ``state`` (the first one is the preferred action).

        :param state: flat state index ``row * 12 + col``.
        :return: list of action ids (0=UP, 1=RIGHT, 2=DOWN, 3=LEFT).
        """
        row = state // 12
        col = state % 12
        if state == 36:
            return [0]      # start: step up off the bottom row
        if col == 11:
            return [2]      # right edge: go down towards the goal
        if row == 2:
            return [1]      # row just above the cliff: only RIGHT is safe
        if row < 2:
            return [1, 2]   # upper rows: RIGHT or DOWN both approach the goal
        # Remaining bottom-row cells; presumably unreachable in the real env
        # (falling off the cliff resets to the start) — kept as before.
        return [2]

    def get_expert_qvalue(self, state, action) -> float:
        """
        Return a heuristic (not learned) q-value: 1 for expert actions, -1 otherwise.

        :param state: flat state index.
        :param action: action id.
        :return: 1 or -1.
        """
        return 1 if action in self.get_expert_actions(state) else -1

    def test_expert_policy(self):
        """
        Run the expert policy on ``self.env`` for up to 100 steps, printing each transition.

        :return: None
        """
        state, _ = self.env.reset()
        done = False
        steps = 0
        while not done and steps < 100:
            action = self.get_expert_actions(state)[0]
            next_state, _, done, _, _ = self.env.step(action)
            print(state, action, next_state)
            state = next_state
            steps += 1

    def get_expert_value(self, state) -> float:
        """
        Expert state values are not available for this environment.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("CliffWalkingDataGenerator does not provide expert state values.")




if __name__ == "__main__":
    import os

    # Presumably re-imported through the package path so the class object matches the
    # one other modules import (this file itself runs as ``__main__`` here).
    from DataGenerators.CliffWalkingDataGenerator import CliffWalkingDataGenerator
    from config import PERSISTENT_DATA_PATH

    output_dir = f"{PERSISTENT_DATA_PATH}/CliffWalking"
    # np.save does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    # Other supported feedback types: "action_advising", "binary_feedback", "preference".
    for feedback_type in ["goal_advising"]:
        # dist is the probability of following the expert (0 = random policy, 1 = expert).
        for dist in [0, 0.5, 1]:
            cliffwalking_generator = CliffWalkingDataGenerator("CliffWalking-v0", feedback_type, dist)
            data = cliffwalking_generator.sample_data_from_policy(200)
            # A list of dicts is saved as an object array; load it with np.load(..., allow_pickle=True).
            np.save(f"{output_dir}/cliffwalking_{feedback_type}_{dist}.npy", data)