'''functions for policy optimization'''

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

from jax import value_and_grad
from jax import jacobian

import numpy as np

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

from sklearn.metrics import mean_squared_error

import data
import utils
import estimation_v3


def subgradient_descent(num_iter=10, step_size=0.1, key_start=1, X_samples=None, T_samples=None, Y_samples=None, rank=2, rho=None, method='direct', Z_estimation=True):
    '''
    Run gradient descent on the parameters (phi, b) of the logistic policy
    pi_1(x) = sigmoid(x . phi + b), minimising objective_value.

    Parameters
    ----------
    num_iter : number of descent steps.
    step_size : multiplier applied to the gradient in each step.
    key_start : integer seed for the JAX PRNG that draws the initial
        (phi, b) from a standard normal.
    X_samples : 2-D array; its second dimension (X_samples.shape[1]) gives
        the length p of phi.
    T_samples, Y_samples, rank, rho, Z_estimation : forwarded unchanged to
        gen_mu_values.
    method : estimator name forwarded to gen_mu_values; one of 'direct',
        'IPW-obs', 'IPW', 'DR-obs', 'DR', 'CV'.

    Returns
    -------
    dict of NumPy arrays, each with leading length num_iter + 1:
        'phi_values', 'b_values' : index 0 is the random initial point,
            index k + 1 the parameters after the k-th update.
        'phi_gradients', 'b_gradients' : index k + 1 is the gradient used
            by the k-th update; index 0 stays zero.
        'pol_values', 'obj_values' : index k + 1 is the value evaluated at
            the parameters *before* the k-th update (i.e. at index k of
            'phi_values' / 'b_values'); index 0 stays zero.

    Raises
    ------
    ValueError : if method is not one of the names listed above.
    '''

    # Fail early with a clear message instead of hitting an unbound
    # mu_0_list inside the loop.
    supported_methods = ('direct', 'IPW-obs', 'IPW', 'DR-obs', 'DR', 'CV')
    if method not in supported_methods:
        raise ValueError(
            f"Unsupported method {method!r}; expected one of {supported_methods}"
        )

    p = X_samples.shape[1]
    pol_values = np.zeros(num_iter+1)
    obj_values = np.zeros(num_iter+1)
    phi_values = np.empty([num_iter+1,p])
    b_values = np.zeros(num_iter+1)
    phi_gradients = np.zeros([num_iter+1,p])
    b_gradients = np.zeros(num_iter+1)

    # Initialize random model coefficients
    key = random.PRNGKey(key_start)
    _, phi_key, b_key = random.split(key, 3)
    ## TODO: phi needs to be multidimensional
    phi = random.normal(phi_key, (p,))
    b = random.normal(b_key, ())

    phi_values[0,:] = phi
    b_values[0] = b

    # One transformed function yields the objective and both partial
    # gradients, so the objective is traced once per step instead of three
    # times.
    objective_and_grads = value_and_grad(objective_value, argnums=(0, 1))

    for iters in range(num_iter):

        # Every method uses the same call; gen_mu_values dispatches on
        # `method` itself.
        # NOTE(review): the arguments do not change between iterations, so
        # this re-fit looks hoistable -- confirm the estimation_v3 fits are
        # deterministic before moving it out of the loop.
        mu_0_list, mu_1_list = gen_mu_values(X_samples=X_samples, T_samples=T_samples, Y_samples=Y_samples, rank=rank, rho=rho, method=method, Z_estimation=Z_estimation)

        # Policy value at the current (pre-update) parameters; the
        # per-sample terms returned alongside it are not needed here.
        pol_value, _, _ = compute_x_opt(phi=phi, b=b, X_samples=X_samples, T_samples=T_samples,
                                        mu_0_list=mu_0_list, mu_1_list=mu_1_list)

        # Objective value and gradients at the current parameters
        obj_value, (phi_grad, b_grad) = objective_and_grads(
            phi, b, X_samples=X_samples, T_samples=T_samples,
            mu_0_list=mu_0_list, mu_1_list=mu_1_list)

        pol_values[iters+1] = pol_value
        obj_values[iters+1] = obj_value

        # descent
        phi = phi - step_size * phi_grad
        b = b - step_size * b_grad

        phi_values[iters+1,:] = phi
        phi_gradients[iters+1,:] = phi_grad
        b_values[iters+1] = b
        b_gradients[iters+1] = b_grad

    results_dict = {'phi_values': phi_values, 'b_values': b_values,
                    'phi_gradients': phi_gradients, 'b_gradients': b_gradients,
                    'pol_values': pol_values, 'obj_values': obj_values}

    return results_dict



