import numpy as np
import random
import pdb

def set_hyps(a,a_default):
    """Return `a`, falling back to `a_default` only when `a` is None.

    Falsy values other than None (0, "", empty containers) are returned as-is.
    """
    return a_default if a is None else a

def set_seed(seed=1):
    """Seed NumPy's global RNG and the stdlib `random` module with `seed`.

    Passing seed=None leaves both generators untouched.
    """
    if seed is None:
        return
    np.random.seed(seed)
    random.seed(seed)

def env_setup(seed_init=1,num_states=5,num_actions=4,rho=None,transP_func=None,reward_func=None,lambda1=1.0,gamma=0.95):
    """Seed the random generators and assemble the environment dictionary.

    Parameters
    ----------
    seed_init : passed to set_seed() before anything else, and stored.
    num_states, num_actions : sizes stored in the dictionary.
    rho : None for the uniform distribution over states, or any array-like
        (list, tuple, ndarray, ...) with `num_states` entries. Entries are
        replaced by their absolute values and rescaled to sum to one.
    transP_func : callable mapping a policy pi (num_states x num_actions) to a
        transition array, or None for the default kernel
        p(s,a,s')=[pi(s',a)+pi(s,a)+1]/sum_{s''}[pi(s'',a)+pi(s,a)+1].
    reward_func : callable mapping a policy pi to a reward array, or None for
        the default reward r(s,a)=pi(s,a).
    lambda1 : stored under the key 'lambda'.
    gamma : stored under the key 'gamma'.

    Returns
    -------
    dict with keys 'seed_init', 'num_states', 'num_actions', 'rho', 'gamma',
    'lambda', 'transP_func' and 'reward_func'.

    Raises
    ------
    AssertionError
        If `rho` is given and does not have exactly `num_states` entries.
    """
    env_dict={}
    set_seed(seed_init)
    env_dict['seed_init']=seed_init
    env_dict['num_states']=num_states
    env_dict['num_actions']=num_actions

    if rho is None:
        rho=np.ones(num_states)/num_states
    else:
        # Accept any array-like, not only lists and ndarrays.
        rho=np.asarray(rho)
        if rho.size!=num_states:
            # Raised explicitly (same exception type as before) so that the
            # check is not stripped when Python runs with -O.
            raise AssertionError("rho should have "+str(num_states)+" entries.")
    rho=np.abs(rho).reshape(num_states)
    env_dict['rho']=rho/rho.sum()

    # NOTE(review): the original comment here described an L2 ambiguity set
    # P(s'|s,a)=<Psi[:,a,s'], xi[s,:]> with ||xi[s,:]-xi0[s,:]||_2<=xi_radius
    # (tabular case when Psi=None). Neither Psi nor xi is used anywhere in
    # this function, so it looks left over from another variant -- confirm.

    env_dict['gamma']=gamma
    env_dict['lambda']=lambda1

    if transP_func is None:
        def transP_func(pi):
            # Default kernel: transP[s,a,s'] is proportional to
            # pi[s,a]+1+pi[s',a], normalised over the next state s'.
            num_states,num_actions=pi.shape
            transP=(pi+1).reshape(pi.shape+(1,))+pi.T.reshape((1,num_actions,num_states))
            transP2=transP/transP.sum(axis=2,keepdims=True)
            return transP2

    if reward_func is None:
        def reward_func(pi):
            # Default reward: r(s,a)=pi(s,a).
            return pi

    env_dict['transP_func']=transP_func
    env_dict['reward_func']=reward_func
    return env_dict

def get_transP_s2s(transP,pi):
    """Marginalise the action out of a transition array under policy `pi`.

    Returns the array transP_s2s with
    transP_s2s(s,s')=sum_a p(s,a,s')*pi(s,a),
    where `pi` is a 2-D (state x action) array and `transP` is indexed as
    (state, action, next state).
    """
    n_states,n_actions=pi.shape
    # Add a trailing axis so pi broadcasts over the next-state dimension.
    action_weights=pi.reshape(n_states,n_actions,1)
    weighted=transP*action_weights
    return weighted.sum(axis=1)

