
import sys

from RobustRL_utils import *
import matplotlib.pyplot as plt
from tqdm import tqdm

def generate_trials(env_dict, trial_num, trial_len, state_init=None):
    '''Roll out `trial_num` trajectories of `trial_len` steps each.

    Actions are drawn with `get_a` from env_dict['pi'] and next states with
    `get_s` from env_dict['xi'] (both helpers come from RobustRL_utils).

    Args:
        env_dict: environment dictionary; reads 'state_space',
            'action_space', 'rho', 'pi', 'xi' and 'cost'.
        trial_num: number of trajectories to generate.
        trial_len: number of (state, action) pairs recorded per trajectory.
        state_init: start state used for every trajectory. When None
            (the default) each trajectory draws its own start state from
            env_dict['rho'] with np.random.choice.

    Returns:
        A list of `trial_num` dicts with keys
            'states'  - visited states s_0 .. s_{trial_len-1}
            'actions' - actions taken in those states
            'cost'    - undiscounted sum of env_dict['cost'][s][a][s_next]
        The state reached after the last action is not stored.
    '''
    state_space = env_dict['state_space']
    action_space = env_dict['action_space']

    rho = env_dict['rho']
    pi = env_dict['pi']
    xi = env_dict['xi']

    trials = []
    for _ in range(trial_num):
        # Previously the argument was overwritten unconditionally, so a
        # caller-supplied start state was silently ignored.
        if state_init is None:
            s = np.random.choice(state_space, size=1, p=rho, replace=True)[0]
        else:
            s = state_init

        trial = {'states': [], 'actions': [], 'cost': 0}
        for _ in range(trial_len):
            a = get_a(action_space, pi=pi, s=s, num=1, Psi=None)
            s_next = get_s(state_space, xi, s=s, a=a, num=1, Psi=None)
            trial['states'].append(s)
            trial['actions'].append(a)
            trial['cost'] += env_dict['cost'][s][a][s_next]
            s = s_next

        trials.append(trial)
    return trials

def get_lambda_stochastic(num_states, num_actions, gamma,  trials):
    '''Monte-Carlo estimate of the discounted state-action occupancy.

    lambda_hat[s, a] = (1 - gamma) / m * sum over trials and steps h of
                       gamma**h * 1{s_h == s, a_h == a}

    Args:
        num_states: number of states (rows of the result).
        num_actions: number of actions (columns of the result).
        gamma: discount factor.
        trials: non-empty list of dicts with integer-index lists under
            'states' and 'actions' (as produced by generate_trials).

    Returns:
        Array of shape (num_states, num_actions).

    Raises:
        ValueError: if `trials` is empty.
    '''
    trial_num = len(trials)
    if trial_num == 0:
        raise ValueError('get_lambda_stochastic needs at least one trial')

    lambda_hat = np.zeros((num_states, num_actions))

    # One table of discount weights, long enough for the longest trial.
    longest = max(len(trial['states']) for trial in trials)
    gamma_powers = np.power(gamma, np.arange(longest))

    for trial in trials:
        states = np.asarray(trial['states'], dtype=np.intp)
        actions = np.asarray(trial['actions'], dtype=np.intp)
        # `lambda_hat[states, actions] += w` is buffered: when a (s, a) pair
        # is visited more than once in a trial only the last weight survives.
        # np.add.at accumulates every visit.
        np.add.at(lambda_hat, (states, actions), gamma_powers[:len(states)])

    lambda_hat *= (1 - gamma) / trial_num
    return lambda_hat


def get_nabla_log_pi(pi, s, a):
    '''Score term used by the stochastic gradient of g(theta).

    Returns an array shaped like `pi` that is zero everywhere except at
    entry (s, a), which holds 1 / pi[s][a].
    '''
    score = np.zeros_like(pi)
    prob = pi[s][a]
    score[s][a] = 1 / prob
    return score
    


