import numpy as np
from math import *
from env import dynamics
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Seed both RNGs for reproducibility: torch's RNG drives the uniform_ weight
# initialisation in Model, numpy's drives the action sampling in trial().
torch.manual_seed(8)
np.random.seed(8)

class Model(torch.nn.Module):
    """Five-layer MLP (2 -> 512 -> 256 -> 256 -> 128 -> 1) with hand-managed weights.

    The weight/bias tensors are kept in a plain list stored as ``self.parameters``
    (this shadows ``nn.Module.parameters()``), ordered
    ``[W0, b0, W1, b1, W2, b2, W3, b3, W4, b4]``. Each weight has shape
    ``(out_features, in_features)`` and is drawn from U(-1/sqrt(in), 1/sqrt(in));
    every bias starts at zero. All tensors require gradients.
    """

    def __init__(self):
        super().__init__()
        # (out_features, in_features) per linear layer, from input to output.
        layer_shapes = ((512, 2), (256, 512), (256, 256), (128, 256), (1, 128))
        tensors = []
        for out_features, in_features in layer_shapes:
            bound = 1.0 / sqrt(in_features)
            weight = torch.empty(out_features, in_features).uniform_(-bound, bound)
            tensors.append(weight.requires_grad_())
            tensors.append(torch.zeros(out_features).requires_grad_())
        self.parameters = tensors

    def dense(self, x, parameters):
        """Run ``x`` through the MLP described by ``parameters``.

        ``parameters`` uses the same 10-tensor layout as ``self.parameters``.
        The first four linear layers are followed by ReLU; the last is left linear.
        """
        hidden = x
        for layer in range(5):
            hidden = F.linear(hidden, parameters[2 * layer], parameters[2 * layer + 1])
            if layer < 4:
                hidden = F.relu(hidden)
        return hidden

def policy(Q_matrix,num_action):
  """Greedy policy over a 10x13 grid, splitting probability evenly among ties.

  Args:
    Q_matrix: action values indexable as ``Q_matrix[x][y][a]`` for x in 0..9,
      y in 0..12 and a in 0..num_action-1.
    num_action: number of actions per cell.

  Returns:
    A (10, 13, num_action) object array. In each cell, every action whose value
    equals the cell's maximum gets probability 1/k, where k is the number of tied
    maxima. All other actions get 0.
  """
  # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the same dtype.
  distribution=np.zeros((10,13,num_action)).astype(object)
  for x in range(10):
    for y in range(13):
      values=[Q_matrix[x][y][a] for a in range(num_action)]
      max_value=max(values)
      # Exact equality on purpose: only actions that tie the maximum exactly share mass.
      best_actions=[a for a,value in enumerate(values) if value==max_value]
      for a in best_actions:
        distribution[x][y][a]=1.0/len(best_actions)
  return distribution

def Q_matrix_function(gamma,V_matrix,num_action,reward_matrix,cost_matrix):
  """One-step Bellman backup of action values over the 10x13 grid.

  ``Q[x][y][a] = reward_matrix[x,y] - cost_matrix[x,y] + gamma * V_matrix[x'][y']``,
  where (x', y') is read from ``dynamics((x, y), a)`` via ``.item(0)`` / ``.item(1)``.

  Args:
    gamma: multiplier on the successor value.
    V_matrix: state values indexable as ``V_matrix[x][y]``.
    num_action: number of actions per cell.
    reward_matrix, cost_matrix: per-cell arrays indexable as ``[x, y]``.

  Returns:
    A (10, 13, num_action) object array of action values.
  """
  # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the same dtype.
  Q_matrix=np.zeros((10,13,num_action)).astype(object)
  for x in range(10):
    for y in range(13):
      # The immediate term does not depend on the action, so compute it once per cell.
      immediate=reward_matrix[x,y]-cost_matrix[x,y]
      for a in range(num_action):
        # ``np.mat`` was removed in NumPy 2.0; ``np.asmatrix`` builds the same matrix type.
        next_state=dynamics(np.asmatrix([x,y]).T,np.asmatrix([a]).T)
        value=V_matrix[next_state.item(0)][next_state.item(1)]
        Q_matrix[x][y][a]=immediate+gamma*value
  return Q_matrix

