import numpy as np


class ValueIterationSolver:
    """Tabular value iteration for a finite MDP with known rewards and dynamics."""

    def __init__(self,
                 reward_matrix,
                 terminal_matrix,
                 transition_matrix,
                 only_state_reward,
                 gamma=1.0,
                 early_stop=False,
                 ):
        """
        Value Iteration Solver.
        Args:
            reward_matrix: reward function. Shape [S,] (state only) or
                [S, A] (state-action).
            terminal_matrix: terminal indicator per state. Shape [S,].
            transition_matrix: transition probabilities. Shape [S, A, S];
                the last axis is the one the utilities are weighted over.
            only_state_reward: whether the reward only depends on state. If
                True, ``reward_matrix`` is repeated along a new action axis.
            gamma: default discount factor.
            early_stop: if True, the utility of a terminal state is pinned to
                its immediate reward (no discounted future value is added).
                This requires the reward of a terminal state to be identical
                for every action.
        """
        # np.asarray is a no-op for ndarrays and lets plain sequences through.
        self._terminal_matrix = np.asarray(terminal_matrix)
        self._T = np.asarray(transition_matrix)
        self._discount = gamma
        self._size, self._num_actions = self._T.shape[:2]
        self._early_stop = early_stop

        reward_matrix = np.asarray(reward_matrix)
        if only_state_reward:
            # [S, ] -> [S, A]
            self._reward_matrix = np.repeat(reward_matrix[..., None], axis=-1, repeats=self._num_actions)
        else:
            # shape of [S, A]
            self._reward_matrix = reward_matrix

        assert self._reward_matrix.shape == (self._size, self._num_actions)
        assert self._terminal_matrix.shape == (self._size, )

    def run_value_iterations(self, iterations, discount=None):
        """
        Run ``iterations`` synchronous sweeps starting from all-zero utilities.
        Args:
            iterations: number of sweeps to perform.
            discount: discount factor for this run; the constructor's
                ``gamma`` is used when None.
        Returns:
            (policy_grids, utility_grids), both of shape [S, iterations].
            Column ``i`` holds the greedy action indices (stored as floats)
            and the utilities after sweep ``i``.
        """
        if discount is None:
            discount = self._discount
        utility_grids, policy_grids = self._init_utility_policy_storage(iterations)

        utility_grid = np.zeros(self._size, dtype=utility_grids.dtype)
        for i in range(iterations):
            utility_grid = self._value_iteration(utility_grid=utility_grid, discount=discount)
            policy_grids[..., i] = self._get_best_policy(utility_grid, discount=discount)
            utility_grids[..., i] = utility_grid
        return policy_grids, utility_grids

    def _init_utility_policy_storage(self, depth):
        """Allocate zero-filled [S, depth] buffers for utilities and policies."""
        utility_grids = np.zeros((self._size, depth))
        policy_grids = np.zeros_like(utility_grids)
        return utility_grids, policy_grids

    def _action_values(self, utility_grid, discount):
        """Return the [S, A] one-step look-ahead values R + discount * T . U."""
        expected_next_utility = (
            utility_grid.reshape((1, 1, self._size)) * self._T
        ).sum(axis=-1)
        return self._reward_matrix + discount * expected_next_utility

    def _value_iteration(self, utility_grid, discount=1.0):
        """
        Perform one synchronous Bellman backup over all states at once.
        Args:
            utility_grid: current utilities. Shape [S,].
            discount: discount factor.
        Returns:
            Updated utilities, shape [S,], with the dtype of ``utility_grid``.
        """
        assert utility_grid.shape == (self._size, )

        value_out = np.max(self._action_values(utility_grid, discount), axis=-1)
        if self._early_stop:
            terminal = self._terminal_matrix.astype(bool)
            terminal_rewards = self._reward_matrix[terminal]
            # A terminal state's utility is a single number, so its reward
            # must not depend on the action.
            assert np.all(terminal_rewards == terminal_rewards[:, :1])
            value_out[terminal] = terminal_rewards[:, 0]
        return value_out.astype(utility_grid.dtype, copy=False)

    def _calculate_utility(self, index, discount, utility_grid):
        """
        Bellman backup for the single state ``index``.

        Full sweeps go through ``_value_iteration``; this per-state form
        computes the same quantity for one state.
        """
        if self._early_stop and self._terminal_matrix[index]:
            assert np.all(self._reward_matrix[index, 0] == self._reward_matrix[index])
            return self._reward_matrix[index, 0]
        assert utility_grid.shape == (self._size, )

        best_utility = np.max(
            self._reward_matrix[index] + discount * np.sum(
                self._T[index, :, :] * utility_grid,
                axis=-1)
        )
        return best_utility

    def _get_best_policy(self, utility_grid, discount):
        """
        Return the greedy action index per state, shape [S,].

        Action values are rounded to 4 decimals before the argmax, so values
        that differ only beyond that precision tie and the lowest action
        index wins. Terminal states are not treated specially here.
        """
        out = np.argmax(
            np.round(self._action_values(utility_grid, discount), decimals=4),
            axis=-1)
        return out