def compute_x_opt(phi=None, b=None, X_samples=None, T_samples=None, mu_0_list=None,
                 mu_1_list=None):
    '''
    Evaluate the policy value of the logistic policy (phi, b).

    Each sample contributes (1 - pi_1(x)) * mu_0 + pi_1(x) * mu_1; the
    policy value is the mean of the first term plus the mean of the second.
    T_samples is accepted but not used.

    Returns a tuple (policy_val, pi_0_list, pi_1_list): the policy value as
    a Python float, followed by the per-sample weighted terms for arm 0 and
    arm 1.
    '''
    # Probability of choosing arm 1 for every sample; arm 0 gets the rest.
    treat_prob = pi_1(phi, b, X_samples)

    weighted_control = jnp.multiply(1 - treat_prob, mu_0_list)
    weighted_treated = jnp.multiply(treat_prob, mu_1_list)

    total = jnp.mean(weighted_control) + jnp.mean(weighted_treated)

    # Convert the JAX scalar to a plain float for the caller.
    return float(total), weighted_control, weighted_treated
    
    
def objective_value(phi, b, X_samples=None, T_samples=None,
                    mu_0_list=None, mu_1_list=None):
    '''
    Objective used for policy optimization.

    For the logistic policy (phi, b) and the given estimates, returns
    mean((1 - pi_1(x)) * mu_0) + mean(pi_1(x) * mu_1) as a JAX scalar.
    phi and b are positional so the function can be differentiated with
    respect to them. T_samples is accepted but not used.
    '''
    control_term = jnp.mean(jnp.multiply(1 - pi_1(phi, b, X_samples), mu_0_list))
    treated_term = jnp.mean(jnp.multiply(pi_1(phi, b, X_samples), mu_1_list))

    return control_term + treated_term


def gen_mu_values(X_samples=None, T_samples=None,Y_samples=None, rank=2, rho=None, method='IPW', Z_estimation=True):
    '''
    Fit the outcome estimates (mu_0, mu_1) with the requested method.

    With Z_estimation equal to True the fit is delegated to
    estimation_v3.fit_latent_outcome and method must be one of 'direct',
    'IPW-obs', 'IPW', 'CV'. With Z_estimation equal to False it is
    delegated to estimation_v3.fit_outcome and method may additionally be
    'DR-obs' or 'DR'. All other arguments are forwarded unchanged.

    Returns the pair (mu_0, mu_1) produced by the estimator. The pair is
    also stored in the module-level globals mu_0 and mu_1, as before.

    Raises
    ------
    ValueError : if Z_estimation equals neither True nor False, or if
        method is not supported for the chosen Z_estimation. Previously
        such calls returned None or whatever stale values the globals held
        from an earlier call (or raised NameError on the first call).
    '''
    # Kept as module-level state for backward compatibility with any code
    # that reads these names after a call.
    global mu_0
    global mu_1

    # `==` rather than `is`, so 1/0 and NumPy booleans are accepted exactly
    # as they were before.
    if Z_estimation == True:
        use_latent = True
        supported_methods = ('direct', 'IPW-obs', 'IPW', 'CV')
    elif Z_estimation == False:
        use_latent = False
        supported_methods = ('direct', 'IPW-obs', 'IPW', 'DR-obs', 'DR', 'CV')
    else:
        raise ValueError(
            f"Z_estimation must be True or False, got {Z_estimation!r}"
        )

    # Validate before touching the globals so a bad call can never hand
    # back results left over from a previous fit.
    if method not in supported_methods:
        raise ValueError(
            f"Unsupported method {method!r} for Z_estimation={Z_estimation!r}; "
            f"expected one of {supported_methods}"
        )

    if use_latent:
        fit = estimation_v3.fit_latent_outcome
    else:
        fit = estimation_v3.fit_outcome

    mu_0, mu_1 = fit(X_samples=X_samples, T_samples=T_samples, Y_samples=Y_samples, rank=rank, rho=rho, method=method)

    return mu_0, mu_1



def sigmoid(x):
    '''Logistic function of x, computed as 0.5 * (tanh(x / 2) + 1).'''
    half_x = x / 2
    shifted_tanh = jnp.tanh(half_x) + 1
    return 0.5 * shifted_tanh


def pi_1(phi, b, inputs):
    '''Logistic policy: sigmoid of the affine score dot(inputs, phi) + b.'''
    score = jnp.dot(inputs, phi) + b
    return sigmoid(score)