def get_stochastic_grad_theta(num_states, num_actions, gamma, pi, trials, costs_hat):
    '''Stochastic gradient with respect to the policy table `pi`.

    For every trial the estimator is
        sum_t gamma**t * costs_hat[s_t, a_t] * sum_{h <= t} score(s_h, a_h)
    where score(s, a) is zero except for 1 / pi[s][a] at entry (s, a).
    The result is the average over trials.

    Args:
        num_states, num_actions: kept for interface compatibility; the
            result takes its shape from `pi`.
        gamma: discount factor.
        pi: policy table indexed as pi[s][a].
        trials: non-empty list of dicts with 'states' and 'actions'.
        costs_hat: array indexed as costs_hat[s, a].

    Returns:
        Array shaped like `pi`.

    Raises:
        ValueError: if `trials` is empty.
    '''
    trial_num = len(trials)
    if trial_num == 0:
        raise ValueError('get_stochastic_grad_theta needs at least one trial')

    grad_sum = np.zeros_like(pi)
    for trial in trials:
        trial_gradsum = np.zeros_like(pi)
        # Running sum of the scores for steps 0..t. The inner sum used to be
        # rebuilt from scratch for every t (quadratic in the trial length);
        # the additions happen in the same order, so the numbers are the same.
        score_sum = np.zeros_like(pi)
        for t, (s, a) in enumerate(zip(trial['states'], trial['actions'])):
            prob = pi[s][a]
            if prob > 0:
                score_sum[s, a] += 1 / prob
            else:
                # A sampled action should never have zero probability; report
                # it and leave the score untouched (best effort, as before).
                print('impossible error: pi[{}][{}]=0'.format(s, a))

            trial_gradsum += pow(gamma, t) * costs_hat[s, a] * score_sum
        grad_sum += trial_gradsum

    return grad_sum / trial_num
    

def get_nabla_log_pxi(xi, s, a, s_next):
    '''Score term used by the stochastic gradient of g(xi).

    Returns an array shaped like `xi` that is zero everywhere except at
    entry (s, a, s_next), which holds 1 / xi[s][a][s_next].
    '''
    score = np.zeros_like(xi)
    prob = xi[s][a][s_next]
    score[s][a][s_next] = 1 / prob
    return score

def get_stochastic_grad_xi(num_states, num_actions, gamma, xi, trials, costs_hat):
    '''Stochastic gradient with respect to the transition kernel `xi`.

    For every trial the estimator is
        sum_t gamma**t * costs_hat[s_t, a_t] * L_t
    where L_t adds the score 1 / xi[s][a][s'] at entry (s, a, s') for
      * every recorded transition (s_h, a_h, s_{h+1}) with h < t, and
      * one freshly sampled successor of (s_t, a_t) drawn with `get_s`
        (the recorded successor of step t is not used for that last term).
    The result is the average over trials.

    Args:
        num_states: number of states; defines the state space handed to
            `get_s`.
        num_actions: kept for interface compatibility; unused.
        gamma: discount factor.
        xi: transition kernel indexed as xi[s][a][s_next].
        trials: non-empty list of dicts with 'states' and 'actions'.
        costs_hat: array indexed as costs_hat[s, a].

    Returns:
        Array shaped like `xi` (state, action, next state).

    Raises:
        ValueError: if `trials` is empty.
    '''
    trial_num = len(trials)
    if trial_num == 0:
        raise ValueError('get_stochastic_grad_xi needs at least one trial')

    state_space = np.array(range(0, num_states))

    grad_sum = np.zeros_like(xi)
    for trial in trials:
        states, actions = trial['states'], trial['actions']
        steps = len(states)

        trial_gradsum = np.zeros_like(xi)
        # Scores of the recorded transitions h < t. The inner sum used to be
        # rebuilt for every t (quadratic in the trial length); this keeps a
        # running prefix with the same addition order and the same sequence
        # of get_s draws (exactly one per step, in step order).
        recorded = np.zeros_like(xi)
        for t in range(steps):
            s, a = states[t], actions[t]

            s_fresh = get_s(state_space, xi, s, a, num=1)
            log_sum_mat = recorded.copy()
            log_sum_mat[s, a, s_fresh] += 1 / xi[s][a][s_fresh]

            trial_gradsum += pow(gamma, t) * costs_hat[s, a] * log_sum_mat  # only for cost

            if t + 1 < steps:
                s_rec = states[t + 1]
                recorded[s, a, s_rec] += 1 / xi[s][a][s_rec]
        grad_sum += trial_gradsum

    return grad_sum / trial_num

