import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from gym.envs.registration import registry, register, make, spec




# Command-line configuration for the tabular grid-world experiments.


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string -- so ``--explore False`` used to enable exploration.
    This accepts the usual true/false spellings instead; the empty string
    stays False, as it was with ``bool('')``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=123, metavar='N',
                    help='random seed (default: 123)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='interval between training status logs (default: 1)')
parser.add_argument('--learning-rate', type=float, default=1e-4, metavar='G',
                    help='learning-rate of optimizer (default: 1e-4)')
parser.add_argument('--classifier', type=str, required=True,
                    help='classifier of policy loss (HPO-AM, HPO-AM-log, HPO-AM-root, HPO-AM-sub or HPO-AM-square)')
parser.add_argument('--value-method', type=str, required=True,
                    help='returns based or Qvalue (returns or Qvalue)')
parser.add_argument('--max-episode', type=int, default=10000, metavar='N',
                    help='max-episode (default: 10000)')
parser.add_argument('--actor-delay', type=int, default=1, metavar='N',
                    help='actor-delay (default: 1)')
parser.add_argument('--wandb', action='store_true',
                    help='use wandb for profiling')
parser.add_argument('--margin', type=float, default=0.1, metavar='G',
                    help='margin (default: 0.1)')
# nargs='?' + const=True keeps "--explore True" working and also allows a bare "--explore".
parser.add_argument('--explore', type=_str2bool, nargs='?', const=True, default=False,
                    help='explore action with epsilon 0.1, default=False')
parser.add_argument('--envi', type=str, choices=['4x4', '6x6', '8x8'],
                    help='grid size: 4x4, 6x6 or 8x8')
parser.add_argument('--weight', action='store_true',
                    help='marginal loss y = advantage (default:y = advantage.sign())')
parser.add_argument('--fixedEps', type=str,
                    help='fixedEps (fixedEps or fixedEpsWeight)')
parser.add_argument('--vTable', type=_str2bool, nargs='?', const=True, default=False,
                    help='Use HPO-AM-log value table in 40k episode')
args = parser.parse_args()
import os
import random
from datetime import datetime

# --envi -> (gym entry point, grid width) for the random-reward grid world.
_GRIDWORLD_VARIANTS = {
    '4x4': ('gridworld_randR_env:Gridworld_RandReward_4x4_Env', 4),
    '6x6': ('gridworld_randR_env:Gridworld_RandReward_6x6_Env', 6),
    '8x8': ('gridworld_randR_env:Gridworld_RandReward_8x8_Env', 8),
}
if args.envi not in _GRIDWORLD_VARIANTS:
    # Without this check GRIDWIDTH would stay undefined and gym.make would fail
    # later because nothing was registered under envname.
    raise ValueError("--envi must be one of {}, got {!r}".format(
        sorted(_GRIDWORLD_VARIANTS), args.envi))

envname = 'gridworld_randR_env-v0'
_entry_point, GRIDWIDTH = _GRIDWORLD_VARIANTS[args.envi]
register(id=envname, entry_point=_entry_point, reward_threshold=500.0,)
ACTION_DIM = 4
env = gym.make(envname)

# Seed every RNG the script touches (torch CUDA/CPU, numpy, random, env).
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
env.seed(args.seed)
torch.manual_seed(args.seed)
print("env.observation_space.shape:",env.observation_space.shape)

# Module-level log buffers (not written to in the visible part of this file).
last_episode_log = []
a2c_loss_log = []
loss_log = []
ploss_log = []
vloss_log = []
a2c_ploss_log = []
a2c_vloss_log = []

# Shared run identifier: used as the wandb run name and as the tensorboard dir name.
run_name = ("tb_" + str(args.classifier) + "_weight" + str(args.weight)
            + "_fixedEps" + str(args.fixedEps) + "_alpha" + str(args.margin)
            + "_seed" + str(args.seed))

if args.wandb:
    import wandb
    wandb.init(project="grid fixed", name=run_name,)
    config = wandb.config
    config.learning_rate = args.learning_rate

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

hpg_checkpoint_dir = "./models/hpg_fch256_0519tabular/"+envname
os.makedirs(hpg_checkpoint_dir, exist_ok=True)

