import numpy as np
from math import *
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.gridspec as gridspec
import sys, random, time
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Divisor applied to the raw visit counts before the reference cells are pinned to 1.
# NOTE(review): this matches the num_trials value used by the script at the
# bottom of this file; confirm whether it should follow num_trials instead.
_HEAT_COUNT_SCALE = 100.0

# Action index shown in each cell of a 3x3 action panel, in (row, column) order.
_ACTION_PANEL_LAYOUT = np.array([[7, 2, 8],
                                 [3, 0, 4],
                                 [5, 1, 6]])

# (x, y) anchor corners of the unit cells outlined and crossed out on the state map.
_STATE_MARKED_CELLS = [(4, 0), (4, 1), (4, 2), (4, 6), (4, 7), (4, 8), (4, 3),
                       (3, 4), (2, 5), (5, 4), (6, 5)]

# (x, y, mathtext) labels written on the state map.  Raw strings keep the
# backslashes of \prime without relying on invalid escape sequences.
_STATE_LABELS = [(0.35, 0.35, r'$s_0^{\prime\prime}$'),
                 (8.35, 0.35, r'$s_G^{\prime\prime}$'),
                 (0.35, 4.35, r'$s_0^{\prime}$'),
                 (8.35, 4.35, r'$s_G^{\prime}$'),
                 (0.35, 8.35, r'$s_0$'),
                 (8.35, 8.35, r'$s_G$')]


def _hide_ticks(ax):
  """Hide the major tick marks and tick labels of ax while keeping its grid."""
  # Tick.tick1On / label1On no longer exist in current Matplotlib, so setting
  # them is a silent no-op there; tick_params is the supported equivalent.
  ax.tick_params(axis='both', which='major',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)


def _mark_cells(ax, cells, linewidth, marker_size):
  """Outline each unit cell in cells in black and put a red cross at its centre."""
  # All outlines are added before any cross, as in the original drawing order.
  for cx, cy in cells:
    ax.add_patch(plt.Rectangle((cx, cy), 1, 1, linewidth=linewidth,
                               edgecolor='black', facecolor='none'))
  for cx, cy in cells:
    ax.scatter(cx + 0.5, cy + 0.5, s=marker_size, c="r", marker="x")


def _scaled_action_heat(counts):
  """Arrange nine per-action counts into the 3x3 panel layout, scaled to a max of 1."""
  heat = counts[_ACTION_PANEL_LAYOUT] / _HEAT_COUNT_SCALE
  heat[1, 1] = 1  # the centre cell (action index 0) is pinned before normalising
  return heat / np.max(heat)


def _draw_action_panel(ax, heat, xlabel, marked_cells, arrows, title=None):
  """Draw one 3x3 action heat map with its marked cells and direction arrows."""
  ax.axis('scaled')
  ax.set_xticks(np.linspace(0, 3, 4))
  ax.set_yticks(np.linspace(0, 3, 4))
  ax.axis([0, 3, 0, 3])
  ax.grid(linestyle='-', color='black')
  if title is not None:
    ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.imshow(heat, cmap='viridis', extent=(0, 3, 3, 0))
  _mark_cells(ax, marked_cells, linewidth=1.5, marker_size=80)
  for dx, dy in arrows:
    ax.arrow(1.5, 1.5, dx, dy, head_width=0.2, head_length=0.2, color='silver')
  _hide_ticks(ax)