def V_matrix_funciton(Q_matrix,num_action,policy):
  """Expected state value under ``policy`` for every cell of the 10x13 grid.

  Args:
    Q_matrix: action values indexable as ``Q_matrix[x][y][a]``.
    num_action: number of actions per cell.
    policy: action probabilities indexable as ``policy[x][y][a]``. Note that this
      parameter shadows the module-level ``policy`` function inside this body.

  Returns:
    A (10, 13) object array with
    ``V[x][y] = sum_a policy[x][y][a] * Q_matrix[x][y][a]``.
  """
  # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the same dtype.
  V_matrix=np.zeros((10,13)).astype(object)
  for x in range(10):
    for y in range(13):
      value=0.0
      # Plain left-to-right accumulation in action order.
      for a in range(num_action):
        value=value+policy[x][y][a]*Q_matrix[x][y][a]
      V_matrix[x][y]=value
  return V_matrix

def calculate_policy(reward_function1,reward_function2,omega,gamma,num_action,iteration):
  """Plan greedy policies for both agents by value iteration on the 10x13 grid.

  Rewards are read out of the two networks as follows:
    - agent 2: every cell is evaluated at ``((9 - x) / 10, y / 13)``;
    - agent 1: every cell is evaluated at ``(x / 10, y / 13)`` while
      ``iteration < 10``. From iteration 10 on, only cell (9, 12) gets a learned
      reward and every other cell is 0.
  Both agents pay a per-cell cost of ``100 * omega``.

  Each of the 51 synchronous sweeps backs up Q from the previous V and takes the
  tie-splitting greedy policy. It then sets ``V = sum_a pi(a) Q(a)``, which
  equals ``max_a Q`` for that policy.

  Returns:
    (policy1, policy2): (10, 13, num_action) object arrays from the final sweep.
  """
  # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the same dtype.
  reward1_matrix=np.zeros((10,13)).astype(object)
  reward2_matrix=np.zeros((10,13)).astype(object)
  # Rewards are only read out as Python floats, so there is no need to build autograd graphs.
  with torch.no_grad():
    for x in range(10):
      for y in range(13):
        reward2_matrix[x,y]=reward_function2.dense(torch.tensor([(9.0-1.0*x)/10,1.0*y/13]),reward_function2.parameters).item()
        if iteration<10:
          reward1_matrix[x,y]=reward_function1.dense(torch.tensor([1.0*x/10,1.0*y/13]),reward_function1.parameters).item()
    if iteration>=10:
      # This value does not depend on the loop indices. The previous version recomputed it for
      # all 130 cells; computing it once gives the same result.
      reward1_matrix[9,12]=reward_function1.dense(torch.tensor([1.0*9/10,1.0*12/13]),reward_function1.parameters).item()
  cost_matrix=100*omega

  V1_matrix=np.zeros((10,13)).astype(object)
  V2_matrix=np.zeros((10,13)).astype(object)
  # One initial sweep plus 50 more, as before, written as a single loop.
  for _ in range(51):
    Q1_matrix=Q_matrix_function(gamma,V1_matrix,num_action,reward1_matrix,cost_matrix)
    policy1=policy(Q1_matrix,num_action)
    V1_matrix=V_matrix_funciton(Q1_matrix,num_action,policy1)
    Q2_matrix=Q_matrix_function(gamma,V2_matrix,num_action,reward2_matrix,cost_matrix)
    policy2=policy(Q2_matrix,num_action)
    V2_matrix=V_matrix_funciton(Q2_matrix,num_action,policy2)

  return policy1,policy2