# Short tag describing the ablation settings.
# NOTE(review): not read in the visible part of this file; presumably used later.
ablation = "W" if args.weight else ""
if args.fixedEps == 'fixedEps':
    ablation += "CE"
elif args.fixedEps == 'fixedEpsWeight':
    ablation += "CEW"
else:
    ablation += "AE"
if args.vTable:
    ablation += "_vOptTable"
if args.explore:
    ablation += "_explore"

# NOTE(review): the directory name does not include --envi or --value-method, so
# runs that differ only in those settings write into the same directory.
log_dir = "./tensorboards/" + run_name + "_logTime" + str(args.log_interval)
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

class TabularModel:
    """Tabular policy / action-value store for the grid world.

    Attributes (one row per state, one column per action):
        past_prob_table: action probabilities, initialised uniformly.
        past_value_table: per-action values, initialised to 0.0, or to the
            precomputed 4x4 table below when ``--vTable`` is set.
        saved_actions, rewards: per-episode buffers; ``finish_episode`` reads
            and then clears them.
    """

    # Q*(s, a) for the 4x4 grid, rows for states 1..15 (state 0, the goal,
    # keeps 0.0). Columns follow the action order U, R, D, L.
    # Values satisfy Q(s, a) = -0.2 + 0.99 * max_a' Q(s', a'), with mean
    # reward -0.2 (reward -1.2 or 1). This assumes deterministic moves, walls
    # that keep the agent in place, and a terminal goal at state 0 -- confirm
    # against gridworld_randR_env.
    # Four entries of the earlier table broke that recurrence and are corrected:
    # 11/D (-1.335869304), 12/R (-0.9880199), 14/D and 14/L.
    _OPTIMAL_VALUES_4X4 = (
        (-0.398, -0.59402, -0.59402, -0.2),
        (-0.59402, -0.7880798, -0.7880798, -0.398),
        (-0.7880798, -0.7880798, -0.980199, -0.59402),
        (-0.2, -0.59402, -0.59402, -0.398),
        (-0.398, -0.7880798, -0.7880798, -0.398),
        (-0.59402, -0.980199, -0.980199, -0.59402),
        (-0.7880798, -0.980199, -1.17039701, -0.7880798),
        (-0.398, -0.7880798, -0.7880798, -0.59402),
        (-0.59402, -0.980199, -0.980199, -0.59402),
        (-0.7880798, -1.17039701, -1.17039701, -0.7880798),
        (-0.980199, -1.17039701, -1.35869304, -0.980199),
        (-0.59402, -0.980199, -0.7880798, -0.7880798),
        (-0.7880798, -1.17039701, -0.980199, -0.7880798),
        (-0.980199, -1.35869304, -1.17039701, -0.980199),
        (-1.17039701, -1.35869304, -1.35869304, -1.17039701),
    )

    def __init__(self):
        n_states = env.observation_space.n
        n_actions = env.action_space.n
        # Uniform initial policy (0.25 each for the 4-action grid world).
        self.past_prob_table = np.full((n_states, n_actions), 1.0 / n_actions)
        self.past_value_table = np.full((n_states, n_actions), 0.0)
        self.saved_actions = []
        self.rewards = []

        if args.vTable:
            # Only rows 1..15 are filled, so on 6x6/8x8 grids the remaining
            # rows keep their 0.0 initialisation.
            for state, row in enumerate(self._OPTIMAL_VALUES_4X4, start=1):
                self.past_value_table[state] = row

    def update_prob(self, state, action, prob_gradient, lr=0.01):
        """Take a gradient step on one state's action distribution.

        The row moves by ``-lr * prob_gradient`` and, unless it is already a
        valid distribution, is then Euclidean-projected back onto the
        probability simplex.

        Args:
            state: row index into ``past_prob_table``.
            action: unused; kept for call-site compatibility.
            prob_gradient: per-action gradient, one entry per column.
            lr: step size (default 0.01, the previously hard-coded value).
        """
        row = self.past_prob_table[state, :] - lr * np.asarray(prob_gradient)
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        if row.sum() == 1 and np.all(row >= 0):
            self.past_prob_table[state, :] = row
            return
        self.past_prob_table[state, :] = self._project_to_simplex(row)

    @staticmethod
    def _project_to_simplex(v):
        """Return the Euclidean projection of 1-D ``v`` onto the probability simplex.

        Sort-based algorithm (Duchi et al., 2008), following
        https://gist.github.com/daien/1272551
        """
        m = v.shape[0]
        u = np.sort(v)[::-1]
        sum_u = np.cumsum(u)
        # Largest index whose sorted entry stays positive after the shift;
        # index 0 always qualifies, so the nonzero result is never empty.
        rho = np.nonzero(u * np.arange(1, m + 1) > (sum_u - 1))[0][-1]
        theta = float(sum_u[rho] - 1) / (rho + 1)
        return (v - theta).clip(min=0)

    def set_value(self, state, action, value):
        """Overwrite Q(state, action) with ``value``."""
        self.past_value_table[state, action] = value

    def update_value(self, state, action, value, lr=0.01):
        """Gradient step on Q(state, action): subtract ``lr * value``.

        Args:
            value: gradient (e.g. a TD error) for this entry.
            lr: step size (default 0.01, the previously hard-coded value).
        """
        self.past_value_table[state, action] -= lr * value

