import gym
import numpy as np
from utils.envs.test_env import test_policy_evaluation, test_occupancy_measure
from envs.tabular_env import TabularEnv
from envs.dis_tabular_env import DiscountedTabularEnv


class CliffWalking(TabularEnv):
    """Finite-horizon tabular environment with a single absorbing bad state.

    The last state (index ``num_state - 1``) is the bad state: every action
    taken there leads back to it with probability one. In every other state,
    exactly one action (the "optimal" action, sampled at construction time
    with ``np.random.randint``) yields reward 1.0 and draws the next state
    from the initial state distribution; all remaining actions yield reward
    0.0 and move to the bad state with probability one.

    NOTE(review): ``self._ns``, ``self._na``, ``self._T``,
    ``self._max_episode_steps``, ``self.reward_vec`` and ``self.reset`` are
    not defined in this class; presumably ``TabularEnv`` provides them.
    """

    def __init__(self, num_state: int, num_action: int, initial_state_dis: np.ndarray, max_episode_steps: int,
                 early_stop=False) -> None:
        """Build the reward table and dynamics, then initialise the base class.

        Args:
            num_state: number of states; the last index is the bad state.
            num_action: number of actions.
            initial_state_dis: initial state distribution; it also serves as
                the next-state distribution under the optimal action.
            max_episode_steps: forwarded to the base-class constructor.
            early_stop: forwarded to the base-class constructor.
        """
        # The bad state is always the last one. The optimal action is NOT
        # fixed to 0: it is drawn from NumPy's global RNG, so callers who
        # need reproducibility must seed that RNG beforehand.
        self._bad_state_idx = num_state - 1
        self._opt_action_idx = np.random.randint(num_action)

        # r(s, a) = 1.0 only for (non-bad state, optimal action).
        rewards = np.zeros(shape=[num_state, num_action], dtype=np.float32)
        rewards[0: -1, self._opt_action_idx] = 1.0

        dynamics = self._create_transition_matrix(num_state, num_action, initial_state_dis)
        super().__init__(num_state, num_action, max_episode_steps, initial_state_dis,
                         rewards, dynamics, early_stop)
        self.reset()

    def generate_experience(self, current_state_idx, action_idx):
        """Sample one transition from the given state-action pair.

        Returns:
            A tuple ``(next_state_idx, reward, terminal)`` where ``terminal``
            is True exactly when the sampled next state is the bad state.
        """
        successor_probs = self._T[current_state_idx, action_idx, :]
        successor = np.random.choice(a=self._ns, p=successor_probs)
        return (successor,
                self.reward_vec[current_state_idx, action_idx],
                bool(successor == self._bad_state_idx))

    def get_optimal_policy(self):
        """Return the optimal policy as a float32 array of shape [ns, na, H].

        At every state and every step, all probability mass is placed on the
        optimal action.
        """
        shape = [self._ns, self._na, self._max_episode_steps]
        policy = np.zeros(shape=shape, dtype=np.float32)
        policy[:, self._opt_action_idx, :] = 1.0
        return policy

    def compute_policy_value(self):
        """Not implemented for this environment."""
        raise NotImplementedError

    def _check_init_state_dis(self, state_dis: np.ndarray):
        """Tell whether ``state_dis`` is an acceptable initial distribution.

        It must have ``self._ns`` entries, sum to (approximately) one and put
        (approximately) zero mass on the bad state. Checks are evaluated in
        that order and stop at the first failure.
        """
        length_ok = state_dis.shape[0] == self._ns
        if not length_ok:
            return length_ok
        mass_ok = np.isclose(np.sum(state_dis), 1.0)
        if not mass_ok:
            return mass_ok
        return np.isclose(state_dis[self._bad_state_idx], 0.0)

    def _create_transition_matrix(self, num_state: int, num_action: int, init_state_dis: np.ndarray):
        """Create the transition matrix, a float32 array of shape [ns, na, ns].

        Entry ``[s, a, s']`` is the probability of reaching ``s'`` after
        taking ``a`` in ``s``.
        """
        ns, na = num_state, num_action
        dynamics = np.zeros(shape=[ns, na, ns], dtype=np.float32)

        # By default every (state, action) pair falls into the bad state.
        dynamics[:, :, self._bad_state_idx] = 1.0

        # In non-bad states the optimal action resamples from the initial
        # distribution instead; the bad state itself stays absorbing.
        is_safe_state = np.arange(ns) != self._bad_state_idx
        dynamics[is_safe_state, self._opt_action_idx, :] = init_state_dis

        return dynamics

    def render(self, mode='human'):
        """Rendering is not supported."""
        raise NotImplementedError


