import numpy as np
import math
import pdb

class ChainMDP(object):
    """Tabular chain MDP with ``length`` states and two actions (0=left, 1=right).

    Taking left in the left-most state yields reward 2 and taking right in the
    right-most state yields reward 1. Every other (state, action) reward is
    drawn once, at construction, from a zero-mean normal distribution.

    An action moves one step in its intended direction with probability
    ``1 - stoch_prob`` and one step the opposite way with probability
    ``stoch_prob``. A move past either end of the chain leaves the agent in
    place. Episodes always start in state 0 and ``step`` never terminates.
    """

    def __init__(self, length, stoch_prob = 0.1, reward_variance = 0.):
        """Build the transition and reward tables.

        Args:
            length: number of states in the chain; must be >= 1.
            stoch_prob: probability that an action moves the opposite way.
            reward_variance: passed as ``scale`` to ``np.random.normal``, so it
                is the standard deviation (not the variance) of the random
                rewards. With 0, all rewards except the two end rewards are 0.

        Raises:
            ValueError: if ``length`` is smaller than 1.
        """
        if length < 1:
            raise ValueError("length must be >= 1, got {!r}".format(length))

        self.length = length
        self.x = 0  # current state index
        self.n_state = self.num_states = length
        self.n_action = self.num_actions = 2
        # transition_probs[s, a, s'] = P(s' | s, a)
        self.transition_probs = np.zeros((self.n_state, self.n_action, self.n_state))
        # Placeholder arrays whose shapes expose the state/action counts
        # (presumably for gym-style callers).
        self.observation_space = np.zeros((self.n_state,))
        self.action_space = np.zeros((self.n_action,))
        # Not read anywhere in this class; kept for external users.
        self.use_terminal_state = False

        # rewards[s, a]: random background rewards, then the two end rewards.
        self.rewards = np.random.normal(loc=0., scale=reward_variance,
                                        size=(self.num_states, self.num_actions))
        self.rewards[0, 0] = 2. # taking left in left-most state
        self.rewards[length - 1, 1] = 1 # taking right in right-most state

        self.max_rew = np.max(self.rewards)
        self.min_rew = np.min(self.rewards)
        self.reward_range = np.abs(self.max_rew - self.min_rew)

        for s in range(self.n_state):
            # Neighbours, clamped at the chain ends (moving off an end = stay).
            left_ns = max(0, s - 1)
            right_ns = min(self.n_state - 1, s + 1)

            # Accumulate with += so that when left_ns == right_ns (a
            # single-state chain) both probabilities add up to 1 instead of
            # the second write overwriting the first.
            # left action
            self.transition_probs[s, 0, left_ns] += (1. - stoch_prob)
            self.transition_probs[s, 0, right_ns] += stoch_prob

            # right action
            self.transition_probs[s, 1, right_ns] += (1. - stoch_prob)
            self.transition_probs[s, 1, left_ns] += stoch_prob

    def reset(self):
        """Move the agent to the left-most state.

        Returns:
            ``(state, info)`` where ``state`` is 0 and ``info`` is an empty dict.
        """
        self.x = 0
        return self.state_encoding(), {}

    def state_encoding(self):
        """Return the observation for the current state (its integer index)."""
        return self.x

    def step(self, action):
        """Apply ``action`` (0=left, 1=right) from the current state.

        The reward is that of the (current state, action) pair, looked up
        before the transition. The next state is sampled with the global
        ``np.random`` generator.

        Returns:
            ``(next_state, reward, False, False, {})``. The two flags are
            always False because the chain never terminates or truncates.
        """
        rew = self.rewards[self.x, action]
        n_state = np.random.choice(np.arange(self.n_state),
            p=self.transition_probs[self.x, action])
        self.x = n_state
        return self.state_encoding(), rew, False, False, {}

    def get_policy_probs(self, policy):
        """Compute the dynamics and rewards induced by a stochastic policy.

        Args:
            policy: array-like of shape (n_state, n_action). ``policy[s, a]``
                is the weight of action ``a`` in state ``s``; presumably each
                row sums to 1.

        Returns:
            A tuple ``(policy_transition_probs, policy_rewards, pi_trans_sa)``:
            - ``policy_transition_probs``: (n_state, n_state) array,
              ``sum_a policy[s, a] * P[s, a, s']``.
            - ``policy_rewards``: (n_state,) array,
              ``sum_a policy[s, a] * R[s, a]``.
            - ``pi_trans_sa``: (n_state, n_action, n_state, n_action) array,
              ``P[s, a, s'] * policy[s', a']``, i.e. the state-action to
              state-action transition matrix under the policy.
        """
        policy = np.asarray(policy)
        policy_transition_probs = np.einsum('ijk,ij->ik',
                                            self.transition_probs,
                                            policy)
        policy_rewards = np.einsum('ij,ij->i', self.rewards, policy)

        # One broadcasted product instead of a Python loop over (s, a):
        # (S, A, S', 1) * (1, 1, S', A') -> (S, A, S', A').
        pi_trans_sa = (self.transition_probs[:, :, :, np.newaxis]
                       * policy[np.newaxis, np.newaxis, :, :])

        return policy_transition_probs, policy_rewards, pi_trans_sa



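# Example (kept for reference, not executed): evaluate a random policy by
# computing Q = (I - gamma * P_sa)^{-1} r over the flattened state-action space.
# The snippet assumes `env`, `gamma` and `length` are already defined.
# pi = np.random.dirichlet(np.ones(env.n_action), size=env.n_state)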
# p_s, r_s, p_sa = env.get_policy_probs(pi)
# p_sa = p_sa.reshape(-1, p_sa.shape[2] * p_sa.shape[3])

# r = env.rewards
# r_flat = r.reshape(r.shape[0] * r.shape[1], -1)

# num_sa = p_sa.shape[0]
# discounted_p = p_sa * gamma

# diff = np.eye(num_sa) - discounted_p
# sr_sa = np.linalg.inv(diff)
# qvals = np.matmul(sr_sa, r_flat)

# # r = r_s
# # p = p_s
# # num_states = p.shape[0]
# # discounted_p = p * gamma
# # vals = np.matmul(np.linalg.inv(np.eye(num_states) - discounted_p), r)
# qvals = qvals.reshape(length, 2)
# print (qvals)
# pdb.set_trace()