def V_func(transP,reward,pi,gamma,max_iters=1000,eps=1e-12,is_print=False):
    """Evaluate policy `pi` by fixed-point iteration of its Bellman equation.

    Starting from V=0, repeats
        V <- sum_a reward(s,a)*pi(s,a) + gamma * transP_s2s.dot(V)
    until the largest absolute change is at most `eps` or `max_iters` sweeps
    have been done.

    Parameters
    ----------
    transP : transition array indexed (state, action, next state).
    reward : array multiplied elementwise with `pi` (state x action).
    pi : 2-D policy array (state x action).
    gamma : discount factor.
    max_iters : maximum number of sweeps.
    eps : stopping tolerance on max|V_next-V|.
    is_print : if True, print V after every non-final sweep and a blank line
        at the end.

    Returns
    -------
    1-D array with one value per state.
    """
    num_states,_=pi.shape
    transP_s2s=get_transP_s2s(transP,pi)
    # The expected one-step reward does not depend on V: compute it once
    # instead of once per sweep.
    expected_reward=(reward*pi).sum(axis=1)
    V=np.zeros(num_states)
    for t in range(max_iters):
        V_next=expected_reward+gamma*transP_s2s.dot(V)
        converged=np.max(np.abs(V_next-V))<=eps
        # V_next is a freshly allocated array, so no copy is needed.
        V=V_next
        if converged:
            break
        if is_print:
            print("Value-iteration "+str(t)+": V="+str(V))
    if is_print:
        print()
    return V

def Q_func(transP,reward,pi,gamma,V=None):
    """Return the matrix Q where Q[s,a] is the Q function value at (s,a).

    Q[s,a]=sum_{s'} transP[s,a,s']*(reward[s,a]+gamma*V[s']).
    When `V` is None it is first obtained from V_func with its default
    iteration settings; otherwise `pi` is not used.
    """
    if V is None:
        V=V_func(transP,reward,pi,gamma,max_iters=1000,eps=1e-12,is_print=False)
    n_rows,n_cols=reward.shape[0],reward.shape[1]
    # Shape reward as (s,a,1) and V as (1,1,s') so both broadcast against transP.
    reward_sa=reward.reshape(n_rows,n_cols,1)
    value_next=V.reshape(1,1,-1)
    backup=reward_sa+gamma*value_next
    return (transP*backup).sum(axis=2)