class DisCliffWalking(DiscountedTabularEnv):
    """Discounted tabular environment with a single absorbing bad state.

    The last state (index ``num_state - 1``) is the bad state: every action
    taken there leads back to it with probability one. In every other state,
    the caller-supplied optimal action yields reward 1.0 and draws the next
    state from the initial state distribution; all remaining actions yield
    reward 0.0 and move to the bad state with probability one.

    NOTE(review): ``self._ns``, ``self._na``, ``self._T``,
    ``self._reward_mat`` and ``self.reset`` are not defined in this class;
    presumably ``DiscountedTabularEnv`` provides them.
    """

    def __init__(self, num_state: int, num_action: int, gamma: float, initial_state_dis: np.ndarray,
                 optimal_action: int, early_stop=False) -> None:
        """Build the reward table and dynamics, then initialise the base class.

        Args:
            num_state: number of states; the last index is the bad state.
            num_action: number of actions.
            gamma: forwarded to the base-class constructor.
            initial_state_dis: initial state distribution; it also serves as
                the next-state distribution under the optimal action.
            optimal_action: index of the optimal action; must be smaller
                than ``num_action``.
            early_stop: forwarded to the base-class constructor.
        """
        # The bad state is always the last one; the optimal action is the
        # one chosen by the caller (it is not fixed to 0).
        self._bad_state_idx = num_state - 1
        assert optimal_action < num_action, 'The optimal action is invalid.'
        self._opt_action_idx = optimal_action

        # r(s, a) = 1.0 only for (non-bad state, optimal action).
        rewards = np.zeros(shape=(num_state, num_action), dtype=np.float32)
        rewards[: num_state-1, self._opt_action_idx] = 1.0

        dynamics = self._create_transition_matrix(num_state, num_action, initial_state_dis)
        super().__init__(num_state, num_action, gamma, initial_state_dis, rewards,
                         dynamics, early_stop)
        self.reset()

    def generate_experience(self, current_state_idx, action_idx):
        """Sample one transition from the given state-action pair.

        Returns:
            A tuple ``(next_state_idx, reward, terminal)`` where ``terminal``
            is True exactly when the sampled next state is the bad state.
        """
        successor_probs = self._T[current_state_idx, action_idx, :]
        successor = np.random.choice(a=self._ns, p=successor_probs)
        return (successor,
                self._reward_mat[current_state_idx, action_idx],
                bool(successor == self._bad_state_idx))

    def _create_transition_matrix(self, num_state: int, num_action: int, init_state_dis: np.ndarray):
        """Create the transition matrix, a float32 array of shape [ns, na, ns].

        Entry ``[s, a, s']`` is the probability of reaching ``s'`` after
        taking ``a`` in ``s``.
        """
        ns, na = num_state, num_action
        dynamics = np.zeros(shape=[ns, na, ns], dtype=np.float32)

        # By default every (state, action) pair falls into the bad state.
        dynamics[:, :, self._bad_state_idx] = 1.0

        # In non-bad states the optimal action resamples from the initial
        # distribution instead; the bad state itself stays absorbing.
        is_safe_state = np.arange(ns) != self._bad_state_idx
        dynamics[is_safe_state, self._opt_action_idx, :] = init_state_dis

        return dynamics

    def get_optimal_policy(self):
        """Return the stationary optimal policy, a float32 array of shape [ns, na].

        Every row places all probability mass on the optimal action.
        """
        policy = np.zeros(shape=(self._ns, self._na), dtype=np.float32)
        policy[:, self._opt_action_idx] = 1.0
        return policy

    def render(self, mode='human'):
        """Rendering is not supported."""
        raise NotImplementedError

    def compute_policy_value(self):
        """Not implemented for this environment."""
        raise NotImplementedError

    def _check_init_state_dis(self, state_dis: np.ndarray):
        """Tell whether ``state_dis`` is an acceptable initial distribution.

        It must have ``self._ns`` entries, sum to (approximately) one and put
        (approximately) zero mass on the bad state. Checks are evaluated in
        that order and stop at the first failure.
        """
        length_ok = state_dis.shape[0] == self._ns
        if not length_ok:
            return length_ok
        mass_ok = np.isclose(np.sum(state_dis), 1.0)
        if not mass_ok:
            return mass_ok
        return np.isclose(state_dis[self._bad_state_idx], 0.0)


