from envs.dis_tabular_env import DiscountedTabularEnv
from utils.envs.test_env import test_occupancy_measure, test_policy_evaluation
import numpy as np


class DisTwoWayMDP(DiscountedTabularEnv):
    """Discounted tabular "two-way" MDP: an expert chain and an alternative chain.

    State layout, with ``L = (num_state - 3) // 2`` the chain length:

    * state ``0`` (start state): the action ``optimal_det_policy[0]`` moves to
      state ``1`` (head of the expert chain); every other action moves to
      state ``L + 1`` (head of the alternative chain).
    * states ``1 .. L`` (expert chain): the action ``optimal_det_policy[s]``
      moves to ``s + 1`` (to the good state when ``s == L``); every other
      action moves to the bad state.
    * states ``L + 1 .. 2L`` (alternative chain): every action moves to
      ``s + 1`` (to the good state when ``s == 2L``).
    * state ``num_state - 2`` (good state): absorbing, reward ``+1`` for every
      action.
    * state ``num_state - 1`` (bad state): absorbing, reward ``-1`` for every
      action.

    All transitions are deterministic and every other reward is zero. When
    ``L == 0`` (``num_state == 3``), index ``1`` is the good state, so every
    action at state ``0`` leads straight to it.
    """

    def __init__(self, num_state: int, num_action: int, gamma: float, initial_state_dist: np.ndarray,
                 optimal_det_policy: np.ndarray, early_stop=False):
        """Build the reward matrix and transition tensor and initialise the base env.

        Args:
            num_state: number of states; must be ``>= 3`` with ``num_state - 3`` even.
            num_action: number of actions available in every state.
            gamma: discount factor, forwarded to the base class.
            initial_state_dist: initial-state distribution, forwarded to the base
                class (presumably validated there via ``_check_init_state_dis`` —
                TODO confirm).
            optimal_det_policy: integer array indexed by state; entry ``s`` is the
                optimal action at state ``s``.
            early_stop: forwarded unchanged to the base class.

        Raises:
            AssertionError: if ``num_state`` does not fit the two-chain layout.
        """
        assert (num_state >= 3) and ((num_state - 3) % 2 == 0), 'The number of states is invalid.'
        # The absorbing states take the last two indices. They are set before the
        # base-class constructor runs because methods of this class read them.
        self._bad_state_idx = num_state - 1
        self._good_state_idx = num_state - 2
        # optimal_det_policy is a numpy array with shape = [num_state]
        self.optimal_det_policy = optimal_det_policy
        # Exact integer division; num_state - 3 is even by the assertion above.
        self._chain_length = (num_state - 3) // 2

        # Only the absorbing states carry reward: +1 good, -1 bad.
        reward_mat = np.zeros(shape=(num_state, num_action), dtype=np.float32)
        reward_mat[self._good_state_idx, :] = 1.0
        reward_mat[self._bad_state_idx, :] = -1.0
        transition_prob = self._create_transition_matrix(num_state, num_action)

        super().__init__(num_state, num_action, gamma, initial_state_dist, reward_mat,
                         transition_prob, early_stop)

        self.reset()

    def generate_experience(self, current_state_idx, action_idx):
        """Sample one transition from ``(current_state_idx, action_idx)``.

        ``self._T``, ``self._reward_mat`` and ``self._ns`` are presumably
        populated by the base class from the arrays passed to its constructor
        (TODO confirm).

        Returns:
            Tuple ``(next_state_idx, reward, terminal)``:

            * ``reward`` is the reward of the *current* state-action pair.
            * ``terminal`` is ``True`` when the sampled next state is either
              absorbing state.
        """
        next_state_dis = self._T[current_state_idx, action_idx, :]
        next_state_idx = np.random.choice(a=self._ns, p=next_state_dis)
        reward = self._reward_mat[current_state_idx, action_idx]
        # Entering either absorbing state ends the episode.
        terminal = next_state_idx in (self._bad_state_idx, self._good_state_idx)

        return next_state_idx, reward, terminal

    def _create_transition_matrix(self, num_state: int, num_action: int):
        """Build the deterministic transition tensor ``P[s, a, s']``.

        Returns:
            float32 array of shape ``(num_state, num_action, num_state)``.
            Each row ``P[s, a, :]`` is one-hot.

        Raises:
            AssertionError: sanity check that every row is a probability
                distribution.
        """
        chain_len = self._chain_length
        policy = self.optimal_det_policy
        transition_prob = np.zeros(shape=(num_state, num_action, num_state), dtype=np.float32)

        # Start state: the optimal action enters the expert chain (state 1);
        # every other action enters the alternative chain (state chain_len + 1).
        transition_prob[0, :, chain_len + 1] = 1.0
        transition_prob[0, policy[0], chain_len + 1] = 0.0
        transition_prob[0, policy[0], 1] = 1.0

        # Both absorbing states loop back to themselves under every action.
        transition_prob[self._good_state_idx, :, self._good_state_idx] = 1.0
        transition_prob[self._bad_state_idx, :, self._bad_state_idx] = 1.0

        # Expert chain: deviating from the optimal action is punished.
        for state_idx in range(1, chain_len + 1):
            # Non-optimal actions lead to the bad state.
            transition_prob[state_idx, :, self._bad_state_idx] = 1.0
            transition_prob[state_idx, policy[state_idx], self._bad_state_idx] = 0.0
            # The optimal action advances along the chain, ending in the good state.
            next_idx = self._good_state_idx if state_idx == chain_len else state_idx + 1
            transition_prob[state_idx, policy[state_idx], next_idx] = 1.0

        # Alternative chain: every action advances, ending in the good state.
        for state_idx in range(chain_len + 1, 2 * chain_len + 1):
            next_idx = self._good_state_idx if state_idx == 2 * chain_len else state_idx + 1
            transition_prob[state_idx, :, next_idx] = 1.0

        normalizer = np.sum(transition_prob, axis=2)
        assert np.all(transition_prob >= 0) and np.allclose(normalizer, 1.0), 'Invalid transition probability'

        return transition_prob

    def get_optimal_policy(self):
        """Return ``optimal_det_policy`` as a one-hot (stochastic-matrix) policy.

        Returns:
            ``np.ndarray`` of shape ``(len(optimal_det_policy), num_action)``.
            Row ``s`` is the one-hot encoding of ``optimal_det_policy[s]``.
        """
        return np.eye(self._na)[self.optimal_det_policy]

    def render(self, mode='human'):
        """Rendering is not supported for this environment."""
        raise NotImplementedError

    def compute_policy_value(self):
        """Not implemented for this environment."""
        raise NotImplementedError

    def _check_init_state_dis(self, state_dis: np.ndarray):
        """Return whether ``state_dis`` is a valid initial-state distribution.

        A valid distribution:

        * has one entry per state,
        * is non-negative (up to ``np.isclose`` tolerance),
        * sums to one,
        * puts no mass on the bad absorbing state.
        """
        if state_dis.shape[0] != self._ns:
            return False
        # Without this, negative entries could cancel out and still sum to one.
        # np.isclose keeps tiny round-off negatives acceptable.
        non_negative = np.all((state_dis >= 0.0) | np.isclose(state_dis, 0.0))
        return bool(non_negative
                    and np.isclose(np.sum(state_dis), 1.0)
                    and np.isclose(state_dis[self._bad_state_idx], 0.0))


if __name__ == '__main__':
    # Smoke test: build an instance with a random optimal policy.
    # Check that the policy returned by value iteration is worth as much as the
    # hand-specified optimal policy.
    ns = 9
    na = 4
    gamma = 0.9
    # Every episode starts in state 0.
    init_state_dist = np.zeros(shape=[ns], dtype=np.float32)
    init_state_dist[0] = 1.0
    optimal_det_policy = np.random.randint(0, na, size=[ns])
    print(optimal_det_policy)
    env = DisTwoWayMDP(ns, na, gamma, init_state_dist, optimal_det_policy)
    opt_policy = env.get_optimal_policy()
    calculated_optimal_policy = env.run_value_iteration()
    value1 = env.policy_evaluation(opt_policy)
    value2 = env.policy_evaluation(calculated_optimal_policy)
    print(value2)
    print(calculated_optimal_policy)
    # np.allclose reduces to a single bool whether policy_evaluation returns a
    # scalar or a per-state vector.
    # `assert np.isclose(...)` would raise "truth value of an array is ambiguous"
    # in the vector case.
    assert np.allclose(value1, value2), (value1, value2)
    # test_policy_evaluation(env, 'Discounted')
    # test_occupancy_measure(env, 'Discounted')


