import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from gym.envs.registration import registry, register, make, spec




# Command-line configuration and gridworld environment registration.

parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=123, metavar='N',
                    help='random seed (default: 123)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='interval between training status logs (default: 1)')
parser.add_argument('--learning-rate', type=float, default=1e-4, metavar='G',
                    help='learning-rate of optimizer (default: 1e-4)')
parser.add_argument('--classifier', type=str, required=True,
                    help='policy-loss classifier; finish_episode recognises AM, AM-log, '
                         'AM-root, AM-sub, AM-square, HPO-AM, HPO-AM-root, '
                         'HPO-AM-sub and HPO-AM-square')
parser.add_argument('--value-method', type=str, required=True,
                    choices=['returns', 'Qvalue'],
                    help='critic target: Monte-Carlo returns or Q-values (returns or Qvalue)')
parser.add_argument('--max-episode', type=int, default=10000, metavar='N',
                    help='max-episode (default: 10000)')
parser.add_argument('--actor-delay', type=int, default=1, metavar='N',
                    help='update the actor only every N-th episode (default: 1)')
parser.add_argument('--wandb', action='store_true',
                    help='use wandb for profiling')
parser.add_argument('--margin', type=float, default=0.1, metavar='G',
                    help='hinge margin scale (default: 0.1)')
parser.add_argument('--weight', action='store_true',
                    help='marginal loss y = advantage (default: y = advantage.sign())')
parser.add_argument('--fixedEps', type=str,
                    help='override the hinge margin: fixedEps (constant margin) or '
                         'fixedEpsWeight (margin * advantage); omit for the default')
parser.add_argument('--envi', type=str, required=True, choices=['4x4', '6x6', '8x8'],
                    help='grid size: 4x4, 6x6 or 8x8')
args = parser.parse_args()

# Grid size -> (gym entry point, grid width). Only the selected variant is
# registered, always under the same env id.
_GRID_ENVS = {
    '4x4': ('gridworld_randR_env:Gridworld_RandReward_4x4_OneHot_Env', 4),
    '6x6': ('gridworld_randR_env:Gridworld_RandReward_6x6_Env', 6),
    '8x8': ('gridworld_randR_env:Gridworld_RandReward_8x8_Env', 8),
}
envname = 'gridworld_randR_env-v0'
_entry_point, GRIDWIDTH = _GRID_ENVS[args.envi]
register(id=envname, entry_point=_entry_point, reward_threshold=500.0)
ACTION_DIM = 4

# Create the environment, then seed torch (CUDA and CPU), numpy, Python's
# `random` and the env itself with --seed.
import gym
env = gym.make(envname)
import random
# torch.manual_seed(123)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
# NOTE(review): Env.seed() is only available in older gym releases — confirm
# the installed gym version provides it.
env.seed(args.seed)
# torch.manual_seed(args.seed)
torch.manual_seed(args.seed)
print("env.observation_space.shape:",env.observation_space.shape)
# Loss histories: a2c_finish_episode appends to the a2c_* lists and
# finish_episode appends to loss_log / ploss_log / vloss_log.
# last_episode_log is not referenced in the visible part of this file.
last_episode_log = []
a2c_loss_log = []
loss_log = []
ploss_log = []
vloss_log = []
a2c_ploss_log = []
a2c_vloss_log = []
import os
# Optional Weights & Biases run; the run name encodes the classifier, the
# --weight flag, the --fixedEps mode, the margin and the seed.
if args.wandb:
    import wandb
    wandb.init(project="grid one hot",
        name=str(args.classifier)+"_weight"+str(args.weight)+"_fixedEps"+str(args.fixedEps)+ "_alpha"+str(args.margin)  + "_seed" + str(args.seed),)
    config = wandb.config
    config.learning_rate = args.learning_rate


# `device` is not referenced in the visible part of this file; models stay on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from datetime import datetime
# Directory presumably meant for model checkpoints (nothing is saved to it in
# the visible code) — created up front if missing.
hpg_checkpoint_dir = "./models/hpg_fch256/"+envname
if not os.path.exists(hpg_checkpoint_dir):
    os.makedirs(hpg_checkpoint_dir)
# TensorBoard run directory: env id, start timestamp, classifier, value method.
log_dir = "./logs/hpg_fch256/{}_{}_{}_{}".format(envname, datetime.now().strftime("%m-%d-%H-%M-%S"), args.classifier ,args.value_method )
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
# Shared writer used by both finish-episode routines for loss scalars.
writer = SummaryWriter(log_dir=log_dir)

def hard_update(source,target):
    """Overwrite every parameter of ``target`` with the matching one from ``source``.

    Parameters are paired positionally through ``.parameters()``. The copy
    goes through ``.data`` so autograd does not record it, and each target
    tensor keeps its own storage (values are copied, not shared).
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)

def init_weights(m):
    """Initialise an ``nn.Linear`` layer: Kaiming-normal weight, bias filled with 0.1.

    Meant to be used as ``module.apply(init_weights)``. Modules that are not
    linear layers are left untouched. A linear layer built with
    ``bias=False`` (``m.bias is None``) only has its weight initialised.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)