# Single module-level tabular model shared by the action-selection and training functions.
model = TabularModel()
if args.wandb and hasattr(model, "named_parameters"):
    # wandb.watch hooks the parameters/gradients of a torch.nn.Module and rejects
    # objects without ``named_parameters``. TabularModel has none, so the guard
    # keeps --wandb from crashing at startup.
    wandb.watch(model)
# float32 machine epsilon, kept as a module-level constant.
eps = np.finfo(np.float32).eps.item()

def select_action(state):
    """Sample an action from the current policy, returning torch tensors.

    If ``model`` is callable (e.g. a torch ``nn.Module`` policy), it is called on
    the state and must return ``(probs, per_action_values)``. The module-level
    ``TabularModel`` defines no ``__call__``, so calling it would raise
    TypeError; in that case the state's rows of ``past_prob_table`` /
    ``past_value_table`` are used instead.

    Args:
        state: numpy array holding the observation (passed to ``torch.from_numpy``).
            For the tabular fallback its first element is taken as the state index.

    Returns:
        ``(probs, state_value, state_tensor, action, log_prob)``: action
        probabilities, per-action values, the float tensor built from ``state``,
        the sampled action, and its log-probability under ``probs``.
    """
    state = torch.from_numpy(state).float()
    if callable(model):
        probs, state_value = model(state)
    else:
        row = int(state.reshape(-1)[0].item())
        probs = torch.as_tensor(model.past_prob_table[row], dtype=torch.float32)
        state_value = torch.as_tensor(model.past_value_table[row], dtype=torch.float32)

    # Categorical over the action probabilities, then sample one action from it.
    m = Categorical(probs)
    action = m.sample()
    return probs, state_value, state, action, m.log_prob(action)

def select_action_tabular(state, explore=False):
    """Sample an action from the tabular policy for ``state``.

    The trailing ``[0]`` after row indexing suggests ``state`` is a length-1
    index array (e.g. ``np.array([i])``) -- TODO confirm against the env.

    When ``explore`` compares equal to True, with probability 0.1 the sampled
    action is replaced by a uniformly random one. The returned log-probability
    is still taken from the policy row.

    Returns:
        ``(probs, state_value, state, action, log_prob)`` where ``action`` is a
        length-1 numpy array and ``log_prob = log(probs[action] + 1e-8)``.
    """
    row_probs = model.past_prob_table[state, :][0]
    row_values = model.past_value_table[state, :][0]
    chosen = np.random.choice(ACTION_DIM, 1, p=row_probs)

    wants_exploration = explore == True  # noqa: E712 -- equality check, not truthiness
    if wants_exploration and np.random.rand() < 0.1:
        chosen = np.random.choice(ACTION_DIM, 1)

    chosen_log_prob = np.log(row_probs[int(chosen[0])] + 1e-8)
    return row_probs, row_values, state, chosen, chosen_log_prob

