import gym
import numpy as np


class DiscountedTabularEnv(gym.Env):
    """Base class for discounted tabular (finite state / finite action) environments.

    Subclasses are expected to implement ``_check_init_state_dis``,
    ``generate_experience``, ``get_optimal_policy`` and ``render``; the versions
    defined here raise ``NotImplementedError``.

    Attributes:
        _ns: number of states
        _na: number of actions
        _gamma: discount factor, in [0, 1)
        _max_episode_steps: the effective horizon; ``step`` reports ``done`` once the
            per-episode step counter reaches it
        _init_state_dis: initial state distribution
        _T: stationary transition matrix, a numpy array with shape [ns, na, ns]
        _reward_mat: stationary reward matrix, a numpy array with shape [ns, na]
        early_stop: whether ``step`` reports ``done`` as soon as a terminal transition occurs
        observation_space: the state space
        action_space: the action space
    """

    # Per-episode bookkeeping; both are overwritten on the instance by reset()/step().
    _episode_step = 0
    _current_state_idx = 0

    def __init__(self, num_state: int, num_action: int, gamma: float, initial_state_dis: np.ndarray,
                 reward_mat: np.ndarray, transition_matrix: np.ndarray, early_stop: bool):
        """
        Args:
            num_state: number of states.
            num_action: number of actions.
            gamma: discount factor, must satisfy 0 <= gamma < 1.
            initial_state_dis: initial state distribution, validated by ``_check_init_state_dis``.
            reward_mat: reward matrix, a numpy array with shape [ns, na].
            transition_matrix: transition matrix, a numpy array with shape [ns, na, ns].
            early_stop: whether an episode ends as soon as a terminal transition occurs.
        Raises:
            ValueError: if gamma is outside [0, 1).
        """
        # The horizon formula below divides by (1 - gamma) and takes log(1 / (1 - gamma)),
        # so gamma >= 1 cannot work; fail with a clear message instead of an arithmetic error.
        if not 0.0 <= gamma < 1.0:
            raise ValueError('gamma must be in [0, 1), got {}'.format(gamma))

        self._ns = num_state
        self._na = num_action
        self._gamma = gamma

        # The effective horizon, we may need to tune it.
        # Note that int() truncates, so this is 0 for small gamma.
        self._max_episode_steps = int((2.0 / (1.0 - gamma)) * np.log(1.0 / (1.0 - gamma)))

        assert self._check_init_state_dis(initial_state_dis), 'Invalid initial state distribution!'
        self._init_state_dis = initial_state_dis
        self._T = transition_matrix

        self._reward_mat = reward_mat
        self.early_stop = early_stop
        self.observation_space = gym.spaces.Discrete(self._ns)
        self.action_space = gym.spaces.Discrete(self._na)

    def render(self, mode='human'):
        """Rendering is not provided by the base class."""
        raise NotImplementedError

    @property
    def size(self):
        """Number of states."""
        return self._ns

    @property
    def reward_mat(self):
        """A copy of the reward matrix."""
        return self._reward_mat.copy()

    @property
    def init_state_distribution(self):
        """A copy of the initial state distribution."""
        return self._init_state_dis.copy()

    @property
    def transition_probability(self):
        """A copy of the transition matrix."""
        return self._T.copy()

    def _check_init_state_dis(self, state_dis: np.ndarray):
        """Return a truthy value iff ``state_dis`` is acceptable; left to subclasses."""
        raise NotImplementedError

    def reset(self):
        """Start a new episode by sampling a state index from the initial state distribution.

        Returns:
            the sampled initial state index.
        """
        self._episode_step = 0
        self._current_state_idx = np.random.choice(self._ns, p=self._init_state_dis)

        return self._current_state_idx

    def generate_experience(self, current_state_idx, action_idx):
        """Produce one transition; left to subclasses.

        ``step`` unpacks the result as ``(next_state_idx, reward, terminal)``.
        """
        raise NotImplementedError

    def get_optimal_policy(self):
        """
            Return the stationary optimal policy, a numpy array with shape [ns, na].
        """

        raise NotImplementedError

    def step(self, action):
        """Advance the environment by one transition.

        Args:
            action: an action contained in ``action_space``.
        Returns:
            (next_state_idx, reward, done, info), where ``info['terminal']`` is the terminal
            flag returned by ``generate_experience``. ``done`` follows that flag only when
            ``early_stop`` is set, and is always True once the horizon is reached.
        """
        assert self.action_space.contains(action), 'Invalid action'
        self._episode_step += 1
        next_state_idx, reward, terminal = self.generate_experience(self._current_state_idx, action)
        self._current_state_idx = next_state_idx
        done = terminal if self.early_stop else False
        if self._episode_step >= self._max_episode_steps:
            done = True

        return next_state_idx, reward, done, {'terminal': terminal}

    def _init_value_function_policy_storage(self):
        """Return a zero V function [ns], a zero Q function [ns, na] and a uniform policy [ns, na]."""
        M, N = self._ns, self._na
        V_function = np.zeros(M)
        Q_function = np.zeros((M, N))
        policy = np.full_like(Q_function, fill_value=1.0 / N)
        return V_function, Q_function, policy

    def _generate_greedy_policy(self, Q_functions: np.ndarray):
        """
        Args:
            Q_functions: Q functions, a numpy array with shape [num_state, num_action].
        Returns:
            greedy_policy: the policy acts greedily w.r.t Q_function, a one-hot float32 numpy
            array with shape [num_state, num_action]
        """
        M, N = self._ns, self._na
        greedy_policy = np.zeros(shape=(M, N), dtype=np.float32)

        for state in range(M):
            q_at_state = Q_functions[state, :]
            # randomized the greedy policy to avoid bias: ties are broken uniformly at random
            best_actions = np.flatnonzero(q_at_state == q_at_state.max())
            greedy_policy[state, np.random.choice(best_actions)] = 1.0

        return greedy_policy

    def _run_policy_evaluation(self, policy: np.ndarray):
        """ Run policy evaluation with underlying reward function and transition probability.
        Here we calculate the value function by equation: V^{pi} = (I - gamma P^pi)^-1 r^pi

        Args:
            policy: a numpy array with shape [num_state, num_action].
        Returns:
            V_function: the V function of policy, a numpy array with shape [num_state]
            Q_function: the Q function of policy, a numpy array with shape [num_state, num_action]
            policy_value: a float, the V function averaged over the initial state distribution
        """
        M, N = self._ns, self._na
        gamma = self._gamma
        transition_prob = self.transition_probability
        reward_mat = self.reward_mat

        # P^pi[s, s'] = sum_a pi(a|s) P(s'|s, a) and r^pi[s] = sum_a pi(a|s) r(s, a)
        transition_matrix_by_pi = np.sum(transition_prob * np.reshape(policy, (M, N, 1)), axis=1)
        reward_vector_by_pi = np.sum(reward_mat * policy, axis=1)

        # Solve the linear system directly rather than forming the explicit inverse.
        V_function = np.linalg.solve(np.eye(M) - gamma * transition_matrix_by_pi, reward_vector_by_pi)
        Q_function = reward_mat + gamma * np.sum(transition_prob * np.reshape(V_function, (1, 1, M)), axis=2)
        policy_value = float(np.sum(V_function * self._init_state_dis))

        return V_function, Q_function, policy_value

    def cal_policy_transition_prob(self, policy: np.ndarray):
        """Output the transition kernel P^{pi} under the induced policy.

        Entry [(s, a), (s', a')] equals P(s'|s, a) * pi(a'|s').

        Args:
            policy: a numpy array with shape [num_state, num_action].
        Returns:
            transition_matrix_by_pi: a numpy array with shape [num_state*num_action, num_state*num_action]
        """
        M, N = self._ns, self._na
        transition_prob = np.reshape(self.transition_probability, (M, N, M, 1))
        next_action_prob = np.reshape(policy, (1, 1, M, N))
        # Broadcast to [M, N, M, N], then flatten the (s, a) and (s', a') pairs.
        transition_matrix_by_pi = np.reshape(transition_prob * next_action_prob, (M * N, -1))

        return transition_matrix_by_pi

    def policy_evaluation(self, policy: np.ndarray):
        """Evaluate a policy under the initial state distribution.

        Args:
            policy: numpy array with shape [ns, na]
        Returns:
            policy_value: the policy value
        """

        _, _, policy_value = self._run_policy_evaluation(policy)
        return policy_value

    def run_policy_iteration(self, eps=1e-3):
        """
        Policy iteration algorithm. Warning: determine the stop rule of the loop by bellman error.
        Args:
            eps: the desired error
        Returns:
            policy: the final greedy policy, a numpy array with shape [ns, na]
            opt_v_function: its V function, a numpy array with shape [ns]
            opt_q_function: its Q function, a numpy array with shape [ns, na]
            opt_value: its policy value, a float
        """
        gamma = self._gamma
        _, _, policy = self._init_value_function_policy_storage()
        reward_vec = np.reshape(self.reward_mat, -1)
        while True:
            _, q_func, _ = self._run_policy_evaluation(policy=policy)
            policy = self._generate_greedy_policy(Q_functions=q_func)
            policy_transition_mat = self.cal_policy_transition_prob(policy)
            q_vec = np.reshape(q_func, -1)
            # Bellman residual of the evaluated Q function under the new greedy policy.
            bellman_error = reward_vec + gamma * np.matmul(policy_transition_mat, q_vec) - q_vec
            if np.max(np.abs(bellman_error)) < (1 - gamma) * eps:
                break

        opt_v_function, opt_q_function, opt_value = self._run_policy_evaluation(policy=policy)

        return policy, opt_v_function, opt_q_function, opt_value

    def run_value_iteration(self, eps=1e-3):
        """
        Value iteration algorithm, stopped by the same bellman error rule as policy iteration.
        Args:
            eps: the desired error
        Returns:
            opt_policy: the policy greedy w.r.t. the final Q function, a numpy array with shape [ns, na]
        """
        _, Q_function, _ = self._init_value_function_policy_storage()
        M, gamma = self._ns, self._gamma
        reward_matrix = self.reward_mat
        reward_vec = np.reshape(reward_matrix, -1)
        transition_matrix = self.transition_probability

        while True:
            cur_v = np.max(Q_function, axis=1)
            Q_function = reward_matrix + gamma * np.sum(transition_matrix * np.reshape(cur_v, (1, 1, M)), axis=2)
            Q_vec = np.reshape(Q_function, -1)
            greedy_policy = self._generate_greedy_policy(Q_function)
            policy_transition_mat = self.cal_policy_transition_prob(greedy_policy)
            bellman_error = reward_vec + gamma * np.matmul(policy_transition_mat, Q_vec) - Q_vec
            if np.max(np.abs(bellman_error)) < (1 - gamma) * eps:
                break

        opt_policy = self._generate_greedy_policy(Q_functions=Q_function)
        return opt_policy

    def calculate_occupancy_measure(self, policy: np.ndarray):
        """
        Calculate the discounted occupancy measure induced by a policy by fixed-point iteration.
        d_pi (s) = (1-gamma) d_0 (s) + gamma sum_{s', a'} rho_pi (s', a') P (s| s', a')

        The iteration starts from a random distribution and runs for the effective horizon,
        so the result is an approximation; see calculate_occupancy_measure_v2 for the
        closed-form version.

        Args:
            policy: a numpy array with shape [S, A]
        Returns:
            rho: a numpy array with shape [S, A], where rho(s, a) = (1-gamma) sum_{h=0}^inf gamma^h Pr (s_h=s, a_h=a).
        """
        M, N, gamma = self._ns, self._na, self._gamma
        # The effective horizon truncates to 0 for small gamma. Always run at least one
        # sweep so the random starting point is never returned as the answer (a single
        # sweep is already exact when gamma == 0).
        num_sweeps = max(self._max_episode_steps, 1)
        rho = np.random.random(size=(M, N))
        rho = rho / float(np.sum(rho))
        transition_matrix = self.transition_probability
        init_state_dis = self.init_state_distribution
        for _ in range(num_sweeps):
            state_dis = (1.0 - gamma) * init_state_dis + \
                        gamma * np.sum(transition_matrix * np.reshape(rho, (M, N, 1)), axis=(0, 1))
            rho = policy * np.reshape(state_dis, (M, 1))

        return rho

    def calculate_occupancy_measure_v2(self, policy: np.ndarray):
        """
        Calculate the discounted occupancy measure by formula
        d^pi = (1 - gamma) (I - gamma P^pi^T)^-1 d_0 and rho(s, a) = d^pi(s) pi(a|s).
        This implementation is quicker than the iterative one.

        Args:
            policy: a numpy array with shape [S, A]
        Returns:
            rho: a numpy array with shape [S, A], where rho(s, a) = (1-gamma) sum_{h=0}^inf gamma^h Pr (s_h=s, a_h=a).
        """

        transition_prob = self.transition_probability
        M, N, gamma = self._ns, self._na, self._gamma
        init_state_dist = self.init_state_distribution
        transition_matrix_by_policy = np.sum(transition_prob * np.reshape(policy, (M, N, 1)), axis=1)
        # Solve the linear system directly rather than forming the explicit inverse.
        state_dist = np.linalg.solve(np.eye(M) - gamma * transition_matrix_by_policy.transpose(),
                                     init_state_dist)
        rho = (1.0 - gamma) * policy * np.reshape(state_dist, (M, 1))
        return rho



