import numpy as np

from core.MDP import MDP
from utils.utils import phi_exp_factory, phi_exp_inv_factory, phi_exp_combined_factory


class RandomMDP():
    '''
    Randomly generated tabular MDP plus a thin front-end to the solvers
    exposed by the wrapped ``MDP`` object.

    Attributes set by ``__init__``:
        S_size, A_size, gamma -- stored exactly as passed in.
        P       -- array of shape (A_size, S_size, S_size); every row
                   ``P[a, s, :]`` is normalized to sum to 1.
        rewards -- array of shape (A_size, S_size, S_size); the value for a
                   given (a, s) is repeated along the last axis.
        mdp     -- ``MDP(P, gamma, rewards)``.
    '''

    # Names accepted in mode["alg"] by solve_mdp.
    _ALGORITHMS = ("value_iteration", "policy_iteration",
                   "projected_Q_descent", "policy_descent",
                   "softmax", "softmax_adaptive", "softmax_temp", "softmax_NPG",
                   "phi", "escort", "escort_normalized")

    # mode["alg"] -> value forwarded as ``mode=`` to MDP.softmax_descent
    # (None: the keyword is not passed at all).
    _SOFTMAX_MODES = {"softmax": None,
                      "softmax_adaptive": "adaptive",
                      "softmax_NPG": "NPG"}

    # mode["alg"] -> value forwarded as ``mode=`` to MDP.escort_descent.
    _ESCORT_MODES = {"escort_normalized": "normalized",
                     "escort": "origin"}

    def __init__(self,
                 S_size=70,
                 A_size=10,
                 gamma=.9,
                 seed=21):
        '''
        Randomly generate a MDP model.

        Args:
            S_size: size of the last two axes of ``P`` / ``rewards``.
            A_size: size of the first axis of ``P`` / ``rewards``.
            gamma: stored and handed to ``MDP`` unchanged.
            seed: passed to ``np.random.seed``. NOTE: this reseeds NumPy's
                *global* generator, which is kept on purpose because code
                running afterwards may rely on it.
        '''

        self.S_size = S_size
        self.A_size = A_size
        self.gamma = gamma

        # Randomly generate MDP.
        np.random.seed(seed)
        self.P = np.random.uniform(size=(A_size, S_size, S_size), low=0, high=1)
        # Normalize every row P[a, s, :] so that it sums to 1.
        self.P /= self.P.sum(axis=2, keepdims=True)

        # One uniform draw per (a, s), copied along a new last axis so that
        # the array has the same shape as P.
        self.rewards = np.random.uniform(size=(A_size, S_size))
        self.rewards = np.expand_dims(self.rewards, axis=2).repeat(S_size, 2)

        self.mdp = MDP(self.P, self.gamma, self.rewards)

    def solve_mdp(self,
                  mode=None,
                  max_iter=10000,
                  epsilon=1e-7,
                  step_size=1,
                  asynchronous=False,
                  init=False,
                  verbose=False,
                  need_return=False,
                  noise=None,
                  seed=21):
        '''
        Run one of the solvers of the wrapped ``MDP`` object.

        Args:
            mode: dict whose ``"alg"`` entry selects the solver. ``None``
                (the default) means ``{"alg": "policy_iteration"}``.
                Extra entries read from it:
                  * ``"p"`` (required) for "escort" / "escort_normalized";
                  * ``"phi"`` (required) and the optional
                    ``"step_include_d"`` (default False) and ``"init_type"``
                    (default "softmax") for "phi".
            max_iter: forwarded to every solver.
            epsilon, asynchronous: forwarded to value iteration only.
            step_size, noise: forwarded to every solver except value and
                policy iteration.
            init: if true, ``self.mdp.init_policy_and_V(random_init=True)``
                is called before solving.
            verbose: forwarded to the solver as ``silence=``; despite the
                name, a true value therefore suppresses output, including
                the "Solving!" line printed here.
            need_return: forwarded to the solver; if true, the solver's
                return value is returned from this method.
            seed: forwarded to every solver.

        Returns:
            Whatever the selected solver returned if ``need_return`` is
            true, otherwise ``None``.

        Raises:
            AssertionError: ``mode["alg"]`` is not a known algorithm name.
            NotImplementedError: ``mode["alg"] == "softmax_temp"``; the name
                is accepted but nothing is wired up for it.
            KeyError: ``mode`` lacks ``"alg"`` or an entry required by the
                selected algorithm.
        '''

        if mode is None:
            mode = {"alg": "policy_iteration"}

        alg = mode["alg"]
        assert alg in self._ALGORITHMS, "Unknown algorithm: %r" % (alg,)

        # Previously this name fell through every branch: the call crashed
        # with UnboundLocalError when need_return was set and silently did
        # nothing otherwise. Fail loudly, before any state is touched.
        if alg == "softmax_temp":
            raise NotImplementedError(
                'The "softmax_temp" algorithm is currently disabled.')

        if init:
            self.mdp.init_policy_and_V(random_init=True)

        if not verbose:
            print("Solving!")

        # Keyword arguments shared by every solver ...
        shared = dict(max_iter=max_iter, need_return=need_return,
                      silence=verbose, seed=seed)
        # ... and by every solver other than value / policy iteration.
        stepped = dict(shared, step_size=step_size, noise=noise)

        if alg == "value_iteration":
            return_dict = self.mdp.value_iteration(
                epsilon=epsilon, asynchronous=asynchronous, **shared)
        elif alg == "policy_iteration":
            return_dict = self.mdp.policy_iteration(**shared)
        elif alg == "projected_Q_descent":
            return_dict = self.mdp.projected_Q_descent(**stepped)
        elif alg == "policy_descent":
            return_dict = self.mdp.projected_Q_descent(
                mode="policy_descent", **stepped)
        elif alg in self._SOFTMAX_MODES:
            softmax_mode = self._SOFTMAX_MODES[alg]
            if softmax_mode is None:
                return_dict = self.mdp.softmax_descent(**stepped)
            else:
                return_dict = self.mdp.softmax_descent(
                    mode=softmax_mode, **stepped)
        elif alg in self._ESCORT_MODES:
            p = mode["p"]
            return_dict = self.mdp.escort_descent(
                mode=self._ESCORT_MODES[alg], p=p, **stepped)
        else:  # alg == "phi"
            phi = mode["phi"]
            step_include_d = mode.get("step_include_d", False)
            init_type = mode.get("init_type", "softmax")
            return_dict = self.mdp.phi_policy_update(
                phi, step_include_d=step_include_d, init_type=init_type,
                **stepped)

        if need_return:
            return return_dict

    def V_curve_in_VI(self,
                      raw_V_list):
        '''
        Re-evaluate a sequence of value estimates through the wrapped MDP.

        For every ``V`` in ``raw_V_list`` the wrapped object is driven through
        ``set_V(V)``, ``extract_policy()`` and ``evaluate_policy()``, and a
        copy of the resulting ``self.mdp.V`` is recorded.
        NOTE(review): presumably this yields the value of the policy
        extracted from each iterate -- confirm against ``MDP``.

        The calls operate on ``self.mdp`` itself, so its state after this
        method reflects the last element processed.

        Args:
            raw_V_list: iterable of value estimates accepted by
                ``self.mdp.set_V``.

        Returns:
            List with one copied ``self.mdp.V`` per input element, in input
            order (empty if ``raw_V_list`` is empty).
        '''

        V_policy_list = []

        for V in raw_V_list:
            self.mdp.set_V(V)
            self.mdp.extract_policy()
            self.mdp.evaluate_policy()
            V_policy_list.append(self.mdp.V.copy())

        return V_policy_list