#!/usr/bin/env python
# coding: utf-8

# # Dueling Double Deep Q-Network (DDDQN)
# ---
# Implementation of the agent with OpenAI Gym's LunarLander-v2 environment. The code is based on materials from Udacity Deep Reinforcement Learning Nanodegree Program. 
# 
# ### 1. Import the Necessary Packages

# In[1]:


import gym
# !pip install box2d
import random
import torch
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

import os
import time as ti

# Output location for checkpoints and pickled training curves.
# NOTE: ``data_prefix`` is a path *prefix*, not a directory -- later code
# appends names such as 'checkpoint_dqn.pth' directly to it, producing
# './data/<MMDD-HHMM>checkpoint_dqn.pth'.
data_dir = './data/'
# Create the directory up front: neither torch.save() nor open(..., 'wb')
# creates missing parent directories, so without this the first checkpoint
# write fails on a clean checkout.
os.makedirs(data_dir, exist_ok=True)
data_prefix = data_dir + ti.strftime("%m%d-%H%M")
s_currentpath = os.getcwd()

import argparse
# Command-line options, declared as (flag, type, default, help) rows and
# registered with the parser in a single pass.
parser = argparse.ArgumentParser()
_cli_options = (
    ("--beta", float, 0.40, "for Atk_C theshold"),
    ("--baseline", int, 1, " 0 - with an attack or 1 - yes, use baseline model "),
    ("--Atktype", int, 1, "choose attack (treatment) type, 1 - attack_F, 2 = adversarial"),
    ("--causal", int, 1, "0 - Not use treatment info, 1 - use treatment info"),
    ("--network", int, 0, "0 - DQN, 1 - CEVAE, 2 - VAE, 3 - New CEVAE"),
    ("--history", int, 0, "0 - not show treatment history"),
    ("--attack_network_path", str, "", "path to attack model, must be specified if use adversarial attack."),
    ("--evaluate", int, 1, "1 - yes, evaluate the trained agent, 0 - no"),
    ("--Ftype", int, 1, "type of F: 1 - reverse, 2 - Random Noise, else - Zerout"),
    ("--num_eval", int, 10, "10 times eval"),
    ("--ratio", float, 0.2, "ratio of random time frame select"),
    ("--shellratio", int, 1, "noise ratio for shell"),
)
for _flag, _type, _default, _help in _cli_options:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Parsed options are read by the rest of the script through ``args``.
args = parser.parse_args()

# Figure output prefix:
#   'figures/test_<MMDD-HHMM>_F<Ftype>' + run tag + '_Net_<network>_0.<shellratio>'
env_type = "figures"
# savefig() does not create missing directories; create the folder now so a
# missing directory is not discovered only after the whole training run.
os.makedirs(env_type, exist_ok=True)
fig_prefix = env_type + '/test_' + ti.strftime("%m%d-%H%M") + '_F' + str(args.Ftype)

# The run tag depends only on --baseline (and --Atktype when attacking).
# The former inner branches on --causal produced identical strings, so the
# dead branching has been collapsed.
if args.baseline == 1:
    run_tag = "_baseline"
else:
    run_tag = "_Atk_T" + str(args.Atktype)
fig_prefix = fig_prefix + run_tag + "_Net_" + str(args.network) + "_0." + str(args.shellratio)



# !python -m pip install pyvirtualdisplay
#from pyvirtualdisplay import Display
#display = Display(visible=0, size=(1400, 900))
#display.start()

# Detect an inline (notebook) matplotlib backend; IPython's display module is
# imported only in that case.
_backend_name = plt.get_backend()
is_ipython = 'inline' in _backend_name
if is_ipython:
    from IPython import display

# Turn on matplotlib's interactive mode.
plt.ion()


# ### 2. Instantiate the Environment and Agent
# 
# Initialize the environment in the code cell below.

# In[2]:


# Build the LunarLander-v2 environment with a fixed seed and report the
# dimensions of its observation and action spaces.
env = gym.make('LunarLander-v2')
env.seed(0)
_obs_shape = env.observation_space.shape
_n_actions = env.action_space.n
print('State shape: ', _obs_shape)
print('Number of actions: ', _n_actions)


# ### 3. Train the Agent 
# 
# Run the code cell below to train the agent from scratch. 

# In[8]:


import sys
sys.path.append("../")  # include the root directory as the main
from algorithms_step.dqn_agent import DQNAgent, DDQNAgent, DDQNPREAgent
import pandas as pd
import numpy as np
import gym
from collections import deque
import pickle
import random
import torch
from collections import deque
from algorithms_step.attacker import atk_model
# Dimensions handed to the agent constructors.
state_size = 8
action_size = 4

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train with the architecture selected by --network. The evaluation code
# further down rebuilds the agent with ``select_network=args.network`` and
# loads this agent's saved weights, so both must be constructed with the same
# setting for the state_dict to match.
agent = DQNAgent(state_size=state_size, action_size=action_size, seed=0,
                 select_network=args.network)
atker = atk_model(Beta=args.beta, Ftype=args.Ftype)

# The adversarial attack (--Atktype 2) needs an attack model; the training
# loop passes ``attack_network`` to atker.Atck_Adv(). It used to be undefined,
# which surfaced as a NameError in the middle of training. Fail fast instead.
attack_network = None
if args.baseline == 0 and args.Atktype == 2:
    if not args.attack_network_path:
        raise ValueError("--attack_network_path must be specified when "
                         "--Atktype 2 (adversarial attack) is used")
    # NOTE(review): assumes the file holds a complete pickled model object
    # (not a bare state_dict) -- confirm against how the attack model is saved.
    attack_network = torch.load(args.attack_network_path, map_location=device)

# Training hyper-parameters.
eval_num = args.num_eval   # number of evaluation episodes
n_episodes = 1000          # maximum number of training episodes
eps_start = 1.             # initial epsilon for epsilon-greedy action selection
eps_end = 0.01             # lower bound for epsilon
eps_decay = 0.998          # multiplicative epsilon decay factor
s_model = 'dqn'            # tag used in checkpoint / data file names

print("C_beta: ", atker.Beta, " T: " + str(args.Atktype), " Use baseline: ", str(args.baseline), " CEVAE: ", str(args.causal), end = "\n")

# Per-episode bookkeeping.
loss_book = []                     # agent.check_loss() recorded after each episode
scores = []                        # raw score of every episode
scores_std = []                    # std dev of the last 100 episode scores
scores_avg = []                    # mean of the last 100 episode scores
scores_atkratio = []               # fraction of attacked steps in each episode
scores_window = deque(maxlen=100)  # sliding window of the last 100 scores
eps = eps_start                    # current epsilon

checkpoint_path = '%scheckpoint_%s.pth' % (data_prefix, s_model)
# --shellratio is expressed in tenths: 1 -> 10 % of the steps are attacked.
attack_percent = args.shellratio * 10

for i_episode in range(n_episodes):
    cnt = 0.         # attacked steps in this episode
    total_cnt = 0.   # all steps in this episode
    state = env.reset()
    score = 0
    while True:
        action, act_vals = agent.act(state, eps)
        # Decide whether this step is attacked (never in baseline mode).
        # randrange(100) draws from 0..99, so the test succeeds with
        # probability exactly attack_percent / 100. The previous
        # randint(0, 100) is inclusive on both ends (101 outcomes), which
        # skewed the ratio and made a 100 % attack rate unreachable.
        C_atk = args.baseline == 0 and random.randrange(100) < attack_percent
        next_state, reward, done, _ = env.step(action)

        if C_atk:
            # Replace the observation stored in replay memory by the attacked one.
            if args.Atktype == 1:
                fake_state, t_i = atker.Atck_F(state, C_atk)
            elif args.Atktype == 2:
                fake_state, t_i = atker.Atck_Adv(state, action, reward, next_state, done, attack_network, agent.qnetwork_target, C_atk)
            else:
                raise ValueError("Attack type not supported")

            # With --causal 0 the treatment value is withheld from the agent.
            agent.step(fake_state, action, reward, next_state, done,
                       t_i if args.causal == 1 else 0.)
            cnt += 1.
        else:
            t_i = 0.
            agent.step(state, action, reward, next_state, done, t_i)

        score += reward
        state = next_state
        # NOTE: epsilon is decayed once per environment step, not per episode.
        eps = max(eps_end, eps_decay*eps)
        total_cnt += 1.
        if done:
            break

    # Checkpoint on the first episode and whenever the episode score ties or
    # beats the best score seen so far.
    if not scores or score >= max(scores):
        torch.save(agent.qnetwork_local.state_dict(), checkpoint_path)

    scores_window.append(score)
    scores.append(score)
    scores_std.append(np.std(scores_window))
    scores_avg.append(np.mean(scores_window))
    scores_atkratio.append(round(cnt/total_cnt,2))
    loss_book.append(agent.check_loss())

    if i_episode % 20 == 0:
        print("With: ",round(100*cnt/total_cnt,2),"% timing attack", end = "\n")
        print('\rEpisode {}   Score: {:.2f}, Average Score: {:.2f}, Loss: {:.2f}'.format(i_episode, score, np.mean(scores_window), agent.check_loss()))
    # Stop early once the windowed mean score reaches 195.
    if np.mean(scores_window)>=195.0:
        s_msg = '\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'
        print(s_msg.format(i_episode, np.mean(scores_window)))
        break


