import os, sys
sys.path.append('..')
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import argparse
from explainer.DGP_XRL import DGPXRL
import time


# ---------------------------------------------------------------------------
# Shared experiment settings and the train/test episode split.
# ---------------------------------------------------------------------------

# Episode geometry: len_diff is the gap between the maximum episode length
# and the sequence length handed to the explainer (zero with these values).
max_ep_len = 200
seq_len = 200
len_diff = max_ep_len - seq_len

# Model input/output sizes.
num_class = 2
input_dim = 380 + 17
n_action = 0

# 80/20 split of the 30000 episode indices: first 80% train, remainder test.
total_data_idx = np.arange(30000)
_n_train = int(total_data_idx.shape[0] * 0.8)
train_idx = total_data_idx[:_n_train]
test_idx = total_data_idx[_n_train:]

# Explainer architecture and training settings.
encoder_type = 'MLP'
rnn_cell_type = 'GRU'
hiddens = [64, 32, 8]  # alternative tried: [64, 64]
likelihood_type = 'classification'
n_epoch = 200
batch_size = 40
n_stab_samples = 10
save_path = 'models_g/'

def iqm(x):
    """Return the interquartile mean of ``x``.

    ``x`` is converted to an array and flattened; the result is the mean of
    the values lying between the 25th and 75th percentiles, both bounds
    inclusive.
    """
    flat = np.asarray(x).ravel()
    lower = np.percentile(flat, 25)
    upper = np.percentile(flat, 75)
    # Keep only the values inside the inclusive [lower, upper] band.
    in_band = (lower <= flat) & (flat <= upper)
    return np.mean(flat[in_band])


if True:
    # Explainer 6 - DGP.
    # This section loads the trajectory data, builds the DGP explainer,
    # restores its trained weights and selects the episodes to analyse.
    hiddens = [64, 32, 8]
    optimizer = 'adam'
    num_inducing_points = 600

    # Approximation switches forwarded to DGPXRL (all disabled here).
    using_ngd = False  # Natural gradient descent.
    using_ksi = False  # KSI approximation; use with the other options set to False.
    using_ciq = False  # Contour Integral Quadrature for K_{zz}^{-1/2}; use together with NGD.
    using_sor = False  # SoR approximation; not applicable for KSI and CIQ.
    using_OrthogonallyDecouple = False  # Combining with NGD may cause numerical issues.

    grid_bound = [(-3, 3) for _ in range(hiddens[-1] * 2)]
    weight_x = False  # True
    logit = True
    lambda_1 = 0.01  # 0.005
    local_samples = 10
    likelihood_sample_size = 16

    # Benign trajectories; their training split supplies the reference
    # ("pseudo trajectory") episodes.
    obs = np.load("traj_dat/obs.npy")
    acts = np.load("traj_dat/acts.npy")
    obs_ref = obs[train_idx]
    acts_ref = acts[train_idx]

    rewards = np.load("traj_dat/rewards.npy")

    # Evaluation set. For the benign variant use instead:
    #   obs = obs[test_idx]
    #   acts = acts[test_idx]
    #   length_test = np.load("traj_dat/length_ob.npy")[test_idx]
    # The trojan (triggered) variant is the one active here.
    obs = np.load("traj_dat/obs_trigger_2.npy")
    acts = np.load("traj_dat/acts_trigger_2.npy")
    length_test = np.load("traj_dat/trigger_2_length.npy")

    length_ref = np.load("traj_dat/length_ob.npy")[train_idx]

    # length_trigger = np.load("traj_dat/trigger_2_length.npy")[-64:]
    print(length_test.shape)

    # Binary label per episode: last-step reward (terminal reward) above 300.
    rewards = (rewards[:, -1, 0] > 300).astype(np.int64)

    explainer_kwargs = dict(
        train_len=train_idx.shape[0],
        seq_len=seq_len,
        len_diff=len_diff,
        input_dim=input_dim,
        hiddens=hiddens,
        likelihood_type=likelihood_type,
        lr=0.01,
        optimizer_type=optimizer,
        n_epoch=n_epoch,
        gamma=0.1,
        num_inducing_points=num_inducing_points,
        n_action=n_action,
        grid_bounds=grid_bound,
        encoder_type=encoder_type,
        inducing_points=None,
        mean_inducing_points=None,
        num_class=num_class,
        rnn_cell_type=rnn_cell_type,
        using_ngd=using_ngd,
        using_ksi=using_ksi,
        using_ciq=using_ciq,
        using_sor=using_sor,
        using_OrthogonallyDecouple=using_OrthogonallyDecouple,
        weight_x=weight_x,
        lambda_1=lambda_1,
    )
    dgp_explainer = DGPXRL(**explainer_kwargs)

    # Run identifier assembled from the settings above.
    # NOTE: using_ngd appears twice on purpose -- the checkpoint file name
    # loaded below carries the same seven boolean fields.
    name_fields = (
        'dgp', likelihood_type, rnn_cell_type, num_inducing_points,
        using_ngd, using_ngd, using_ksi, using_ciq, using_sor,
        using_OrthogonallyDecouple, weight_x, lambda_1,
        local_samples, likelihood_sample_size, logit,
    )
    name = '_'.join(str(field) for field in name_fields)

    dgp_explainer.load("models_g/dgp_classification_GRU_600_False_False_False_False_False_False_False_0.01_10_16_True_model.data_200_model.data")

    # Episodes to explain (first 192 of the evaluation set) and the 64
    # reference episodes they are spliced into.
    eps_x = np.arange(64*3)
    eps_ref = np.arange(64)

    obs_ref = obs_ref[:64]
    acts_ref = acts_ref[:64]




    
