

import os
import time as ti

# fig_prefix = 'figures/'+ti.strftime("%m%d-%H%M")
s_currentpath = os.getcwd()  # working directory at launch

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--beta", type = float, default= 0.40, help="for Atk_C theshold")
parser.add_argument("--baseline", type = int, default= 1, help= " 0 - with an attack or 1 - yes, use baseline model " )
parser.add_argument("--Atktype", type = int, default= 1, help= "choose attack (treatment) type, 1 - attack_F, 2 = adversarial")
parser.add_argument("--causal", type = int, default= 1, help= "0 - Not use treatment info, 1 - use treatment info")
parser.add_argument("--network", type = int, default=0, help="0 - DQN, 1 - CEVAE, 2 - VAE, 3 - New CEVAE")
parser.add_argument("--history", type = int, default= 0, help= "0 - not show treatment history")
parser.add_argument("--attack_network_path", type = str, default="", help="path to attack model, must be specified if use adversarial attack.")
parser.add_argument("--evaluate", type = int, default= 1, help= "1 - yes, evaluate the trained agent, 0 - no" )
parser.add_argument("--Ftype", type = int, default= 1, help= "type of F: 1 - reverse, 2 - Random Noise, else - Zerout" )
parser.add_argument("--ratio", type = float, default= 0.2, help= "ratio of random time frame select" )
parser.add_argument("--shellratio", type = int, default= 1, help= "noise ratio for shell" )

args = parser.parse_args()

# The adversarial attack (--baseline 0 --Atktype 2) needs a pretrained attacker
# network, as the --attack_network_path help text states. Enforce this at the
# command line instead of failing with a NameError deep inside the training loop.
if args.baseline == 0 and args.Atktype == 2 and args.attack_network_path == "":
    parser.error("--attack_network_path must be specified for the adversarial attack (--baseline 0 --Atktype 2)")

# Output locations for this run. One timestamp is taken so that the figure and the
# saved data/checkpoint share the same tag; two separate strftime calls could
# straddle a minute boundary and give mismatched names.
run_tag = ti.strftime("%m%d-%H%M") + '_F' + str(args.Ftype)
floder = "figures_pixel"
fig_prefix = floder + '/test_' + run_tag
data_prefix = './data_pixel/' + run_tag

# Encode the run configuration in the figure file name. ``args.causal`` is not part
# of the name: the causal / non-causal branches used to build identical strings.
# NOTE(review): the "_0.<shellratio>" tag reads as a ratio of 0.1 * shellratio, while
# the number of attacked steps is 200 * shellratio out of max_t -- confirm intent.
if args.baseline == 1:
    fig_prefix = fig_prefix + "_baseline_Net_" + str(args.network) + "_0." + str(args.shellratio)
else:
    fig_prefix = fig_prefix + "_Atk_T" + str(args.Atktype) + '_Net_' + str(args.network) + "_0." + str(args.shellratio)

# torch.save, pickle.dump and savefig do not create missing directories.
os.makedirs(floder, exist_ok=True)
os.makedirs(os.path.dirname(data_prefix), exist_ok=True)

from unityagents import UnityEnvironment


import sys
sys.path.append("../")  # include the root directory as the main
import eda
import pandas as pd
import numpy as np
import gym
from collections import deque
import pickle
import random
import torch
from collections import deque
from algorithms_step.dqn_agent import DQNAgent, DDQNAgent, DDQNPREAgent
from algorithms_step.model import AbstractDQN
from algorithms_step.attacker import atk_model

from pixel_tool import CartPole_Pixel


# CartPole-v0 wrapped by CartPole_Pixel (presumably returns rendered frames -- the
# training loop indexes observations as [::8, ::8, 0]); seeded with 0.
env = CartPole_Pixel(gym.make('CartPole-v0'))
env.seed(0)

# Two discrete actions; the state is a flattened 2500-value frame difference
# (the loop reshapes each down-sampled frame to (1, 2500)).
action_size, state_size = 2, 2500


# Training hyper-parameters. Training stops early once the moving average of the
# last 100 episode scores reaches the "solved" threshold checked in the loop below.
n_episodes = 5000                 # upper bound on the number of training episodes
max_t = 1000                      # per-episode step cap (the env may end an episode earlier via `done`)
eps_start, eps_end, eps_decay = 1., 0.01, 0.995   # epsilon schedule: eps <- max(eps_end, eps_decay * eps)
s_model = 'dqn'                   # tag used in checkpoint / data file names
ratio_t = int(args.shellratio * 200)  # attack-candidate time steps drawn per episode

# Learner, attacker and (optionally) the pretrained network used by the adversarial attack.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
agent = DQNAgent(state_size=state_size, action_size=action_size, seed=0, select_network=args.network)
atker = atk_model(Beta = args.beta, Ftype = args.Ftype)