# Hinge-loss (HPO) tabular update, run once per finished episode.
def finish_episode(episode):
    """Apply the tabular HPO policy/value updates for the buffered episode.

    Each ``model.saved_actions`` entry is unpacked below as
    ``(probs, value, state, action, next_state, log_p)``; ``model.rewards``
    holds the matching rewards. For every step, in time order:

    * a per-action policy gradient ``table_gradient`` is computed according to
      ``args.classifier`` (HPO-AM, -log, -root, -sub, -square). These branches
      only run when ``args.value_method == 'Qvalue'``; otherwise the gradient
      stays all zeros;
    * a value gradient (TD error against ``r + gamma * sum_a pi(a|s') Q(s', a)``)
      is computed;
    * both are applied immediately through ``model.update_value`` and
      ``model.update_prob``, so later steps in the same episode see the tables
      already modified by earlier steps.

    Afterwards the mean gradients are logged to TensorBoard at step ``episode``,
    and ``model.rewards`` / ``model.saved_actions`` are cleared.

    NOTE(review): the ``args.value_method == 'returns'`` path looks
    non-functional. ``vloss`` is not defined in the visible source, and its
    result is bound to ``val_gradient`` while ``value_gradient`` is what gets
    applied. ``value.item()`` also raises if ``value`` is a per-action row.

    Args:
        episode: global step used for the TensorBoard histograms.
    """
    R = 0
    saved_actions = model.saved_actions
    policy_gradients = [] # per-step policy gradient vectors (one entry per action)
    value_gradients = [] # per-step value gradients (TD errors)
    returns = [] # discounted returns G_t, filled back-to-front
    rewards = []
    # Diagnostics only: these lists are filled but not read back in this function.
    pos_adv_prob_sums = []
    neg_adv_prob_sums = []
    prob_ratios = []

    epsilons = []
    np_advs = []
    # G_t = r_t + gamma * G_{t+1}, computed from the last reward backwards.
    for r in model.rewards[::-1]:
        R = r + args.gamma * R
        returns.insert(0, R)
        rewards.insert(0, r)
    for idx in range(len(returns)):
        # Fresh per-action policy gradient for this step.
        table_gradient = [0] *ACTION_DIM
        r = rewards[idx]
        R = returns[idx]

        # The next_* values unpacked here are not read later in this function;
        # next_state in particular is overwritten by the unpack just below.
        if idx+1 < len(returns):
            next_r = rewards[idx+1]
            next_R = returns[idx+1]
            next_probs,next_value ,next_state,next_action ,next_next_state,next_log_p = saved_actions[idx+1]
        else :
            next_R = torch.tensor(0)
            next_r = torch.tensor(0)
            next_probs,next_value ,next_state,next_action ,next_next_state,next_log_p = saved_actions[idx]
        probs,value ,state,action ,next_state,log_p = saved_actions[idx]
        # advantage = Q(s,a) - V(s)
        if args.value_method == 'returns':
            advantage = R - value.item()
        elif args.value_method == 'Qvalue':
            # V(s) = sum_a probs[a] * value[a], from the saved probs/values of this step.
            Vs = 0
            for i in range(ACTION_DIM):
                Vs+=probs[i]*value[i]
            advantage = value[action] - Vs
        # NOTE(review): any other value_method leaves `advantage` unbound and the
        # np_adv line below raises NameError.
        state_numpy =  int(state)
        next_state_numpy =  int(next_state)
        action_numpy =  np.squeeze( action)

        # Margin-scaled advantage; collected in np_advs and used by the HPO-AM epsilons entry.
        np_adv = args.margin *  advantage
        pos_adv_prob_sum = 0.0
        neg_adv_prob_sum = 0.0
        # Recompute V(s) with the saved probs but the *current* value table;
        # this Vs is what the classifier branches below compare against.
        Vs = 0
        for i in range(ACTION_DIM):
            Vs+=probs[i]*model.past_value_table[state_numpy, i]
        # Current-table probability mass on actions whose saved value lies
        # strictly above / below Vs (ties count toward neither).
        for a in range(ACTION_DIM):
            if value[a] > Vs: # positive advantage
                pos_adv_prob_sum+= model.past_prob_table[state_numpy, a]
            elif value[a] < Vs:
                neg_adv_prob_sum+= model.past_prob_table[state_numpy, a]

        # In every branch below: a_adv uses the current value table; w is |a_adv|
        # with --weight, else 1. --fixedEps / fixedEpsWeight replace the adaptive
        # epsilon with margin or margin * a_adv, per action. An action's gradient
        # is non-zero only while the hinge is active (epsilon > signed ratio term).
        if args.classifier == 'HPO-AM-log' and args.value_method == 'Qvalue':
            # epsilon = log(1 + margin * eta), eta = min(1, neg_mass / pos_mass)
            epsilon = min(1.0, neg_adv_prob_sum / (pos_adv_prob_sum + 1e-8))
            epsilon = np.log( 1+args.margin * epsilon )
            epsilons.append(epsilon)
            for a in range(ACTION_DIM):
                a_adv = model.past_value_table[state_numpy, a] - Vs
                if args.weight:
                    w = np.abs(a_adv)
                else :
                    w = 1
                if args.fixedEps == 'fixedEps':
                    epsilon = args.margin
                elif args.fixedEps == 'fixedEpsWeight':
                    epsilon = args.margin* a_adv

                # signed term: sign(A) * (log probs[a] - log table[a])
                if epsilon > np.sign(a_adv) * (np.log(probs[a] + 1e-8) - np.log(model.past_prob_table[ state_numpy ,a] + 1e-8)):
                    action_gradient = - w * np.sign(a_adv) / (probs[a] + 1e-8)
                else:
                    action_gradient = 0
                table_gradient[a] = action_gradient
        elif args.classifier == 'HPO-AM' and args.value_method == 'Qvalue':
            # epsilon = eta = min(1, neg_mass / pos_mass).
            # NOTE(review): unlike the other variants this is not scaled by
            # args.margin -- confirm that is intended.
            epsilon = min(1.0, neg_adv_prob_sum / (pos_adv_prob_sum + 1e-8))
            epsilons.append(np_adv * epsilon)
            for a in range(ACTION_DIM):
                a_adv = model.past_value_table[state_numpy, a] - Vs
                # saved probability over current table probability
                prob_ratio = probs[a] / (model.past_prob_table[ state_numpy ,a] + 1e-8)
                if args.weight:
                    w = np.abs(a_adv)
                else :
                    w = 1
                if args.fixedEps == 'fixedEps':
                    epsilon = args.margin
                elif args.fixedEps == 'fixedEpsWeight':
                    epsilon = args.margin* a_adv

                # signed term: sign(A) * (ratio - 1)
                if epsilon >  np.sign(a_adv) * (prob_ratio - 1):
                    action_gradient = - w * np.sign(a_adv) / (model.past_prob_table[ state_numpy ,a] + 1e-8)
                else: 
                    action_gradient = 0
                table_gradient[a] = action_gradient
        elif args.classifier == 'HPO-AM-root' and args.value_method == 'Qvalue':
            # epsilon = sqrt(1 + margin * eta) - 1
            heta = min(1.0, neg_adv_prob_sum / (pos_adv_prob_sum + 1e-8 ))
            epsilon = np.sqrt( 1 + args.margin*heta) - 1
            for a in range(ACTION_DIM):
                a_adv = model.past_value_table[state_numpy, a] - Vs
                # saved probability over current table probability
                prob_ratio = probs[a] / (model.past_prob_table[ state_numpy ,a] + 1e-8)
                if args.weight:
                    w = np.abs(a_adv)
                else :
                    w = 1
                if args.fixedEps == 'fixedEps':
                    epsilon = args.margin
                elif args.fixedEps == 'fixedEpsWeight':
                    epsilon = args.margin* a_adv

                # signed term: sign(A) * (sqrt(ratio) - 1)
                if epsilon > np.sign(a_adv) * (np.sqrt(prob_ratio) - 1) :
                    # evaluates as (-w * sign(A) / 2) * sqrt(ratio)
                    action_gradient = - w * np.sign(a_adv) / 2 * np.sqrt(prob_ratio)
                else: 
                    action_gradient = 0
                table_gradient[a] = action_gradient
        elif args.classifier == 'HPO-AM-sub' and args.value_method == 'Qvalue':
            # epsilon = min_a table[a] * margin * eta
            heta = min(1.0, neg_adv_prob_sum / (pos_adv_prob_sum + 1e-8 ))
            min_mu = np.min(model.past_prob_table[ state_numpy ,:])
            epsilon = min_mu * args.margin * heta
            for a in range(ACTION_DIM):
                a_adv = model.past_value_table[state_numpy, a] - Vs
                if args.weight:
                    w = np.abs(a_adv)
                else :
                    w = 1
                if args.fixedEps == 'fixedEps':
                    epsilon = args.margin
                elif args.fixedEps == 'fixedEpsWeight':
                    epsilon = args.margin* a_adv

                # signed term: sign(A) * (probs[a] - table[a])
                if epsilon > np.sign(a_adv) * (probs[a] - model.past_prob_table[ state_numpy ,a]):
                    action_gradient = - w * np.sign(a_adv)
                else: 
                    action_gradient = 0
                table_gradient[a] = action_gradient
        elif args.classifier == 'HPO-AM-square' and args.value_method == 'Qvalue':
            # epsilon = (1 + margin * eta)^2 - 1
            heta = min(1.0, neg_adv_prob_sum / (pos_adv_prob_sum + 1e-8 ))
            epsilon = (1 + args.margin*heta)**2 - 1
            for a in range(ACTION_DIM):
                a_adv = model.past_value_table[state_numpy, a] - Vs
                # saved probability over current table probability
                prob_ratio = probs[a] / (model.past_prob_table[ state_numpy ,a] + 1e-8)
                if args.weight:
                    w = np.abs(a_adv)
                else :
                    w = 1
                if args.fixedEps == 'fixedEps':
                    epsilon = args.margin
                elif args.fixedEps == 'fixedEpsWeight':
                    epsilon = args.margin* a_adv

                # signed term: sign(A) * (ratio^2 - 1)
                if epsilon > np.sign(a_adv) * ((prob_ratio)**2 - 1):
                    action_gradient = - w * np.sign(a_adv) * 2 * prob_ratio
                else: 
                    action_gradient = 0
                table_gradient[a] = action_gradient
                
        pos_adv_prob_sums.append(pos_adv_prob_sum)
        neg_adv_prob_sums.append(neg_adv_prob_sum)
        # NOTE(review): pos / (pos + 1e-8) is ~1 whenever pos > 0; a neg/pos ratio
        # may have been intended. The list is not read in this function.
        prob_ratios.append(pos_adv_prob_sum / (pos_adv_prob_sum + 1e-8))

        np_advs.append(np_adv)
        policy_gradients.append(table_gradient)


        if args.value_method == 'returns':
            # NOTE(review): see docstring -- `vloss` is undefined here and this
            # binds `val_gradient`, not the `value_gradient` applied below.
            val_gradient = vloss(value,   torch.tensor([R]).clone().float() )
        elif args.value_method == 'Qvalue':
            # Bootstrap value of next_state under the current tables:
            # sum_a pi(a|s') * Q(s', a). State 0 gets no bootstrap (treated as terminal).
            Vnexts = 0
            if(next_state == 0):
                Vnexts = 0
            else:
                for i in range(ACTION_DIM):
                    prob_next = model.past_prob_table[ next_state, i][0]
                    q_next = model.past_value_table[ next_state, i][0]
                    Vnexts += prob_next * q_next
            # TD error Q(s,a) - (r + gamma * V(s')); update_value applies it as a gradient step.
            value_gradient = (model.past_value_table[state_numpy, action_numpy] - (args.gamma *Vnexts + r ))
        model.update_value(state_numpy, action_numpy, value_gradient)
        model.update_prob(state_numpy ,action_numpy, table_gradient)
        value_gradients.append( value_gradient )
    # Per-episode TensorBoard diagnostics: mean policy gradient per action, mean value gradient.
    mean_policy_gradients = np.mean(np.array(policy_gradients), axis=0)
    writer.add_histogram("Gradient/policy", mean_policy_gradients, episode)
    writer.add_histogram("Gradient/scalar", np.mean(np.array(value_gradients)), episode)
    loss = 0
    # reset rewards and action buffer
    del model.rewards[:]
    del model.saved_actions[:]

