import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Named Matplotlib "tab" palette colours (not referenced in the visible part of this file).
color_list = [
  'tab:green',
  'tab:orange',
  'tab:blue',
]
# Number of trials stored in the trajectory file; env() uses it to size the reshape.
num_trials = 50

def env():
  """Draw the visitation heatmap of the recorded trajectories on the grid map.

  Reads ``learner_trajectory_file1.txt``, reshapes it to ``35*num_trials`` rows
  of 6 values and, for every row, counts one visit for the cell indexed by
  columns (0, 1) and one for the cell indexed by columns (2, 3), each pair
  being used as (x, y). The scaled counts are shown on a 10 x 13 grid together
  with wall lines, boxed cells, red "x" markers and text labels. The call
  blocks in ``plt.show()`` until the figure window is closed.

  Raises:
    OSError: if the trajectory file cannot be opened by ``np.loadtxt``.
    ValueError: if the file does not contain ``35*num_trials*6`` values.
  """
  steps_per_trial = 35
  raw = np.loadtxt("learner_trajectory_file1.txt", dtype=float)
  trajectories = raw.reshape(steps_per_trial * num_trials, 6)

  # Accumulate visits for both (x, y) pairs of every row in one pass.
  # astype(int) truncates toward zero exactly like int() did per element, and
  # np.add.at (unlike fancy-index "+=") counts repeated cells correctly.
  cells = trajectories[:, :4].astype(int)
  state_heat = np.zeros((13, 10))  # indexed as [y, x]
  np.add.at(state_heat, (cells[:, 1], cells[:, 0]), 1)
  np.add.at(state_heat, (cells[:, 3], cells[:, 2]), 1)

  # NOTE(review): 100.0 equals 2 * num_trials for num_trials = 50 (every row
  # adds two counts); confirm whether this divisor should follow num_trials.
  state_heat = state_heat / 100.0

  # Cells forced to the maximum value regardless of the recorded visits.
  forced_rows = [0, 0, 0, 0, 11, 12, 11, 11]
  forced_cols = [1, 2, 3, 9, 8, 9, 6, 7]
  state_heat[forced_rows, forced_cols] = 1.0

  fig, ax = plt.subplots()
  ax.axis('scaled')
  ax.set_xticks(np.linspace(0, 9, 10))
  ax.set_yticks(np.linspace(0, 12, 13))
  # Fixing the limits here switches autoscaling off, so the later plot and
  # imshow calls do not move (or invert) the axes.
  ax.axis([0, 10, 0, 13])
  ax.grid(linestyle='-', color='black')

  # Thick black wall segments, each given as ([x0, x1], [y0, y1]).
  wall_segments = [
    ([0, 10], [0, 0]),
    ([3, 10], [2, 2]),
    ([0, 0], [0, 13]),
    ([1, 1], [0, 13]),
    ([3, 10], [11, 11]),
    ([3, 3], [2, 11]),
    ([0, 10], [13, 13]),
    ([10, 10], [0, 2]),
    ([10, 10], [11, 13]),
  ]
  for xs, ys in wall_segments:
    ax.plot(xs, ys, color='black', linewidth=2)

  # Individually boxed cells (lower-left corner of each 1 x 1 square).
  for corner in [(8, 0), (2, 1), (3, 11), (7, 12)]:
    ax.add_patch(plt.Rectangle(corner, 1, 1, linewidth=2, edgecolor='black', facecolor='none'))

  # Cells marked with a red "x": the four boxed cells, the whole x = 0 column,
  # the block x in 3..8 / y in 2..10, and part of the x = 9 column.
  marked_cells = [(8, 0), (2, 1), (3, 11), (7, 12)]
  marked_cells += [(0, y) for y in range(13)]
  marked_cells += [(x, y) for x in range(3, 9) for y in range(2, 11)]
  marked_cells += [(9, y) for y in range(2, 8)]
  marked_cells.append((9, 10))
  ax.scatter(
    [x + 0.5 for x, _ in marked_cells],
    [y + 0.5 for _, y in marked_cells],
    s=160, c="r", marker="x",
  )

  ax.text(0.25, 0.35, '$L2$', color='white', fontsize=10)
  ax.text(0.25, 12.35, '$L1$', color='white', fontsize=10)
  ax.text(9.25, 0.35, '$E2$', fontsize=10)
  ax.text(9.25, 12.35, '$E1$', fontsize=10)
  ax.text(0.35, 6.35, '$G$', color='white', fontsize=10)

  # With the y-limits fixed at 0..13 above, this extent should place row r of
  # state_heat in the band y in [r, r+1], matching the [y, x] indexing.
  im = ax.imshow(state_heat, cmap='viridis', extent=[0, 10, 13, 0])

  # Hide tick marks and tick labels but keep the grid lines. The former
  # per-tick tick1On/tick2On/label1On/label2On attributes are no longer
  # honoured by current Matplotlib, so assigning them had no visible effect.
  ax.tick_params(
    axis='both', which='major',
    bottom=False, top=False, left=False, right=False,
    labelbottom=False, labeltop=False, labelleft=False, labelright=False,
  )
  plt.show()

# Render the heatmap immediately; this call runs whenever the module is
# executed or imported and blocks until the figure window is closed.
env()

def dynamics(state,action):
  """Apply one grid move to ``state`` and return the new position.

  Args:
    state: array or matrix whose first two entries (``state[0]``, ``state[1]``)
      each hold a single value, read as x and y.
    action: single-element array or matrix holding the action code:
      0 -> y + 1 (only while y < 12), 1 -> y - 1 (only while y > 0),
      2 -> x - 1 (only while x > 0), 3 -> x + 1 (only while x < 9).
      Any other code, or a move blocked by a bound, leaves the position unchanged.

  Returns:
    A 2 x 1 ``np.matrix`` ``[[x], [y]]`` with the updated position.
  """
  # .item() already yields fresh Python scalars, so no extra copy is needed.
  x = state[0].item()
  y = state[1].item()
  a = action.item()
  if a == 0 and y < 12:
    y = y + 1
  elif a == 1 and y > 0:
    y = y - 1
  elif a == 2 and x > 0:
    x = x - 1
  elif a == 3 and x < 9:
    x = x + 1
  # np.asmatrix builds the same np.matrix as the old np.mat alias, which was
  # removed in NumPy 2.0.
  return np.asmatrix([x, y]).T


