attack_network = None  # only used by the adversarial attack (Atktype == 2)
if args.attack_network_path != "":
    # NOTE(review): ``QNetwork`` was referenced here without ever being imported, so this
    # branch always raised NameError. It is assumed to live next to ``AbstractDQN`` in
    # algorithms_step.model -- confirm the module.
    from algorithms_step.model import QNetwork
    attack_network = QNetwork(state_size=state_size, action_size=action_size, seed=0)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    attack_state_dict = torch.load(args.attack_network_path, map_location=device)
    attack_network.load_state_dict(attack_state_dict)
    attack_network = attack_network.to(device)
    attack_network.eval()

print("Beta: ", atker.Beta, " T: " + str(args.Atktype), " Baseline: ", str(args.baseline), " CEVAE: ", str(args.causal), " Show T_His: ", str(args.history), end = "\n")


# Per-episode training statistics: raw score, rolling std dev, rolling mean, attack ratio.
scores, scores_std, scores_avg, scores_atkratio = [], [], [], []
scores_window = deque(maxlen=100)  # most recent 100 episode scores
eps = eps_start                    # current epsilon for epsilon-greedy exploration


for i_episode in range(1, n_episodes+1):
    treat_his = np.ones(max_t)  # per-step treatment flag t_i (1 = step not attacked)
    cnt = 0.                     # number of attacked steps in this episode
    # Each observation is down-sampled (every 8th pixel, channel 0) and flattened to
    # (1, 2500). The agent's state is the difference of consecutive frames; the first
    # state of an episode is all zeros.
    prev_x = env.reset()[::8, ::8, 0].reshape(1, 2500)
    state = np.zeros(2500)
    score = 0
    # Distinct time steps eligible for an attack this episode. Sampling without
    # replacement makes each chosen step count once; a set gives O(1) membership tests.
    atk_steps = set(np.random.choice(max_t, min(ratio_t, max_t), replace=False).tolist())
    for t in range(max_t):
        action, act_vals = agent.act(state, eps)
        C_atk = args.baseline == 0 and t in atk_steps
        # Execute the agent's own action. The transition stored below is labelled
        # with ``action``; stepping with a random action corrupted the replay memory.
        cur_x, reward, done, _ = env.step(action)
        cur_x = cur_x[::8, ::8, 0].reshape(1, 2500)
        next_state = cur_x - prev_x
        prev_x = cur_x
        if C_atk:
            if args.Atktype == 1:
                fake_state, t_i = atker.Atck_F(state, C_atk)
            elif args.Atktype == 2:
                fake_state, t_i = atker.Atck_Adv(state, action, reward, next_state, done, attack_network, agent.qnetwork_target, C_atk)
            else:
                raise ValueError("Attack type not supported")

            # With causal == 1 the agent receives the treatment indicator t_i;
            # otherwise every transition is tagged with the constant 1.
            agent.step(fake_state, action, reward, next_state, done, t_i if args.causal == 1 else 1.)
            treat_his[t] = t_i  # record the treatment in the episode history
            cnt += 1.
        else:
            t_i = 1.
            agent.step(state, action, reward, next_state, done, t_i)

        score += reward
        state = next_state
        if done:
            if args.history == 1:
                treat_his = treat_his[0:t+1]
                print("Treatment history: ", treat_his, end = "\n")
            break
    scores_window.append(score)
    scores.append(score)
    scores_std.append(np.std(scores_window))
    scores_avg.append(np.mean(scores_window))
    # Fraction of this episode's steps that were attacked. This was previously cnt / 300.,
    # a constant unrelated to max_t or to the actual episode length.
    scores_atkratio.append(cnt / (t + 1))
    eps = max(eps_end, eps_decay*eps)  # decay exploration
    if i_episode % 20 == 0:
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
    # "Solved" = average of 195 over the last 100 episodes, the same reference line the
    # score plot draws. With 200 as the threshold, every one of the 100 episodes had to
    # reach 200.
    if np.mean(scores_window) >= 195.0:
        s_msg = '\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'
        print(s_msg.format(i_episode, np.mean(scores_window)))
        break

# Persist the trained Q-network and the training curves, then read the curves back
# for plotting. Paths are built once so the writer and the reader always agree.
checkpoint_path = '%scheckpoint_%s.pth' % (data_prefix, s_model)
sim_data_path = '%ssim-data-%s.data' % (data_prefix, s_model)
# torch.save / open(..., 'wb') do not create missing directories.
os.makedirs(os.path.dirname(data_prefix) or '.', exist_ok=True)

torch.save(agent.qnetwork_local.state_dict(), checkpoint_path)
d_data = {'episodes': i_episode,
          'scores': scores,
          'scores_std': scores_std,
          'scores_avg': scores_avg,
          'scores_window': scores_window, 'atk_ratio': scores_atkratio}
with open(sim_data_path, 'wb') as fh:
    pickle.dump(d_data, fh)

# Only this script's own freshly written output is unpickled here; never point
# pickle.load at untrusted files.
with open(sim_data_path, 'rb') as fh:
    d_data = pickle.load(fh)
s_msg = 'Environment solved in {:d} episodes!\tAverage Score: {:.2f} +- {:.2f}'
print(s_msg.format(d_data['episodes'], np.mean(d_data['scores_window']), np.std(d_data['scores_window'])))

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Per-episode curves recovered from the saved run data.
na_raw, na_mu, na_sigma = (np.array(d_data[key]) for key in ('scores', 'scores_avg', 'scores_std'))

# Two side-by-side panels sharing both axes: raw scores (left), moving average (right).
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)

# Left panel: raw score per episode with a dashed red reference line at 195.
episodes_raw = np.arange(len(na_raw))
ax1.plot(episodes_raw, na_raw)
ax1.set_xlim(0, len(na_raw) + 1)
ax1.set_ylabel('Score')
ax1.set_xlabel('Episode # solved in ' + str(i_episode))
ax1.set_title('raw scores, red-solveed 195')
ax1.axhline(y=195., xmin=0.0, xmax=1.0, color='r', linestyle='--', linewidth=0.9, alpha=0.9)

# Right panel: 100-episode moving average with a +/- one-std-dev shaded band.
episodes_avg = np.arange(len(na_mu))
ax2.plot(episodes_avg, na_mu)
ax2.fill_between(episodes_avg, na_mu + na_sigma, na_mu - na_sigma, facecolor='gray', alpha=0.1)


# Near-greedy evaluation of the saved checkpoint over a few episodes. ``eva_score`` is
# defined unconditionally because the figure title later uses it even when evaluation
# is skipped.
n_eval_episodes = 4
eva_score = 0.
if args.evaluate == 1:
    print('############# Basic Evaluate #############', end = '\n')
    eva_agent = DQNAgent(state_size=state_size, action_size=action_size, seed=0, select_network=args.network)
    # Load the weights written by this run; map_location allows CPU-only reloading.
    eva_log = data_prefix + 'checkpoint_' + s_model + '.pth'
    eva_agent.qnetwork_local.load_state_dict(torch.load(eva_log, map_location=device))
    for i in range(n_eval_episodes):
        # Build the same frame-difference state the agent was trained on. The previous
        # version fed raw frames and never refreshed ``state`` after env.reset().
        prev_x = env.reset()[::8, ::8, 0].reshape(1, 2500)
        state = np.zeros(2500)
        while True:
            action, act_vals = eva_agent.act(state, eps=0.0001)
            cur_x, reward, done, _ = env.step(action)
            cur_x = cur_x[::8, ::8, 0].reshape(1, 2500)
            state = cur_x - prev_x
            prev_x = cur_x
            eva_score += reward
            if done:
                break

    print("Evaluate Score (max_t = 1000): {}".format(eva_score / n_eval_episodes))
    # Mean evaluation score as a dashed green line on the moving-average panel.
    ax2.axhline(y=eva_score / n_eval_episodes, xmin=0.0, xmax=1.0, color='green', linestyle='--', linewidth=0.9, alpha=0.9)

# if args.evaluate == 1:
#     # reset the environment
#     print('############# Robust Evaluate #############', end = '\n')
#         # get the default brain
#     state = env.reset()
#     # examine the state space 
#     eva_agent = DQNAgent(state_size=state_size, action_size=action_size, seed=0, select_network=args.network)
#     # load the weights from file
#     eva_log = data_prefix + 'checkpoint_' + s_model +'.pth'# '0626-1919checkpoint_dqn.pth'
#     eva_agent.qnetwork_local.load_state_dict(torch.load(eva_log))
#     robust_score = 0.
#     for i in range(4):
#         while True:
#             if random.randint(0,100) < args.ratio * 100:
#                 state, _ = atker.Atck_F(state, True)
#             action, act_vals = eva_agent.act(state, eps=0.01)
#             next_state, reward, done, _ = env.step(action)  
#             robust_score += reward 
#             state = next_state
#             if done:
#                 break
#         env_info = env.reset()

# print("Robust Score (max_t = 1000): {}".format(robust_score/4.))
# ax2.axhline(y=robust_score/4., xmin=0.0, xmax=1.0, color='black', linestyle='-.', linewidth=0.9, alpha=0.9)

ax2.set_ylabel('Scores')
ax2.set_xlabel('Episode # solved in '+ str(i_episode))
ax2.set_title('avg-blue | validation-green:' + str(eva_score/4.))

f.tight_layout()

# savefig does not create missing directories.
os.makedirs(os.path.dirname(fig_prefix) or '.', exist_ok=True)
f.savefig(fig_prefix + '_dqn.pdf', format='pdf')
plt.close(f)  # release the figure once written

env.close()

