import argparse
from itertools import count
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.optimize
import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from reinforcement import trpo_step
from utils import *
from motion_dynamic import dynamic1, dynamic2
import time
from math import *

# Ask PyTorch to emit warnings where legacy broadcasting / keepdim semantics
# would have behaved differently.
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

# Make float64 the default tensor type, so torch.Tensor(...) calls below build
# double tensors.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases;
# torch.set_default_dtype(torch.float64) is the suggested replacement - confirm
# that torch.Tensor(...) still yields float64 before switching.
torch.set_default_tensor_type('torch.DoubleTensor')

class Parser():
    """Plain container for the TRPO / experiment hyper-parameters.

    Despite the name nothing is parsed: values come from the constructor
    defaults or from explicit (keyword) overrides.
    """
    def __init__(self,gamma=0.995,tau=0.97,l2_reg=1e-3,max_kl=1e-2,damping=1e-1,batch_size=10000,terminal_time_onetrack=500,log_interval=1,seed=543,test_reward_iteration=5):
        self.gamma=gamma  # discount factor used in the return / advantage recursions
        self.tau=tau  # multiplies gamma in the advantage (GAE) recursion
        self.l2_reg=l2_reg  # L2 penalty added to the value-function loss
        self.max_kl=max_kl  # KL limit handed to trpo_step
        self.damping=damping  # damping handed to trpo_step
        self.batch_size=batch_size  # minimum transitions collected per TRPO iteration
        self.log_interval=log_interval  # print training stats every N iterations
        self.test_reward_iteration=test_reward_iteration  # evaluation rollout every N iterations
        self.terminal_time_onetrack=terminal_time_onetrack  # max steps per training episode
        # The seed used to be accepted and silently discarded; keep it so callers
        # can read it back. NOTE: nothing in this class seeds any RNG with it.
        self.seed=seed

# Global hyper-parameters read by the training and evaluation code below.
args = Parser()

# One dynamics object per agent. The network sizes are taken from agent 1's
# dynamics and are also used for agent 2's networks.
dynamics1= dynamic1()
dynamics2= dynamic2()
num_inputs = dynamics1.num_states
num_actions = dynamics1.num_actions

def feature1(state):
  """Return agent 1's 2x1 feature vector (np.matrix) for a grid position.

  * (50, 50) inside agent 1's goal cell, x in [7, 8] and y in [12, 13];
  * (-0.5, -0.5) outside the 9 x 13 grid (x < 0, x > 9, y < 0 or y > 13);
  * otherwise a tiny position-proportional feature 1e-4 * (x, y).

  The out-of-grid test used to join its comparisons with ``and``. Those
  conditions are mutually exclusive, so the -0.5 branch could never fire; they
  are now joined with ``or``. ``np.asmatrix`` replaces ``np.mat``, which was
  removed in NumPy 2.0.

  Args:
    state: indexable whose first two entries are the (x, y) position.
  """
  x, y = state[0], state[1]
  if 7 <= x <= 8 and 12 <= y <= 13:
    return np.asmatrix([50, 50]).T
  if x < 0 or x > 9 or y < 0 or y > 13:
    return np.asmatrix([-0.5, -0.5]).T
  return np.asmatrix([0.0001 * x, 0.0001 * y]).T

def feature2(state):
  """Return agent 2's 2x1 feature vector (np.matrix) for a grid position.

  * (50, 50) inside agent 2's goal cell, x in [1, 2] and y in [12, 13];
  * (-0.5, -0.5) outside the 9 x 13 grid (x < 0, x > 9, y < 0 or y > 13);
  * otherwise a tiny position-proportional feature 1e-4 * (x, y).

  The out-of-grid test used to join its comparisons with ``and``. Those
  conditions are mutually exclusive, so the -0.5 branch could never fire; they
  are now joined with ``or``. ``np.asmatrix`` replaces ``np.mat``, which was
  removed in NumPy 2.0.

  Args:
    state: indexable whose first two entries are the (x, y) position.
  """
  x, y = state[0], state[1]
  if 1 <= x <= 2 and 12 <= y <= 13:
    return np.asmatrix([50, 50]).T
  if x < 0 or x > 9 or y < 0 or y > 13:
    return np.asmatrix([-0.5, -0.5]).T
  return np.asmatrix([0.0001 * x, 0.0001 * y]).T

