from envs.tabular_env import TabularEnv
from envs.dis_tabular_env import DiscountedTabularEnv
import numpy as np
from utils.utils import sample_one_trajetory
from utils.envs.test_env import test_policy_evaluation, test_occupancy_measure
from utils.envs.env_utils import set_init_state_dis


class Bandit(TabularEnv):
    """
    Bandit-like, finite-horizon environment.

    Every state is absorbing: any action taken in state s leads back to s with probability 1. A single
    action index a_0 is optimal in every state, with r(s, a_0) = 1; every other action a != a_0 has
    r(s, a) = 0.

    Args:
        num_state: number of states.
        num_action: number of actions.
        initial_state_dis: initial state distribution; must be a 1-D, non-negative array of length
            num_state summing to 1 (see _check_init_state_dis).
        max_episode_steps: episode horizon H.
        early_stop: forwarded unchanged to TabularEnv.
        optimal_action: index a_0 of the optimal action. If None (the default) it is drawn with
            np.random.randint(num_action), so seeding numpy beforehand makes it reproducible.

    Raises:
        ValueError: if an explicit optimal_action is outside [0, num_action).
    """

    def __init__(self, num_state: int, num_action: int, initial_state_dis: np.ndarray, max_episode_steps: int,
                 early_stop=False, optimal_action=None):

        if optimal_action is None:
            # Same single RNG draw as before this parameter existed, so seeded runs are unchanged.
            optimal_action = np.random.randint(num_action)
        elif not 0 <= optimal_action < num_action:
            raise ValueError('optimal_action must be in [0, {}), got {}.'.format(num_action, optimal_action))
        self._opt_action_idx = optimal_action

        reward_vec = np.zeros(shape=[num_state, num_action], dtype=np.float32)
        reward_vec[:, self._opt_action_idx] = 1.0

        # Each state is absorbing: T[s, a, s] = 1 for every action a.
        transition_matrix = np.zeros(shape=[num_state, num_action, num_state], dtype=np.float32)
        for state in range(num_state):
            transition_matrix[state, :, state] = 1.0

        super(Bandit, self).__init__(num_state, num_action, max_episode_steps, initial_state_dis, reward_vec,
                                     transition_matrix, early_stop)

        self.reset()

    def _check_init_state_dis(self, state_dis: np.ndarray):
        """
        Return True if state_dis is a valid probability distribution over the self._ns states.

        Valid means: 1-D, of length self._ns, with all entries non-negative, and with entries summing to 1
        (within np.isclose tolerance).
        """
        state_dis = np.asarray(state_dis)
        # The sum test alone is not enough: a negative entry offset by a larger positive one can still
        # sum to 1. The ndim test also avoids an IndexError on 0-d input.
        is_valid = (state_dis.ndim == 1
                    and state_dis.shape[0] == self._ns
                    and bool(np.all(state_dis >= 0))
                    and bool(np.isclose(np.sum(state_dis), 1.0)))
        return is_valid

    def generate_experience(self, current_state_idx, action_idx):
        """
        Simulate one step from current_state_idx under action_idx.

        Returns:
            (next_state_idx, reward, terminal).
            - next_state_idx is sampled from self._T[current_state_idx, action_idx, :]. Presumably self._T
              is the transition matrix handed to TabularEnv, which is one-hot on the current state here.
            - reward is self.reward_vec[current_state_idx, action_idx].
            - terminal is always True. How it is acted upon is decided by TabularEnv; presumably this
              interacts with early_stop (not visible here).
        """
        next_state_dis = self._T[current_state_idx, action_idx, :]
        next_state_idx = np.random.choice(a=self._ns, p=next_state_dis)
        reward = self.reward_vec[current_state_idx, action_idx]
        terminal = True

        return (next_state_idx,
                reward,
                terminal)

    def get_optimal_policy(self):
        """
        Get the optimal policy.

        Returns:
            optimal_policy: float32 numpy array with shape [ns, na, H] that puts probability 1 on the
            optimal action in every state and at every time step.
        """
        optimal_policy = np.zeros(shape=[self._ns, self._na, self._max_episode_steps], dtype=np.float32)
        optimal_policy[:, self._opt_action_idx, :] = 1.0

        return optimal_policy

    def render(self, mode='human'):
        raise NotImplementedError