def _sample_action(distribution,num_action,iteration):
  """Pick one agent's action from its policy row, with decaying random exploration.

  Actions with positive probability are chosen uniformly at random, falling back to
  action 0 when there are none. From ``iteration`` 4 onward, a uniformly random action
  in ``0..num_action-1`` is taken instead with probability ``0.1 / iteration``.
  """
  choices=[a for a in range(num_action) if distribution[a]>0.0]
  if not choices:
    choices=[0]
  if iteration<4:
    return choices[np.random.randint(len(choices))]
  if np.random.uniform()>(0.1/iteration):
    return choices[np.random.randint(len(choices))]
  # Was hard-coded to randint(0, 4); identical for the 4-action setup.
  return np.random.randint(0,num_action)

def trial(initial_state,policy1,policy2,num_action,iteration):
  """Roll out both agents for 35 steps from ``initial_state``.

  ``initial_state`` holds four entries: rows 0-1 are agent 1's (x, y) and rows 2-3
  are agent 2's. Each step samples agent 1's action, advances it through
  ``dynamics``, then does the same for agent 2 (this RNG call order is fixed).

  Returns:
    A list of 35 rows ``[x1, y1, x2, y2, action1, action2]``. Each row records the
    state before the step and the actions taken from it.
  """
  trajectory=[]
  state=initial_state
  for _ in range(35):
    action1=_sample_action(policy1[state.item(0)][state.item(1)][:],num_action,iteration)
    # ``np.mat`` was removed in NumPy 2.0; ``np.asmatrix`` builds the same matrix type.
    next_state1=dynamics(state[0:2],np.asmatrix([action1]).T)

    action2=_sample_action(policy2[state.item(2)][state.item(3)][:],num_action,iteration)
    next_state2=dynamics(state[2:4],np.asmatrix([action2]).T)

    trajectory.append([state.item(0),state.item(1),state.item(2),state.item(3),action1,action2])
    state=np.copy(np.vstack((next_state1,next_state2)))
  return trajectory

def constraint_map(number_trials,trajectories):
  """Per-trial average visit counts over the 10x13 grid, both agents combined.

  Args:
    number_trials: how many leading trials to use. Each trial spans 35 consecutive
      rows of ``trajectories``.
    trajectories: 2-D array indexed ``[row, col]``. Columns 0-1 hold agent 1's
      (x, y) and columns 2-3 hold agent 2's.

  Returns:
    A (10, 13) object array. Each entry is the number of visits to that cell by
    either agent over the trials, divided by ``number_trials``.
  """
  # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the same dtype.
  visits=np.zeros((10,13)).astype(object)
  for row_index in range(35*number_trials):
    row=trajectories[row_index]
    visits[int(row[0]),int(row[1])]+=1.0
    visits[int(row[2]),int(row[3])]+=1.0
  return visits/number_trials

def reward_gradient_map(reward_function1,reward_function2,number_trials,trajectories):
  """Per-trial average of reward gradients summed along the trajectories.

  For each of the first ``number_trials`` trials (35 rows each), this sums the gradient
  of ``reward_functionK.dense(state, reward_functionK.parameters)`` with respect to each
  of the first 10 tensors in ``reward_functionK.parameters`` over every visited
  state. Agent 1's state comes from columns 0-1 and agent 2's from columns 2-3.
  The sums are then divided by ``number_trials``.

  NOTE(review): calculate_policy evaluates these networks at normalised (and, for
  agent 2, x-mirrored) coordinates, while the raw trajectory coordinates are fed in
  here. Confirm this is intended.

  Returns:
    (grads1, grads2): two lists of 10 tensors, shaped like the parameter tensors.
  """
  def state_gradients(reward_function,x,y):
    # torch.autograd.grad returns fresh gradients and leaves ``.grad`` untouched.
    # The earlier ``backward()`` + ``.grad`` version never zeroed ``.grad``, so each
    # visited state added the running total of all earlier gradients, from this call
    # and from previous calls, instead of its own gradient.
    reward_value=reward_function.dense(torch.tensor([x,y],dtype=torch.float),reward_function.parameters)
    return torch.autograd.grad(reward_value,[reward_function.parameters[k] for k in range(10)])

  reward_parameters_count1=[0.0]*10
  reward_parameters_count2=[0.0]*10
  for i in range(number_trials):
    for j in range(35):
      row=35*i+j
      grads1=state_gradients(reward_function1,trajectories[row,0],trajectories[row,1])
      grads2=state_gradients(reward_function2,trajectories[row,2],trajectories[row,3])
      for number in range(10):
        reward_parameters_count1[number]=reward_parameters_count1[number]+grads1[number]
        reward_parameters_count2[number]=reward_parameters_count2[number]+grads2[number]
  for number in range(10):
    reward_parameters_count1[number]=reward_parameters_count1[number]/number_trials
    reward_parameters_count2[number]=reward_parameters_count2[number]/number_trials
  return reward_parameters_count1,reward_parameters_count2