def draw_trajectory1(trajectory1):
  """Plot one agent's path over the 9 x 13 grid world and show the figure.

  Args:
    trajectory1: sequence of points whose first two entries are the (x, y)
      position at each step; drawn as a red line.
  """
  x1 = [point[0] for point in trajectory1]
  y1 = [point[1] for point in trajectory1]

  _, ax = plt.subplots()
  ax.axis('scaled')
  # Integer tick positions exist only so that grid() draws the cell borders.
  ax.set_xticks(np.linspace(0, 8, 9))
  ax.set_yticks(np.linspace(0, 12, 13))
  ax.axis([0, 9, 0, 13])
  ax.grid(linestyle='-', color='black')

  # Obstacle rectangles: (lower-left corner, width, height).
  # NOTE(review): passing color= together with facecolor= makes matplotlib warn
  # and apply `color` to the face as well; use edgecolor= if outlines were intended.
  for corner, width, height in (((3.05, 5.05), 0.9, 0.9),
                                ((3.05, 6.05), 2.9, 0.9),
                                ((3.05, 7.05), 2.9, 0.9),
                                ((4.05, 8.05), 1.9, 0.9)):
    ax.add_patch(plt.Rectangle(corner, width, height, linewidth=2,
                               color='tab:orange', facecolor='none'))

  # Raw strings: '\p' is an invalid escape in a normal literal (SyntaxWarning
  # on Python 3.12+). The resulting text is unchanged.
  ax.text(1.35, 1.35, r'$s_0$', fontsize=10)
  ax.text(7.35, 12.35, r'$s_G$', fontsize=10)
  ax.text(7.35, 1.35, r'$s_0^{\prime}$', fontsize=10)
  ax.text(1.35, 12.35, r'$s_G^{\prime}$', fontsize=10)

  # Hide tick marks and labels but keep the grid lines. The old per-tick
  # attributes (tick1On, label1On, ...) were removed from matplotlib, so
  # assigning them no longer had any effect.
  ax.tick_params(axis='both', which='both',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)

  plt.plot(x1, y1, 'r', linewidth=2)
  plt.show()

def draw_trajectory(trajectory1,trajectory2):
  """Plot both agents' paths on the 9 x 13 grid, save 'trajectories.pdf', and show.

  Args:
    trajectory1: agent 1's points (first two entries are x, y); drawn in red.
    trajectory2: agent 2's points (first two entries are x, y); drawn in blue.
  """
  x1 = [point[0] for point in trajectory1]
  y1 = [point[1] for point in trajectory1]
  x2 = [point[0] for point in trajectory2]
  y2 = [point[1] for point in trajectory2]

  _, ax = plt.subplots()
  ax.axis('scaled')
  # Integer tick positions exist only so that grid() draws the cell borders.
  ax.set_xticks(np.linspace(0, 8, 9))
  ax.set_yticks(np.linspace(0, 12, 13))
  ax.axis([0, 9, 0, 13])
  ax.grid(linestyle='-', color='black')

  # Obstacle rectangles: (lower-left corner, width, height).
  # NOTE(review): passing color= together with facecolor= makes matplotlib warn
  # and apply `color` to the face as well; use edgecolor= if outlines were intended.
  for corner, width, height in (((3.05, 5.05), 0.9, 0.9),
                                ((3.05, 6.05), 2.9, 0.9),
                                ((3.05, 7.05), 2.9, 0.9),
                                ((4.05, 8.05), 1.9, 0.9)):
    ax.add_patch(plt.Rectangle(corner, width, height, linewidth=2,
                               color='tab:orange', facecolor='none'))

  # Raw strings: '\p' is an invalid escape in a normal literal (SyntaxWarning
  # on Python 3.12+). The resulting text is unchanged.
  ax.text(1.35, 1.35, r'$s_0$', fontsize=10)
  ax.text(7.35, 12.35, r'$s_G$', fontsize=10)
  ax.text(7.35, 1.35, r'$s_0^{\prime}$', fontsize=10)
  ax.text(1.35, 12.35, r'$s_G^{\prime}$', fontsize=10)

  # Hide tick marks and labels but keep the grid lines. The old per-tick
  # attributes (tick1On, label1On, ...) were removed from matplotlib, so
  # assigning them no longer had any effect.
  ax.tick_params(axis='both', which='both',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)

  plt.plot(x1, y1, 'r', linewidth=2)
  plt.plot(x2, y2, 'blue', linewidth=2)
  plt.savefig('trajectories.pdf')
  plt.show()

def select_action1(state,policy_net1):
    """Sample one stochastic action for agent 1.

    The numpy ``state`` gets a leading batch axis of size 1 and is fed to
    ``policy_net1``. The network's 3-tuple output is unpacked as
    (mean, log_std, std), and the action is drawn elementwise from
    Normal(mean, std). The batch axis is kept, which is why callers take ``[0]``.
    """
    batched_state = torch.unsqueeze(torch.from_numpy(state), 0)
    mean, _log_std, std = policy_net1(batched_state)
    return torch.normal(mean, std)

def select_action2(state,policy_net2):
    """Sample one stochastic action for agent 2.

    Works exactly like the agent-1 sampler: add a batch axis, run
    ``policy_net2`` (3-tuple output unpacked as mean, log_std, std), and draw
    from Normal(mean, std). The result keeps the batch axis of size 1.
    """
    net_input = torch.from_numpy(state).unsqueeze(dim=0)
    outputs = policy_net2(net_input)
    mean, std = outputs[0], outputs[2]
    return torch.normal(mean, std)

def update_params1(batch,policy_net1,value_net1):
    """Run one TRPO iteration for agent 1 on a batch of collected transitions.

    Steps, in order:
      1. Walk the batch backwards to compute discounted returns, TD residuals
         and GAE-style advantages (discount ``args.gamma``, extra decay
         ``args.tau`` on the advantage). ``mask == 0`` cuts the recursion at
         episode ends.
      2. Refit ``value_net1`` to the returns with SciPy L-BFGS-B (25 iterations,
         L2 penalty ``args.l2_reg``).
      3. Standardise the advantages and take one ``trpo_step`` on
         ``policy_net1`` with KL limit ``args.max_kl`` and damping ``args.damping``.

    Args:
        batch: transitions exposing ``reward``, ``mask``, ``action`` and
            ``state`` sequences (presumably the result of ``Memory.sample()``;
            confirm in replay_memory).
        policy_net1: agent 1's policy; its 3-tuple output is unpacked as
            (mean, log_std, std).
        value_net1: agent 1's state-value network.

    Nothing is returned. The networks are modified through the project helpers
    set_flat_params_to / trpo_step (presumably in place; confirm in utils /
    reinforcement).
    """
    # Convert the batch fields to tensors (float64 under the module default).
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net1(Variable(states))

    # Uninitialised (N, 1) buffers. Every row is written by the backward loop
    # below, assuming len(rewards) == len(actions) (TODO confirm).
    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    # prev_* hold the values of the *next* transition in buffer order. mask == 0
    # on an episode's last step stops bootstrapping from the unrelated next entry.
    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]  # discounted return-to-go
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]  # TD residual
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]  # GAE recursion

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
    def get_value_loss(flat_params):
        # L-BFGS-B objective: load flat_params into value_net1 and return
        # (loss, flat gradient) as float64 numpy arrays.
        set_flat_params_to(value_net1, torch.Tensor(flat_params))
        # Clear gradients left over from the previous evaluation before backward().
        for param in value_net1.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net1(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net1.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net1).data.double().numpy())

    # The final loss value and opt_info (convergence info) are not used.
    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net1).double().numpy(), maxiter=25)
    set_flat_params_to(value_net1, torch.Tensor(flat_params))

    # Standardise advantages. NOTE: std() is 0 for constant advantages and NaN
    # for a single-transition batch, which would poison the update.
    advantages = (advantages - advantages.mean()) / advantages.std()

    # Log-probabilities of the taken actions under the pre-update policy,
    # cloned so they stay fixed as the importance-sampling reference.
    action_means, action_log_stds, action_stds = policy_net1(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        # Surrogate loss -mean(A * exp(log pi_new - log pi_old)); with
        # volatile=True the policy forward pass runs without gradient tracking.
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net1(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net1(Variable(states))

        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    def get_kl():
        # Per-state KL(pi_0 || pi_1) between diagonal Gaussians, where pi_0 is a
        # detached copy of the current policy output (its value is 0, but its
        # derivatives are what trpo_step needs); summed over action dimensions.
        mean1, log_std1, std1 = policy_net1(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net1, get_loss, get_kl, args.max_kl, args.damping)

def update_params2(batch,policy_net2,value_net2):
    """Run one TRPO iteration for agent 2 on a batch of collected transitions.

    Steps, in order:
      1. Walk the batch backwards to compute discounted returns, TD residuals
         and GAE-style advantages (discount ``args.gamma``, extra decay
         ``args.tau`` on the advantage). ``mask == 0`` cuts the recursion at
         episode ends.
      2. Refit ``value_net2`` to the returns with SciPy L-BFGS-B (25 iterations,
         L2 penalty ``args.l2_reg``).
      3. Standardise the advantages and take one ``trpo_step`` on
         ``policy_net2`` with KL limit ``args.max_kl`` and damping ``args.damping``.

    Args:
        batch: transitions exposing ``reward``, ``mask``, ``action`` and
            ``state`` sequences (presumably the result of ``Memory.sample()``;
            confirm in replay_memory).
        policy_net2: agent 2's policy; its 3-tuple output is unpacked as
            (mean, log_std, std).
        value_net2: agent 2's state-value network.

    Nothing is returned. The networks are modified through the project helpers
    set_flat_params_to / trpo_step (presumably in place; confirm in utils /
    reinforcement).
    """
    # Convert the batch fields to tensors (float64 under the module default).
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net2(Variable(states))

    # Uninitialised (N, 1) buffers. Every row is written by the backward loop
    # below, assuming len(rewards) == len(actions) (TODO confirm).
    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    # prev_* hold the values of the *next* transition in buffer order. mask == 0
    # on an episode's last step stops bootstrapping from the unrelated next entry.
    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]  # discounted return-to-go
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]  # TD residual
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]  # GAE recursion

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
    def get_value_loss(flat_params):
        # L-BFGS-B objective: load flat_params into value_net2 and return
        # (loss, flat gradient) as float64 numpy arrays.
        set_flat_params_to(value_net2, torch.Tensor(flat_params))
        # Clear gradients left over from the previous evaluation before backward().
        for param in value_net2.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net2(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net2.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net2).data.double().numpy())

    # The final loss value and opt_info (convergence info) are not used.
    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net2).double().numpy(), maxiter=25)
    set_flat_params_to(value_net2, torch.Tensor(flat_params))

    # Standardise advantages. NOTE: std() is 0 for constant advantages and NaN
    # for a single-transition batch, which would poison the update.
    advantages = (advantages - advantages.mean()) / advantages.std()

    # Log-probabilities of the taken actions under the pre-update policy,
    # cloned so they stay fixed as the importance-sampling reference.
    action_means, action_log_stds, action_stds = policy_net2(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        # Surrogate loss -mean(A * exp(log pi_new - log pi_old)); with
        # volatile=True the policy forward pass runs without gradient tracking.
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net2(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net2(Variable(states))

        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    def get_kl():
        # Per-state KL(pi_0 || pi_1) between diagonal Gaussians, where pi_0 is a
        # detached copy of the current policy output (its value is 0, but its
        # derivatives are what trpo_step needs); summed over action dimensions.
        mean1, log_std1, std1 = policy_net2(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net2, get_loss, get_kl, args.max_kl, args.damping)

# Observation filter shared by BOTH agents: every raw state passes through it
# before reaching a policy, and each call presumably also updates its running
# statistics (ZFilter from running_state; assumed to be a running mean/std
# normaliser clipped at +/-5 - confirm in running_state.py).
running_state = ZFilter((num_inputs,), clip=5)
# NOTE(review): running_reward is not referenced anywhere else in this file.
running_reward = ZFilter((1,), demean=False, clip=10)  

def trial1(memory,theta,policy_net1,omega=None):
  """Roll out agent 1's policy once from the fixed start (1.5, 1.5).

  Every transition is pushed into ``memory`` as
  (state, action, mask, next_state, reward), with states passed through
  ``running_state``. The rollout stops on ``done``/``out`` or after 500 steps.

  Args:
    memory: buffer receiving the transitions.
    theta: constraint map forwarded to ``dynamics1.step``.
    policy_net1: agent 1's policy network.
    omega: reward weights forwarded to ``dynamics1.step``. Defaults to the
      module-level ``omega1``, which is what this function always used. That
      is NOT the omega RL() trains with, so pass it explicitly to evaluate
      under the trained weights.

  Returns:
    list of [x, y] raw (unfiltered) positions, starting with the start state.
  """
  if omega is None:
    omega = omega1
  state1 = dynamics1.reset(np.array([1.5, 1.5]))
  trajectory1 = [[state1[0], state1[1]]]

  state = running_state(state1)
  for _ in range(500):  # hard step cap so an untrained policy cannot loop forever
    action = select_action1(state, policy_net1).data[0].numpy()
    next_state, reward, done, out = dynamics1.step(action, omega, theta)
    trajectory1.append([next_state[0], next_state[1]])
    next_state = running_state(next_state)
    terminal = done or out
    # mask 0 marks the last transition of the episode.
    memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
    if terminal:
      break
    state = next_state

  return trajectory1

def trial2(memory,theta,policy_net2,omega=None):
  """Roll out agent 2's policy once from the fixed start (7.5, 1.5).

  Every transition is pushed into ``memory`` as
  (state, action, mask, next_state, reward), with states passed through
  ``running_state``. The rollout stops on ``done``/``out`` or after 500 steps.

  Args:
    memory: buffer receiving the transitions.
    theta: constraint map forwarded to ``dynamics2.step``.
    policy_net2: agent 2's policy network.
    omega: reward weights forwarded to ``dynamics2.step``. Defaults to the
      module-level ``omega2``, which is what this function always used. That
      is NOT the omega RL() trains with, so pass it explicitly to evaluate
      under the trained weights.

  Returns:
    list of [x, y] raw (unfiltered) positions, starting with the start state.
  """
  if omega is None:
    omega = omega2
  state2 = dynamics2.reset(np.array([7.5, 1.5]))
  trajectory2 = [[state2[0], state2[1]]]

  state = running_state(state2)
  for _ in range(500):  # hard step cap so an untrained policy cannot loop forever
    action = select_action2(state, policy_net2).data[0].numpy()
    next_state, reward, done, out = dynamics2.step(action, omega, theta)
    trajectory2.append([next_state[0], next_state[1]])
    next_state = running_state(next_state)
    terminal = done or out
    # mask 0 marks the last transition of the episode.
    memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
    if terminal:
      break
    state = next_state

  return trajectory2

def collision(x,y):
  """Return True when (x, y) lies in one of the two constraint regions.

  The regions are the closed boxes [2, 7] x [3, 4] and [3, 6] x [7, 8].
  """
  in_lower_band = 2 <= x <= 7 and 3 <= y <= 4
  in_upper_band = 3 <= x <= 6 and 7 <= y <= 8
  return bool(in_lower_band or in_upper_band)

def success_collision_rate(num_trials=10):
  """Score the saved learned trajectories of both agents.

  Reads ``learned_agent1_trajectory_file{i}.txt`` and
  ``learned_agent2_trajectory_file{i}.txt`` for i in range(num_trials). Each is
  a flat text file of alternating x, y values. A trajectory's outcome is decided
  by the first step that is inside a collision region (a collision; checked
  first) or inside the agent's goal cell (a success). Agent 1's goal is
  x in [7, 8]; agent 2's is x in [1, 2]; both use y in [12, 13]. Trajectories
  that do neither count towards neither rate.

  Args:
    num_trials: number of saved trajectories per agent. The denominator used
      to be a hard-coded 20.0 tied to range(10); it is now 2 * num_trials.

  Returns:
    (success_rate, collision_rate) as fractions of the 2 * num_trials
    trajectories.

  Raises:
    ValueError: if num_trials is not positive.
  """
  if num_trials <= 0:
    raise ValueError("num_trials must be positive")
  # (file prefix, goal x-range) per agent.
  agents = (("learned_agent1_trajectory_file", 7, 8),
            ("learned_agent2_trajectory_file", 1, 2))
  success_times = 0.0
  collision_times = 0.0
  for i in range(num_trials):
    for prefix, goal_x_lo, goal_x_hi in agents:
      flat = np.loadtxt(prefix + str(i) + ".txt", dtype=float)
      # A trailing unpaired value (odd length) is ignored, as before.
      for j in range(len(flat) // 2):
        x, y = flat[2 * j], flat[2 * j + 1]
        if collision(x, y):
          collision_times += 1.0
          break
        if goal_x_lo <= x <= goal_x_hi and 12 <= y <= 13:
          success_times += 1.0
          break

  total = 2.0 * num_trials
  return success_times / total, collision_times / total

def _train_agent(dynamics, select_action, update_params, policy_net, value_net,
                 omega, theta, test_start, num_iterations=100):
  """Train one agent with TRPO and return the memory of the final iteration.

  Each iteration collects at least ``args.batch_size`` transitions from starts
  chosen by ``dynamics.reset()``, runs one ``update_params`` step and logs the
  batch-average reward. Every ``args.test_reward_iteration`` iterations it also
  runs one evaluation rollout from ``test_start``. As before, evaluation
  transitions also go into the memory and through ``running_state``.

  Args:
    dynamics: environment with ``reset([start])`` and
      ``step(action, omega, theta) -> (next_state, reward, done, out)``.
    select_action: ``(state, policy_net) -> action tensor`` sampler.
    update_params: ``(batch, policy_net, value_net)`` TRPO update.
    policy_net, value_net: the agent's networks.
    omega, theta: reward weights and constraint map passed to ``dynamics.step``.
    test_start: start position array for the periodic evaluation rollout.
    num_iterations: number of TRPO iterations.
  """
  memory = None
  for i_episode in range(num_iterations):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < args.batch_size:
      state = running_state(dynamics.reset())
      reward_sum = 0
      for t in range(args.terminal_time_onetrack):  # hard cap on episode length
        action = select_action(state, policy_net).data[0].numpy()
        next_state, reward, done, out = dynamics.step(action, omega, theta)
        reward_sum += reward
        next_state = running_state(next_state)
        terminal = done or out
        memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
        if terminal:
          break
        state = next_state
      # t indexes the last executed step, so the episode produced t + 1
      # transitions. The old `t - 1` undercounted by two and even decreased
      # num_steps when an episode ended on its first step.
      num_steps += t + 1
      num_episodes += 1
      reward_batch += reward_sum

    reward_batch /= num_episodes
    update_params(memory.sample(), policy_net, value_net)

    if i_episode % args.log_interval == 0:
      print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
          i_episode, reward_sum, reward_batch))

    if i_episode % (args.test_reward_iteration) == 0:
      # A fresh copy each time, in case reset() keeps or mutates its argument.
      state = running_state(dynamics.reset(np.array(test_start)))
      reward_sum = 0
      for t in range(500):
        action = select_action(state, policy_net).data[0].numpy()
        next_state, reward, done, out = dynamics.step(action, omega, theta)
        reward_sum += reward
        next_state = running_state(next_state)
        terminal = done or out
        memory.push(state, np.array([action]), 0 if terminal else 1, next_state, reward)
        if terminal:
          break
        state = next_state
      print("from: "+str(np.array(test_start)))
      print('reward:  '+str(reward_sum))

    print("--------------------------------------------------------------")
  return memory


def _write_trajectory(path, trajectory):
  """Write a trajectory as a flat text file, one coordinate per line (x, y, x, y, ...)."""
  with open(path, "w") as trajectory_file:
    for entry in trajectory:
      np.savetxt(trajectory_file, entry)


def _evaluate_agent(trial, feature, memory, theta, policy_net, prefix, test, outer):
  """Roll out a trained agent, save its trajectories and aggregate visitation.

  In test mode, 5 trajectories are saved as ``prefix + str(5*outer + k) + ".txt"``
  and (None, None) is returned. Otherwise 10 trajectories are saved as
  ``prefix + str(k) + ".txt"``, and the function returns the per-trajectory
  average 9x13 cell-visitation map and 2x1 feature count.
  """
  if test:
    num_trials = 5
    for k in range(num_trials):
      _write_trajectory(prefix + str(num_trials * outer + k) + ".txt",
                        trial(memory, theta, policy_net))
    return None, None

  constraint_map = np.zeros((9, 13))
  feature_count = np.zeros((2, 1))
  num_trials = 10
  for k in range(num_trials):
    trajectory = trial(memory, theta, policy_net)
    _write_trajectory(prefix + str(k) + ".txt", trajectory)
    for x, y in trajectory:
      cell_x, cell_y = int(x), int(y)
      # Positions exactly on the far border fold into the last cell.
      # NOTE(review): coordinates <= -1 or >= 10 (x) / >= 14 (y) would wrap or
      # raise IndexError; this assumes the dynamics keep agents on the grid.
      if cell_x == 9:
        cell_x = 8
      if cell_y == 13:
        cell_y = 12
      constraint_map[cell_x, cell_y] += 1.0
      feature_count = feature_count + feature([cell_x, cell_y])
  return constraint_map / num_trials, feature_count / num_trials


def RL(omega,theta,i,test):
  """Train both agents from scratch under (omega, theta) and evaluate them.

  Agent 1 (dynamics1, evaluation start (1.5, 1.5)) is fully trained and
  evaluated before agent 2 (dynamics2, evaluation start (7.5, 1.5)). Each
  agent gets 100 TRPO iterations.

  Args:
    omega: reward weights used during training.
    theta: constraint map used during training and evaluation.
    i: outer (online) iteration index; only used to number test-mode files.
    test: if True, save 5 trajectories per agent to
      learned_agent{1,2}_trajectory_file{5*i + k}.txt and return None.

  Returns:
    When ``test`` is False: (constraint_map1, constraint_map2, feature_count1,
    feature_count2), averaged over 10 trajectories per agent saved to
    learned_agent{1,2}_trajectory_file{0..9}.txt. None otherwise.

  NOTE(review): the evaluation trials use trial1/trial2's default reward
  weights, not ``omega``.
  """
  policy_net1 = Policy(num_inputs, num_actions)
  value_net1 = Value(num_inputs)
  policy_net2 = Policy(num_inputs, num_actions)
  value_net2 = Value(num_inputs)
  outer = i

  memory1 = _train_agent(dynamics1, select_action1, update_params1, policy_net1,
                         value_net1, omega, theta, np.array([1.5, 1.5]))
  constraint_map1, feature_count1 = _evaluate_agent(
      trial1, feature1, memory1, theta, policy_net1,
      "learned_agent1_trajectory_file", test, outer)

  memory2 = _train_agent(dynamics2, select_action2, update_params2, policy_net2,
                         value_net2, omega, theta, np.array([7.5, 1.5]))
  constraint_map2, feature_count2 = _evaluate_agent(
      trial2, feature2, memory2, theta, policy_net2,
      "learned_agent2_trajectory_file", test, outer)

  if not test:
    return constraint_map1, constraint_map2, feature_count1, feature_count2

def expert_constraint_feature_map(num_iteration):
  """Aggregate the first ``num_iteration`` expert demonstrations of both agents.

  For each i < num_iteration, ``expert1_trajectory_file{i}.txt`` and
  ``expert2_trajectory_file{i}.txt`` are read as flat sequences of alternating
  x, y values. Every (int(x), int(y)) cell visited is counted in a 9x13 map,
  and the agent's feature vector of that cell is accumulated.

  Returns:
    (constraint_map1, constraint_map2, feature_count1, feature_count2):
    the visitation maps are summed over the files and then scaled by 10 (not
    averaged); the feature counts are averaged over ``num_iteration``.
  """
  prefixes = ("expert1_trajectory_file", "expert2_trajectory_file")
  features = (feature1, feature2)
  visit_maps = (np.zeros((9, 13)), np.zeros((9, 13)))
  feature_sums = [np.zeros((2, 1)), np.zeros((2, 1))]
  for idx in range(num_iteration):
    flats = [np.loadtxt(prefix + str(idx) + ".txt", dtype=float) for prefix in prefixes]
    for agent, flat in enumerate(flats):
      for pair in range(len(flat) // 2):
        cell = [int(flat[2 * pair]), int(flat[2 * pair + 1])]
        visit_maps[agent][cell[0], cell[1]] += 1.0
        feature_sums[agent] = feature_sums[agent] + features[agent](cell)
  return (10 * visit_maps[0], 10 * visit_maps[1],
          feature_sums[0] / num_iteration, feature_sums[1] / num_iteration)

#expert_constraint_map1, expert_constraint_map2=expert_constraint_map()
#print('experts visitation frequency: \n', expert_constraint_map1+expert_constraint_map2)
#print(expert_constraint_map2)

def false_positive_negative(theta):
  """Compare a learned 9x13 constraint map against the 8 true constraint cells.

  A cell counts as "flagged" when theta[x, y] > 0.05. The true cells are
  y == 3 for x in 2..6 and y == 7 for x in 3..5.

  Returns:
    (false_positive_rate, false_negative_rate): flagged non-constraint cells
    out of the 109 non-constraint cells, and unflagged true cells out of 8.
  """
  true_cells = {(x, 3) for x in range(2, 7)} | {(x, 7) for x in range(3, 6)}
  flagged = [(x, y) for x in range(9) for y in range(13) if theta[x, y] > 0.05]
  hits = float(sum(1 for cell in flagged if cell in true_cells))
  spurious = float(len(flagged)) - hits
  return spurious / 109, (8 - hits) / 8


# Module-level reward weights (4x1 matrices) and constraint maps (9x13).
# experiment() creates its own local copies, so training never updates these.
# omega1 / omega2 are still read as the default reward weights by trial1 /
# trial2, and theta1 / theta2 are what the end of the script writes to
# theta1.txt / theta2.txt. np.asmatrix replaces the np.mat alias, which was
# removed in NumPy 2.0.
omega1=np.asmatrix([1,1,1,1]).T
theta1=np.zeros((9,13))

omega2=np.asmatrix([1,1,1,1]).T
theta2=np.zeros((9,13))

omega3=np.asmatrix([1,1,1,1]).T
theta3=np.zeros((9,13))

omega4=np.asmatrix([1,1,1,1]).T
theta4=np.zeros((9,13))

online_iterations=18  # online learning rounds per experiment

# Aggregated result curves. Each starts with one [value] entry (0 success,
# 1.0 violation, 0 spread; presumably the untrained baseline), and the
# aggregation loop appends one [value] per online round.
success_mean_list1=[[0.0]]
success_sd_list1=[[0.0]]
collision_mean_list1=[[1.0]]
collision_sd_list1=[[0.0]]

success_mean_list2=[[0.0]]
success_sd_list2=[[0.0]]
collision_mean_list2=[[1.0]]
collision_sd_list2=[[0.0]]

success_mean_list3=[[0.0]]
success_sd_list3=[[0.0]]
collision_mean_list3=[[1.0]]
collision_sd_list3=[[0.0]]

success_mean_list4=[[0.0]]
success_sd_list4=[[0.0]]
collision_mean_list4=[[1.0]]
collision_sd_list4=[[0.0]]

# Raw per-experiment curves: one list per experiment repetition is appended.
all_success_list1=[]
all_collision_list1=[]

all_success_list2=[]
all_collision_list2=[]

all_success_list3=[]
all_collision_list3=[]

all_success_list4=[]
all_collision_list4=[]

num_experiment=5  # independent repetitions of experiment()

def experiment():
  """Run one decentralized constraint-inference experiment with four learners.

  In each of ``online_iterations`` rounds, every learner k trains a fresh pair
  of agents with its own reward weights omega_k and constraint map theta_k
  (via RL) and is scored with success_collision_rate(). Then all learners take
  one consensus + gradient step:

    theta_k <- clip(avg(theta) + 0.1 * (learner visitation - expert visitation), 0, 1)
    omega_k <- avg(omega) - 1e-8 * (learner features - expert features)

  avg() gives learners 1-2 weight a and learners 3-4 weight b, with (a, b)
  alternating between (0.24, 0.26) on even rounds and (0.26, 0.24) on odd rounds.

  Every learner now mixes the SAME snapshot of the previous iterates. The old
  code updated theta1 first and then fed the new theta1 into theta2's average,
  and so on (the same for omega), so later learners mixed partly-updated
  neighbours.

  Returns:
    (success1, collision1, success2, collision2, success3, collision3,
     success4, collision4): per-learner lists with one rate per round.
  """
  num_learners = 4
  omegas = [np.asmatrix([1, 1, 1, 1]).T for _ in range(num_learners)]
  thetas = [np.zeros((9, 13)) for _ in range(num_learners)]
  success_lists = [[] for _ in range(num_learners)]
  collision_lists = [[] for _ in range(num_learners)]

  for i in range(online_iterations):
    # Only the expert files are read here and nothing in this file writes them,
    # so one load per round serves all learners (it used to be loaded 4 times).
    expert_map1, expert_map2, expert_count1, expert_count2 = expert_constraint_feature_map(i + 1)

    theta_steps = []
    omega_steps = []
    first_visitation = None
    for k in range(num_learners):
      map1, map2, count1, count2 = RL(omegas[k], thetas[k], i, False)
      success_rate, collision_rate = success_collision_rate()
      success_lists[k].append(success_rate)
      collision_lists[k].append(collision_rate)
      theta_steps.append(map1 - expert_map1 + map2 - expert_map2)
      omega_steps.append(np.vstack((count1 - expert_count1, count2 - expert_count2)))
      if k == 0:
        first_visitation = map1 + map2

    print(first_visitation)

    weight_12, weight_34 = (0.24, 0.26) if i % 2 == 0 else (0.26, 0.24)
    theta_avg = weight_12 * (thetas[0] + thetas[1]) + weight_34 * (thetas[2] + thetas[3])
    omega_avg = weight_12 * (omegas[0] + omegas[1]) + weight_34 * (omegas[2] + omegas[3])
    thetas = [np.clip(theta_avg + 0.1 * step, 0.0, 1.0) for step in theta_steps]
    omegas = [omega_avg - 0.00000001 * step for step in omega_steps]

    print('omega is', omegas[0])
    print(thetas[0])

  return tuple(curve
               for pair in zip(success_lists, collision_lists)
               for curve in pair)

# Repeat the whole online experiment num_experiment times. For each learner,
# keep every repetition's per-round success and collision curves, and time the
# whole run.
start_time = time.time()
for m in range(num_experiment):
  (list11, list12, list21, list22,
   list31, list32, list41, list42) = experiment()
  for sink, curve in ((all_success_list1, list11), (all_collision_list1, list12),
                      (all_success_list2, list21), (all_collision_list2, list22),
                      (all_success_list3, list31), (all_collision_list3, list32),
                      (all_success_list4, list41), (all_collision_list4, list42)):
    sink.append(curve)

end_time = time.time()

# For every online round t, summarise each learner's num_experiment repetitions
# by their mean and population standard deviation (np.var defaults to ddof=0),
# appending one [value] entry per round to the result curves.
for t in range(online_iterations):
  for runs_success, runs_collision, s_mean, s_sd, c_mean, c_sd in (
      (all_success_list1, all_collision_list1, success_mean_list1, success_sd_list1,
       collision_mean_list1, collision_sd_list1),
      (all_success_list2, all_collision_list2, success_mean_list2, success_sd_list2,
       collision_mean_list2, collision_sd_list2),
      (all_success_list3, all_collision_list3, success_mean_list3, success_sd_list3,
       collision_mean_list3, collision_sd_list3),
      (all_success_list4, all_collision_list4, success_mean_list4, success_sd_list4,
       collision_mean_list4, collision_sd_list4)):
    at_t_success = [runs_success[i][t] for i in range(num_experiment)]
    at_t_collision = [runs_collision[i][t] for i in range(num_experiment)]
    s_mean.append([1.0*sum(at_t_success)/len(at_t_success)])
    s_sd.append([sqrt(np.var(at_t_success))])
    c_mean.append([1.0*sum(at_t_collision)/len(at_t_collision)])
    c_sd.append([sqrt(np.var(at_t_collision))])


# Persist the aggregated curves: one text file per learner and statistic,
# written in the same order as before. Each entry is a one-element list, so
# np.savetxt writes one value per line. `with` closes every file even if a
# write fails (the 16 copy-pasted open/close sequences did not).
for _learner, _curves in enumerate((
    (success_mean_list1, success_sd_list1, collision_mean_list1, collision_sd_list1),
    (success_mean_list2, success_sd_list2, collision_mean_list2, collision_sd_list2),
    (success_mean_list3, success_sd_list3, collision_mean_list3, collision_sd_list3),
    (success_mean_list4, success_sd_list4, collision_mean_list4, collision_sd_list4)),
    start=1):
  for _suffix, _series in zip(("success_mean_file", "success_sd_file",
                               "constraint_violation_mean_file",
                               "constraint_violation_sd_file"), _curves):
    with open("learner" + str(_learner) + "_" + _suffix + ".txt", "w") as _out:
      for entry in _series:
        np.savetxt(_out, entry)

# Wall-clock time of the timed run divided by (4 learners x online_iterations).
# NOTE(review): the timed region (start_time..end_time) covers all
# num_experiment repetitions, but num_experiment is not in the divisor;
# confirm whether a per-repetition figure was intended.
tpoi=np.zeros((1,1))
tpoi[0,0]=(end_time-start_time)/(4*online_iterations)
time_file=open("time.txt","w")
for entry in tpoi:
  np.savetxt(time_file,entry)
time_file.close()

# NOTE(review): theta1 / theta2 here are the module-level np.zeros((9,13))
# arrays. experiment() keeps its learned thetas in local variables and does not
# return them, so these files receive the initial all-zero maps, not the
# learned constraints.
theta1_file=open("theta1.txt","w")
for entry in theta1:
  np.savetxt(theta1_file,entry)
theta1_file.close()

theta2_file=open("theta2.txt","w")
for entry in theta2:
  np.savetxt(theta2_file,entry)
theta2_file.close()

#for j in range(10):
#  RL(omega2,theta2,j,True)