def evaluation_everyGrid(model):
    """Print the greedy policy grid and check every cell has converged.

    Each cell prints its argmax direction (U/R/D/L; 'G' for cell 0) followed by
    the probability mass on the expected direction(s):
    L on the first row, U on the first column, U+L elsewhere.
    A non-goal cell passes when its argmax is in that expected set and that
    mass is at least 0.99.

    Returns:
        True if every non-goal cell passes, else False.
    """
    labels = ('U ', 'R ', 'D ', 'L ')  # action order: up, right, down, left
    converged = True
    for cell in range(env.observation_space.n):
        row = model.past_prob_table[cell, :]
        best = np.argmax(row)

        if cell == 0:
            print('G ', end='')
        elif best < len(labels):
            print(labels[best], end='')

        if cell == 0:
            print('       ', end='')
        else:
            if cell <= GRIDWIDTH - 1:
                allowed, mass = (3,), row[3]
            elif cell % GRIDWIDTH == 0:
                allowed, mass = (0,), row[0]
            else:
                allowed, mass = (0, 3), row[0] + row[3]
            if best not in allowed:
                converged = False
            if mass < 0.99:
                converged = False
            print('({:.2f}) '.format(mass), end='')

        if cell % GRIDWIDTH == GRIDWIDTH - 1:
            print(' ')
    return converged
    