def get_cost_from_grad(gamma, state_space, lambda_hat):
    '''
    Cost table derived from an occupancy estimate (entropy-type utility).

    Reference given by the original author:
    Example 2.2, Variational policy gradient method for reinforcement
    learning with general utilities, https://arxiv.org/pdf/2007.02151.pdf

    For each state s the value
        -(1 - gamma) * log((1 - gamma) * sum_a lambda_hat[s, a])
    is computed and repeated across all actions, so the result has the same
    (state, action) shape as a 2-D `lambda_hat`. `state_space` is accepted
    but not used.
    '''
    scale = 1 - gamma
    state_mass = scale * np.sum(lambda_hat, axis=1)
    per_state = -scale * np.log(state_mass[:, np.newaxis])
    # Copy the single column once per action.
    return np.tile(per_state, (1, lambda_hat.shape[1]))


def obtain_stochastic_function(env_dict, trial_nums, trial_lens):
    r'''
    Algorithm 1
    Get function $\hat{F}(z)=[g_\theta, -g_\xi]$ from sampled trajectories.

    Three independent batches of trajectories are generated under the
    current env_dict['pi'] / env_dict['xi']:
      1. 'lambda' batch -> occupancy estimate -> cost table,
      2. 'theta' batch  -> stochastic gradient w.r.t. the policy,
      3. 'xi' batch     -> stochastic gradient w.r.t. the kernel.

    Args:
        env_dict: environment dictionary; reads 'num_states',
            'num_actions', 'gamma', 'state_space', 'pi', 'xi' here and
            whatever generate_trials reads.
        trial_nums: dict with keys 'lambda', 'theta', 'xi' giving the number
            of trajectories per batch.
        trial_lens: dict with the same keys giving the trajectory length
            per batch.

    Returns:
        1-D array: the flattened policy gradient followed by the negated,
        flattened kernel gradient.
    '''
    # NOTE: the docstring is raw because it contains LaTeX backslashes;
    # '\h' in a normal string literal is an invalid escape sequence.
    m_lambda, m_theta, m_xi = trial_nums['lambda'], trial_nums['theta'], trial_nums['xi']
    H_lambda, H_theta, H_xi = trial_lens['lambda'], trial_lens['theta'], trial_lens['xi']

    num_states = env_dict['num_states']
    num_actions = env_dict['num_actions']
    gamma = env_dict['gamma']

    # 1) occupancy estimate
    lambda_trials = generate_trials(env_dict, m_lambda, H_lambda, state_init=None)
    lambda_hat = get_lambda_stochastic(num_states, num_actions, gamma, lambda_trials)

    # Cost table from the occupancy estimate (the "Type II: entropy" variant;
    # the linear "Type I" variant would use the true cost summed over s').
    state_space = env_dict['state_space']
    cost_hat = get_cost_from_grad(gamma, state_space, lambda_hat)

    # 2) policy gradient
    theta_trials = generate_trials(env_dict, m_theta, H_theta, state_init=None)
    grad_theta = get_stochastic_grad_theta(
        num_states, num_actions, gamma,
        pi=env_dict['pi'], trials=theta_trials, costs_hat=cost_hat,
    )

    # 3) kernel gradient
    xi_trials = generate_trials(env_dict, m_xi, H_xi, state_init=None)
    grad_xi = get_stochastic_grad_xi(
        num_states, num_actions, gamma,
        xi=env_dict['xi'], trials=xi_trials, costs_hat=cost_hat,
    )

    # The kernel block enters with a minus sign.
    Fz_flatten = np.concatenate((grad_theta.flatten(), -grad_xi.flatten()))
    return Fz_flatten