def false_positive_negative_rate(omega):
  """Score a learned constraint map against the true obstacle layout.

  A cell counts as predicted-constrained when ``omega[x, y] > 0``. Ground truth comes
  from ``obstacle_collision``, so the layout is defined in a single place instead of
  being duplicated here.

  Returns:
    (false_positive_rate, false_negative_rate):
      - false_positive_rate: predicted cells that are actually free, divided by the
        number of free cells (53 for the current layout);
      - false_negative_rate: obstacle cells not predicted, divided by the number of
        obstacle cells (77 for the current layout).
  """
  num_obstacles=0.0
  num_free=0.0
  true_positive=0.0
  false_positive=0.0
  for x in range(10):
    for y in range(13):
      is_obstacle=obstacle_collision(x,y)
      if is_obstacle:
        num_obstacles+=1.0
      else:
        num_free+=1.0
      if omega[x,y]>0.0:
        if is_obstacle:
          true_positive+=1.0
        else:
          false_positive+=1.0
  return false_positive/num_free,(num_obstacles-true_positive)/num_obstacles

def obstacle_collision(x,y):
  """Return True when grid cell (x, y) is an obstacle, otherwise False."""
  # Large rectangular block: x in 3..9, y in 2..10.
  in_block=3<=x<=9 and 2<=y<=10
  # Two segments at x == 0 (y in 1..5 and y in 7..11), with a gap at y == 6.
  in_x0_segments=x==0 and (1<=y<=5 or 7<=y<=11)
  # Isolated single-cell obstacles.
  is_single=(x==8 and y==0) or (x==2 and y==1) or (x==3 and y==11) or (x==7 and y==12)
  return bool(in_block or in_x0_segments or is_single)

def constraint_violation_rate(number_trials,trajectories):
  """Mean, over trials, of the fraction of agents that hit an obstacle at least once.

  Each trial spans 35 rows of ``trajectories``. Agent 1's position is in columns
  0-1 and agent 2's in columns 2-3. A trial scores 0, 0.5 or 1.0.
  """
  def ever_collided(first_row,x_col,y_col):
    return any(obstacle_collision(trajectories[r,x_col],trajectories[r,y_col])
               for r in range(first_row,first_row+35))

  per_trial=[]
  for trial_index in range(number_trials):
    first_row=35*trial_index
    hits=(1.0 if ever_collided(first_row,0,1) else 0.0)+(1.0 if ever_collided(first_row,2,3) else 0.0)
    per_trial.append(hits/2)
  return sum(per_trial)/len(per_trial)

def success_rate(number_trials,trajectories):
  """Mean, over trials, of the fraction of agents that reach their goal without colliding.

  Each trial spans 35 rows of ``trajectories``. Agent 1 (columns 0-1) succeeds on
  reaching (9, 12) and agent 2 (columns 2-3) on reaching (9, 0). An agent fails
  if it touches an obstacle before reaching its goal. A trial scores 0, 0.5 or 1.0.
  """
  def reaches_goal(first_row,x_col,y_col,goal_x,goal_y):
    for r in range(first_row,first_row+35):
      x,y=trajectories[r,x_col],trajectories[r,y_col]
      if obstacle_collision(x,y):
        return False
      if x==goal_x and y==goal_y:
        return True
    return False

  per_trial=[]
  for trial_index in range(number_trials):
    first_row=35*trial_index
    score=0.0
    if reaches_goal(first_row,0,1,9,12):
      score+=1.0
    if reaches_goal(first_row,2,3,9,0):
      score+=1.0
    per_trial.append(score/2)
  return sum(per_trial)/len(per_trial)