def log_value_table(model):
    """Print each state's action-value row, GRIDWIDTH states per output line.

    Each entry looks like ``" 3 [ -0.20 -0.40 ... ] "``.
    """
    for idx in range(env.observation_space.n):
        cells = "".join("{:4.2f} ".format(q) for q in model.past_value_table[idx, :])
        print("{:2} [ ".format(idx) + cells + "] ", end='')
        if idx % GRIDWIDTH == GRIDWIDTH - 1:
            print(' ')

def evaluation(model):
    """Print the greedy direction grid and loosely check it against the expected route.

    A non-goal cell passes when its argmax action is L on the first row,
    U on the first column, and U or L elsewhere. In addition, cell GRIDWIDTH-1
    needs P(L) >= 0.66 and cell GRIDWIDTH needs P(U) >= 0.66.

    Returns:
        True if every check passes, else False.
    """
    labels = ('U ', 'R ', 'D ', 'L ')  # action order: up, right, down, left
    passed = True
    for cell in range(env.observation_space.n):
        row = model.past_prob_table[cell, :]
        best = np.argmax(row)

        if cell == 0:
            print('G ', end='')
        elif best < len(labels):
            print(labels[best], end='')
        if cell % GRIDWIDTH == GRIDWIDTH - 1:
            print(' ')

        if cell != 0:
            if cell <= GRIDWIDTH - 1:
                allowed = (3,)
            elif cell % GRIDWIDTH == 0:
                allowed = (0,)
            else:
                allowed = (0, 3)
            if best not in allowed:
                passed = False

        if cell == GRIDWIDTH - 1 and row[3] < 0.66:
            passed = False
        if cell == GRIDWIDTH and row[0] < 0.66:
            passed = False
    return passed

