import numpy as np
import seeding as seeding

class synthetic_env:
    """Randomly generated finite MDP whose states carry linear features.

    The environment holds:

    * ``R``      -- reward table of shape ``(state_space, action_space)``.
    * ``P``      -- transition tensor of shape
      ``(state_space, state_space, action_space)`` indexed as
      ``P[next_state, state, action]``; every ``P[:, s, a]`` sums to one.
    * ``states`` -- feature matrix of shape ``(state_space, state_dim)``;
      row ``s`` is the feature vector observed in state ``s``.

    Policies passed to :meth:`get_mu` / :meth:`get_opt` are arrays indexed
    as ``pi[action, state]``.

    The random generator comes from the project's ``seeding.np_random`` and
    is used through the ``rand`` / ``randint`` / ``choice`` methods, i.e. a
    ``numpy.random.RandomState``-style API.
    """

    def __init__(self,
                 state_space=30,
                 action_space=4,
                 state_dim=4,
                 gam=0.95,
                 seed=0,
                 tildeP=False):
        """Draw a random reward table, transition tensor and feature matrix.

        Args:
            state_space: number of discrete states.
            action_space: number of discrete actions.
            state_dim: length of each state's feature vector.
            gam: discount factor.
            seed: seed forwarded to ``seeding.np_random``.
            tildeP: when true, replace ``P`` by the mixture
                ``(1 - gam) * P + gam / state_space`` (a blend of ``P`` with
                the uniform next-state distribution).
        """
        self.np_random, _ = seeding.np_random(seed)
        self.state_space = state_space
        self.action_space = action_space
        self.state_dim = state_dim
        self.gam = gam

        # NOTE: the order of the random draws below (R, P, states) determines
        # the generated MDP for a given seed -- do not reorder.
        self.R = self.np_random.rand(state_space, action_space)
        raw_p = self.np_random.rand(state_space, state_space, action_space)
        # Normalise over the next-state axis so P[:, s, a] is a distribution.
        self.P = raw_p / raw_p.sum(axis=0)
        if tildeP:
            self.P = (1 - self.gam) * self.P + self.gam / self.state_space
        self.states = self.np_random.rand(state_space, state_dim)

    def seed(self, seed):
        """Re-seed the internal random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Sample an initial state uniformly over *all* states.

        Returns:
            ``(features, onehot)`` of the sampled state.
        """
        # ``randint``'s upper bound is exclusive, so pass ``state_space``
        # itself; using ``state_space - 1`` would never select the last state
        # and would raise for a single-state environment.
        self.state = self.np_random.randint(0, self.state_space)
        return self.states[self.state], self.onehot(self.state, self.state_space)

    def step(self, action):
        """Apply ``action`` in the current state and sample the next state.

        Args:
            action: action index; size-one array wrappers are squeezed away.

        Returns:
            ``(next_features, reward, next_onehot, done)`` where ``reward`` is
            ``R[state, action]`` for the state *before* the transition and
            ``done`` is always ``False``.
        """
        action_ = np.squeeze(action)
        r = self.R[self.state, action_]
        self.state = self.np_random.choice(self.state_space,
                                           p=self.P[:, self.state, action_])
        return self.states[self.state], r, self.onehot(self.state, self.state_space), False

    def set_stationary_state(self, mu):
        """Set the current state by sampling from the distribution ``mu``.

        Returns:
            ``(features, onehot)`` of the sampled state.
        """
        self.state = self.np_random.choice(self.state_space, p=mu)
        return self.states[self.state], self.onehot(self.state, self.state_space)

    def get_opt(self, pi):
        """Solve ``A theta + b = 0`` for the linear value weights under ``pi``.

        With ``phi(s) = states[s]`` and ``mu`` / ``Ps`` from :meth:`get_mu`::

            A = sum_s mu[s] * sum_s' Ps[s', s] * phi(s) (gam*phi(s') - phi(s))^T
            b = sum_s mu[s] * sum_a pi[a, s] * R[s, a] * phi(s)

        Args:
            pi: policy array indexed as ``pi[action, state]``.

        Returns:
            The weight vector ``theta`` of length ``state_dim``.

        Raises:
            Exception: if the stationary-distribution system is rank
                deficient ('Non-solvable MDP.').
            numpy.linalg.LinAlgError: if ``A`` is singular.
        """
        pi = np.asarray(pi)
        mu, rank, Ps = self.get_mu(pi)
        if not rank == self.state_space:
            raise Exception('Non-solvable MDP.')

        phi = self.states
        # Row s: sum_s' Ps[s', s] * phi(s')  (expected next feature).
        expected_next = Ps.T @ phi
        # sum_s' Ps[s', s]; kept explicit so the result matches the summation
        # formula even if the columns of Ps do not sum exactly to one.
        col_mass = Ps.sum(axis=0)
        td_diff = self.gam * expected_next - col_mass[:, np.newaxis] * phi
        A = phi.T @ (mu[:, np.newaxis] * td_diff)

        # Expected one-step reward in each state under pi.
        r_pi = (pi.T * self.R).sum(axis=1)
        b = phi.T @ (mu * r_pi)
        return -np.linalg.solve(A, b)

    def get_mu(self, pi):
        """Compute the stationary state distribution induced by ``pi``.

        Builds ``Ps[s', s] = sum_a P[s', s, a] * pi[a, s]`` and solves, in the
        least-squares sense, ``(I - Ps) mu = 0`` together with
        ``sum(mu) = 1``.

        Args:
            pi: policy array indexed as ``pi[action, state]``.

        Returns:
            ``(mu, rank, Ps)`` where ``rank`` is the rank of the stacked
            ``(state_space + 1, state_space)`` system reported by
            ``numpy.linalg.lstsq``.
        """
        pi = np.asarray(pi)
        Ps = np.einsum('nsa,as->ns', self.P, pi)
        # Stack the normalisation constraint under the fixed-point equations.
        system = np.concatenate(
            (np.eye(self.state_space) - Ps,
             np.ones((1, self.state_space))),
            axis=0)
        target = np.zeros(self.state_space + 1)
        target[-1] = 1
        mu, _, rank, _ = np.linalg.lstsq(system, target, rcond=None)
        return mu, rank, Ps

    def onehot(self, index, length):
        """Return a float vector of ``length`` zeros with a one at ``index``."""
        v = np.zeros(length)
        v[index] = 1
        return v
        
        

# if __name__ == '__main__':
#     env = synthetic_env()
#     s, s_onehot = env.reset()
#     actions, states, rewards = [], [], []
#     for i in range(50):
#         action = np.random.randint(0, env.action_space)  # upper bound is exclusive
#         actions.append(action)
#         state, rew, state_onehot, done = env.step(action)
#         states.append(state)
#         rewards.append(rew)
        