# Noise-robustness sweep over the evaluation trajectories.
#
# For each noise scale `eps`, the observations are perturbed, and for each
# offset t in 10..14 and each evaluated episode i, the first `indx + 1` steps
# of episode i are spliced into copies of the reference episodes. The
# explainer's `get_var` output at step `indx` is then summarised over the
# reference episodes with the interquartile mean and stored in `u_iqm`.

# Output directory. `FILE` is not defined anywhere in this script, so the
# original f"{FILE}/..." raised NameError only after a full sweep had been
# computed. Use `FILE` if it has been defined, otherwise fall back to
# "results", and make sure the directory exists before any work is done.
out_dir = globals().get("FILE", "results")
os.makedirs(out_dir, exist_ok=True)

# Keep the unperturbed observations so every eps perturbs the clean data;
# re-noising the previous iteration's array would accumulate noise and the
# saved file would no longer correspond to its eps.
obs_clean = obs
steps = range(10, 15)
n_eps_x = len(eps_x)

for eps in [0.05, 0.1, 0.15, 0.2, 0.25]:

    # Additive uniform noise in [0, eps), then clip to [-5, 5].
    obs = np.clip(obs_clean + eps * np.random.random(obs_clean.shape), -5, 5)

    # One slot per (step, episode) pair, laid out step-major.
    u_iqm = np.empty(len(steps) * n_eps_x)

    for step_pos, t in enumerate(steps):
        results = []
        latency = []  # per-call get_var wall time; collected but not reported
        for col, i in enumerate(eps_x):

            obs_ref_cp = obs_ref.copy()
            acts_ref_cp = acts_ref.copy()

            indx = max_ep_len - length_test[i] + t

            # Overwrite the first indx+1 steps of every reference episode
            # with the same steps of episode i (broadcast over eps_ref rows).
            obs_ref_cp[eps_ref, :indx + 1] = obs[i][:indx + 1]
            acts_ref_cp[eps_ref, :indx + 1] = acts[i][:indx + 1]

            start = time.time()
            var = dgp_explainer.get_var(eps_ref, 64, obs_ref_cp, acts_ref_cp, rewards)
            end = time.time()

            # Value at step indx for each reference episode.
            results_ref = [var[h][indx] for h in eps_ref]

            # np.save(f"benign_kx/{t}_{i}.npy", np.array(results_ref))
            score = iqm(results_ref)
            results.append(score)
            u_iqm[step_pos * n_eps_x + col] = score
            latency.append(end - start)

        print("STEP:", t)
        print(iqm(results))
        print(results)

    np.save(f"{out_dir}/trojan_{eps}.npy", u_iqm)