def evaluation2(model):
    """Check the greedy policy and measure its L1 distance from the expected route.

    The distance sums, over non-goal cells, one minus the probability of the
    expected direction(s): L on the first row, U on the first column, U+L
    elsewhere. Pass/fail uses the same argmax rule as ``evaluation``, with 0.9
    thresholds for cells GRIDWIDTH-1 (on L) and GRIDWIDTH (on U). The
    probability row of cell GRIDWIDTH-1 is printed, followed by the direction grid.

    Returns:
        (passed, l1_norm)
    """
    labels = ('U ', 'R ', 'D ', 'L ')  # action order: up, right, down, left
    passed = True
    l1_norm = 0.0
    for cell in range(env.observation_space.n):
        row = model.past_prob_table[cell, :]
        best = np.argmax(row)

        if cell != 0:
            if cell <= GRIDWIDTH - 1:
                l1_norm += (1 - row[3])
                if cell == GRIDWIDTH - 1:
                    print(cell, "prob :", row)
                allowed = (3,)
            elif cell % GRIDWIDTH == 0:
                l1_norm += (1 - row[0])
                allowed = (0,)
            else:
                l1_norm += (1 - row[3] - row[0])
                allowed = (0, 3)
            if best not in allowed:
                passed = False

        if cell == GRIDWIDTH - 1 and row[3] < 0.9:
            passed = False
        if cell == GRIDWIDTH and row[0] < 0.9:
            passed = False

        if cell == 0:
            print('G ', end='')
        elif best < len(labels):
            print(labels[best], end='')
        if cell % GRIDWIDTH == GRIDWIDTH - 1:
            print(' ')
    return passed, l1_norm


