import random

from simple_rl.tasks import GridWorldMDP
from simple_rl.tasks.grid_world.GridWorldStateClass import GridWorldState

class GridWorld(GridWorldMDP):
    """
    Tweaked version of GridWorldMDP from simple_rl.

    Differences from the parent class:

    1) Reward when reaching goal is variable and set as an attribute:
       ``goal_rewards[i]`` is paid when entering ``goal_locs[i]``.
    2) Every step additionally receives a proximity-shaping reward
       (see ``_calculate_dynamic_reward``).
    3) Rewards and next states come from the joint ``transition`` method; the
       parent's ``_reward_func`` / ``_transition_func`` raise ValueError.
    4) ``get_transtions`` exposes the full tabular (s, a) -> outcomes model.

    Slip semantics (``slip_unidirectional`` is set to True in ``__init__``):
    with probability ``slip_prob`` the chosen action is replaced by a
    perpendicular one. When ``slip_unidirectional`` is True the replacement is
    drawn uniformly from BOTH perpendicular directions; when False the action
    always rotates clockwise (up->right, right->down, down->left, left->up).
    NOTE(review): the flag name reads inverted relative to this behavior; it is
    kept unchanged for backward compatibility.
    """

    # Perpendicular directions an action may slip to (slip_unidirectional=True).
    _PERPENDICULAR = {
        "up": ("left", "right"),
        "down": ("left", "right"),
        "left": ("up", "down"),
        "right": ("up", "down"),
    }
    # Fixed clockwise slip target (slip_unidirectional=False).
    _CLOCKWISE = {"up": "right", "right": "down", "down": "left", "left": "up"}

    def __init__(
            self,
            width=5,
            height=3,
            init_loc=(1, 1),
            rand_init=False,
            goal_locs=None,
            lava_locs=None,
            walls=None,
            is_goal_terminal=True,
            gamma=0.99,
            slip_prob=0.0,
            step_cost=0.0,
            lava_cost=0.01,
            goal_rewards=None,
            name="Grid-world"
    ):
        """
        :param width: (int) number of columns; x ranges over 1..width
        :param height: (int) number of rows; y ranges over 1..height
        :param init_loc: (tuple) initial (x, y) location
        :param rand_init: (bool) forwarded to GridWorldMDP
        :param goal_locs: (list of tuple) goal (x, y) locations; default [(5, 3)]
        :param lava_locs: (list of tuple) lava (x, y) locations; default [()]
        :param walls: (list) wall locations forwarded to GridWorldMDP; default []
        :param is_goal_terminal: (bool) whether reaching a goal marks the next
            state as terminal
        :param gamma: (float) discount factor
        :param slip_prob: (float) probability of replacing the action by a
            perpendicular one
        :param step_cost: (float) cost subtracted on every non-lava step
            (goal steps included)
        :param lava_cost: (float) cost subtracted when entering lava
        :param goal_rewards: (list of float) reward for each goal, aligned by
            index with goal_locs; default [1.]
        :param name: (str) environment name
        """
        # Build list defaults per call so separate instances never share (and
        # cannot accidentally mutate) the same list objects.
        if goal_locs is None:
            goal_locs = [(5, 3)]
        if lava_locs is None:
            lava_locs = [()]
        if walls is None:
            walls = []
        if goal_rewards is None:
            goal_rewards = [1.]

        GridWorldMDP.__init__(
            self,
            width=width,
            height=height,
            init_loc=init_loc,
            rand_init=rand_init,
            goal_locs=goal_locs,
            lava_locs=lava_locs,
            walls=walls,
            is_goal_terminal=is_goal_terminal,
            gamma=gamma,
            slip_prob=slip_prob,
            step_cost=step_cost,
            lava_cost=lava_cost,
            name=name
        )
        self.goal_rewards = goal_rewards
        self.slip_unidirectional = True

    def transition(self, s, a):
        """
        Sample one environment step (joint reward / next-state method).

        With probability ``slip_prob`` the action is first replaced according
        to ``_slip_actions``; the resulting move is then resolved without
        further randomness by ``_deterministic_transition``.

        :param s: (GridWorldState) state
        :param a: (str) action
        :return: reward and resulting state (r, s_p)
        """
        if s.is_terminal():
            return 0., s

        # A uniform draw is consumed on every non-terminal step, even when
        # slip_prob == 0.
        if self.slip_prob > random.random():
            candidates = self._slip_actions(a)
            if len(candidates) > 1:
                a = random.choice(candidates)
            elif candidates:
                # Single fixed slip target: no extra random draw needed.
                a = candidates[0]

        return self._deterministic_transition(s, a)

    def _slip_actions(self, a):
        """
        Return the actions that ``a`` may be replaced by when a slip occurs.

        :param a: (str) intended action
        :return: (list of str) equally likely replacement actions; empty when
            ``a`` is not one of the four movement actions (no slip possible)
        """
        if a not in self._PERPENDICULAR:
            return []
        if self.slip_unidirectional:
            return list(self._PERPENDICULAR[a])
        return [self._CLOCKWISE[a]]

    def _deterministic_transition(self, s, a):
        """
        Resolve action ``a`` from state ``s`` without any slipping.

        A move that would leave the grid or enter a wall keeps the agent in
        place. Entering a goal marks the next state terminal when
        ``is_goal_terminal`` is set.

        :param s: (GridWorldState) state
        :param a: (str) action actually executed
        :return: reward and resulting state (r, s_p)
        """
        if s.is_terminal():
            return 0., s

        x, y = s.x, s.y
        if a == "up" and y < self.height and not self.is_wall(x, y + 1):
            s_p = GridWorldState(x, y + 1)
        elif a == "down" and y > 1 and not self.is_wall(x, y - 1):
            s_p = GridWorldState(x, y - 1)
        elif a == "right" and x < self.width and not self.is_wall(x + 1, y):
            s_p = GridWorldState(x + 1, y)
        elif a == "left" and x > 1 and not self.is_wall(x - 1, y):
            s_p = GridWorldState(x - 1, y)
        else:
            s_p = GridWorldState(x, y)

        s_p_loc = (s_p.x, s_p.y)
        if self.is_goal_terminal and s_p_loc in self.goal_locs:
            s_p.set_terminal(True)

        # Proximity shaping is added on every step, whatever the cell type.
        step_reward = self._calculate_dynamic_reward(s, s_p, a)

        if s_p_loc in self.goal_locs:
            # The first matching goal determines which goal_rewards entry is paid.
            goal_idx = self.goal_locs.index(s_p_loc)
            r = -self.step_cost + (self.goal_rewards[goal_idx] + step_reward)
        elif s_p_loc in self.lava_locs:
            r = 0. - self.lava_cost + step_reward
        else:
            r = 0. - self.step_cost + step_reward

        return r, s_p

    def _calculate_dynamic_reward(self, s, s_p, a):
        """
        Calculate the proximity-shaping reward for arriving in ``s_p``.

        For each goal i the term is
        ``goal_rewards[i] / (|s_p.x - gx_i| + |s_p.y - gy_i| + 1)``, i.e. the
        goal reward scaled by one over (Manhattan distance + 1). The mean over
        all goals is returned, or 0. when there are no goals. ``s`` and ``a``
        are currently unused.

        :param s: (GridWorldState) current state
        :param s_p: (GridWorldState) next state
        :param a: (str) action
        :return: (float) reward
        """
        if not self.goal_locs:
            return 0.
        proximity_rewards = [
            self.goal_rewards[i] / (abs(s_p.x - loc[0]) + abs(s_p.y - loc[1]) + 1)
            for i, loc in enumerate(self.goal_locs)
        ]
        return sum(proximity_rewards) / len(proximity_rewards)

    def _reward_func(self, state, action):
        raise ValueError('Method _reward_func not implemented in this Grid-world version, see transition method.')

    def _transition_func(self, state, action):
        raise ValueError('Method _transition_func not implemented in this Grid-world version, see transition method.')

    def states(self):
        """
        Return a list of the states of the environment.

        Goal cells are marked terminal when ``is_goal_terminal`` is set.

        :return: list of states
        """
        states = []
        for x in range(1, self.width + 1):
            for y in range(1, self.height + 1):
                state = GridWorldState(x, y)
                if self.is_goal_terminal and (x, y) in self.goal_locs:
                    state.set_terminal(True)
                states.append(state)
        return states

    def get_states(self):
        """
        Return a list of the states of the environment.

        :return: list of states
        """
        return self.states()

    def get_actions(self):
        """
        Return a list of the actions of the environment.

        :return: list of actions
        """
        return ["up", "down", "left", "right"]

    def get_transtions(self):
        """
        Build the full tabular model: for each non-terminal (s, a) pair, the
        list of possible outcomes with their rewards and probabilities.

        Outcomes are computed with ``_deterministic_transition``, so the model
        is exact rather than randomly sampled. The slip outcomes follow the
        same ``_slip_actions`` rule that ``transition`` samples from.

        The method name keeps its original spelling; ``get_transitions`` is a
        correctly spelled alias.

        :return: dict mapping (s, a) -> list of (s_p, r, prob); for each key
            the probabilities sum to 1. Terminal states are omitted.
        """
        transitions = {}
        actions = self.get_actions()

        for s in self.get_states():
            if s.is_terminal():
                continue

            for a in actions:
                slip_actions = self._slip_actions(a) if self.slip_prob > 0 else []
                # Probability mass moved away from the intended action.
                slip_mass = self.slip_prob if slip_actions else 0.

                r, s_p = self._deterministic_transition(s, a)
                outcomes = [(s_p, r, 1 - slip_mass)]

                for slip_a in slip_actions:
                    slip_r, slip_s_p = self._deterministic_transition(s, slip_a)
                    outcomes.append((slip_s_p, slip_r, slip_mass / len(slip_actions)))

                transitions[(s, a)] = outcomes

        return transitions

    def get_transitions(self):
        """
        Correctly spelled alias of ``get_transtions``.

        :return: dict mapping (s, a) -> list of (s_p, r, prob)
        """
        return self.get_transtions()