class DisBandit(DiscountedTabularEnv):
    """
    Discounted bandit environment.

    Every state is absorbing: any action taken in state s leads back to s with probability 1. The action
    optimal_action yields reward 1 in every state; every other action yields reward 0.

    Args:
        num_state: number of states.
        num_action: number of actions.
        gamma: discount factor, forwarded to DiscountedTabularEnv.
        initial_state_dis: initial state distribution; must be a 1-D, non-negative array of length
            num_state summing to 1 (see _check_init_state_dis).
        optimal_action: index of the optimal action.
        early_stop: forwarded unchanged to DiscountedTabularEnv.
    """

    def __init__(self, num_state: int, num_action: int, gamma: float, initial_state_dis: np.ndarray,
                 optimal_action: int, early_stop=False):

        self.optimal_action = optimal_action
        reward_mat = np.zeros(shape=(num_state, num_action), dtype=np.float32)
        reward_mat[:, optimal_action] = 1.0
        # Each state is absorbing: T[s, a, s] = 1 for every action a.
        transition_matrix = np.zeros(shape=(num_state, num_action, num_state), dtype=np.float32)
        for state in range(num_state):
            transition_matrix[state, :, state] = 1.0
        super(DisBandit, self).__init__(num_state, num_action, gamma, initial_state_dis, reward_mat, transition_matrix,
                                        early_stop)

    def _check_init_state_dis(self, state_dis: np.ndarray):
        """
        Return True if state_dis is a valid probability distribution over the self._ns states.

        Valid means: 1-D, of length self._ns, with all entries non-negative, and with entries summing to 1
        (within np.isclose tolerance).
        """
        state_dis = np.asarray(state_dis)
        # The sum test alone is not enough: a negative entry offset by a larger positive one can still
        # sum to 1. The ndim test also avoids an IndexError on 0-d input.
        is_valid = (state_dis.ndim == 1
                    and state_dis.shape[0] == self._ns
                    and bool(np.all(state_dis >= 0))
                    and bool(np.isclose(np.sum(state_dis), 1.0)))
        return is_valid

    def generate_experience(self, current_state_idx, action_idx):
        """
        Simulate one step from current_state_idx under action_idx.

        Returns:
            (next_state_idx, reward, terminal).
            - next_state_idx is sampled from self._T[current_state_idx, action_idx, :]. Presumably self._T
              is the transition matrix handed to DiscountedTabularEnv, which is one-hot on the current
              state here.
            - reward is self.reward_mat[current_state_idx, action_idx].
            - terminal is always True. How it is acted upon is decided by the base class (not visible
              here).
        """
        next_state_dis = self._T[current_state_idx, action_idx, :]
        next_state_idx = np.random.choice(a=self._ns, p=next_state_dis)
        reward = self.reward_mat[current_state_idx, action_idx]
        terminal = True

        return (next_state_idx,
                reward,
                terminal)

    def get_optimal_policy(self):
        """
        Get the optimal stationary policy.

        Returns:
            optimal_policy: float32 numpy array with shape [ns, na]; each row puts probability 1 on
            self.optimal_action.
        """
        M, N = self._ns, self._na
        optimal_policy = np.zeros(shape=(M, N), dtype=np.float32)
        optimal_policy[:, self.optimal_action] = 1.0

        return optimal_policy

    def render(self, mode='human'):
        raise NotImplementedError


def test_bandit():
    """Smoke-test Bandit with the shared policy-evaluation and occupancy-measure checks."""
    # Seed first: Bandit draws its optimal action from the global numpy RNG.
    np.random.seed(300)
    env = Bandit(
        num_state=5,
        num_action=3,
        initial_state_dis=np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32),
        max_episode_steps=10,
    )
    for check in (test_policy_evaluation, test_occupancy_measure):
        check(env)


def test_dis_bandit():
    """
    Check that value iteration recovers DisBandit's optimal policy for 100 seeds, then run the shared
    discounted policy-evaluation and occupancy-measure checks.
    """
    num_states, num_actions, discount = 5, 3, 0.99
    start_dis = np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32)

    def build_env():
        # The optimal action comes from the global numpy RNG, drawn just before construction.
        return DisBandit(num_state=num_states, num_action=num_actions, gamma=discount,
                         initial_state_dis=start_dis, optimal_action=np.random.randint(num_actions))

    for seed in range(100):
        np.random.seed(seed)
        env = build_env()
        computed = env.run_value_iteration()
        expected = env.get_optimal_policy()
        assert np.array_equal(computed, expected), \
            'The calculated optimal policy is wrong.'
    print('Pass the value iteration test.')

    env = build_env()
    test_policy_evaluation(env, env_type='Discounted')
    test_occupancy_measure(env, env_type='Discounted')


def main():
    """Run value iteration on a small Bandit, then print the policy, a sampled trajectory and its return."""
    env = Bandit(
        num_state=5,
        num_action=3,
        initial_state_dis=np.array([0.1, 0.2, 0.3, 0.4, 0], dtype=np.float32),
        max_episode_steps=10,
    )
    policy = env.run_value_iteration()
    trajectory, episode_return = sample_one_trajetory(env, policy)
    for item in (policy, trajectory, episode_return):
        print(item)


if __name__ == '__main__':
    # Runs the discounted-bandit checks by default. Call main() or test_bandit() here instead to
    # exercise the finite-horizon Bandit.
    test_dis_bandit()