num_action=4    # actions available in every grid cell
gamma=1.0       # multiplier on successor values in the Bellman backup
num_trials=50   # learner rollouts per update; the expert file is reshaped to this many trials
a=np.loadtxt("expert_trajectory_file.txt",dtype=float)
# 35 steps per trial, 6 columns per step. Assumed to follow trial()'s row layout
# [x1, y1, x2, y2, action1, action2] -- confirm against how the expert file was produced.
expert_trajectories=a.reshape(35*num_trials,6)
# Start cells: agent 1 at (9, 0), agent 2 at (9, 12).
# ``np.mat`` was removed in NumPy 2.0; ``np.asmatrix`` builds the same matrix type.
initial_state=np.asmatrix([9,0,9,12]).T

def experiment():
  """Run one complete online constraint-and-reward learning experiment.

  There are 20 online iterations; online iteration i (1-based) uses the first 2*i
  expert trials. Each online iteration performs 5 inner updates:
    1. plan both agents' policies with calculate_policy under the current omega;
    2. roll out num_trials learner trajectories, write them to
       learner_trajectory_file.txt and read them back;
    3. move omega by 0.1 * (learner visitation - 1000 * expert visitation), then
       clip every cell to [0, 1];
    4. adjust both reward networks by the scaled difference between learner and
       expert reward-gradient maps.
  After the inner updates, the evaluation metrics are computed on the last learner
  trajectories and printed.

  Returns:
    Four lists, each with one entry per online iteration: false-positive rate,
    false-negative rate, constraint-violation rate and success rate.
  """
  false_positive_list=[]
  false_negative_list=[]
  constraint_violation_list=[]
  success_list=[]
  # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the same dtype.
  omega=np.zeros((10,13)).astype(object)
  reward_function1=Model()
  reward_function2=Model()
  for i in range(20):
    print('online iteration', i+1)
    expert_constraint_map=1000*constraint_map(2*(i+1),expert_trajectories)
    for iteration in range(5):
      print('omega', omega)
      policy1,policy2=calculate_policy(reward_function1,reward_function2,omega,gamma,num_action,i+1)
      # The context manager closes (and flushes) the file even if a rollout raises.
      with open("learner_trajectory_file.txt","w") as trajectory_file:
        for j in range(num_trials):
          trajectory=np.copy(trial(initial_state,policy1,policy2,num_action,i+1))
          for entry in trajectory:
            np.savetxt(trajectory_file,entry)
      b=np.loadtxt("learner_trajectory_file.txt",dtype=float)
      learner_trajectories=b.reshape(35*num_trials,6)
      learner_constraint_map=constraint_map(num_trials,learner_trajectories)
      omega=omega+0.1*(learner_constraint_map-expert_constraint_map)
      # Keep every omega entry inside [0, 1].
      for x in range(10):
        for y in range(13):
          if omega[x,y]>1.0:
            omega[x,y]=1.0
          if omega[x,y]<0.0:
            omega[x,y]=0.0

      expert_reward_gradient_map1,expert_reward_gradient_map2=reward_gradient_map(reward_function1,reward_function2,2*(i+1),expert_trajectories)
      learner_reward_gradient_map1,learner_reward_gradient_map2=reward_gradient_map(reward_function1,reward_function2,num_trials,learner_trajectories)

      with torch.no_grad():
        for number in range(10):
          # NOTE(review): agent 1 moves along (learner - expert) while agent 2 moves
          # along (expert - learner), although the two agents are otherwise set up
          # symmetrically. Confirm the opposite signs are intended.
          reward_function1.parameters[number]+=0.001*(0.00001*(learner_reward_gradient_map1[number]-expert_reward_gradient_map1[number]))
          reward_function2.parameters[number]-=0.001*(0.00001*(learner_reward_gradient_map2[number]-expert_reward_gradient_map2[number]))

    false_positive_rate,false_negative_rate=false_positive_negative_rate(omega)
    constraint_violation_mean=constraint_violation_rate(num_trials,learner_trajectories)
    success_mean=success_rate(num_trials,learner_trajectories)

    false_positive_list.append(false_positive_rate)
    false_negative_list.append(false_negative_rate)
    constraint_violation_list.append(constraint_violation_mean)
    success_list.append(success_mean)

    print('false positive rate', false_positive_rate)
    print('false negative rate', false_negative_rate)
    print('constraint violation', constraint_violation_mean)
    print('success_rate', success_mean)

    # Last learner rollout of this online iteration.
    print(trajectory)
  return false_positive_list,false_negative_list,constraint_violation_list,success_list

