import argparse
from itertools import count
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.optimize
import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from reinforcement import trpo_step
from utils import *
from motion_dynamic import dynamic1, dynamic2
import time


# Turn on PyTorch's legacy back-compatibility warnings (broadcasting and
# keepdim behaviour) so silent shape changes are reported.
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

# Make float64 the default tensor type, so the bare torch.Tensor(...) calls in
# update_params1/update_params2 build double tensors.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch in
# favour of torch.set_default_dtype(torch.float64) - confirm before migrating.
torch.set_default_tensor_type('torch.DoubleTensor')

class Parser():
    """Plain container for the hyper-parameters used by this script.

    Despite the name this is not an argparse parser: every value comes from
    the constructor keyword arguments.

    Attributes:
        gamma: discount factor used for returns and GAE in update_params1/2.
        tau: GAE lambda (multiplies gamma in the advantage recursion).
        l2_reg: coefficient of the L2 penalty added to the value loss.
        max_kl: trust-region size forwarded to ``trpo_step``.
        damping: damping forwarded to ``trpo_step``.
        batch_size: environment steps to collect before each policy update.
        terminal_time_onetrack: step cap for each training episode.
        log_interval: print training progress every this many iterations.
        seed: random seed value (kept so callers can read it back).
            NOTE(review): nothing in this file seeds an RNG with it.
        test_reward_iteration: run an evaluation episode every this many
            iterations.
    """

    def __init__(self, gamma=0.995, tau=0.97, l2_reg=1e-3, max_kl=1e-2,
                 damping=1e-1, batch_size=10000, terminal_time_onetrack=500,
                 log_interval=1, seed=543, test_reward_iteration=5):
        self.gamma = gamma
        self.tau = tau
        self.l2_reg = l2_reg
        self.max_kl = max_kl
        self.damping = damping
        self.batch_size = batch_size
        self.log_interval = log_interval
        # Previously accepted but discarded; store it like every other option.
        self.seed = seed
        self.test_reward_iteration = test_reward_iteration
        self.terminal_time_onetrack = terminal_time_onetrack

# Global hyper-parameters read by the training helpers in this file
# (args.gamma, args.tau, args.batch_size, ...).
args = Parser()

# One dynamics model per agent.
dynamics1= dynamic1()
dynamics2= dynamic2()
# Network input/output sizes for BOTH agents are taken from agent 1's model.
# NOTE(review): assumes dynamic2 has the same num_states/num_actions - confirm.
num_inputs = dynamics1.num_states
num_actions = dynamics1.num_actions

def feature1(state):
  """Return agent 1's 2x1 feature vector for a grid position.

  Args:
    state: indexable with ``state[0]`` = x and ``state[1]`` = y (every call
      site in this file passes integer cell coordinates).

  Returns:
    ``np.matrix`` of shape (2, 1):
      * ``[50, 50]^T`` when 7 <= x <= 8 and 12 <= y <= 13;
      * ``[-0.5, -0.5]^T`` when x < 0, x > 9, y < 0 or y > 13;
      * ``[0.0001*x, 0.0001*x]^T`` otherwise.
  """
  x, y = state[0], state[1]
  if 7 <= x <= 8 and 12 <= y <= 13:
    return np.asmatrix([50, 50]).T
  # Any single coordinate outside the region is enough to be penalised;
  # joining these with `and` would make the branch unreachable.
  if x < 0 or x > 9 or y < 0 or y > 13:
    return np.asmatrix([-0.5, -0.5]).T
  # NOTE(review): both entries use x (state[0]); confirm the second entry was
  # not meant to use y.
  # np.asmatrix is the same function as the np.mat alias removed in NumPy 2.0.
  return np.asmatrix([0.0001 * x, 0.0001 * x]).T

def feature2(state):
  """Return agent 2's 2x1 feature vector for a grid position.

  Args:
    state: indexable with ``state[0]`` = x and ``state[1]`` = y (every call
      site in this file passes integer cell coordinates).

  Returns:
    ``np.matrix`` of shape (2, 1):
      * ``[50, 50]^T`` when 1 <= x <= 2 and 12 <= y <= 13;
      * ``[-0.5, -0.5]^T`` when x < 0, x > 9, y < 0 or y > 13;
      * ``[0.0001*x, 0.0001*x]^T`` otherwise.
  """
  x, y = state[0], state[1]
  if 1 <= x <= 2 and 12 <= y <= 13:
    return np.asmatrix([50, 50]).T
  # Any single coordinate outside the region is enough to be penalised;
  # joining these with `and` would make the branch unreachable.
  if x < 0 or x > 9 or y < 0 or y > 13:
    return np.asmatrix([-0.5, -0.5]).T
  # NOTE(review): both entries use x (state[0]); confirm the second entry was
  # not meant to use y.
  # np.asmatrix is the same function as the np.mat alias removed in NumPy 2.0.
  return np.asmatrix([0.0001 * x, 0.0001 * x]).T

def draw_trajectory1(trajectory1):
  """Plot one agent's path on the 9 x 13 grid world and display it.

  Draws the cell grid, four orange obstacle rectangles and the start/goal
  labels ($s_0$, $s_G$ and their primed counterparts), then the path as a
  red line, and calls ``plt.show()``.

  Args:
    trajectory1: sequence of points; element 0 of each point is plotted as x
      and element 1 as y.
  """
  x1 = [point[0] for point in trajectory1]
  y1 = [point[1] for point in trajectory1]

  _, ax = plt.subplots()
  ax.axis('scaled')
  # Ticks sit on integer cell boundaries so the grid lines outline the cells.
  ax.set_xticks(np.linspace(0, 8, 9))
  ax.set_yticks(np.linspace(0, 12, 13))
  ax.axis([0, 9, 0, 13])
  ax.grid(linestyle='-', color='black')

  # Obstacles as (lower-left corner, width); all are 0.9 tall and inset 0.05
  # from the cell edges. Only `color` is passed: combined with `facecolor`,
  # matplotlib warns and lets `color` override it, so the boxes were always
  # drawn filled orange.
  for corner, width in (((3.05, 5.05), 0.9),
                        ((3.05, 6.05), 2.9),
                        ((3.05, 7.05), 2.9),
                        ((4.05, 8.05), 1.9)):
    ax.add_patch(plt.Rectangle(corner, width, 0.9, linewidth=2,
                               color='tab:orange'))

  # Raw strings: '\p' inside a plain literal is an invalid escape sequence.
  for (tx, ty), label in (((1.35, 1.35), r'$s_0$'),
                          ((7.35, 12.35), r'$s_G$'),
                          ((7.35, 1.35), r'$s_0^{\prime}$'),
                          ((1.35, 12.35), r'$s_G^{\prime}$')):
    ax.text(tx, ty, label, fontsize=10)

  # Hide tick marks and tick labels on every side while keeping grid lines.
  # Assigning Tick.tick1On / label1On no longer has any effect in current
  # matplotlib, so the supported tick_params API is used.
  ax.tick_params(axis='both', which='major',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)

  ax.plot(x1, y1, 'r', linewidth=2)
  plt.show()

def draw_trajectory(trajectory1,trajectory2):
  """Plot both agents' paths on the 9 x 13 grid world, save and display.

  Draws the cell grid, four orange obstacle rectangles and the start/goal
  labels, then agent 1's path in red and agent 2's path in blue. Writes the
  figure to ``trajectories.pdf`` and then calls ``plt.show()``.

  Args:
    trajectory1: agent 1's points; element 0 of each point is x, element 1 y.
    trajectory2: agent 2's points, same layout.
  """
  x1 = [point[0] for point in trajectory1]
  y1 = [point[1] for point in trajectory1]
  x2 = [point[0] for point in trajectory2]
  y2 = [point[1] for point in trajectory2]

  _, ax = plt.subplots()
  ax.axis('scaled')
  # Ticks sit on integer cell boundaries so the grid lines outline the cells.
  ax.set_xticks(np.linspace(0, 8, 9))
  ax.set_yticks(np.linspace(0, 12, 13))
  ax.axis([0, 9, 0, 13])
  ax.grid(linestyle='-', color='black')

  # Obstacles as (lower-left corner, width); all are 0.9 tall and inset 0.05
  # from the cell edges. Only `color` is passed: combined with `facecolor`,
  # matplotlib warns and lets `color` override it, so the boxes were always
  # drawn filled orange.
  for corner, width in (((3.05, 5.05), 0.9),
                        ((3.05, 6.05), 2.9),
                        ((3.05, 7.05), 2.9),
                        ((4.05, 8.05), 1.9)):
    ax.add_patch(plt.Rectangle(corner, width, 0.9, linewidth=2,
                               color='tab:orange'))

  # Raw strings: '\p' inside a plain literal is an invalid escape sequence.
  for (tx, ty), label in (((1.35, 1.35), r'$s_0$'),
                          ((7.35, 12.35), r'$s_G$'),
                          ((7.35, 1.35), r'$s_0^{\prime}$'),
                          ((1.35, 12.35), r'$s_G^{\prime}$')):
    ax.text(tx, ty, label, fontsize=10)

  # Hide tick marks and tick labels on every side while keeping grid lines.
  # Assigning Tick.tick1On / label1On no longer has any effect in current
  # matplotlib, so the supported tick_params API is used.
  ax.tick_params(axis='both', which='major',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)

  ax.plot(x1, y1, 'r', linewidth=2)
  ax.plot(x2, y2, 'blue', linewidth=2)
  plt.savefig('trajectories.pdf')
  plt.show()

def select_action1(state,policy_net1):
    """Sample one action for agent 1 from its Gaussian policy.

    ``state`` (a NumPy array) is given a leading batch dimension of 1 before
    being fed to ``policy_net1``, which must return
    ``(action_mean, action_log_std, action_std)``.  The result is a draw from
    ``torch.normal(action_mean, action_std)``; it keeps the batch dimension.
    """
    batched_state = Variable(torch.from_numpy(state).unsqueeze(0))
    mean, _log_std, std = policy_net1(batched_state)
    return torch.normal(mean, std)

def select_action2(state,policy_net2):
    """Sample one action for agent 2 from its Gaussian policy.

    Identical contract to agent 1's sampler: ``state`` (NumPy) gets a batch
    dimension of 1, ``policy_net2`` returns ``(mean, log_std, std)``, and a
    sample from ``torch.normal(mean, std)`` is returned with that batch
    dimension still present.
    """
    outputs = policy_net2(Variable(torch.from_numpy(state).unsqueeze(0)))
    action_mean, action_std = outputs[0], outputs[2]
    return torch.normal(action_mean, action_std)

def update_params1(batch,policy_net1,value_net1):
    """Run one TRPO update for agent 1 from a sampled batch of transitions.

    Steps:
      1. compute discounted returns and GAE advantages (gamma = args.gamma,
         lambda = args.tau);
      2. refit ``value_net1`` to the returns with scipy's L-BFGS-B (at most 25
         iterations, MSE + args.l2_reg * sum of squared parameters);
      3. standardise the advantages and take a trust-region step on
         ``policy_net1`` via ``trpo_step`` (args.max_kl, args.damping).

    Args:
        batch: result of ``Memory.sample()``; the fields read are ``reward``,
            ``mask``, ``action`` and ``state``.  NOTE(review): the backward
            recursion assumes transitions are in collection order - confirm
            against replay_memory.Memory.sample.
        policy_net1: agent 1's policy; called on a batch of states and
            returning ``(action_mean, action_log_std, action_std)``.
        value_net1: agent 1's value network; its parameters are overwritten
            through ``set_flat_params_to``.
    """
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net1(Variable(states))

    # Uninitialised buffers; every row is filled by the backward loop below.
    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    # Walk the batch backwards. mask == 0 (pushed when an episode ended with
    # done/out) stops the return, value and advantage from bootstrapping
    # across an episode boundary.
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # As in the original code this was adapted from, the value loss is
    # minimised with scipy's L-BFGS-B.  This closure loads flat_params into
    # value_net1 and returns (loss, flat gradient) as float64 NumPy arrays.
    def get_value_loss(flat_params):
        set_flat_params_to(value_net1, torch.Tensor(flat_params))
        # Clear stale gradients so backward() below does not accumulate
        # across successive L-BFGS evaluations.
        for param in value_net1.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net1(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net1.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net1).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net1).double().numpy(), maxiter=25)
    set_flat_params_to(value_net1, torch.Tensor(flat_params))

    # Standardise advantages.
    # NOTE(review): produces inf/NaN when all advantages are equal or the
    # batch holds a single transition (std is 0 or NaN).
    advantages = (advantages - advantages.mean()) / advantages.std()

    # Log-probabilities of the taken actions under the pre-update policy,
    # detached and copied so they stay constant inside the surrogate ratio.
    action_means, action_log_stds, action_stds = policy_net1(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    # Surrogate loss: -mean(A * exp(log pi_new - log pi_old)).
    # volatile=True runs the policy forward pass under torch.no_grad().
    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net1(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net1(Variable(states))

        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    # Per-state KL(pi_old || pi_new) for diagonal Gaussians, summed over
    # action dimensions.  The "old" terms are detached copies of the current
    # output, so the value is 0 but its derivatives with respect to the policy
    # parameters are not; presumably trpo_step differentiates it (e.g. for
    # Fisher-vector products) - confirm in reinforcement.trpo_step.
    def get_kl():
        mean1, log_std1, std1 = policy_net1(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net1, get_loss, get_kl, args.max_kl, args.damping)

def update_params2(batch,policy_net2,value_net2):
    """Run one TRPO update for agent 2 from a sampled batch of transitions.

    Same algorithm as update_params1, applied to agent 2's networks:
      1. compute discounted returns and GAE advantages (gamma = args.gamma,
         lambda = args.tau);
      2. refit ``value_net2`` to the returns with scipy's L-BFGS-B (at most 25
         iterations, MSE + args.l2_reg * sum of squared parameters);
      3. standardise the advantages and take a trust-region step on
         ``policy_net2`` via ``trpo_step`` (args.max_kl, args.damping).

    Args:
        batch: result of ``Memory.sample()``; the fields read are ``reward``,
            ``mask``, ``action`` and ``state``.  NOTE(review): the backward
            recursion assumes transitions are in collection order - confirm
            against replay_memory.Memory.sample.
        policy_net2: agent 2's policy; called on a batch of states and
            returning ``(action_mean, action_log_std, action_std)``.
        value_net2: agent 2's value network; its parameters are overwritten
            through ``set_flat_params_to``.
    """
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net2(Variable(states))

    # Uninitialised buffers; every row is filled by the backward loop below.
    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    # Walk the batch backwards. mask == 0 (pushed when an episode ended with
    # done/out) stops the return, value and advantage from bootstrapping
    # across an episode boundary.
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # As in the original code this was adapted from, the value loss is
    # minimised with scipy's L-BFGS-B.  This closure loads flat_params into
    # value_net2 and returns (loss, flat gradient) as float64 NumPy arrays.
    def get_value_loss(flat_params):
        set_flat_params_to(value_net2, torch.Tensor(flat_params))
        # Clear stale gradients so backward() below does not accumulate
        # across successive L-BFGS evaluations.
        for param in value_net2.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net2(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net2.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net2).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net2).double().numpy(), maxiter=25)
    set_flat_params_to(value_net2, torch.Tensor(flat_params))

    # Standardise advantages.
    # NOTE(review): produces inf/NaN when all advantages are equal or the
    # batch holds a single transition (std is 0 or NaN).
    advantages = (advantages - advantages.mean()) / advantages.std()

    # Log-probabilities of the taken actions under the pre-update policy,
    # detached and copied so they stay constant inside the surrogate ratio.
    action_means, action_log_stds, action_stds = policy_net2(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    # Surrogate loss: -mean(A * exp(log pi_new - log pi_old)).
    # volatile=True runs the policy forward pass under torch.no_grad().
    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net2(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net2(Variable(states))

        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    # Per-state KL(pi_old || pi_new) for diagonal Gaussians, summed over
    # action dimensions.  The "old" terms are detached copies of the current
    # output, so the value is 0 but its derivatives with respect to the policy
    # parameters are not; presumably trpo_step differentiates it (e.g. for
    # Fisher-vector products) - confirm in reinforcement.trpo_step.
    def get_kl():
        mean1, log_std1, std1 = policy_net2(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net2, get_loss, get_kl, args.max_kl, args.damping)

# Observation normaliser shared by BOTH agents: every rollout in this file
# passes raw states through it before they reach a policy.
# NOTE(review): presumably clip=5 bounds the normalised values to +/-5 and each
# call also updates the running statistics - confirm in running_state.ZFilter.
running_state = ZFilter((num_inputs,), clip=5)
# NOTE(review): running_reward is not referenced anywhere in this file.
running_reward = ZFilter((1,), demean=False, clip=10)  

def trial1(memory,theta,policy_net1,omega=None,max_steps=500):
  """Roll out agent 1's current policy once from (1.5, 1.5).

  Every transition is pushed into ``memory`` and states are passed through
  the shared ``running_state`` normaliser, exactly as in training.

  Args:
    memory: buffer receiving ``push(state, action, mask, next_state, reward)``.
    theta: forwarded unchanged to ``dynamics1.step``.
    policy_net1: agent 1's policy, sampled through ``select_action1``.
    omega: weights forwarded to ``dynamics1.step``.  Defaults to the
      module-level ``omega1`` so existing three-argument calls are unchanged;
      pass the omega used for training to evaluate under the same setting.
    max_steps: maximum number of steps in the rollout (default 500).

  Returns:
    list of ``[x, y]`` raw (un-normalised) positions: the reset state followed
    by one entry per step taken.
  """
  if omega is None:
    omega = omega1  # module-level global, bound later in this file
  start = dynamics1.reset(np.array([1.5, 1.5]))
  trajectory1 = [[start[0], start[1]]]

  state = running_state(start)
  for _ in range(max_steps):
    action = select_action1(state, policy_net1).data[0].numpy()
    next_state, reward, done, out = dynamics1.step(action, omega, theta)
    trajectory1.append([next_state[0], next_state[1]])
    next_state = running_state(next_state)
    # Either flag ends the rollout; mask 0 marks the final transition.
    terminal = done or out
    memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
    if terminal:
      break
    state = next_state

  return trajectory1

def trial2(memory,theta,policy_net2,omega=None,max_steps=500):
  """Roll out agent 2's current policy once from (7.5, 1.5).

  Every transition is pushed into ``memory`` and states are passed through
  the shared ``running_state`` normaliser, exactly as in training.

  Args:
    memory: buffer receiving ``push(state, action, mask, next_state, reward)``.
    theta: forwarded unchanged to ``dynamics2.step``.
    policy_net2: agent 2's policy, sampled through ``select_action2``.
    omega: weights forwarded to ``dynamics2.step``.  Defaults to the
      module-level ``omega2`` so existing three-argument calls are unchanged;
      pass the omega used for training to evaluate under the same setting.
    max_steps: maximum number of steps in the rollout (default 500).

  Returns:
    list of ``[x, y]`` raw (un-normalised) positions: the reset state followed
    by one entry per step taken.
  """
  if omega is None:
    omega = omega2  # module-level global, bound later in this file
  start = dynamics2.reset(np.array([7.5, 1.5]))
  trajectory2 = [[start[0], start[1]]]

  state = running_state(start)
  for _ in range(max_steps):
    action = select_action2(state, policy_net2).data[0].numpy()
    next_state, reward, done, out = dynamics2.step(action, omega, theta)
    trajectory2.append([next_state[0], next_state[1]])
    next_state = running_state(next_state)
    # Either flag ends the rollout; mask 0 marks the final transition.
    terminal = done or out
    memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
    if terminal:
      break
    state = next_state

  return trajectory2

def _run_episode(dynamics, select_fn, policy_net, memory, omega, theta, max_steps, start=None):
  """Roll out one episode and push every transition into ``memory``.

  Args:
    dynamics: model providing ``reset([start])`` and
      ``step(action, omega, theta) -> (next_state, reward, done, out)``.
    select_fn: action sampler (``select_action1`` / ``select_action2``).
    policy_net: policy handed to ``select_fn``.
    memory: buffer receiving ``push(state, action, mask, next_state, reward)``.
    omega, theta: forwarded unchanged to ``dynamics.step``.
    max_steps: step cap; must be >= 1.
    start: optional start position.  When given, ``dynamics.reset`` receives
      a fresh ``np.array(start)``; otherwise it is called with no argument.

  Returns:
    ``(reward_sum, t)`` where ``t`` is the 0-based index of the last step.
  """
  if start is None:
    state = dynamics.reset()
  else:
    state = dynamics.reset(np.array(start))
  state = running_state(state)
  reward_sum = 0
  for t in range(max_steps):
    action = select_fn(state, policy_net).data[0].numpy()
    next_state, reward, done, out = dynamics.step(action, omega, theta)
    reward_sum += reward
    next_state = running_state(next_state)
    # Either flag ends the episode; mask 0 stops bootstrapping past this
    # transition in update_params1/update_params2.
    terminal = done or out
    memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
    if terminal:
      break
    state = next_state
  return reward_sum, t


def _train_agent(dynamics, select_fn, update_fn, policy_net, value_net, omega, theta, start):
  """Run 100 TRPO iterations for one agent and return the last memory.

  Each iteration collects at least ``args.batch_size`` steps, calls
  ``update_fn``, logs every ``args.log_interval`` iterations, and every
  ``args.test_reward_iteration`` iterations runs a 500-step evaluation
  episode from ``start``.
  """
  memory = None
  for i_episode in range(100):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < args.batch_size:
      reward_sum, t = _run_episode(dynamics, select_fn, policy_net, memory,
                                   omega, theta, args.terminal_time_onetrack)
      # t is 0-based, so the episode contributed t + 1 transitions.
      # Counting t - 1 under-counted and went negative for one-step episodes,
      # which could keep this loop from ever terminating.
      num_steps += t + 1
      num_episodes += 1
      reward_batch += reward_sum

    reward_batch /= num_episodes
    batch = memory.sample()
    update_fn(batch, policy_net, value_net)

    if i_episode % args.log_interval == 0:
      print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
          i_episode, reward_sum, reward_batch))

    if i_episode % (args.test_reward_iteration) == 0:
      reward_sum, _ = _run_episode(dynamics, select_fn, policy_net, memory,
                                   omega, theta, 500, start=start)
      print("from: "+str(np.array(start)))
      print('reward:  '+str(reward_sum))

    print("--------------------------------------------------------------")
  return memory


def _save_learned_trajectories(trial_fn, memory, theta, policy_net, prefix, outer, num_trials=5):
  """Write ``num_trials`` rollouts to ``{prefix}{num_trials*outer + k}.txt``.

  Each ``[x, y]`` point is written with ``np.savetxt``, i.e. one value per line.
  """
  for k in range(num_trials):
    with open(prefix+str(num_trials*outer+k)+".txt", "w") as trajectory_file:
      trajectory = trial_fn(memory, theta, policy_net)
      for entry in trajectory:
        np.savetxt(trajectory_file, entry)


def _visitation_stats(trial_fn, feature_fn, memory, theta, policy_net, num_trials=10):
  """Average 9x13 cell-visit counts and feature sums over ``num_trials`` rollouts."""
  constraint_map = np.zeros((9, 13))
  feature_count = np.zeros((2, 1))
  for _ in range(num_trials):
    for point in trial_fn(memory, theta, policy_net):
      x, y = int(point[0]), int(point[1])
      # Fold the upper boundary (x == 9 / y == 13) onto the last cell.
      if x == 9:
        x = 8
      if y == 13:
        y = 12
      # NOTE(review): a negative coordinate wraps to the opposite edge of the
      # map and x >= 10 / y >= 14 raises IndexError - confirm the dynamics
      # never report such positions.
      constraint_map[x, y] += 1.0
      feature_count = feature_count + feature_fn([x, y])
  return constraint_map / num_trials, feature_count / num_trials


def RL(omega,theta,i,test):
  """Train both agents with TRPO, then save or summarise their rollouts.

  Agent 1 (start (1.5, 1.5)) is trained and evaluated first, then agent 2
  (start (7.5, 1.5)).  Each agent is trained for 100 iterations.

  Args:
    omega, theta: forwarded to ``dynamics*.step`` during training.
      The final rollouts go through ``trial1``/``trial2``, which use their
      own defaults for omega.
    i: run index; with ``test`` it numbers the output files
      ``learned_agent{1,2}_trajectory_file{5*i + k}.txt`` for k in 0..4.
    test: if true, write 5 trajectories per agent to disk and return None;
      otherwise summarise 10 rollouts per agent.

  Returns:
    ``None`` when ``test`` is true, else
    ``(constraint_map1, constraint_map2, feature_count1, feature_count2)``:
    per-agent 9x13 average cell-visit counts and 2x1 average feature sums.
  """
  # Build all four networks first, in this order, as before (their
  # construction may consume the torch RNG).
  policy_net1 = Policy(num_inputs, num_actions)
  value_net1 = Value(num_inputs)
  policy_net2 = Policy(num_inputs, num_actions)
  value_net2 = Value(num_inputs)
  outer = i

  memory1 = _train_agent(dynamics1, select_action1, update_params1,
                         policy_net1, value_net1, omega, theta, (1.5, 1.5))
  if test:
    _save_learned_trajectories(trial1, memory1, theta, policy_net1,
                               "learned_agent1_trajectory_file", outer)
  else:
    constraint_map1, feature_count1 = _visitation_stats(
        trial1, feature1, memory1, theta, policy_net1)

  memory2 = _train_agent(dynamics2, select_action2, update_params2,
                         policy_net2, value_net2, omega, theta, (7.5, 1.5))
  if test:
    _save_learned_trajectories(trial2, memory2, theta, policy_net2,
                               "learned_agent2_trajectory_file", outer)
    return None
  constraint_map2, feature_count2 = _visitation_stats(
      trial2, feature2, memory2, theta, policy_net2)
  return constraint_map1, constraint_map2, feature_count1, feature_count2

def expert_constraint_feature_map(num_iteration):
  """Aggregate cell-visit maps and feature counts from saved expert runs.

  For each k in ``range(num_iteration)`` this loads
  ``expert1_trajectory_file{k}.txt`` and ``expert2_trajectory_file{k}.txt``
  with ``np.loadtxt``.  Each file is read as a flat sequence
  x0, y0, x1, y1, ...; an unpaired trailing value is ignored.

  Returns:
    ``(constraint_map1, constraint_map2, feature_count1, feature_count2)``:
    9x13 visit counts summed over all files and multiplied by 10, and 2x1
    feature sums (``feature1`` / ``feature2``) divided by ``num_iteration``.
  """
  visit_maps = [np.zeros((9, 13)), np.zeros((9, 13))]
  feature_sums = [np.zeros((2, 1)), np.zeros((2, 1))]
  feature_fns = (feature1, feature2)

  for k in range(num_iteration):
    # Both files are loaded before either is processed.
    loaded = [np.loadtxt("expert1_trajectory_file"+str(k)+".txt", dtype=float),
              np.loadtxt("expert2_trajectory_file"+str(k)+".txt", dtype=float)]
    for agent, flat in enumerate(loaded):
      num_points = len(flat) // 2
      for x, y in flat[:2 * num_points].reshape(num_points, 2):
        cell = [int(x), int(y)]
        visit_maps[agent][cell[0], cell[1]] += 1.0
        feature_sums[agent] = feature_sums[agent] + feature_fns[agent](cell)

  feature_count1 = feature_sums[0] / num_iteration
  feature_count2 = feature_sums[1] / num_iteration
  constraint_map1 = 10 * visit_maps[0]
  constraint_map2 = 10 * visit_maps[1]
  return constraint_map1, constraint_map2, feature_count1, feature_count2


# Weights passed to dynamics*.step.  trial1/trial2 read omega1/omega2 as
# module globals.  np.asmatrix is identical to the np.mat alias, which was
# removed in NumPy 2.0.
omega1=np.asmatrix([1,1,1,1]).T
omega2=np.asmatrix([1,1,1,1]).T
# 9 x 13 grids passed to dynamics*.step as theta.  theta2 marks cells with
# 1.0 (first index x, second y): x=2..3 and x=5..6 at y=3, x=3..5 at y=7,
# and (4, 8).  NOTE(review): presumably these are constraint cells
# interpreted by the dynamics - confirm in motion_dynamic.
theta1=np.zeros((9,13))
theta2=np.zeros((9,13))
theta2[2:4,3]=1.0
theta2[5:7,3]=1.0
theta2[3:6,7]=1.0
theta2[4,8]=1.0
# NOTE(review): theta1 and online_iterations are not used in this file.
online_iterations=9

# Ten independent train-and-record runs.  With test=True each RL call writes
# 5 trajectory files per agent, numbered 5*j .. 5*j+4.
for j in range(10):
  RL(omega2,theta2,j,True)