def z_proj(env_dict, h, num_states, num_actions, check=False):
    '''Project a flat vector `h` onto the feasible (policy, kernel) set.

    The first num_states * num_actions entries of `h` are reshaped to a
    (state, action) table and passed to proj_L2_pi; the remainder is
    reshaped to (state, action, state) and passed to proj_L2_xi together
    with env_dict['xi0'] and env_dict['xi_radius']. Both projections come
    from RobustRL_utils. The two results are flattened and concatenated in
    the same order.

    With check=True the row sums of both projected tables are printed.
    '''
    split = num_states * num_actions
    pi_part = h[0:split].reshape((num_states, num_actions))
    xi_part = h[split:].reshape((num_states, num_actions, num_states))

    center = env_dict['xi0']
    radius = env_dict['xi_radius']

    pi_proj = proj_L2_pi(pi_part)
    xi_proj = proj_L2_xi(xi_part, center, radius, Psi=None, Psi_proj=None)

    if check:
        print('check in function: z_proj')
        print('check pi sum: {}'.format(np.sum(pi_proj, axis=1)))
        print('check xi sum (reshaped): {}'.format(np.sum(xi_proj, axis=2).reshape(1,-1)))

    return np.concatenate((pi_proj.flatten(), xi_proj.flatten()))
    

def update_pi_xi(env_dict, z_new, check=False):
    '''Unpack a flat iterate into policy and kernel and store them.

    The first num_states * num_actions entries of `z_new` become the
    (state, action) table `pi`; the rest becomes the
    (state, action, state) table `xi`. Both are written to
    env_dict['pi'] / env_dict['xi'] and returned.

    With check=True the row sums of both tables are printed.
    '''
    n_states = env_dict['num_states']
    n_actions = env_dict['num_actions']
    split = n_states * n_actions

    pi = z_new[:split].reshape((n_states, n_actions))
    xi = z_new[split:].reshape((n_states, n_actions, n_states))

    # Store only after both reshapes succeeded.
    env_dict['pi'] = pi
    env_dict['xi'] = xi

    if check:
        print('check pi sum: {}'.format(np.sum(pi, axis=1)))
        print('check xi sum (reshaped): {}'.format(np.sum(xi, axis=2).reshape(1,-1)))

    return pi, xi


def get_exact_gradient(pi, xi, cost, gamma, lambda_exact, Psi=None,V=None):
    '''Gradient blocks computed from exact value functions.

    V and Q are obtained from V_func / Q_func (RobustRL_utils). Then
        V_theta[s, b] = sum_a lambda_exact[s, a] * Q[s, b] / (1 - gamma)
        V_xi[s, a, :] = lambda_exact[s, a] * (xi[s, a, :] + gamma * V)

    Returns:
        (V_theta, V_xi): gradient blocks for the policy and the kernel.

    NOTE(review): the `Psi` and `V` arguments are not used; the value
    function is always recomputed with Psi=None - confirm this is intended.
    NOTE(review): V_xi is built from `xi + gamma * V` and is not divided by
    (1 - gamma), unlike V_theta - confirm against the derivation.
    '''
    # Unpacking doubles as a check that `pi` is a 2-D table.
    n_states, n_actions = pi.shape

    V = V_func(pi, xi, cost, gamma, Psi=None)
    Q = Q_func(pi, xi, cost, gamma, Psi=None, V=V)

    occupancy = lambda_exact[:, :, np.newaxis]

    weighted_q = occupancy * Q[:, np.newaxis, :]
    V_theta = np.sum(weighted_q, axis=1) / (1 - gamma)

    V_xi = occupancy * (xi + gamma * V)

    return V_theta, V_xi  # gradient on theta and xi

def obtain_exact_gradient(env_dict):
    '''Exact counterpart of the sampled operator for the current env_dict.

    Computes the occupancy with `occupation` (RobustRL_utils) from
    env_dict['pi'], ['xi'], ['rho'] and ['gamma'], derives the cost table
    with get_cost_from_grad, evaluates get_exact_gradient and returns the
    flattened policy block followed by the negated, flattened kernel block.
    '''
    pi = env_dict['pi']
    xi = env_dict['xi']
    rho = env_dict['rho']
    gamma = env_dict['gamma']
    state_space = env_dict['state_space']

    # occupancy has shape S x A (per the original author's note)
    occupancy = occupation(pi, xi, rho, gamma, Psi=None)
    cost_table = get_cost_from_grad(gamma, state_space, occupancy)
    grad_pi, grad_xi = get_exact_gradient(
        pi, xi, cost_table, gamma, occupancy, Psi=None, V=None
    )

    return np.concatenate((grad_pi.flatten(), -grad_xi.flatten()))