def draw_env(num_trials,gamma):
  """Plot state and action visitation heat maps from the recorded trajectories.

  Reads "optimal_trajectory_file.txt", reshapes it to (30*num_trials, 9) rows,
  and accumulates:
    * a 9x9 state map from the three (column, row) pairs in columns 0-5
      (presumably one position per expert -- confirm against the file writer),
    * one 9-bin action histogram from each of columns 6, 7 and 8, shown in the
      panels labelled 'Expert 1', 'Expert 2' and 'Expert 3'.
  The figure is written to 'heat_map.pdf' and then shown.

  Args:
    num_trials: number of trials in the file; each trial contributes 30 rows.
    gamma: accepted for interface compatibility; not used by the drawing code.
  """
  steps = np.loadtxt("optimal_trajectory_file.txt", dtype=float)
  steps = steps.reshape(30 * num_trials, 9)

  # State visitation: the heat map is indexed [row, column], i.e. [y, x].
  state_heat = np.zeros((9, 9))
  for k in range(3):
    cols = steps[:, 2 * k].astype(int)
    rows = steps[:, 2 * k + 1].astype(int)
    np.add.at(state_heat, (rows, cols), 1)  # add.at counts repeated cells correctly
  state_heat = state_heat / _HEAT_COUNT_SCALE
  # Pin the three cells labelled s_G'', s_G' and s_G on the plot to 1.
  state_heat[0, 8] = 1.0
  state_heat[4, 8] = 1.0
  state_heat[8, 8] = 1.0
  state_heat = state_heat / np.max(state_heat)

  # Action visitation, one 3x3 panel per action column.
  action_heats = []
  for k in range(3):
    counts = np.zeros(9)
    np.add.at(counts, steps[:, 6 + k].astype(int), 1)
    action_heats.append(_scaled_action_heat(counts))

  fig = plt.figure()
  gs = gridspec.GridSpec(3, 6)
  ax1 = fig.add_subplot(gs[:, 0:5])
  ax2 = fig.add_subplot(gs[0, 5])
  ax3 = fig.add_subplot(gs[1, 5])
  ax4 = fig.add_subplot(gs[2, 5])

  # ---- state map ----
  ax1.axis('scaled')
  ax1.axis([0, 9, 0, 9])
  ax1.grid(linestyle='-', color='black', which='both', linewidth=0.8)
  # Extra vertical lines at the odd x positions.
  for x_line in (1, 3, 5, 7):
    ax1.plot([x_line, x_line], [0, 9], 'black', linewidth=0.8)
  im = ax1.imshow(state_heat, cmap='viridis', extent=(0, 9, 9, 0))
  cb = plt.colorbar(im, ax=[ax1], fraction=0.046, pad=0.04, location='left')
  cb.set_label('Visitation Frequency (Scaled)')
  _mark_cells(ax1, _STATE_MARKED_CELLS, linewidth=2, marker_size=160)
  for tx, ty, label in _STATE_LABELS:
    ax1.text(tx, ty, label, fontsize=10)
  ax1.set_title('States')
  _hide_ticks(ax1)

  # ---- action panels ----
  _draw_action_panel(
      ax2, action_heats[0], 'Expert 1',
      marked_cells=[(0, 0), (2, 0)],
      arrows=[(0, 1), (0.707, 0.707), (1, 0), (0, -1), (-1, 0), (-0.707, 0.707)],
      title='Actions')
  _draw_action_panel(
      ax3, action_heats[1], 'Expert 2',
      marked_cells=[(0, 2), (2, 0)],
      arrows=[(0, 1), (0.707, 0.707), (1, 0), (0, -1), (-0.707, -0.707), (-1, 0)])
  _draw_action_panel(
      ax4, action_heats[2], 'Expert 3',
      marked_cells=[(0, 2), (2, 2)],
      arrows=[(0, 1), (1, 0), (0.707, -0.707), (0, -1), (-0.707, -0.707), (-1, 0)])

  plt.savefig('heat_map.pdf')
  plt.show()


# Displacement (dx, dy) applied to (x, y) by each of the nine action indices.
_ACTION_MOVES = {
  0: (0, 0),
  1: (0, 1),
  2: (0, -1),
  3: (-1, 0),
  4: (1, 0),
  5: (-1, 1),
  6: (1, 1),
  7: (-1, -1),
  8: (1, -1),
}


def _action_move(action):
  """Return the (dx, dy) displacement for action, or raise ValueError if unknown."""
  # np.asarray(...).item() accepts plain numbers, NumPy scalars and
  # one-element arrays/matrices alike.
  key = np.asarray(action).item()
  try:
    return _ACTION_MOVES[key]
  except KeyError:
    raise ValueError("unknown action index: %r" % (key,)) from None


def _hits_wall(x, y, dx, dy):
  """Tell whether moving by (dx, dy) from (x, y) is stopped by the 0..8 grid edge."""
  return ((dx < 0 and x == 0) or (dx > 0 and x == 8)
          or (dy < 0 and y == 0) or (dy > 0 and y == 8))


def _next_state(x, y, move, bounded):
  """Build the 2x1 next-state column; a wall-blocked move is cancelled entirely."""
  dx, dy = move
  if bounded and _hits_wall(x, y, dx, dy):
    dx, dy = 0, 0
  # np.asmatrix is the supported spelling of np.mat, which NumPy 2.0 removed.
  return np.asmatrix([x + dx, y + dy]).T


def single_expert_stochastic_dynamics(state,action):
  """Sample the next state of one expert under noisy, wall-bounded motion.

  A single np.random.uniform() draw decides the outcome: if it is <= 0.2 the
  expert stays where it is, otherwise the action is applied as in
  single_expert_dynamics (a move blocked by a wall leaves the state unchanged,
  including for diagonal moves).

  Args:
    state: indexable with state[0] and state[1] giving x and y via .item().
    action: action index 0..8 (scalar or one-element array).

  Returns:
    2x1 matrix holding the next (x, y).

  Raises:
    ValueError: if action is not one of the nine known indices.
  """
  move = _action_move(action)
  x = state[0].item()
  y = state[1].item()
  if np.random.uniform() <= 0.2:
    move = (0, 0)  # slip: remain in place
  return _next_state(x, y, move, bounded=True)


def single_expert_dynamics(state,action):
  """Return the next state of one expert under wall-bounded motion.

  The action's displacement is applied unless it would cross the edge of the
  0..8 grid on either axis, in which case the state is returned unchanged.

  Args:
    state: indexable with state[0] and state[1] giving x and y via .item().
    action: action index 0..8 (scalar or one-element array).

  Returns:
    2x1 matrix holding the next (x, y).

  Raises:
    ValueError: if action is not one of the nine known indices.
  """
  move = _action_move(action)
  return _next_state(state[0].item(), state[1].item(), move, bounded=True)


def single_expert_deterministic_dynamics(state,action):
  """Return the next state of one expert with no wall handling.

  The displacement is always applied, so the result may lie outside the
  0..8 grid; the feature functions use that to detect wall-crossing actions.

  Args:
    state: indexable with state[0] and state[1] giving x and y via .item().
    action: action index 0..8 (scalar or one-element array).

  Returns:
    2x1 matrix holding the next (x, y).

  Raises:
    ValueError: if action is not one of the nine known indices.
  """
  move = _action_move(action)
  return _next_state(state[0].item(), state[1].item(), move, bounded=False)


def feature1(state,action):
  """Return the 2x1 feature column that expert1_reward weights.

  Cases, checked in this order:
    * the unbounded move (single_expert_deterministic_dynamics) would leave
      the 0..8 grid            -> [0, 8]
    * state is the cell (8, 8) -> [40, 0]
    * otherwise                -> [x, 8 - y]

  Args:
    state: array/matrix whose .item(0) and .item(1) are x and y.
    action: action index, forwarded to single_expert_deterministic_dynamics.

  Returns:
    2x1 float matrix.
  """
  next_state=single_expert_deterministic_dynamics(state,action)
  next_x=next_state.item(0)
  next_y=next_state.item(1)
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  if next_x<0 or next_x>8 or next_y<0 or next_y>8:
    return np.asmatrix([0.0,8.0]).T
  x=state.item(0)
  y=state.item(1)
  if x==8 and y==8:
    return np.asmatrix([40.0,0.0]).T
  # The 1.0 factor forces a float result even for integer states.
  return 1.0*np.asmatrix([x,8-y]).T