class a2c_Policy(nn.Module):
    """Actor-critic network used by the A2C baseline.

    The trunk (``affine1`` -> ``affine2``, applied with no activation in
    between) feeds two heads:

    * ``action_head``: ``ACTION_DIM`` logits, turned into action
      probabilities with a softmax over the last dimension.
    * ``value_head``: a single state value for ``--value-method returns``,
      or one Q-value per action for ``--value-method Qvalue``.

    ``saved_actions`` and ``rewards`` are per-episode buffers that the
    training loop fills and ``a2c_finish_episode`` clears.

    Raises:
        ValueError: if ``args.value_method`` is neither ``returns`` nor
            ``Qvalue`` (otherwise ``value_head`` would be missing and
            ``forward`` would fail later).
    """
    def __init__(self):
        super(a2c_Policy, self).__init__()
        if args.value_method not in ('returns', 'Qvalue'):
            raise ValueError(
                "unknown value method {!r}; expected 'returns' or 'Qvalue'".format(args.value_method))
        self.affine1 = nn.Linear( env.observation_space.n , 32)
        self.affine2 = nn.Linear(32, 32)
        # NOTE: affine3 is not used in forward(). It is kept so the parameter
        # list, state_dict keys and the seeded initialisation of the layers
        # created after it stay unchanged.
        self.affine3 = nn.Linear(64, 32)
        # actor's layer
        self.action_head = nn.Linear(32, ACTION_DIM)

        # critic's layer: V(s) for 'returns', Q(s, .) for 'Qvalue'
        if args.value_method == 'returns':
            self.value_head = nn.Linear(32, 1)
        else:
            self.value_head = nn.Linear(32, ACTION_DIM)
        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """Return ``(action_prob, state_values)`` for input ``x``.

        ``action_prob`` is the softmax over the action logits; ``state_values``
        is the raw critic output (shape set by ``--value-method``).
        """
        x = self.affine1(x)
        x = self.affine2(x)
        action_prob = F.softmax(self.action_head(x), dim=-1)
        state_values = self.value_head(x)
        return action_prob, state_values


# Global A2C network. It is constructed right after seeding, so its initial
# weights depend on --seed and on being created before the hinge `Policy`
# instances further down.
a2c_model = a2c_Policy()

# Adam over all a2c_model parameters (including the unused affine3 layer).
a2c_optimizer = optim.Adam(a2c_model.parameters(), lr=args.learning_rate)
# float32 machine epsilon; not referenced in the visible A2C code.
eps = np.finfo(np.float32).eps.item()
# wandb.watch(a2c_model)

def a2c_select_action(state):
    """Sample an action from ``a2c_model`` for one numpy observation.

    Args:
        state: numpy observation; converted to a float tensor here.

    Returns:
        ``(probs, state_value, state_tensor, action, log_prob)``: the action
        probabilities and critic output of ``a2c_model``, the float tensor
        that was fed to it, the sampled action (a tensor) and its
        log-probability. Nothing is appended to the model's buffers here.
    """
    obs = torch.from_numpy(state).float()
    action_probs, critic_out = a2c_model(obs)
    policy = Categorical(action_probs)
    chosen = policy.sample()
    chosen_log_prob = policy.log_prob(chosen)
    return action_probs, critic_out, obs, chosen, chosen_log_prob


def a2c_finish_episode(episode):
    """Run one A2C update from the transitions buffered in ``a2c_model``.

    For every buffered step this builds an actor loss ``-log pi(a|s) * A``
    and a critic MSE loss. It then takes one ``a2c_optimizer`` step, logs the
    losses (stdout, TensorBoard, optionally wandb) and clears
    ``a2c_model.rewards`` and ``a2c_model.saved_actions``.

    Critic, selected by ``--value-method``:

    * ``returns``: ``A = G_t - V(s)``; ``V(s)`` regresses onto the discounted
      return ``G_t``.
    * ``Qvalue``: ``A = Q(s,a) - sum_a' pi(a'|s) Q(s,a')``; ``Q(s,a)``
      regresses onto ``r + gamma * V(s')``, with ``V(s') = 0`` when ``s'`` is
      state 0 (the goal cell).

    In both cases the advantage is treated as a constant, so the actor loss
    does not backpropagate into the critic.

    Each ``saved_actions`` entry is expected to be
    ``(probs, value, state, action, next_state, log_p)``. ``next_state`` is
    assumed to be a one-hot numpy vector (TODO confirm against the env).

    Args:
        episode: episode index. The actor loss is included only when
            ``episode % --actor-delay == 0``; also used as the logging step.
    """
    saved_actions = a2c_model.saved_actions
    if not a2c_model.rewards:
        # Nothing to learn from; torch.stack([]) below would raise.
        del saved_actions[:]
        return

    # Discounted returns G_t, accumulated back to front.
    returns = []
    R = 0
    for r in reversed(a2c_model.rewards):
        R = r + args.gamma * R
        returns.append(R)
    returns.reverse()
    returns = torch.tensor(returns)
    rewards = torch.tensor(a2c_model.rewards)

    vloss = nn.MSELoss()
    policy_losses = []  # actor loss per step
    value_losses = []   # critic loss per step
    for idx, (r, R) in enumerate(zip(rewards, returns)):
        probs, value, state, action, next_state, log_p = saved_actions[idx]

        if args.value_method == 'returns':
            advantage = R - value.item()
            val_loss = vloss(value, R.reshape(1).float())
        elif args.value_method == 'Qvalue':
            # V(s) = E_{a ~ pi}[Q(s, a)]
            Vs = (probs * value).sum()
            # Constant for the actor, like the .item() in the 'returns' branch.
            advantage = (value[action] - Vs).detach()
            next_input = torch.from_numpy(np.array([next_state])).float()
            with torch.no_grad():
                probs_next, q_next = a2c_model(next_input)
                Vnexts = (probs_next.squeeze() * q_next.squeeze()).sum()
            if np.asarray(next_state).reshape(-1)[0] == 1:
                # s' is state 0 (the goal): nothing to bootstrap from.
                Vnexts = torch.tensor(0.0)
            val_loss = vloss(value[action], args.gamma * Vnexts + r)
        else:
            raise ValueError("unknown value method {!r}".format(args.value_method))

        policy_losses.append(-log_p * advantage)
        value_losses.append(val_loss)

    a2c_optimizer.zero_grad()

    policy_total = torch.stack(policy_losses).sum()
    value_total = torch.stack(value_losses).sum()
    # With --actor-delay N the actor is only trained every N-th episode.
    if episode % args.actor_delay == 0:
        loss = policy_total + value_total
    else:
        loss = value_total
    print("Loss:", loss)

    loss_log_unit = loss.detach().numpy()
    vloss_log_unit = value_total.detach().numpy()
    ploss_log_unit = policy_total.detach().numpy()
    a2c_loss_log.append(loss_log_unit)
    a2c_vloss_log.append(vloss_log_unit)
    a2c_ploss_log.append(ploss_log_unit)
    writer.add_scalar("A2C/Loss/Total", loss_log_unit, episode)
    writer.add_scalar("A2C/Loss/Value_loss", vloss_log_unit, episode)
    writer.add_scalar("A2C/Loss/Policy_loss", ploss_log_unit, episode)
    if args.wandb and episode % args.log_interval == 0:
        wandb.log({"a2c loss": loss_log_unit})
        wandb.log({"a2c vloss": vloss_log_unit})
        wandb.log({"a2c ploss": ploss_log_unit})

    loss.backward()
    a2c_optimizer.step()

    # reset rewards and action buffer
    del a2c_model.rewards[:]
    del a2c_model.saved_actions[:]

# hinge
class Policy(nn.Module):
    """Actor-critic network trained with the hinge policy loss.

    The trunk (``affine1`` -> ``affine2``, applied with no activation in
    between) feeds two heads:

    * ``action_head``: ``ACTION_DIM`` logits, turned into action
      probabilities with a softmax over the last dimension.
    * ``value_head``: a single state value for ``--value-method returns``,
      or one Q-value per action for ``--value-method Qvalue``.

    ``saved_actions`` and ``rewards`` are per-episode buffers that the
    training loop fills and ``finish_episode`` clears.

    Raises:
        ValueError: if ``args.value_method`` is neither ``returns`` nor
            ``Qvalue`` (otherwise ``value_head`` would be missing and
            ``forward`` would fail later).
    """
    def __init__(self):
        super(Policy, self).__init__()
        if args.value_method not in ('returns', 'Qvalue'):
            raise ValueError(
                "unknown value method {!r}; expected 'returns' or 'Qvalue'".format(args.value_method))
        self.affine1 = nn.Linear( env.observation_space.n , 32)
        self.affine2 = nn.Linear(32, 32)
        # NOTE: affine3 is not used in forward(). It is kept so the parameter
        # list, state_dict keys and the seeded initialisation of the layers
        # created after it stay unchanged.
        self.affine3 = nn.Linear(64, 32)
        # actor's layer
        self.action_head = nn.Linear(32, ACTION_DIM)

        # critic's layer: V(s) for 'returns', Q(s, .) for 'Qvalue'
        if args.value_method == 'returns':
            self.value_head = nn.Linear(32, 1)
        else:
            self.value_head = nn.Linear(32, ACTION_DIM)
        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """Return ``(action_prob, state_values)`` for input ``x``.

        ``action_prob`` is the softmax over the action logits; ``state_values``
        is the raw critic output (shape set by ``--value-method``).
        """
        x = self.affine1(x)
        x = self.affine2(x)
        action_prob = F.softmax(self.action_head(x), dim=-1)
        state_values = self.value_head(x)
        return action_prob, state_values


# Hinge-loss policy (pi), its "old" copy (mu) and their optimizer. Both
# networks are still constructed in this order so the seeded initialisation
# of everything created afterwards is unchanged.
model = Policy()
if args.wandb :
    wandb.watch(model)
old_model = Policy()
# Start mu as an exact copy of pi. Otherwise the first update compares pi
# against an independently initialised network. finish_episode re-syncs the
# two on every call.
hard_update(model, old_model)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
eps = np.finfo(np.float32).eps.item()

def past_log_p(state, action, old_model):
    """Log-probability of ``action`` in ``state`` under ``old_model``'s policy head.

    ``state`` must already be an input ``old_model`` accepts; only the first
    element of its ``(probs, values)`` output is used. The result keeps the
    autograd graph, so callers detach it when they need a constant.
    """
    old_policy = Categorical(old_model(state)[0])
    return old_policy.log_prob(action)
def past_p(state, action, old_model):
    """Probability that ``old_model``'s policy head assigns to ``action`` in ``state``.

    Only the first element of ``old_model``'s ``(probs, values)`` output is
    used. The returned tensor keeps the autograd graph; callers detach it.
    """
    old_probs = old_model(state)[0]
    return old_probs[action]

def select_action(state):
    """Sample an action from the hinge-policy ``model`` for one numpy observation.

    Args:
        state: numpy observation; converted to a float tensor here.

    Returns:
        ``(probs, state_value, state_tensor, action, log_prob)``: the action
        probabilities and critic output of ``model`` (one value, or one
        Q-value per action, depending on ``--value-method``), the float
        tensor that was fed to it, the sampled action (a tensor) and its
        log-probability. Nothing is appended to the model's buffers here.
    """
    obs = torch.from_numpy(state).float()
    action_probs, critic_out = model(obs)
    policy = Categorical(action_probs)
    chosen = policy.sample()
    chosen_log_prob = policy.log_prob(chosen)
    return action_probs, critic_out, obs, chosen, chosen_log_prob

## hinge loss finish
# def finish_episode(past_prob_table):
def _hinge_classifier_inputs(classifier, probs, log_p, state, action):
    """Return the ``(x1, x2)`` pair compared by the margin ranking loss.

    ``HPO-`` prefixed classifiers use the classifier of the same name without
    the prefix (e.g. ``HPO-AM-root`` -> ``AM-root``). Their margins, computed
    in ``_hinge_margin``, are expressed in that classifier's space.

    Raises:
        ValueError: for an unrecognised classifier name.
    """
    base = classifier[len('HPO-'):] if classifier.startswith('HPO-') else classifier
    if base == 'AM-log':
        # log pi(a|s) vs log mu(a|s)
        x1 = log_p.unsqueeze(0)
        x2 = past_log_p(state, action, old_model).clone().detach().unsqueeze(0)
        return x1, x2
    if base not in ('AM', 'AM-root', 'AM-sub', 'AM-square'):
        raise ValueError("unknown classifier {!r}".format(classifier))
    pi = probs[action].unsqueeze(0)
    mu = past_p(state, action, old_model).clone().detach().unsqueeze(0)
    if base == 'AM-sub':
        # pi(a|s) vs mu(a|s)
        return pi, mu
    ratio = pi / mu
    if base == 'AM-root':
        ratio = torch.sqrt(ratio)
    elif base == 'AM-square':
        ratio = torch.square(ratio)
    # f(pi/mu) vs 1
    return ratio, torch.ones_like(ratio)


def _hinge_heta(state, value, Vs):
    """Return ``(heta, min_mu)`` for one state under the old policy ``mu``.

    ``heta = min(1, mu(A < 0) / mu(A > 0))``, where ``A(s, a) = Q(s, a) - V(s)``
    is evaluated for every action ``a``. ``heta`` is 1 when no action has a
    positive advantage. ``min_mu`` is the smallest old-policy probability in
    ``state``.
    """
    with torch.no_grad():
        mu, _ = old_model(state)
    per_action_adv = value.detach() - Vs.detach()
    pos_mass = mu[per_action_adv > 0].sum().item()
    neg_mass = mu[per_action_adv < 0].sum().item()
    heta = 1.0 if pos_mass == 0.0 else min(1.0, neg_mass / pos_mass)
    return heta, mu.min().item()


# Heta-scaled margins: the threshold pi/mu >= 1 + margin * heta, rewritten in
# the space of each classifier. Arguments: (margin * heta, min_mu).
_HETA_MARGINS = {
    'AM-log': lambda scaled, min_mu: np.log(1 + scaled),
    'HPO-AM': lambda scaled, min_mu: scaled,
    'HPO-AM-root': lambda scaled, min_mu: np.sqrt(1 + scaled) - 1,
    'HPO-AM-sub': lambda scaled, min_mu: scaled * min_mu,
    'HPO-AM-square': lambda scaled, min_mu: np.square(scaled + 1.0) - 1.0,
}