# Persist the training curves so they can be re-used later.
d_data = {'episodes': i_episode,
          'scores': scores,
          'scores_std': scores_std,
          'scores_avg': scores_avg,
          'scores_window': scores_window, 'atk_ratio': scores_atkratio, 'loss':loss_book}
# One path for both the write and the read-back; the read used to hard-code
# 'dqn' instead of using s_model and would break if the tag ever changed.
data_path = '%ssim-data-%s.data' % (data_prefix, s_model)
# Context managers make sure the file is flushed and closed before it is
# reopened for reading.
with open(data_path, 'wb') as data_file:
    pickle.dump(d_data, data_file)

# Read the file written above back in; the plots below use this copy.
with open(data_path, 'rb') as data_file:
    d_data = pickle.load(data_file)
# NOTE: this message is printed whether or not the early-stop threshold was
# reached; 'episodes' is simply the index of the last episode that ran.
s_msg = 'Environment solved in {:d} episodes!\tAverage Score: {:.2f} +- {:.2f}'
print(s_msg.format(d_data['episodes'], np.mean(d_data['scores_window']), np.std(d_data['scores_window'])))

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Pull the recorded curves out of the reloaded data as numpy arrays.
na_raw, na_mu, na_sigma, record_attack, loss_r = (
    np.array(d_data[_key])
    for _key in ('scores', 'scores_avg', 'scores_std', 'atk_ratio', 'loss')
)

# Two side-by-side panels sharing both axes: raw scores (left) and the
# running mean (right).
f, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), sharex=True, sharey=True)

# --- left panel: raw score per episode, with the 195 threshold in red ---
_x_raw = np.arange(len(na_raw))
ax1.plot(_x_raw, na_raw)
ax1.set_xlim(0, len(na_raw)+1)
ax1.set_title('raw scores, red-solved 195')
ax1.set_xlabel('Episode # solved in ' + str(i_episode))
ax1.set_ylabel('Score')
ax1.axhline(y=195., xmin=0.0, xmax=1.0, color='r', linestyle='--', linewidth=0.9, alpha=0.9)

# Secondary y-axis on the left panel: per-episode attack ratio.
color = 'tab:orange'
ax12 = ax1.twinx()
ax12.set_ylabel('atk ratio', color=color)
_x_attack = np.arange(len(record_attack))
ax12.plot(_x_attack, record_attack, color=color)
ax12.tick_params(axis='y', labelcolor=color)

# --- right panel: running mean with a +/- one std-dev band ---
_x_mu = np.arange(len(na_mu))
_band_upper = na_mu+na_sigma
_band_lower = na_mu-na_sigma
ax2.plot(_x_mu, na_mu)
ax2.fill_between(_x_mu, _band_upper, _band_lower, facecolor='gray', alpha=0.1)

# Secondary y-axis on the right panel: recorded loss per episode.
ax22 = ax2.twinx()
ax22.set_ylabel('loss')
_x_loss = np.arange(len(loss_r))
ax22.plot(_x_loss, loss_r, 'c')
ax22.tick_params(axis='y')

def _evaluate_checkpoint(noisy):
    """Run ``eval_num`` greedy episodes with the saved checkpoint.

    A fresh agent is built with ``select_network=args.network`` and loaded
    with the weights stored at ``data_prefix + 'checkpoint_' + s_model + '.pth'``.
    Actions are chosen with ``eps=0.00``.

    Args:
        noisy (bool): when True, before each action the observation is
            replaced by ``atker.Atck_F(state, True)`` with probability
            ``args.shellratio * 10`` percent.

    Returns:
        float: sum of the rewards collected over all evaluation episodes.
    """
    eva_agent = DQNAgent(state_size=state_size, action_size=action_size, seed=0, select_network=args.network)
    eva_log = data_prefix + 'checkpoint_' + s_model + '.pth'
    eva_agent.qnetwork_local.load_state_dict(torch.load(eva_log))

    total_reward = 0.
    for _episode in range(eval_num):
        # Start every episode from the observation returned by reset().
        # Previously the result of the reset between episodes was discarded,
        # so the first action of each later episode was chosen from the
        # previous episode's terminal observation.
        state = env.reset()
        done = False
        while not done:
            # randrange(100) covers 0..99, so the attack fires with exactly
            # shellratio*10 percent probability (randint(0, 100) had 101 outcomes).
            if noisy and random.randrange(100) < args.shellratio * 10:
                state, _treatment = atker.Atck_F(state, True)
            action, act_vals = eva_agent.act(state, eps=0.00)
            state, reward, done, _ = env.step(action)
            total_reward += reward
    return total_reward


