import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Matplotlib "tab:" palette colours (not referenced by env() or dynamics()).
color_list = ['tab:green', 'tab:orange', 'tab:blue']

# Thick black wall segments of the map, each as ((x0, x1), (y0, y1)).
_WALL_SEGMENTS = (
  ((0, 10), (0, 0)),
  ((3, 10), (2, 2)),
  ((0, 1), (1, 1)),
  ((0, 0), (0, 1)),
  ((0, 0), (6, 7)),
  ((1, 1), (1, 6)),
  ((1, 1), (7, 12)),
  ((0, 1), (12, 12)),
  ((3, 10), (11, 11)),
  ((3, 3), (2, 11)),
  ((0, 1), (6, 6)),
  ((0, 1), (7, 7)),
  ((0, 0), (12, 13)),
  ((0, 10), (13, 13)),
  ((10, 10), (0, 2)),
  ((10, 10), (11, 13)),
  ((8, 8), (0, 1)),
  ((9, 9), (0, 1)),
  ((8, 9), (1, 1)),
  ((3, 3), (11, 12)),
  ((4, 4), (11, 12)),
  ((3, 4), (12, 12)),
  ((2, 2), (1, 2)),
  ((3, 3), (1, 2)),
  ((2, 3), (1, 1)),
  ((2, 3), (2, 2)),
  ((7, 7), (12, 13)),
  ((8, 8), (12, 13)),
  ((7, 8), (12, 12)),
)

# Hatched obstacle rectangles as ((x, y) lower-left corner, width, height).
# An optional 1x1 obstacle at (1, 3) is left disabled.
_OBSTACLES = (
  ((8, 0), 1, 1),
  ((3, 11), 1, 1),
  ((2, 1), 1, 1),
  ((7, 12), 1, 1),
  ((3, 2), 7, 9),
  ((0, 1), 1, 5),
  ((0, 7), 1, 5),
)

# Mathtext labels placed inside specific cells, as (x, y, text).
_LABELS = (
  (0.25, 0.35, '$L2$'),
  (0.25, 12.35, '$L1$'),
  (9.25, 0.35, '$E2$'),
  (9.25, 12.35, '$E1$'),
  (0.35, 6.35, '$G$'),
)


def env(show=True):
  """Draw the 10 x 13 grid-world map: grid, walls, obstacles and cell labels.

  Args:
    show: when True (the default, matching the previous behaviour) call
      ``plt.show()`` at the end; pass False to build the figure without
      displaying it (e.g. to save or inspect it).
  """
  _, ax = plt.subplots()
  ax.axis('scaled')
  # Grid lines at every integer; the right (x=10) and top (y=13) edges are
  # the axes spines, so the tick ranges stop one short of the limits.
  ax.set_xticks(np.linspace(0, 9, 10))
  ax.set_yticks(np.linspace(0, 12, 13))
  ax.axis([0, 10, 0, 13])
  ax.grid(linestyle='-', color='black')

  for xs, ys in _WALL_SEGMENTS:
    ax.plot(xs, ys, color='black', linewidth=2)

  for corner, width, height in _OBSTACLES:
    ax.add_patch(plt.Rectangle(corner, width, height,
                               facecolor='none', hatch='//'))

  for x, y, text in _LABELS:
    ax.text(x, y, text, fontsize=10)

  # Hide tick marks and tick labels on all four sides while keeping the grid.
  # The per-tick `tick1On`/`label1On` attributes used previously were
  # deprecated in matplotlib 3.1 and later removed, so assigning them no
  # longer has any effect; tick_params is the supported API.
  ax.tick_params(axis='both', which='both',
                 bottom=False, top=False, left=False, right=False,
                 labelbottom=False, labeltop=False,
                 labelleft=False, labelright=False)

  if show:
    plt.show()

# Render the map as soon as the script runs (before dynamics() is defined).
# NOTE(review): this also executes on import, and plt.show() may block until
# the window is closed depending on the backend; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
env()

def dynamics(state, action):   # input and output are matrices
  """Apply one grid move and return the next state.

  Args:
    state: column matrix ``[x, y]^T`` (e.g. a 2x1 ``np.matrix``); its
      ``[0]`` and ``[1]`` entries must support ``.item()``.
    action: size-1 action code (matrix, NumPy scalar or plain int):
      0 -> y + 1, 1 -> y - 1, 2 -> x - 1, 3 -> x + 1.
      Any other value leaves the state unchanged.

  Returns:
    A new 2x1 ``np.matrix`` ``[x, y]^T``. A move that would leave the box
    0 <= x <= 9, 0 <= y <= 12 is ignored and the agent stays in place.
    The input ``state`` is never modified.
  """
  # .item() already yields independent Python scalars, so no copy is needed.
  x = state[0].item()
  y = state[1].item()
  # np.asarray(...).item() accepts 1x1 matrices as well as plain ints.
  a = np.asarray(action).item()
  if a == 0 and y < 12:
    y = y + 1
  if a == 1 and y > 0:
    y = y - 1
  if a == 2 and x > 0:
    x = x - 1
  if a == 3 and x < 9:
    x = x + 1
  # np.mat was removed in NumPy 2.0; np.asmatrix is the same function and
  # exists on both NumPy 1.x and 2.x.
  return np.asmatrix([x, y]).T


















