import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Named colours from matplotlib's Tableau ("tab:") palette.
color_list = [
    'tab:green',
    'tab:orange',
    'tab:blue',
]

def env():
  """Draw the 10 x 13 grid map and display it with ``plt.show()``.

  The figure contains a unit grid, thick black boundary segments, four
  outlined unit cells, red ``x`` markers at a fixed set of cell centres
  (presumably the cells that cannot be entered -- confirm against the
  code that consumes this map) and the text labels L1, L2, E1, E2 and G.

  Takes no arguments and returns ``None``.
  """
  _, ax = plt.subplots()
  ax.axis('scaled')
  # Ticks sit on the integer cell borders so the grid lines outline cells.
  ax.set_xticks(np.linspace(0, 9, 10))
  ax.set_yticks(np.linspace(0, 12, 13))
  ax.axis([0, 10, 0, 13])
  ax.grid(linestyle='-', color='black')

  # Thick boundary segments, each given as ((x0, x1), (y0, y1)).
  wall_segments = [
    ((0, 10), (0, 0)),
    ((3, 10), (2, 2)),
    ((0, 0), (0, 13)),
    ((1, 1), (0, 13)),
    ((3, 10), (11, 11)),
    ((3, 3), (2, 11)),
    ((0, 10), (13, 13)),
    ((10, 10), (0, 2)),
    ((10, 10), (11, 13)),
  ]
  for xs, ys in wall_segments:
    ax.plot(xs, ys, color='black', linewidth=2)

  # Cells that get an explicit outline, identified by lower-left corner.
  outlined_cells = [(8, 0), (2, 1), (3, 11), (7, 12)]
  for corner in outlined_cells:
    ax.add_patch(plt.Rectangle(corner, 1, 1, linewidth=2,
                               edgecolor='black', facecolor='none'))

  # Every cell that receives a red 'x', identified by lower-left corner.
  marked_cells = list(outlined_cells)
  marked_cells += [(0, y) for y in range(13)]
  marked_cells += [(x, y) for x in range(3, 9) for y in range(2, 11)]
  marked_cells += [(9, y) for y in range(2, 8)]
  marked_cells.append((9, 10))
  # Shift by half a cell so each marker is drawn at the cell centre.
  centres = np.array(marked_cells) + 0.5
  ax.scatter(centres[:, 0], centres[:, 1], s=160, c="r", marker="x")

  labels = [
    (0.25, 0.35, '$L2$'),
    (0.25, 12.35, '$L1$'),
    (9.25, 0.35, '$E2$'),
    (9.25, 12.35, '$E1$'),
    (0.35, 6.35, '$G$'),
  ]
  for x, y, label in labels:
    ax.text(x, y, label, fontsize=10)

  # Hide tick marks and tick labels on all four sides while keeping the
  # grid.  The Tick.tick1On/label1On attributes used previously were removed
  # from Matplotlib, so assigning them no longer has any effect.
  ax.tick_params(axis='both', which='major',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)
  plt.show()

# Draw and display the map as soon as this module is executed; env() ends
# with plt.show().
env()

def dynamics(state,action):   # input and output are matrices
  """Return the grid position reached by applying ``action`` to ``state``.

  Args:
    state: 2-element column (``state[0]`` is x, ``state[1]`` is y); each
      entry must support ``.item()``.
    action: single-element array/matrix read via ``.item()``.
      0 -> y + 1, 1 -> y - 1, 2 -> x - 1, 3 -> x + 1.

  Returns:
    A 2 x 1 ``np.matrix`` ``[[x], [y]]``.  A move that would leave the range
    0 <= x <= 9, 0 <= y <= 12, or any other action value, leaves the
    position unchanged.
  """
  # .item() already yields independent Python scalars, so no copy is needed.
  x=state[0].item()
  y=state[1].item()
  a=action.item()
  if a==0 and y<12:
    y=y+1
  elif a==1 and y>0:
    y=y-1
  elif a==2 and x>0:
    x=x-1
  elif a==3 and x<9:
    x=x+1
  # np.mat was removed in NumPy 2.0; np.asmatrix builds the same np.matrix
  # type and is available in both old and new NumPy releases.
  return np.asmatrix([x,y]).T


