class FutureDistributionSolver:
    """
    Discounted future state / state-action distribution of a policy.

    For a policy with state-to-state transition matrix P, the state
    distribution is computed as ``(1 - gamma) * d0 @ pinv(I - gamma * P)``
    where ``d0`` is the initial state distribution.
    """

    def __init__(self,
                 transition_matrix,
                 initial_distribution,
                 gamma=0.99):
        """
        Args:
            transition_matrix: transition probabilities. Shape [S, A, S].
            initial_distribution: initial state distribution. Shape [S,].
            gamma: discount factor.
        """
        self._T = transition_matrix
        self._initial_distribution = initial_distribution
        self._discount_factor = gamma

        self._size, self._num_actions = transition_matrix.shape[:2]

    def get_state_distribution(self, policy):
        """
        Return the discounted state distribution under ``policy``.
        Args:
            policy: either action indices of shape [S,] (deterministic) or
                action probabilities of shape [S, A].
        Returns:
            Array of shape [S,].
        """
        policy = self._transform_policy(policy)
        state_transition = self._get_state_transition(policy)

        state_dist = np.linalg.pinv(np.eye(self._size) - self._discount_factor * state_transition)
        state_dist = (1 - self._discount_factor) * np.matmul(self._initial_distribution.T, state_dist)

        assert state_dist.shape == (self._size, )
        return state_dist

    def get_state_action_distribution(self, policy):
        """
        Return the discounted state-action distribution under ``policy``.

        It is the state distribution scaled row-wise by the policy's action
        probabilities. Returns an array of shape [S, A].
        """
        state_dist = self.get_state_distribution(policy)
        state_dist = state_dist.reshape([self._size, 1])

        policy = self._transform_policy(policy)    # [S, A]
        state_action_dist = state_dist * policy
        return state_action_dist

    def _get_state_transition(self, policy):
        """Marginalise actions out of T with an [S, A] policy; returns [S, S]."""
        transition_probability = np.zeros([self._size, self._size])
        assert policy.shape == (self._size, self._num_actions)

        for i in range(self._size):
            transition_probability[i] = np.matmul(policy[i], self._T[i])
        # Compare with a tolerance: rows built from float32 or non-dyadic
        # probabilities rarely sum to exactly 1.0.
        assert np.allclose(np.sum(transition_probability, axis=-1), 1.), \
            "every row of the policy-induced transition matrix must sum to 1"
        return transition_probability

    def _transform_policy(self, policy):
        """
        Normalise ``policy`` to an [S, A] matrix of action probabilities.

        A 1-D policy of shape [S,] is read as action indices and expanded to
        a one-hot matrix. Integral float indices are accepted as well. Any
        other input is copied and must already have shape [S, A].

        Raises:
            IndexError: if a 1-D float policy holds non-integral or
                non-finite values, or an action index is out of range.
        """
        policy = np.asarray(policy)
        if policy.shape == (self._size, ):
            actions = policy
            if np.issubdtype(actions.dtype, np.floating):
                if not np.all(np.isfinite(actions)) or np.any(actions != np.floor(actions)):
                    raise IndexError(
                        "a deterministic policy must contain integral action indices")
                actions = actions.astype(np.intp)
            policy_ = np.zeros([self._size, self._num_actions], dtype=np.float64)
            policy_[np.arange(self._size), actions] = 1.
        else:
            policy_ = policy.copy()
        assert policy_.shape == (self._size, self._num_actions)
        assert np.allclose(np.sum(policy_, axis=-1), 1.), \
            "every row of the policy must sum to 1"
        return policy_