def _hinge_margin(advantage, state, value, Vs):
    """Return the MarginRankingLoss margin for one step.

    Precedence, highest first:

    1. ``--fixedEps`` overrides everything.
    2. With ``--value-method Qvalue``, the classifiers in ``_HETA_MARGINS``
       use a heta-scaled margin.
    3. Everything else uses ``margin * advantage``.
    """
    if args.fixedEps == 'fixedEps':
        return args.margin
    if args.fixedEps == 'fixedEpsWeight':
        return args.margin * advantage.clone().detach().numpy()
    if args.value_method == 'Qvalue' and args.classifier in _HETA_MARGINS:
        heta, min_mu = _hinge_heta(state, value, Vs)
        return _HETA_MARGINS[args.classifier](args.margin * heta, min_mu)
    # NOTE(review): older comments describe this margin as epsilon * |A|; the
    # signed advantage is used here, as before — confirm which is intended.
    return args.margin * advantage.clone().detach().numpy()


def finish_episode(episode,past_prob_table):
    """Run one hinge-policy-gradient update from the buffers in ``model``.

    For every buffered step:

    * The advantage ``A`` comes from ``--value-method``: ``G_t - V(s)`` for
      ``returns``, or ``Q(s,a) - sum_a' pi(a'|s) Q(s,a')`` for ``Qvalue``.
      It is treated as a constant, so the policy loss does not train the
      critic.
    * The policy loss is ``MarginRankingLoss(margin)(x1, x2, y)``, i.e.
      ``max(0, -y * (x1 - x2) + margin)``. ``(x1, x2)`` depend on
      ``--classifier``, ``y`` is ``A`` (``--weight``) or ``sign(A)``, and the
      margin comes from ``_hinge_margin``.
    * ``past_prob_table[s, a]`` is overwritten with ``log pi(a|s)``.
    * The critic loss is an MSE: ``V(s)`` vs ``G_t`` for ``returns``, or
      ``Q(s,a)`` vs ``r + gamma * V(s')`` for ``Qvalue``, with ``V(s') = 0``
      when ``s'`` is state 0 (the goal cell).

    Afterwards ``old_model`` is synced to ``model`` (before the optimizer
    step), one ``optimizer`` step is taken, losses are logged and both
    buffers are cleared.

    Each ``saved_actions`` entry is expected to be
    ``(probs, value, state, action, next_state, log_p)``, with ``state`` the
    one-hot tensor returned by ``select_action``. ``next_state`` is assumed
    to be a one-hot numpy vector (TODO confirm against the env).

    Args:
        episode: episode index. The policy loss is included only when
            ``episode % --actor-delay == 0``; also used as the logging step.
        past_prob_table: table indexed ``[state_index, action_index]``,
            presumably a numpy array; updated in place.

    Raises:
        ValueError: for an unknown ``--value-method`` or ``--classifier``.
    """
    saved_actions = model.saved_actions
    print("args.margin", args.margin)
    if not model.rewards:
        # Nothing to learn from; torch.stack([]) below would raise.
        del saved_actions[:]
        return

    # Discounted returns G_t, accumulated back to front.
    returns = []
    R = 0
    for r in reversed(model.rewards):
        R = r + args.gamma * R
        returns.append(R)
    returns.reverse()
    returns = torch.tensor(returns)
    rewards = torch.tensor(model.rewards)

    vloss = nn.MSELoss()
    policy_losses = []  # hinge loss per step
    value_losses = []   # critic loss per step
    for idx, (r, R) in enumerate(zip(rewards, returns)):
        probs, value, state, action, next_state, log_p = saved_actions[idx]

        if args.value_method == 'returns':
            Vs = None
            advantage = R - value.item()
        elif args.value_method == 'Qvalue':
            # V(s) = E_{a ~ pi}[Q(s, a)]
            Vs = (probs * value).sum()
            advantage = (value[action] - Vs).detach()
        else:
            raise ValueError("unknown value method {!r}".format(args.value_method))

        # Policy (hinge) loss.
        x1, x2 = _hinge_classifier_inputs(args.classifier, probs, log_p, state, action)
        y = advantage.unsqueeze(0)
        if not args.weight:
            y = y.sign()
        state_idx = np.squeeze(np.where(state.detach().numpy() == 1))
        action_idx = np.squeeze(action.numpy())
        past_prob_table[state_idx, action_idx] = log_p.detach().numpy()
        margin = _hinge_margin(advantage, state, value, Vs)
        policy_losses.append(nn.MarginRankingLoss(margin=margin)(x1, x2, y))

        # Critic loss.
        if args.value_method == 'returns':
            val_loss = vloss(value, R.reshape(1).float())
        else:
            next_input = torch.from_numpy(np.array([next_state])).float()
            with torch.no_grad():
                probs_next, q_next = model(next_input)
                Vnexts = (probs_next.squeeze() * q_next.squeeze()).sum()
            if np.asarray(next_state).reshape(-1)[0] == 1:
                # s' is state 0 (the goal): nothing to bootstrap from.
                Vnexts = torch.tensor(0.0)
            val_loss = vloss(value[action], args.gamma * Vnexts + r)
        value_losses.append(val_loss)

    # mu <- pi (pre-step parameters), so the next update compares against
    # the policy that was in effect before this step.
    hard_update(model, old_model)
    optimizer.zero_grad()

    policy_total = torch.stack(policy_losses).sum()
    value_total = torch.stack(value_losses).sum()
    # With --actor-delay N the policy is only trained every N-th episode.
    if episode % args.actor_delay == 0:
        loss = policy_total + value_total
    else:
        loss = value_total
    print("loss", loss)
    loss.backward()
    optimizer.step()

    vloss_log_unit = value_total.detach().numpy()
    ploss_log_unit = policy_total.detach().numpy()
    loss_log_unit = loss.detach().numpy()
    vloss_log.append(vloss_log_unit)
    ploss_log.append(ploss_log_unit)
    loss_log.append(loss_log_unit)
    writer.add_scalar("HPG/Loss/Total", loss_log_unit, episode)
    writer.add_scalar("HPG/Loss/Value_loss", vloss_log_unit, episode)
    writer.add_scalar("HPG/Loss/Policy_loss", ploss_log_unit, episode)
    if args.wandb and episode % args.log_interval == 0:
        wandb.log({"hinge loss": loss_log_unit})
        wandb.log({"hinge vloss": vloss_log_unit})
        wandb.log({"hinge ploss": ploss_log_unit})

    # reset rewards and action buffer
    del model.rewards[:]
    del model.saved_actions[:]


def evaluation(model):
    """Return True when ``model``'s greedy policy passes the grid-route check.

    The check itself (expected direction per grid region, the 0.9
    probability thresholds next to the goal, and the printed direction
    grid) is performed by ``evaluation2``; only its boolean verdict is
    returned here and the L1 term is discarded.

    Why delegate: this function used to probe the network with the raw
    state index (``np.array([i])``), whereas every other probe in this file
    feeds a one-hot vector of length ``env.observation_space.n``. It was
    therefore evaluating the model on a different input encoding than the
    one used everywhere else. Sharing ``evaluation2`` keeps both entry
    points on the same encoding and the same rules.
    """
    passed, _l1_norm = evaluation2(model)
    return passed

def evaluation2(model):
    """Check ``model``'s greedy policy against the expected grid routes.

    Each state index of ``env`` is fed to ``model`` as a one-hot float
    tensor. The greedy action (argmax of the returned probabilities) is
    compared with the expected direction, using action codes
    0=U, 1=R, 2=D, 3=L:

    * state 0 (the goal) is not checked;
    * states 1 .. GRIDWIDTH-1 must choose L;
    * states with ``index % GRIDWIDTH == 0`` must choose U;
    * every other state must choose U or L.

    In addition, state ``GRIDWIDTH-1`` needs P(L) >= 0.9 and state
    ``GRIDWIDTH`` needs P(U) >= 0.9. While scanning, the greedy directions
    are printed as a grid GRIDWIDTH wide (``G`` marks state 0), and the
    probabilities of state ``GRIDWIDTH-1`` are printed.

    Returns:
        ``(ok, l1_norm)``. ``ok`` is True when every check passes.
        ``l1_norm`` is the summed probability mass, over checked states,
        that falls outside the allowed direction(s).
    """
    symbols = ('U ', 'R ', 'D ', 'L ')
    passed = True
    l1_norm = 0.0
    n_states = env.observation_space.n

    for idx in range(n_states):
        one_hot = np.zeros(n_states)
        one_hot[idx] = 1
        probs, _ = model(torch.from_numpy(one_hot).float())
        probs = probs.detach().numpy()
        best = np.argmax(probs)

        if idx == 0:
            pass  # goal state: nothing to check
        elif 1 <= idx <= GRIDWIDTH - 1:
            # first row: only "left" leads towards the goal
            l1_norm += 1 - probs[3]
            if idx == GRIDWIDTH - 1:
                print(idx, "prob :", probs)
            if best != 3:
                passed = False
        elif idx % GRIDWIDTH == 0:
            # first column: only "up" leads towards the goal
            l1_norm += 1 - probs[0]
            if best != 0:
                passed = False
        else:
            # interior: either "up" or "left" is optimal
            l1_norm += 1 - probs[3] - probs[0]
            if best not in (0, 3):
                passed = False

        # the two cells adjacent to the goal must be confidently correct
        if (idx == GRIDWIDTH - 1 and probs[3] < 0.9) or (idx == GRIDWIDTH and probs[0] < 0.9):
            passed = False

        if idx == 0:
            print('G ', end = '')
        elif best < len(symbols):
            print(symbols[best], end = '')
        if idx % GRIDWIDTH == GRIDWIDTH - 1:
            print(' ')

    return passed, l1_norm

def _one_hot_state_tensor(index, size):
    """Return a float tensor of length ``size`` that is 1 at ``index`` and 0 elsewhere."""
    vec = np.zeros(size)
    vec[index] = 1
    return torch.from_numpy(vec).float()


def _rollout_episode(select_fn, net, i_episode, max_step, label):
    """Play one episode of ``env`` and store its transitions on ``net``.

    Args:
        select_fn: action selector called as ``select_fn(state)`` that
            returns ``(probs, value, state, action, log_p)``.
        net: model whose ``saved_actions`` and ``rewards`` lists receive the
            episode's transitions (read afterwards by the finish-episode step).
        i_episode: episode index. When it is a multiple of
            ``args.log_interval``, the visited states are appended to
            ``last_episode_log``.
        max_step: cap on environment steps so an episode cannot run forever.
        label: prefix printed together with the start state and ``env.distance``.

    Returns:
        ``(start_state, distance, ep_reward, step_count)``. ``step_count``
        is the number of ``env.step`` calls made, including the one that
        ended the episode.
    """
    state = env.reset()
    start_state = state
    distance = env.distance
    print(label, state, distance)
    ep_reward = 0
    step_count = 0
    for _ in range(max_step):
        probs, value, state, action, log_p = select_fn(state)
        next_state, reward, done, _info = env.step(action.item())
        # count before the `done` check so the terminating step is included
        step_count += 1
        net.saved_actions.append([probs, value, state, action, next_state, log_p])
        if i_episode % args.log_interval == 0:
            # log the position(s) equal to 1 in the state vector
            last_episode_log.append(np.where(state.detach().numpy() == 1))
        if args.render:
            env.render()
        state = next_state
        net.rewards.append(reward)
        ep_reward += reward
        if done:
            break
    print("ep_reward:", ep_reward)
    return start_state, distance, ep_reward, step_count


def _save_series_plot(title, ylabel, series, labels, path, **plot_kwargs):
    """Plot every sequence in ``series`` against the episode index and save it as a PNG."""
    plt.title(title)
    plt.xlabel('episode')
    plt.ylabel(ylabel)
    handles = [plt.plot(values, **plot_kwargs)[0] for values in series]
    plt.legend(handles, labels)
    plt.savefig(path, format='png', dpi=200)
    plt.clf()


def _write_network_report(fo, net, name):
    """Write ``net``'s greedy-direction grid, action probabilities and value outputs to ``fo``.

    Every state is probed with a one-hot input. Directions use action codes
    0=U, 1=R, 2=D, 3=L, and state 0 is written as ``G``.
    """
    n_states = env.observation_space.n
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = [net(_one_hot_state_tensor(i, n_states)) for i in range(n_states)]
    probs_table = [probs.detach().numpy() for probs, _ in outputs]
    value_table = [values.detach().numpy() for _, values in outputs]

    print("### max prob dir from {} Policy network".format(name), file=fo)
    for i, probs in enumerate(probs_table):
        direction = np.argmax(probs)
        if i == 0:
            print('G ', end='', file=fo)
        elif direction < 4:
            print('URDL'[direction] + ' ', end='', file=fo)
        if i % GRIDWIDTH == GRIDWIDTH - 1:
            print(' ', file=fo)
    print('### prob from {} Policy network'.format(name), file=fo)
    for i, probs in enumerate(probs_table):
        print(i, probs, " ", file=fo)
    print("0 :U, 1:R, 2:D, 3:L", file=fo)
    print('### {} q table value'.format(name), file=fo)
    for values in value_table:
        print(values, file=fo)