def our_0FW(env_dict,lambda1,num_iters,num_samples,Delta,delta,beta,pi0=None,is_print=False,numV_iters=1000,V_iter_eps=1e-12,is_printV=False):
    """Zeroth-order Frank-Wolfe ascent on the entropy-regularised value.

    Every iteration estimates the gradient of
        rho . V(pi; reward_func(pi)-lambda1*log(pi))
    from `num_samples` two-point evaluations at pi-delta*u and pi+delta*u
    (u is a random unit direction whose rows sum to zero), then moves pi a
    step `beta` towards the policy that puts mass 1-Delta*(num_actions-1) on
    the arg-max action of each state and Delta on every other action.

    Parameters
    ----------
    env_dict : dict with keys 'num_states', 'num_actions', 'gamma', 'rho',
        'transP_func' and 'reward_func' (see env_setup).
    lambda1 : entropy regularisation coefficient.
    num_iters : number of Frank-Wolfe iterations.
    num_samples : number of random directions per gradient estimate.
    Delta : lower bound on the policy entries; must exceed `delta`.
    delta : perturbation radius; must be positive.
    beta : Frank-Wolfe step size.
    pi0 : initial policy (copied), or None for the uniform policy.
    is_print : print the values recorded at every iteration.
    numV_iters, V_iter_eps, is_printV : forwarded to V_func as
        max_iters, eps and is_print.

    Returns
    -------
    dict with 'pi_last', the length num_iters+1 arrays 'V_unreg', 'Entropy'
    and 'V_reg', and the hyper-parameters 'delta', 'Delta',
    'beta(stepsize)' and 'num_samples'.

    Raises
    ------
    ValueError
        If delta<=0, if Delta<=delta, or if the initial policy has an entry
        smaller than Delta.
    """
    if delta<=0:
        raise ValueError("The input argument delta should be positive.")
    if Delta<=delta:
        # Entries of pi are kept >= Delta and are perturbed by at most delta,
        # so Delta>delta keeps pi-delta*u positive for the log below.
        raise ValueError("The input arguments should satisfy Delta>delta.")

    num_states,num_actions=env_dict['num_states'],env_dict['num_actions']
    if pi0 is None:
        pi=np.ones((num_states,num_actions))/num_actions
    else:
        pi=pi0.copy()
    if np.min(pi)<Delta:
        raise ValueError("The initial policy pi0 is either None (uniform policy) or with entries not smaller than Delta.")

    V_unreg=np.zeros(num_iters+1)
    Entropy=np.zeros(num_iters+1)
    V_reg=np.zeros(num_iters+1)

    def rho_value(transP,reward,policy):
        # rho-weighted value of `policy` for the given kernel and reward.
        return V_func(transP,reward,policy,env_dict['gamma'],max_iters=numV_iters,\
                      eps=V_iter_eps,is_print=is_printV).dot(env_dict['rho'])

    def record(t,policy):
        # Store the unregularised value, the entropy term and their
        # lambda1-weighted sum of `policy` in slot t.
        transP=env_dict['transP_func'](policy)
        reward=env_dict['reward_func'](policy)
        V_unreg[t]=rho_value(transP,reward,policy)
        Entropy[t]=rho_value(transP,-np.log(policy),policy)
        V_reg[t]=V_unreg[t]+lambda1*Entropy[t]
        if is_print:
            print("0FW-iteration "+str(t)+": V_unreg="+str(V_unreg[t])+", V_reg="+str(V_reg[t]))

    def regularized_value(policy):
        # Objective used by the two-point gradient estimate.
        transP=env_dict['transP_func'](policy)
        reward=env_dict['reward_func'](policy)-lambda1*np.log(policy)
        return rho_value(transP,reward,policy)

    g_coeff=num_states*(num_actions-1)/(2*num_samples*delta)
    pi_max=1-Delta*num_actions
    for t in range(num_iters):
        record(t,pi)

        ghat=np.zeros((num_states,num_actions))
        for _ in range(num_samples):
            v=np.random.normal(size=(num_states,num_actions))
            v/=np.sqrt((v*v).sum())                      #v is a uniformly distributed unit vector of dimensionality |S|*|A|.
            # Remove the per-state mean so that pi+-delta*u keeps its row sums.
            u=v-np.mean(v,axis=1,keepdims=True)
            u/=np.sqrt((u*u).sum())

            Vneg=regularized_value(pi-delta*u)
            Vpos=regularized_value(pi+delta*u)
            ghat+=(Vpos-Vneg)*u
        ghat*=g_coeff

        # Frank-Wolfe vertex: Delta on every action plus the remaining mass
        # pi_max on the action with the largest estimated gradient.
        argmax_a=np.argmax(ghat,axis=1)
        one_hot_max_a=np.eye(num_actions)[argmax_a]
        pi_tilde=Delta+pi_max*one_hot_max_a

        pi=(1-beta)*pi+beta*pi_tilde

    record(num_iters,pi)

    return {'pi_last':pi,"V_unreg":V_unreg,"Entropy":Entropy,"V_reg":V_reg,\
            "delta":delta,"Delta":Delta,"beta(stepsize)":beta,"num_samples":num_samples}