def obtain_bc_pseg_plus(env_dict, num_states, num_actions, K, alphas, beta, z_init, h_init, trial_nums, trial_lens):
    '''
    Algorithm 5

    Runs K iterations over the flat iterate z = [pi.flatten(), xi.flatten()].
    Each iteration k:
      1. builds h from z[k], the previous h/z and the two most recent
         stochastic operator estimates,
      2. projects h to `barz` and installs it in env_dict,
      3. samples a stochastic estimate at barz and evaluates the exact
         operator there,
      4. records ||z_proj(barz + beta * F_barz) - barz|| / beta in norms[k],
      5. unless this is the last iteration, projects the step from z[k] to
         obtain z[k+1], installs it in env_dict and samples a fresh
         estimate there.

    Args:
        env_dict: environment dictionary; its 'pi' and 'xi' entries are
            overwritten repeatedly through update_pi_xi.
        num_states, num_actions: sizes used to split the flat vectors.
        K: number of iterations (rows of the iterate table).
        alphas: per-iteration coefficients, indexed as alphas[k].
        beta: scalar step parameter.
        z_init: initial flat iterate (1-D array).
        h_init: initial value for h_prev.
        trial_nums, trial_lens: forwarded to obtain_stochastic_function.

    Returns:
        (z[k_rand], norms) where k_rand is a length-1 index array drawn
        uniformly from 0..K-1, so the first element has shape
        (1, len(z_init)); norms has shape (K,).

    NOTE(review): the initial estimate is sampled from whatever 'pi'/'xi'
    env_dict holds on entry - presumably these match z_init; confirm at the
    call site.
    NOTE(review): on return env_dict holds the `barz` of the last iteration,
    not the returned iterate.
    '''
    z_prev = z_init # here z_prev means z_{-1}
    norms = np.zeros(K)
    z = np.zeros((K, z_init.shape[0]))
    
    z[0] = z_init
    print('calculating stocastic gradient based on initial parameters')
    # The same initial estimate serves as both F(z_{-1}) and F(z_0).
    hatF_zprev = obtain_stochastic_function(env_dict, trial_nums, trial_lens)
    hatF_z = hatF_zprev
    

    h_prev = h_init
    # The progress bar goes to stderr; per-iteration norms go to stdout.
    for k in tqdm(range(K), file=sys.stderr):
        h = z[k] - beta * hatF_z + (1-alphas[k]) * (h_prev - z_prev + beta * hatF_zprev)
        barz = z_proj(env_dict, h, num_states, num_actions, check=False)
        
        # Install barz so that the estimates below are taken at barz.
        update_pi_xi(env_dict, barz, check=False)
        hatF_barz = obtain_stochastic_function(env_dict, trial_nums, trial_lens) 
        F_barz = obtain_exact_gradient(env_dict)
        # Diagnostic only: uses the exact operator and does not feed back
        # into the iterates.
        norm = np.linalg.norm(z_proj(env_dict, barz + beta * F_barz, num_states, num_actions, check=False) - barz)/beta
        norms[k] = norm
        # print('Iteration {}/{}: norm = {}'.format(k, K, norm), file=sys.stdout)
        tqdm.write('Iteration {}/{}: norm = {}'.format(k, K, norm), file=sys.stdout)
        

        # update z_{k+1}
        ztemp = z[k] - alphas[k] * (h - barz + beta * hatF_barz)
        if k+1 < K:
            z[k+1]  = z_proj(env_dict, ztemp, num_states, num_actions, check=False)
            # Install z_{k+1} so that the next hatF_z is sampled at z_{k+1}.
            update_pi_xi(env_dict, z[k+1], check=False)
        else:
            # Last iteration: there is no row z[K] to fill.
            break
            
        h_prev = h
        # z[k] is a view of row k; that row is not written again.
        z_prev = z[k] 
        hatF_zprev = hatF_z # modified to be the function to use the parameters in the new environment
        hatF_z = obtain_stochastic_function(env_dict, trial_nums, trial_lens)
        
        
    k_rand = np.random.randint(0, K, size=1) # uniformly output an index from [K]
    return z[k_rand], norms