# One inner list per experiment run, each holding that run's 20 per-iteration values.
all_false_positive_list,all_false_negative_list,all_constraint_violation_list,all_success_list=[],[],[],[]
num_experiment=2

start_time=time.time()
for number in range(num_experiment):
  false_positive_list,false_negative_list,constraint_violation_list,success_list=experiment()
  for store,curve in ((all_false_positive_list,false_positive_list),
                      (all_false_negative_list,false_negative_list),
                      (all_constraint_violation_list,constraint_violation_list),
                      (all_success_list,success_list)):
    store.append(curve)
end_time=time.time()
# Mean wall-clock time per experiment run.
print('time cost for one experiment',(end_time-start_time)/num_experiment)

# Index 0 of every curve is a fixed starting value (presumably the untrained
# baseline). Entries 1..20 hold, for each online iteration, the mean and the
# population standard deviation across the experiment runs. Each entry is a
# one-element list so that np.savetxt writes one value per line.
false_positive_mean_list=[[0.0]]
false_positive_sd_list=[[0.0]]
false_negative_mean_list=[[1.0]]
false_negative_sd_list=[[0.0]]
constraint_violation_mean_list=[[1.0]]
constraint_violation_sd_list=[[0.0]]
success_mean_list=[[0.0]]
success_sd_list=[[0.0]]

for i in range(20):
  positive_list=[all_false_positive_list[j][i] for j in range(num_experiment)]
  negative_list=[all_false_negative_list[j][i] for j in range(num_experiment)]
  violation_list=[all_constraint_violation_list[j][i] for j in range(num_experiment)]
  succ_list=[all_success_list[j][i] for j in range(num_experiment)]
  for samples,mean_list,sd_list in ((positive_list,false_positive_mean_list,false_positive_sd_list),
                                    (negative_list,false_negative_mean_list,false_negative_sd_list),
                                    (violation_list,constraint_violation_mean_list,constraint_violation_sd_list),
                                    (succ_list,success_mean_list,success_sd_list)):
    mean_list.append([sum(samples)/len(samples)])
    sd_list.append([sqrt(np.var(samples))])

#print(false_positive_mean_list)

# Write each metric curve to its own text file, one value per line. ``with`` guarantees
# each file is closed even if a write fails; the eight hand-written open/write/close
# sequences are collapsed into one loop, keeping the same file order.
for output_path,curve in (("follow_false_positive_mean_file.txt",false_positive_mean_list),
                          ("follow_false_negative_mean_file.txt",false_negative_mean_list),
                          ("follow_constraint_violation_mean_file.txt",constraint_violation_mean_list),
                          ("follow_success_mean_file.txt",success_mean_list),
                          ("follow_false_positive_sd_file.txt",false_positive_sd_list),
                          ("follow_false_negative_sd_file.txt",false_negative_sd_list),
                          ("follow_constraint_violation_sd_file.txt",constraint_violation_sd_list),
                          ("follow_success_sd_file.txt",success_sd_list)):
  with open(output_path,"w") as output_file:
    for entry in curve:
      np.savetxt(output_file,entry)

