def repeat_train(env_dict,lambda1,outer_iters,inner_iters,eta,pi0=None,is_print=False,numV_iters=1000,V_iter_eps=1e-12,is_printV=False):
    """Repeated training: freeze the environment at the current policy, then
    improve the policy against that frozen environment.

    Each outer iteration evaluates transP_func and reward_func at the current
    policy and then runs `inner_iters` updates log_pi += eta*Q (followed by
    a row-wise renormalisation) with that kernel and reward held fixed.
    `lambda1` only enters the reported 'V_reg' values, not the inner updates.

    Parameters
    ----------
    env_dict : dict with keys 'num_states', 'num_actions', 'gamma', 'rho',
        'transP_func' and 'reward_func' (see env_setup).
    lambda1 : weight of the entropy term in the reported 'V_reg'.
    outer_iters : number of environment re-evaluations.
    inner_iters : number of policy updates per outer iteration.
    eta : step size of the inner update.
    pi0 : initial policy (copied), or None for the uniform policy.
    is_print : print the values recorded at every outer iteration.
    numV_iters, V_iter_eps, is_printV : forwarded to V_func as
        max_iters, eps and is_print.

    Returns
    -------
    dict with 'pi_last', the length outer_iters+1 arrays 'V_unreg',
    'Entropy' and 'V_reg', and 'eta(stepsize)', 'outer_iters',
    'inner_iters'.

    Raises
    ------
    AssertionError
        If exp(log_pi) and pi differ by more than 1e-14 after an update.
    """
    num_states,num_actions=env_dict['num_states'],env_dict['num_actions']
    if pi0 is None:
        pi=np.ones((num_states,num_actions))/num_actions
    else:
        pi=pi0.copy()

    V_unreg=np.zeros(outer_iters+1)
    Entropy=np.zeros(outer_iters+1)
    V_reg=np.zeros(outer_iters+1)

    def record(t,policy):
        # Evaluate the environment at `policy`, store the unregularised value,
        # the entropy term and their lambda1-weighted sum in slot t, and hand
        # back the kernel and reward for the inner loop.
        transP=env_dict['transP_func'](policy)
        reward=env_dict['reward_func'](policy)
        V_unreg[t]=V_func(transP,reward,policy,env_dict['gamma'],max_iters=numV_iters,\
                            eps=V_iter_eps,is_print=is_printV).dot(env_dict['rho'])
        Entropy[t]=V_func(transP,-np.log(policy),policy,env_dict['gamma'],max_iters=numV_iters,\
                            eps=V_iter_eps,is_print=is_printV).dot(env_dict['rho'])
        V_reg[t]=V_unreg[t]+lambda1*Entropy[t]
        if is_print:
            print("Repeated Training iteration "+str(t)+": V_unreg="+str(V_unreg[t])+", V_reg="+str(V_reg[t]))
        return transP,reward

    for t in range(outer_iters):
        transP,reward=record(t,pi)

        log_pi=np.log(pi)
        for _ in range(inner_iters):
            V=V_func(transP,reward,pi,env_dict['gamma'],max_iters=numV_iters,eps=V_iter_eps,is_print=is_printV)
            Q=Q_func(transP,reward,pi,env_dict['gamma'],V)
            log_pi+=eta*Q
            # Shift by the row maximum before exponentiating to avoid overflow.
            log_pi-=log_pi.max(axis=1,keepdims=True)
            pi=np.exp(log_pi)
            row_sum=pi.sum(axis=1).reshape(-1,1)
            log_pi-=np.log(row_sum)
            pi/=row_sum
            if np.max(np.abs(np.exp(log_pi)-pi))>1e-14:
                # Raised explicitly (not via `assert`) so the consistency
                # check still runs when Python is started with -O.
                raise AssertionError("Value Error: exp(log_pi) should almost be equal to pi")

    record(outer_iters,pi)

    return {'pi_last':pi,"V_unreg":V_unreg,"Entropy":Entropy,"V_reg":V_reg,\
            "eta(stepsize)":eta,"outer_iters":outer_iters,"inner_iters":inner_iters}