def feature2(state,action):
  """Return the 2x1 feature column that expert2_reward weights.

  Cases, checked in this order:
    * the unbounded move (single_expert_deterministic_dynamics) would leave
      the 0..8 grid            -> [0, 8]
    * state is the cell (8, 4) -> [40, 0]
    * otherwise                -> [x, |4 - y|]

  Args:
    state: array/matrix whose .item(0) and .item(1) are x and y.
    action: action index, forwarded to single_expert_deterministic_dynamics.

  Returns:
    2x1 float matrix.
  """
  next_state=single_expert_deterministic_dynamics(state,action)
  next_x=next_state.item(0)
  next_y=next_state.item(1)
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  if next_x<0 or next_x>8 or next_y<0 or next_y>8:
    return np.asmatrix([0.0,8.0]).T
  x=state.item(0)
  y=state.item(1)
  if x==8 and y==4:
    return np.asmatrix([40.0,0.0]).T
  # The 1.0 factor forces a float result even for integer states.
  return 1.0*np.asmatrix([x,abs(4-y)]).T

def feature3(state,action):
  """Return the feature column that expert3_reward weights.

  Cases, checked in this order:
    * the unbounded move (single_expert_deterministic_dynamics) would leave
      the 0..8 grid            -> [0, 8]
    * state is the cell (8, 0) -> [40, 0]
    * otherwise                -> the state itself, scaled by 1.0

  Args:
    state: array/matrix whose .item(0) and .item(1) are x and y.
    action: action index, forwarded to single_expert_deterministic_dynamics.

  Returns:
    2x1 float matrix in the first two cases; 1.0*state (same shape as state)
    otherwise.
  """
  next_state=single_expert_deterministic_dynamics(state,action)
  next_x=next_state.item(0)
  next_y=next_state.item(1)
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  if next_x<0 or next_x>8 or next_y<0 or next_y>8:
    return np.asmatrix([0.0,8.0]).T
  if state.item(0)==8 and state.item(1)==0:
    return np.asmatrix([40.0,0.0]).T
  return 1.0*state

def expert1_reward(omega,state,action):
  """Return the scalar reward omega^T * feature1(state, action)."""
  features=feature1(state,action)
  weighted=np.dot(omega.T,features)
  return weighted.item()

def expert2_reward(omega,state,action):
  """Return the scalar reward omega^T * feature2(state, action)."""
  features=feature2(state,action)
  weighted=np.dot(omega.T,features)
  return weighted.item()

def expert3_reward(omega,state,action):
  """Return the scalar reward omega^T * feature3(state, action)."""
  features=feature3(state,action)
  weighted=np.dot(omega.T,features)
  return weighted.item()

def single_expert_basis_state_constraint(single_expert_state):
  """Return the 6x1 state-constraint basis vector for one expert's position.

  Each entry corresponds to one group of grid cells and is 1000.0 when the
  position lies in that group, 0.0 otherwise:
    0: x == 4, 0 <= y <= 3
    1: (3, 4) or (2, 5)
    2: (5, 4) or (6, 5)
    3: x == 4, 6 <= y <= 8
    4: x in {2, 3}, 6 <= y <= 7
    5: x in {1, 2}, 2 <= y <= 3

  Args:
    single_expert_state: array/matrix whose .item(0) and .item(1) are x and y.

  Returns:
    6x1 float matrix.
  """
  x=single_expert_state.item(0)
  y=single_expert_state.item(1)
  # One membership flag per basis entry, in index order.
  in_group=(
    x==4 and 0<=y<=3,
    (x==3 and y==4) or (x==2 and y==5),
    (x==5 and y==4) or (x==6 and y==5),
    x==4 and 6<=y<=8,
    (x==2 or x==3) and 6<=y<=7,
    (x==1 or x==2) and 2<=y<=3,
  )
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  obstacle=np.asmatrix(np.zeros((6,1)))
  for index,hit in enumerate(in_group):
    if hit:
      obstacle[index]=1000.0
  return obstacle

def expert_1_basis_constraint(state,action):
  """Return expert 1's 8x1 constraint basis vector for (state, action).

  The first six entries come from single_expert_basis_state_constraint(state);
  the last two are 1000.0 when the action index is 7 or 8 respectively, and
  0.0 otherwise.

  Args:
    state: position passed through to single_expert_basis_state_constraint.
    action: object exposing .item() that yields the action index.

  Returns:
    8x1 float matrix.
  """
  action=action.item()
  state_basis_constraint=single_expert_basis_state_constraint(state)
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  action_constraint=np.asmatrix(np.zeros((2,1)))
  for slot,penalised_action in enumerate((7,8)):
    if action==penalised_action:
      action_constraint[slot]=1000.0
  return np.vstack((state_basis_constraint,action_constraint))

def expert_2_basis_constraint(state,action):
  """Return expert 2's 8x1 constraint basis vector for (state, action).

  The first six entries come from single_expert_basis_state_constraint(state);
  the last two are 1000.0 when the action index is 1 or 2 respectively, and
  0.0 otherwise.

  Args:
    state: position passed through to single_expert_basis_state_constraint.
    action: object exposing .item() that yields the action index.

  Returns:
    8x1 float matrix.
  """
  action=action.item()
  state_basis_constraint=single_expert_basis_state_constraint(state)
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  action_constraint=np.asmatrix(np.zeros((2,1)))
  for slot,penalised_action in enumerate((1,2)):
    if action==penalised_action:
      action_constraint[slot]=1000.0
  return np.vstack((state_basis_constraint,action_constraint))

def expert_3_basis_constraint(state,action):
  """Return expert 3's 10x1 constraint basis vector for (state, action).

  The first six entries come from single_expert_basis_state_constraint(state);
  the last four are 1000.0 when the action index is 5, 6, 7 or 8 respectively,
  and 0.0 otherwise.

  Args:
    state: position passed through to single_expert_basis_state_constraint.
    action: object exposing .item() that yields the action index.

  Returns:
    10x1 float matrix.
  """
  action=action.item()
  state_basis_constraint=single_expert_basis_state_constraint(state)
  # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
  action_constraint=np.asmatrix(np.zeros((4,1)))
  for slot,penalised_action in enumerate((5,6,7,8)):
    if action==penalised_action:
      action_constraint[slot]=1000.0
  return np.vstack((state_basis_constraint,action_constraint))

def expert1_cost(theta,state,action):
  """Return the scalar cost theta^T * expert_1_basis_constraint(state, action).

  The basis vector has 8 entries, so theta.T must have 8 columns.
  """
  basis=expert_1_basis_constraint(state,action)
  return np.dot(theta.T,basis).item()

def expert2_cost(theta,state,action):
  """Return the scalar cost theta^T * expert_2_basis_constraint(state, action).

  The basis vector has 8 entries, so theta.T must have 8 columns.
  """
  basis=expert_2_basis_constraint(state,action)
  return np.dot(theta.T,basis).item()

def expert3_cost(theta,state,action):
  """Return the scalar cost theta^T * expert_3_basis_constraint(state, action).

  The basis vector has 10 entries, so theta.T must have 10 columns.
  """
  basis=expert_3_basis_constraint(state,action)
  return np.dot(theta.T,basis).item()


# Script section: runs on import/execution and renders the heat maps.
num_trials = 100  # the trajectory file is read as 30 rows per trial
gamma = 0.9       # forwarded to draw_env, which does not use it
draw_env(num_trials=num_trials, gamma=gamma)
















