from typing import Dict, List
import numpy as np

from DataGenerators.DataGenerator import DataGenerator

def cliff_walking_direction_optimal_policy(row, col, direction):
    """
    Return the expert actions for a direction-aware Cliff Walking state.

    Actions:
        0 LEFT
        1 RIGHT
        2 MOVE FORWARD
    (LEFT/RIGHT appear to be turns relative to the current heading; the
    table below is consistent with that reading -- confirm against the env.)

    The table steers the agent out of cell (3, 0) by facing UP, then along
    rows 0-2 by facing RIGHT until column 11, where it faces DOWN. Every
    action in the returned list is treated as equally acceptable.

    Args:
        row: grid row of the agent. Rows 0-2 and the single cell (3, 0)
            are covered.
        col: grid column of the agent; column 11 is where the policy turns
            DOWN.
        direction: current heading, one of "UP", "DOWN", "LEFT", "RIGHT".

    Returns:
        A newly created list of expert action indices (safe to mutate).

    Raises:
        ValueError: if (row, col) is not covered by the policy (row 3 with
            col != 0, or a row outside 0-3).
        KeyError: if ``direction`` is not one of the four headings.
    """
    if row == 3 and col == 0:
        # Cell (3, 0), presumably the start: turn to face UP, then move forward.
        ans_dict = {"UP": [2], "DOWN": [0, 1], "LEFT": [1], "RIGHT": [0]}
    elif row in (0, 1, 2) and col == 11:
        # Right-most column of rows 0-2: turn to face DOWN, then move forward.
        ans_dict = {"UP": [0, 1], "DOWN": [2], "LEFT": [0], "RIGHT": [1]}
    elif row == 2:
        # Row 2 (adjacent to row 3): only moving forward while facing RIGHT.
        ans_dict = {"UP": [1], "DOWN": [0], "LEFT": [0, 1], "RIGHT": [2]}
    elif row in (0, 1):
        # Rows 0-1: facing DOWN, both turning and moving forward are accepted.
        ans_dict = {"UP": [1], "DOWN": [0, 2], "LEFT": [0, 1], "RIGHT": [2]}
    else:
        # Remaining row-3 cells and out-of-grid rows have no expert entry;
        # fail with an explicit message instead of an unbound local.
        raise ValueError(
            f"No expert policy defined for cell (row={row}, col={col})"
        )

    return ans_dict[direction]