def _run_episode(i_episode, max_step):
    """Roll out one episode with the tabular policy and record it on ``model``.

    Each transition is appended to ``model.saved_actions`` and each reward
    to ``model.rewards`` so that ``finish_episode`` can consume them. On
    episodes that will be logged, visited states are also appended to the
    global ``last_episode_log``.

    Args:
        i_episode: index of the current episode; decides whether states are
            recorded in ``last_episode_log``.
        max_step: maximum number of environment steps before the rollout is
            cut off.

    Returns:
        (ep_reward, step_count, distance): summed reward, number of steps
        taken, and ``env.distance`` as read right after ``env.reset()``.
    """
    state = env.reset()
    distance = env.distance
    ep_reward = 0
    step_count = 0
    for _ in range(max_step):
        if args.explore:
            probs, value, state, action, log_p = select_action_tabular(state, explore=True)
        else:
            probs, value, state, action, log_p = select_action_tabular(state)
        next_state, reward, done, _ = env.step(action.item())
        step_count += 1
        model.saved_actions.append([probs, value, state, action, next_state, log_p])
        if i_episode % args.log_interval == 0:
            last_episode_log.append(state)
        if args.render:
            env.render()
        state = next_state
        model.rewards.append(reward)
        ep_reward += reward
        if done:
            break
    return ep_reward, step_count, distance


def main():
    """Train the tabular hinge-loss agent and log its progress.

    Runs ``args.max_episode`` episodes of at most ``MAX_STEP`` steps each
    and calls ``finish_episode`` after every episode. Every
    ``args.log_interval`` episodes it evaluates the policy table with
    ``evaluation2`` and logs reward / L1-norm statistics to stdout, to
    TensorBoard (``writer``), and to wandb when ``--wandb`` is set.

    Returns:
        (reward_record, hinge_reward_EWMA): per-episode rewards and their
        exponentially weighted moving average (weight 0.05 on the newest
        episode), e.g. for plotting by the caller.
    """
    print("learning rate", args.learning_rate)
    print("env.action.space", env.action_space)
    MAX_STEP = 200  # hard cap on environment steps per episode
    EWMA_ALPHA = 0.05  # weight given to the newest episode in running averages
    running_reward = 0
    running_stepcount = 0
    reward_record = []
    hinge_reward_EWMA = []

    for i_episode in range(args.max_episode):
        ep_reward, step_count, distance = _run_episode(i_episode, MAX_STEP)

        reward_record.append(ep_reward)
        running_reward = EWMA_ALPHA * ep_reward + (1 - EWMA_ALPHA) * running_reward
        # EWMA of |env.distance - steps taken|; the log line labels this
        # "(stepcount - beststepcount)".
        running_stepcount = EWMA_ALPHA * abs(distance - step_count) + (1 - EWMA_ALPHA) * running_stepcount
        hinge_reward_EWMA.append(running_reward)

        # Backprop / policy-table update from the transitions just recorded.
        finish_episode(i_episode)

        if i_episode % args.log_interval == 0:
            # Evaluate only when the result is logged: evaluation2 also prints
            # the whole policy grid, which is noisy and wasted work otherwise.
            _, l1_norm = evaluation2(model)
            print('hinge Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f} \trunning_(stepcount - beststepcount): {}'.format(
                  i_episode, ep_reward, running_reward, running_stepcount))
            last_episode_log.clear()
            writer.add_scalar("Reward/EWMA(step)", running_reward, i_episode)
            writer.add_scalar("Reward/Ep_reward", ep_reward, i_episode)
            writer.add_scalar("L1_norm", l1_norm, i_episode)
            if args.wandb:
                # One call per episode: each wandb.log call advances wandb's
                # step counter, so two calls would misalign the charts.
                wandb.log({"hinge ewma reward": running_reward, "l1_norm": l1_norm})

    return reward_record, hinge_reward_EWMA


if __name__ == "__main__":
    # Script entry point: start training only when run directly, not on import.
    main()