def test_vi_solver():
    """Check ValueIterationSolver against the grid-world solver on a 6x10 grid."""
    from envs.Grid.gridworld import GridWorldMDP, GridWorldSolver

    np.set_printoptions(precision=2)

    shape = (6, 10)
    goal = (-1, -1)
    trap = (-1, -2)
    obstacle = (0, 1)
    start = (0, 0)
    default_reward = -0.1
    goal_reward = 1
    trap_reward = -1
    iterations = 10

    reward_grid = np.zeros(shape) + default_reward
    reward_grid[goal] = goal_reward
    reward_grid[obstacle] = 0
    reward_grid[trap] = trap_reward

    # The builtin ``bool`` is used because the ``np.bool`` alias was removed
    # in NumPy 1.24.
    terminal_mask = np.zeros_like(reward_grid, dtype=bool)
    terminal_mask[goal] = True
    terminal_mask[trap] = True

    obstacle_mask = np.zeros_like(reward_grid, dtype=bool)
    obstacle_mask[obstacle] = True

    gw = GridWorldMDP(reward_grid=reward_grid,
                      obstacle_mask=obstacle_mask,
                      terminal_mask=terminal_mask,
                      action_noise_probability=0.1,
                      no_action_probability=0.0,
                      max_episode_steps=100,
                      start_state=start)
    solver = GridWorldSolver(
        reward_grid=reward_grid,
        terminal_mask=terminal_mask,
        transition_probability=gw.transition_probability,
        only_state_reward=True,
        gamma=1.,
        early_stop=True,
    )

    my_solver = ValueIterationSolver(
        reward_matrix=gw.get_reward_matrix(),
        transition_matrix=gw.get_transition_matrix(),
        terminal_matrix=gw.get_terminal_matrix(),
        only_state_reward=True,
        gamma=1.,
        early_stop=True
    )

    policy_real, value_real = solver.run_value_iterations(iterations=iterations)
    policy_ref, value_ref = my_solver.run_value_iterations(iterations=iterations)

    # The grid solver's outputs are flattened to [S, iterations] so they line
    # up with the matrix solver's outputs.
    np.testing.assert_allclose(value_ref, value_real.reshape(-1, iterations))
    np.testing.assert_allclose(policy_ref, policy_real.reshape(-1, iterations))


def test_distribution_solver():
    """
    Check FutureDistributionSolver on a 3-state, 2-action deterministic chain.

    Starting in state 1 with gamma = 0.9, the policy moves right into state 2
    and stays there, so the discounted state distribution is [0, 0.1, 0.9].
    """
    # [0, 1, 2]
    transition_matrix = [
        # state 0
        [   # left
            [1, 0, 0],
            # right
            [1, 0, 0]
        ],
        # state 1
        [
            # left
            [1, 0, 0],
            # right
            [0, 0, 1]
        ],
        # state 2
        [
            # left
            [0, 0, 1],
            # right
            [0, 0, 1]
        ]
    ]

    transition_matrix = np.array(transition_matrix, dtype=np.float32)

    initial_distribution = [0, 1, 0]
    initial_distribution = np.array(initial_distribution, dtype=np.float32)
    solver = FutureDistributionSolver(
        transition_matrix=transition_matrix,
        initial_distribution=initial_distribution,
        gamma=0.9
    )

    policy = np.array([0, 1, 0], dtype=np.int32)
    state_dist = solver.get_state_distribution(policy)
    state_action_dist = solver.get_state_action_distribution(policy)
    print("# state distribution:", state_dist)
    print("# state-action distribution:", state_action_dist)

    # An absolute tolerance is needed because some expected entries are 0.
    np.testing.assert_allclose(state_dist, [0., 0.1, 0.9], atol=1e-6)
    np.testing.assert_allclose(
        state_action_dist,
        [[0., 0.], [0., 0.1], [0.9, 0.]],
        atol=1e-6)


if __name__ == "__main__":
    # Script entry point: only the distribution-solver check is run here.
    # test_vi_solver imports envs.Grid.gridworld and is not invoked.
    test_distribution_solver()