def main():
    """Train the selected agent, then save comparison plots and a policy report.

    If ``args.classifier == 'a2c'``, only the A2C baseline (``a2c_model``)
    is trained. Otherwise only the hinge-loss agent (``model``) is trained.

    Each phase:
    * runs ``args.max_episode`` episodes of at most ``MAX_STEP`` steps;
    * tracks an exponentially weighted moving average (EWMA) of the episode
      reward and of ``|env.distance - steps taken|``;
    * logs every ``args.log_interval`` episodes to stdout and wandb, and to
      TensorBoard in the hinge phase.

    Afterwards it writes five PNG figures to ``./fig/`` and writes the
    per-state policy of both networks to ``<classifier>_prob_dir_<time>.txt``.
    """
    import os
    import re
    from time import gmtime, strftime

    print("learning rate", args.learning_rate)
    print("env.action.space", env.action_space)
    MAX_STEP = 200     # per-episode step cap so learning cannot loop forever
    EWMA_ALPHA = 0.05  # smoothing factor of the running averages

    # ---------------- A2C baseline ----------------
    running_reward = 0
    running_stepcount = 0
    a2c_reward_record = []
    a2c_reward_EWMA = []
    a2c_max_episode = args.max_episode if args.classifier == 'a2c' else 0
    for i_episode in range(a2c_max_episode):
        _, distance, ep_reward, step_count = _rollout_episode(
            a2c_select_action, a2c_model, i_episode, MAX_STEP, "a2c state,distance")
        running_reward = EWMA_ALPHA * ep_reward + (1 - EWMA_ALPHA) * running_reward
        running_stepcount = EWMA_ALPHA * abs(distance - step_count) + (1 - EWMA_ALPHA) * running_stepcount
        a2c_reward_record.append(ep_reward)
        a2c_reward_EWMA.append(running_reward)
        a2c_finish_episode(i_episode)  # backprop on the stored episode
        _, l1_norm = evaluation2(a2c_model)
        if i_episode % args.log_interval == 0:
            print('a2c Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f} running_(stepcount - beststepcount):{}'.format(
                  i_episode, ep_reward, running_reward, running_stepcount))
            print("last_episode_log", last_episode_log)
            last_episode_log.clear()
            if args.wandb:
                wandb.log({"a2c ewma reward": running_reward})
                wandb.log({"l1_norm": l1_norm})

    # ---------------- hinge-loss agent ----------------
    running_reward = 0
    running_stepcount = 0
    past_prob_table = np.full((env.observation_space.n, env.action_space.n), 0.25)
    start_point_log = []
    reward_record = []
    hinge_reward_EWMA = []
    total_steps = 0  # cumulative env steps: monotonic x-axis for TensorBoard
    hinge_max_episode = args.max_episode if args.classifier != 'a2c' else 0
    for i_episode in range(hinge_max_episode):
        start_state, distance, ep_reward, step_count = _rollout_episode(
            select_action, model, i_episode, MAX_STEP, "state,distance")
        start_point_log.append(start_state)
        total_steps += step_count
        reward_record.append(ep_reward)
        running_reward = EWMA_ALPHA * ep_reward + (1 - EWMA_ALPHA) * running_reward
        running_stepcount = EWMA_ALPHA * abs(distance - step_count) + (1 - EWMA_ALPHA) * running_stepcount
        hinge_reward_EWMA.append(running_reward)
        finish_episode(i_episode, past_prob_table)  # backprop on the stored episode
        _, l1_norm = evaluation2(model)
        if i_episode % args.log_interval == 0:
            print('hinge Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f} \trunning_(stepcount - beststepcount): {}'.format(
                  i_episode, ep_reward, running_reward, running_stepcount))
            print("last_episode_log", last_episode_log)
            last_episode_log.clear()
            writer.add_scalar("HPG/Reward/EWMA(ep)", running_reward, i_episode)
            writer.add_scalar("HPG/Reward/Ep_reward", ep_reward, i_episode)
            # global_step must increase monotonically; the per-episode step
            # count would overlay every episode on steps 1..MAX_STEP
            writer.add_scalar("HPG/Reward/EWMA(step)", running_reward, total_steps)
            if args.wandb:
                wandb.log({"hinge ewma reward": running_reward})
                wandb.log({"l1_norm": l1_norm})

    # ---------------- figures ----------------
    timestamp = re.sub(r'[^A-Za-z0-9]+', '', strftime("%Y-%m-%d %H:%M:%S", gmtime()))
    os.makedirs('./fig', exist_ok=True)  # savefig fails if the directory is missing

    def fig_path(stem):
        return './fig/{classifier}_{stem}_{time}.png'.format(
            classifier=args.classifier, stem=stem, time=timestamp)

    _save_series_plot('loss a2c vs hingeloss', 'loss', [loss_log, a2c_loss_log],
                      ['hinge_loss', 'a2c_loss'], fig_path('A2C_randReward_loss'),
                      linestyle="none", marker=".", markersize=3, alpha=0.6)
    _save_series_plot('reward a2c vs hinge', 'reward', [reward_record, a2c_reward_record],
                      ['hinge_reward', 'a2c_reward'], fig_path('A2C_randReward_rewardLog'),
                      alpha=0.8)
    _save_series_plot('ewma reward a2c vs hinge', 'reward', [hinge_reward_EWMA, a2c_reward_EWMA],
                      ['hinge_reward', 'a2c_reward'], fig_path('A2C_randReward_EWMArewardLog'),
                      alpha=0.8)
    _save_series_plot('loss a2c ', 'loss', [a2c_ploss_log, a2c_vloss_log],
                      ['a2c policy_loss', 'a2c value_loss'], fig_path('A2C_p_V_lossLog'),
                      alpha=0.8)
    _save_series_plot('loss hinge ', 'loss', [ploss_log, vloss_log],
                      ['hinge policy_loss', 'hinge value_loss'], fig_path('hinge_p_V_lossLog'),
                      alpha=0.8)

    # ---------------- policy report ----------------
    report_path = "{classifier}_prob_dir_{time}.txt".format(classifier=args.classifier, time=timestamp)
    with open(report_path, "w") as fo:
        _write_network_report(fo, model, "hinge")
        _write_network_report(fo, a2c_model, "a2c")


if __name__ == "__main__":
    # Train and write the reports only when run as a script, not on import.
    main()