def main():
    """Check value iteration on ``CliffWalking`` against its known optimal policy.

    Builds 100 environments and, for each one, asserts that the policy
    returned by ``run_value_iteration`` matches ``get_optimal_policy`` on all
    non-bad states (the bad state's row is excluded from the comparison).

    Raises:
        AssertionError: if the two policies differ on any non-bad state.
    """
    ns = 5
    na = 3
    seed = 300
    max_episode_steps = 10

    # Seed once, before the loop. Re-seeding with the same constant on every
    # iteration would make each CliffWalking draw the same optimal action, so
    # the 100 iterations would all repeat one identical experiment.
    np.random.seed(seed)

    for t in range(100):
        init_state_dis = np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32)
        cw = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        optimal_policy = cw.run_value_iteration()
        true_optimal_policy = cw.get_optimal_policy()
        assert np.array_equal(optimal_policy[:ns-1, :, :], true_optimal_policy[:ns-1, :, :]),\
            'The calculated optimal policy is wrong.'
        print('Pass the test in iteration %d' % t)


def test_cliffwalking():
    """Run the occupancy-measure check on a small, seeded ``CliffWalking``.

    The environment has 5 states, 3 actions and a horizon of 10 steps; the
    global NumPy RNG is seeded first so the sampled optimal action is
    reproducible.
    """
    num_states, num_actions, horizon = 5, 3, 10

    np.random.seed(300)
    start_dis = np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32)
    env = CliffWalking(num_states, num_actions, start_dis, horizon)

    # The policy-evaluation check (test_policy_evaluation) is currently
    # disabled; only the occupancy-measure check runs.
    test_occupancy_measure(env=env)


def test_dis_cliffwalking():
    """Run the occupancy-measure check on a small, seeded ``DisCliffWalking``.

    The environment has 5 states, 3 actions and discount factor 0.99; the
    optimal action is drawn from the seeded global NumPy RNG.
    """
    num_states, num_actions, discount = 5, 3, 0.99

    np.random.seed(300)
    start_dis = np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32)
    best_action = np.random.randint(num_actions)
    env = DisCliffWalking(num_states, num_actions, discount, start_dis, best_action)

    # The returned policy is not used here, but the call is kept.
    # NOTE(review): it may update state on the environment - confirm in the base class.
    env.run_value_iteration()
    test_occupancy_measure(env, 'Discounted')


def test_random_dis_cliffwalking():
    """Check value iteration on ``DisCliffWalking`` for randomly drawn optimal actions.

    Builds 100 environments, each with an optimal action sampled from the
    global NumPy RNG, and asserts that the policy returned by
    ``run_value_iteration`` matches ``get_optimal_policy`` on all non-bad
    states (the bad state's row is excluded from the comparison).

    Raises:
        AssertionError: if the two policies differ on any non-bad state.
    """
    ns = 5
    na = 3
    seed = 300
    gamma = 0.99

    # Seed once, before the loop. Re-seeding with the same constant on every
    # iteration would make np.random.randint return the same optimal action
    # each time, so the test would never actually vary it.
    np.random.seed(seed)

    for t in range(100):
        init_state_dis = np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32)
        optimal_action = np.random.randint(na)
        discount_env = DisCliffWalking(ns, na, gamma, init_state_dis, optimal_action)
        optimal_policy = discount_env.run_value_iteration()
        true_optimal_policy = discount_env.get_optimal_policy()
        assert np.array_equal(optimal_policy[:ns-1], true_optimal_policy[:ns-1]),\
            'The calculated optimal policy is wrong.'
        print('Pass the test in iteration %d' % t)


# Script entry point: only the discounted occupancy-measure check runs.
if __name__ == "__main__":
    test_dis_cliffwalking()
    # Other drivers, currently disabled:
    #   test_random_dis_cliffwalking()
    #   main()
    #   test_cliffwalking()