class CliffWalkingDirectionDataGenerator(DataGenerator):
    """
    Generates expert-labelled data for a direction-aware Cliff Walking task.

    States are ``(row, col, direction)`` tuples on a 4 x 12 grid whose cells
    are flattened as ``row * 12 + col`` (48 cells). ``direction`` is one of
    "UP", "DOWN", "LEFT", "RIGHT". Actions are 0 (LEFT), 1 (RIGHT) and
    2 (MOVE FORWARD). Expert labels come from
    ``cliff_walking_direction_optimal_policy``.
    """

    # Headings enumerated, in this order, for every traversed cell.
    _DIRECTIONS = ("UP", "DOWN", "LEFT", "RIGHT")
    # Action indices: 0 LEFT, 1 RIGHT, 2 MOVE FORWARD.
    _ACTIONS = (0, 1, 2)
    # Grid width used to unflatten cell indices.
    _NUM_COLS = 12
    # The "traverse" data covers flattened cells 0-35, i.e. rows 0-2 only.
    _NUM_TRAVERSED_CELLS = 36

    def __init__(self, env, type, distribution):
        super().__init__(env, type, distribution)

    def sample_array_of_states(self, number=48) -> np.ndarray:
        """
        Sample flattened state indices (``row * 12 + col``) according to
        ``self.distribution``.

        :param number: number of states to draw; only used by the "random"
            distribution (numpy raises ValueError if it exceeds 48, since
            sampling is without replacement).
        :return: np.ndarray of state indices, or None for an unrecognised
            distribution.
        """
        if self.distribution == "traverse":
            return np.arange(0, 48, 1)
        elif self.distribution == "random":
            return np.random.choice(48, number, replace=False)
        elif self.distribution == "expert":
            # Cell 36 = (3, 0), then all of row 2 (24-35), then 47 = (3, 11).
            return np.array([36, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 47])
        # NOTE(review): unknown distributions silently yield None.
        return None

    def sample_array_of_trajectories(self, number=48) -> List[Dict]:
        """Not implemented for this environment; currently returns None."""
        pass

    def _traverse_states(self) -> List[tuple]:
        """
        Enumerate ``(row, col, direction)`` for every traversed cell and every
        heading, ordered by flattened cell index, then by heading.
        """
        return [
            (cell // self._NUM_COLS, cell % self._NUM_COLS, direction)
            for cell in range(self._NUM_TRAVERSED_CELLS)
            for direction in self._DIRECTIONS
        ]

    def sample_data(self, number=48) -> List[Dict]:
        """
        Build expert-labelled records for ``self.type``.

        Only the "traverse" distribution is implemented. Any other
        distribution, or an unrecognised type, yields an empty list.
        ``number`` is currently unused.

        Record shapes:
          - "binary_feedback": {"state", "action", "feedback"} for every
            (state, action).
          - "preference": {"state", "action1", "action2", "feedback"} for
            every state and every ordered pair of distinct actions.
          - "action_advising": {"state", "feedback"} for every state.

        The ``get_expert_*`` feedback methods are presumably provided by the
        ``DataGenerator`` base class (not visible here).

        NOTE(review): the traversal covers rows 0-2 only (cells 0-35). The
        row-3 cell (3, 0), which the expert policy does cover, is therefore
        never sampled. Confirm this exclusion is intended.
        """
        ret_list = []
        if self.distribution != "traverse":
            return ret_list

        states = self._traverse_states()
        if self.type == "binary_feedback":
            for state in states:
                for action in self._ACTIONS:
                    ret_list.append({
                        "state": state,
                        "action": action,
                        "feedback": self.get_expert_binary_feedback(state, action),
                    })
        elif self.type == "preference":
            # Ordered pairs (j, k) with j != k, in nested-loop order.
            action_pairs = [
                (j, k) for j in self._ACTIONS for k in self._ACTIONS if j != k
            ]
            for state in states:
                for action1, action2 in action_pairs:
                    ret_list.append({
                        "state": state,
                        "action1": action1,
                        "action2": action2,
                        "feedback": self.get_expert_preference(state, action1, action2),
                    })
        elif self.type == "action_advising":
            for state in states:
                ret_list.append({
                    "state": state,
                    "feedback": self.get_expert_action_advising(state),
                })

        return ret_list

    def get_expert_actions(self, state) -> List:
        """
        Return the expert actions for a ``(row, col, direction)`` state.

        :param state: indexable with row at [0], col at [1], direction at [2].
        :return: list of expert action indices.
        """
        return cliff_walking_direction_optimal_policy(state[0], state[1], state[2])

    def get_expert_qvalue(self, state, action) -> float:
        """
        Return a heuristic expert "q-value" for a state-action pair.

        Not a learned q-value: 1 if ``action`` is an expert action for
        ``state``, otherwise -1.

        :param state: ``(row, col, direction)`` state.
        :param action: action index.
        :return: 1 or -1.
        """
        return 1 if action in self.get_expert_actions(state) else -1

    def test_expert_policy(self, max_steps: int = 100):
        """
        Roll out the expert policy in ``self.env`` and print each transition.

        Always takes the first action the expert lists. Stops when the episode
        ends (terminated or truncated) or after ``max_steps`` steps.

        :param max_steps: safety cap on the number of environment steps.
        """
        state, _ = self.env.reset()
        done = False
        steps = 0
        while not done and steps < max_steps:
            action = self.get_expert_actions(state)[0]
            # step() returns (obs, reward, terminated, truncated, info) as in the
            # Gymnasium API; a truncated episode must also end the rollout.
            next_state, _, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            print(state, action, next_state)
            state = next_state
            steps += 1

    def get_expert_value(self, state) -> float:
        # NOTE(review): returns the NotImplemented singleton rather than raising
        # NotImplementedError; callers must not treat the result as a number.
        return NotImplemented

