from typing import Tuple, Dict, List

import math
import numpy as np

import logging

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class GridWorld:
    """
    Creates a gridworld object to pass to an RL algorithm.

    States are numbered row-major (``row * num_cols + col``); one extra
    fictional terminal state is appended as the last index. Actions are
    0 = up (row - 1), 1 = down (row + 1), 2 = left (col - 1), 3 = right (col + 1).

    Every position argument may be given as an (n, 2) array-like, a single
    ``[row, col]`` pair, ``None`` or the string ``"None"`` (the last two both
    mean "not specified").

    Args:
        num_rows (int): Number of rows in the gridworld.
        num_cols (int): Number of columns in the gridworld.
        start_state (numpy array of shape (1,2)): The row and column position of the start state.
        goal_states (numpy array of shape (n,2)): The row and column position of the goal states.
            n is the number of goal states.
        restart_states (numpy array of shape (n,2)): The row and column position of the restart states.
            n is the number of restart states.
        obstructed_states (numpy array of shape (n,2)): The row and column position of the obstructed states.
            n is the number of obstructed states.
        bad_states (numpy array of shape (n,2)): The row and column position of the bad states.
            n is the number of bad states.
        gamma (float): Discount factor.
        p_good_transition (float): The probability that the agent successfully
            executes the intended action.
        transition_bias (float): The probability that the agent transitions to the left
            of the intended transition when the action is incorrectly executed.
        step_reward (float): The reward received for each step taken in the gridworld.
        goal_reward (float): The reward received for reaching a goal state.
        bad_state_reward (float): The reward received for reaching a bad state.
        restart_state_reward (float): The reward received for reaching a restart state.

    Attributes:
        R (np.ndarray): Rewards, shape (num_states, num_actions, 1).
        P (np.ndarray): Transition probabilities, shape
            (num_states, num_states, num_actions), indexed [state, next_state, action].

    Raises:
        ValueError: If the start state or the goal states are missing.
        Exception: If bad/restart states are given without their reward, or
            if the transition probability or bias is missing.
    """

    def __init__(
        self,
        num_rows: int,
        num_cols: int,
        start_state: np.ndarray,
        goal_states: np.ndarray,
        restart_states: np.ndarray,
        obstructed_states: np.ndarray,
        bad_states: np.ndarray,
        gamma: float,
        p_good_transition: float,
        transition_bias: float,
        step_reward: float,
        goal_reward: float,
        bad_state_reward: float,
        restart_state_reward: float,
    ) -> None:
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.start_state = self._as_positions(start_state)
        self.goal_states = self._as_positions(goal_states)
        self.restart_states = self._as_positions(restart_states)
        self.obs_states = self._as_positions(obstructed_states)
        self.bad_states = self._as_positions(bad_states)
        self.num_bad_states = len(self.bad_states) if self.bad_states is not None else 0
        self.p_good_trans = p_good_transition
        self.bias = transition_bias
        self.r_step = step_reward
        self.r_goal = goal_reward
        self.r_bad = bad_state_reward
        self.r_restart = restart_state_reward
        self.gamma = gamma

        self._create_gridworld()

    @staticmethod
    def _as_positions(positions):
        """Normalises a position argument to an (n, 2) array or ``None``.

        Comparing an ndarray with the string ``"None"`` is elementwise on
        current NumPy and cannot be used in an ``if``, so the sentinel is
        only tested when the argument really is a string.

        Args:
            positions: (n, 2) array-like, a single ``[row, col]`` pair,
                ``None`` or the string ``"None"``.

        Returns:
            np.ndarray or None: A 2D array with one ``[row, col]`` per row
                (shape (0, 2) for an empty input), or ``None`` if unspecified.
        """
        if positions is None or (isinstance(positions, str) and positions == "None"):
            return None
        arr = np.array(list(positions))
        if arr.size == 0:
            # Keep two columns so row/col broadcasting still works downstream.
            return arr.reshape(0, 2)
        # A bare [row, col] pair becomes a single-row 2D array.
        return np.atleast_2d(arr)

    def _create_gridworld(self) -> None:
        """Builds the reward array ``R`` and the transition array ``P``.

        Raises:
            ValueError: If the start state or the goal states are missing.
            Exception: If bad/restart states are given without their reward,
                or if the transition probability or bias is missing.
        """
        if self.start_state is None or len(self.start_state) == 0:
            raise ValueError("A start state must be specified")
        if self.goal_states is None or len(self.goal_states) == 0:
            raise ValueError("At least one goal state must be specified")

        self.num_actions = 4
        # One extra, fictional terminal state is stored at the last index.
        self.num_states = self.num_cols * self.num_rows + 1
        terminal = self.num_states - 1
        self.start_state_seq = self._row_col_to_seq(self.start_state)
        self.goal_states_seq = self._row_col_to_seq(self.goal_states)

        # rewards structure: later assignments override earlier ones
        # (step < goal < bad < restart).
        self.R = self.r_step * np.ones((self.num_states, self.num_actions, 1))
        self.R[terminal, :] = 0
        self.R[self.goal_states_seq, :] = self.r_goal
        if self.bad_states is not None and len(self.bad_states) > 0:
            if self.r_bad is None:
                raise Exception("Bad state specified but no reward is given")
            self.R[self._row_col_to_seq(self.bad_states), :, :] = self.r_bad
        if self.restart_states is not None and len(self.restart_states) > 0:
            if self.r_restart is None:
                raise Exception("Restart state specified but no reward is given")
            self.R[self._row_col_to_seq(self.restart_states), :, :] = self.r_restart

        # probability model
        if self.p_good_trans is None or self.bias is None:
            raise Exception(
                "Must assign probability and bias terms via the add_transition_probability method."
            )

        self.P = np.zeros((self.num_states, self.num_states, self.num_actions))
        # the fictional end state only transitions to itself
        self.P[terminal, terminal, :] = 1

        # Restart cells are matched by (row, col), not by sequence number.
        restart_cells = set()
        if self.restart_states is not None:
            restart_cells = {tuple(cell) for cell in self.restart_states.tolist()}

        # Relative direction (-1 left, 0 forward, 1 right) with its probability.
        p_slip = 1 - self.p_good_trans
        outcomes = (
            (-1, p_slip * (self.bias)),
            (0, self.p_good_trans),
            (1, p_slip * (1 - self.bias)),
        )

        # NOTE(review): goal and obstructed states are NOT redirected to the
        # fictional end state here; they get ordinary movement dynamics and
        # the end state is only reachable from itself. An earlier comment
        # suggested such a redirect was intended - confirm before changing.
        for state in range(terminal):
            if divmod(state, self.num_cols) in restart_cells:
                # restart states transition back to the start state with
                # probability 1, whatever the action
                self.P[state, self.start_state_seq, :] = 1
                continue

            for action in range(self.num_actions):
                for offset, prob in outcomes:
                    direction = self._get_direction(action, offset)
                    next_state = self._get_state(state, direction)
                    # += because several outcomes may end in the same state
                    # (e.g. bumping into a wall keeps the agent in place).
                    self.P[state, next_state, action] += prob

    def _get_direction(self, action: int, direction: int) -> int:
        """Takes a direction and an action and returns a new direction.

        Args:
            action (int): The current action 0, 1, 2, 3 for gridworld.
            direction (int): The current direction -1, 0, 1 for left, forward, right.

        Returns:
            new_direction (int): The new direction.

        Raises:
            Exception: If ``direction`` is not -1, 0 or 1.
        """
        # Action reached by veering left / right of each intended action.
        left = [2, 3, 1, 0]
        right = [3, 2, 0, 1]
        if direction == 0:
            new_direction = action
        elif direction == -1:
            new_direction = left[action]
        elif direction == 1:
            new_direction = right[action]
        else:
            raise Exception("getDir received an unspecified case")
        return new_direction

    def _get_state(self, state: int, direction: int) -> int:
        """Get the next_state from the current state and a direction.

        Args:
            state (int): The current state.
            direction (int): The direction to move in (0 up, 1 down, 2 left, 3 right).

        Returns:
            int: The state reached, or ``state`` itself if the move would
                leave the grid or enter an obstructed cell.
        """
        row_change = [-1, 1, 0, 0]
        col_change = [0, 0, -1, 1]
        row_col = self._seq_to_col_row(state)
        row_col[0, 0] += row_change[direction]
        row_col[0, 1] += col_change[direction]

        # check for invalid states
        off_grid = (
            np.any(row_col < 0)
            or np.any(row_col[:, 0] > self.num_rows - 1)
            or np.any(row_col[:, 1] > self.num_cols - 1)
        )
        if off_grid:
            return state
        if self.obs_states is not None and np.any(
            np.sum(np.abs(self.obs_states - row_col), 1) == 0
        ):
            return state
        return self._row_col_to_seq(row_col)[0]

    def _row_col_to_seq(self, row_col: np.ndarray) -> np.ndarray:
        """Converts rows and columns to sequence numbers.

        Args:
            row_col (np.array): A 2D array with one ``[row, col]`` pair per row.

        Returns:
            np.array: A 1D array of sequence numbers.
        """
        return row_col[:, 0] * self.num_cols + row_col[:, 1]

    def _seq_to_col_row(self, seq: int) -> np.ndarray:
        """Converts a sequence number to a row and column.

        Args:
            seq (int): A sequence number.

        Returns:
            np.array: A 2D array of shape (1, 2) holding ``[row, col]``.
        """
        r, c = divmod(seq, self.num_cols)
        return np.array([[r, c]])

    def sample_next_state(
        self, state: int, action: int, num: int = 1
    ) -> Tuple[np.ndarray, float]:
        """Samples the next state and reward given a state and action.

        Args:
            state (int): The current state.
            action (int): The current action.
            num (int): The number of next states to draw.

        Returns:
            next_state (np.array): 1D array of ``num`` sampled next states.
            reward (float): The reward for taking ``action`` in ``state``.
        """
        next_state = np.random.choice(
            self.num_states, p=self.P[state, :, action], size=num
        )
        reward = self.R[state, action][0]
        return next_state, reward

    def sample_trajectory(
        self, policy: np.ndarray, horizon: int = 10
    ) -> List[Tuple[int, int, float, int]]:
        """Samples a trajectory from the environment.

        Args:
            policy (np.array): The policy to follow, indexed by state. Either
                a 1D array of actions or a 2D array whose first column holds
                the action.
            horizon (int): The number of steps to sample.

        Returns:
            list: One ``(state, action, reward, next_state)`` tuple per step,
                starting from the start state.
        """
        trajectory = []
        state = int(self.start_state_seq[0])
        for _ in range(horizon):
            # ravel lets a scalar entry (1D policy) and a row entry (2D
            # policy) be read the same way.
            action = int(np.ravel(policy[state])[0])
            next_state, reward = self.sample_next_state(state, action)
            next_state = int(next_state[0])
            trajectory.append((state, action, reward, next_state))
            state = next_state

        return trajectory

    def parallel_sampling(
        self, num: int = 1
    ) -> Tuple[Dict[Tuple[int, int], np.ndarray], Dict[Tuple[int, int], float]]:
        """Samples next states and a reward for all states and actions.

        Args:
            num (int): The number of next states to draw per state-action pair.

        Returns:
            next_states (dict): Maps ``(state, action)`` to a 1D array of
                ``num`` sampled next states.
            next_rewards (dict): Maps ``(state, action)`` to its reward.
        """
        next_states = {}
        next_rewards = {}
        for s in range(self.num_states):
            for a in range(self.num_actions):
                next_state, next_reward = self.sample_next_state(s, a, num)
                next_states[(s, a)] = next_state
                next_rewards[(s, a)] = next_reward

        return next_states, next_rewards


class DummyEnv:
    """Minimal environment stand-in carrying what value iteration reads.

    The state count, action count and discount factor are copied from an
    existing environment, while the reward and transition arrays are the
    approximations supplied by the caller.
    """

    def __init__(
        self, env: GridWorld, approx_rew: np.ndarray, approx_P: np.ndarray
    ) -> None:
        # Approximate model provided by the caller.
        self.P = approx_P
        self.R = approx_rew

        # Problem sizes and discount factor mirrored from the source environment.
        self.gamma = env.gamma
        self.num_actions = env.num_actions
        self.num_states = env.num_states
