"""Defines RandomMDP, a randomly generated tabular MDP environment.

Taken from: https://github.com/google-research/google-research/blob/master/ksme/random_mdps/random_mdp.py
"""

import numpy as np
import pdb

class RandomMDP(object):
    """A randomly generated tabular MDP with a minimal gym-style interface.

    Attributes set by the constructor:
        transition_probs: array of shape (num_states, num_actions, num_states);
            ``transition_probs[s, a, s']`` is the probability of moving to s'
            after taking action a in state s. ``transitions`` refers to the
            very same array object (an alias, not a copy).
        rewards: array of shape (num_states, num_actions), min-max rescaled
            onto [-5, 5].
        reward_range: width of the reward interval (5 - (-5) == 10).
        init_state / state: start state (always 0) and current state.
        observation_space / action_space: zero vectors of length n_state /
            n_action (presumably only their shapes are consumed by callers).
    """

    def __init__(self, num_states, num_actions, policy_type='stochastic', reward_variance=1.0, use_terminal_state=False):
        """Sample random transition probabilities and rewards.

        Args:
            num_states: number of states (must be > 0).
            num_actions: number of actions (must be > 0).
            policy_type: accepted for interface compatibility; not used by
                this class.
            reward_variance: passed to ``np.random.normal`` as ``scale``, so it
                acts as the standard deviation of the raw rewards (not the
                variance).
            use_terminal_state: if True, the last state (num_states - 1) is
                terminal: all of its outgoing transition probabilities are
                zero and ``step`` reports ``done`` upon entering it.

        Raises:
            AssertionError: if num_states or num_actions is not positive.
        """
        assert num_states > 0, 'Number of states must be positive.'
        assert num_actions > 0, 'Number of actions must be positive.'
        self.init_state = 0
        self.state = self.init_state
        self.num_states = num_states
        self.n_state = num_states
        self.num_actions = num_actions
        self.n_action = num_actions
        self.observation_space = np.zeros((self.n_state,))
        self.action_space = np.zeros((self.n_action,))
        self.use_terminal_state = use_terminal_state

        # Start from a fully unnormalized S x A x S matrix, sparsify each
        # (state, action) row, then normalize it into a distribution.
        self.transition_probs = np.random.rand(num_states, num_actions, num_states)
        for x in range(num_states):
            for a in range(num_actions):
                # Zero out between 1 and num_states - 1 successors so at least
                # one successor keeps positive mass. With a single state there
                # is nothing to zero out (np.random.randint(1, 1) would raise).
                if num_states > 1:
                    num_non_next = np.random.randint(1, num_states)
                    non_next_idx = np.random.choice(np.arange(num_states),
                                                    size=num_non_next, replace=False)
                    self.transition_probs[x, a, non_next_idx] = 0.
                # Normalize to make the row sum to one.
                self.transition_probs[x, a, :] /= np.sum(self.transition_probs[x, a, :])
        # Alias of the same array: the terminal-state zeroing below is visible
        # through both `transitions` and `transition_probs`.
        self.transitions = self.transition_probs

        if self.use_terminal_state:
            # The terminal state has no outgoing probability mass.
            self.transitions[self.n_state - 1, :, :] = 0

        # Raw rewards ~ Normal(mean=1, std=reward_variance), clipped to
        # [-5, 8], then min-max rescaled onto [reward_low, reward_high].
        reward_low, reward_high = -5, 5
        rs = np.random.normal(loc=1., scale=reward_variance,
                              size=(num_states, num_actions))
        rs = np.clip(rs, -5.0, 8.0)
        span = rs.max() - rs.min()
        if span > 0:
            self.rewards = ((rs - rs.min()) / span) * (reward_high - reward_low) + reward_low
        else:
            # All raw rewards are equal (e.g. one state-action pair, or
            # reward_variance == 0): min-max scaling would divide by zero and
            # yield NaN, so map every reward to the midpoint of the range.
            self.rewards = np.full_like(rs, (reward_low + reward_high) / 2)
        self.reward_range = reward_high - reward_low

    def reset(self):
        """Return to the initial state; returns ``(state, info)`` with empty info."""
        self.state = self.init_state
        return self.state, {}

    def step(self, action):
        """Take ``action`` in the current state and sample the next state.

        Returns:
            ``(next_state, reward, done, truncated, info)``. ``reward`` is
            ``rewards[state, action]`` for the state the action was taken in.
            ``done`` is set only when ``use_terminal_state`` is enabled and the
            next state is the last state. ``truncated`` is always False, and
            ``info`` is an empty dict.

        Raises:
            ValueError: if the current state has no outgoing transitions for
                ``action`` (e.g. stepping again after reaching the terminal
                state without calling ``reset``).
        """
        rew = self.rewards[self.state, action]
        probs = self.transition_probs[self.state, action]
        if not probs.any():
            raise ValueError(
                f'State {self.state} has no outgoing transitions under action '
                f'{action}; call reset() after reaching the terminal state.')
        n_state = np.random.choice(self.num_states, p=probs)
        self.state = n_state
        if self.use_terminal_state:
            done = self.state == self.n_state - 1
        else:
            done = False
        return n_state, rew, done, False, {}

    def get_policy_probs(self, policy):
        """Compute the dynamics induced by a state-conditioned policy.

        Args:
            policy: array-like of shape (num_states, num_actions); presumably
                ``policy[s, a]`` is pi(a | s) with rows summing to one.

        Returns:
            policy_transition_probs: shape (num_states, num_states),
                ``sum_a policy[s, a] * P[s, a, s']``.
            policy_rewards: shape (num_states,),
                ``sum_a policy[s, a] * rewards[s, a]``.
            pi_trans_sa: shape (num_states, num_actions, num_states,
                num_actions), ``P[s, a, s'] * policy[s', a']`` -- the
                state-action to state-action transition matrix.
        """
        policy = np.asarray(policy)
        policy_transition_probs = np.einsum('ijk,ij->ik', self.transition_probs, policy)
        policy_rewards = np.einsum('ij,ij->i', self.rewards, policy)
        # Outer product over successors, vectorized instead of a per-(s, a) loop.
        pi_trans_sa = np.einsum('ijk,kl->ijkl', self.transition_probs, policy)
        return policy_transition_probs, policy_rewards, pi_trans_sa