# Evaluation is optional. The score lines and the title used to reference
# eva_score / robust_score unconditionally, which raised NameError with
# --evaluate 0 (after training, before the figure was saved). A non-positive
# --num_eval is treated as "no evaluation" to avoid dividing by zero.
if args.evaluate == 1 and eval_num > 0:
    print('############# Basic Evaluate #############', end = '\n')
    eva_score = _evaluate_checkpoint(noisy=False)
    eva_mean = eva_score/float(eval_num)
    print("Evaluate Score "+": {}".format(eva_mean))
    ax2.axhline(y=eva_mean, xmin=0.0, xmax=1.0, color='green', linestyle='--', linewidth=0.9, alpha=0.9)

    print('############# Noise Evaluate #############', end = '\n')
    robust_score = _evaluate_checkpoint(noisy=True)
    robust_mean = robust_score/float(eval_num)
    print("Robust Score " +": {}".format(robust_mean))
    ax2.axhline(y=robust_mean, xmin=0.0, xmax=1.0, color='black', linestyle='-.', linewidth=0.9, alpha=0.9)

    title = 'avg-blue | validation-green:' + str(eva_mean)+'| robust:'+ str(robust_mean)
else:
    title = 'avg-blue'

# Finish the right-hand panel and write the figure to disk.
ax2.set_ylabel('Scores')
ax2.set_xlabel('Episode # solved in '+ str(i_episode))
ax2.set_title(title)

f.tight_layout()
f.savefig(fig_prefix + '_dqn.pdf', format='pdf')

env.close()


# def train(n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
#     """Deep Q-Learning.
    
#     Params
#     ======
#         n_episodes (int): maximum number of training episodes
#         max_t (int): maximum number of timesteps per episode
#         eps_start (float): starting value of epsilon, for epsilon-greedy action selection
#         eps_end (float): minimum value of epsilon
#         eps_decay (float): multiplicative factor (per episode) for decreasing epsilon
#     """
#     scores = []                        # list containing scores from each episode
#     scores_window = deque(maxlen=100)  # last 100 scores
#     eps = eps_start                    # initialize epsilon
#     for i_episode in range(1, n_episodes+1):
#         state = env.reset()
#         score = 0
#         for t in range(max_t):
#             action = agent.act(state, eps)
#             next_state, reward, done, _ = env.step(action)
#             agent.step(state, action, reward, next_state, done)
#             state = next_state
#             score += reward
#             if done:
#                 break 
#         scores_window.append(score)       # save most recent score
#         scores.append(score)              # save most recent score
#         eps = max(eps_end, eps_decay*eps) # decrease epsilon
#         print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
#         if i_episode % 100 == 0:
#             print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
#         if np.mean(scores_window)>=200.0:
#             print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
#             torch.save(agent.qnetwork_local.state_dict(), 'checkpoint_Dueling_DDQN.pth')
#             break
#     return scores

# scores = train()

# # plot the scores
# fig = plt.figure()
# ax = fig.add_subplot(111)
# plt.plot(np.arange(len(scores)), scores)
# plt.ylabel('Score')
# plt.xlabel('Episode #')
# plt.show()


# ### 4. Watch a Smart Agent!
# 
# In the next code cell, you will load the trained weights from file to watch a smart agent!

# In[3]:


# from dqn_agent import Agent

# agent = Agent(state_size=8, action_size=4, seed=0)
# # load the weights from file
# agent.qnetwork_local.load_state_dict(torch.load('checkpoint_Dueling_DDQN.pth', map_location=lambda storage, loc: storage))

# for i in range(3):
#     state = env.reset()
#     img = plt.imshow(env.render(mode='rgb_array'))
#     for j in range(200):
#         action = agent.act(state)
#         img.set_data(env.render(mode='rgb_array')) 
#         plt.axis('off')
#         display.display(plt.gcf())
#         display.clear_output(wait=True)
#         state, reward, done, _ = env.step(action)
#         if done:
#             break 
            
# env.close()


# In[ ]:




