import re
import matplotlib.pyplot as plt
import numpy as np
import copy
import os
import random
from scipy.ndimage.filters import gaussian_filter

# Shared figure that every heatmap2d() call draws into, plus the list of
# panels added so far (its length selects the next subplot slot).
fig, axes = plt.figure(), []


def heatmap2d(arr: np.ndarray, picName: str, bar: bool = False,
              cmapToUse: str = 'viridis', nrows: int = 1, ncols: int = 3):
    """Draw ``arr`` as a heatmap in the next free panel of the module-level ``fig``.

    Panels are filled in order in an ``nrows`` x ``ncols`` subplot grid; the
    number of panels already drawn is the length of the module-level ``axes``
    list, which this function appends the new Axes to.

    Args:
        arr: 2-D array passed to ``imshow``.
        picName: Title shown above the panel.
        bar: If True, attach a colorbar for this image to the panel.
        cmapToUse: Name of the matplotlib colormap to use.
        nrows: Number of rows in the subplot grid (default 1).
        ncols: Number of columns in the subplot grid (default 3).

    Returns:
        The newly created Axes.
    """
    ax = fig.add_subplot(nrows, ncols, len(axes) + 1)
    axes.append(ax)
    ax.set_title(picName)
    # Draw through the Axes rather than the pyplot state machine so the image
    # always lands on `fig`, even if some other figure has become current.
    image = ax.imshow(arr, cmap=cmapToUse)
    # Hide the axis ticks.
    ax.set_xticks([])
    ax.set_yticks([])
    if bar:
        fig.colorbar(image, ax=ax)
    return ax

#################### policy & target
# Panel 1: the "ML" letter strokes (value +1) plus two copies of a squiggly
# target curve (value -1), shown with a diverging colormap.
img = np.zeros((300, 400))
print(img.shape)

# M: two outer strokes over rows 40..159, then two inner strokes over rows 40..99.
for row in range(40, 160):
    shift = row // 4
    img[row, 118 - shift:122 - shift] = 1
    img[row, 158 + shift:162 + shift] = 1
for row in range(40, 100):
    shift = row // 2
    img[row, 88 + shift:92 + shift] = 1
    img[row, 188 - shift:192 - shift] = 1

# L: vertical bar, then horizontal bar, each written as one slice.
img[40:160, 246:250] = 1
img[158:162, 246:320] = 1

# Squiggle: a 4-pixel-thick sine curve, drawn twice, shifted left by 20 and
# by 140 columns.
for col in range(246, 320):
    top = 258 + int(20 * np.sin(col / 15))
    img[top:top + 4, col - 20] = -1
    img[top:top + 4, col - 140] = -1

heatmap2d(img[20:, 55:340], 'policy & target', cmapToUse='bwr')
####################

#################### gail
# Panel 2: only the "ML" strokes, smeared by two Gaussian blurs and passed
# through a logistic function.
img = np.zeros((300, 400))
print(img.shape)

# M: two outer strokes over rows 40..159, then two inner strokes over rows 40..99.
for row in range(40, 160):
    shift = row // 4
    img[row, 118 - shift:122 - shift] = 1
    img[row, 158 + shift:162 + shift] = 1
for row in range(40, 100):
    shift = row // 2
    img[row, 88 + shift:92 + shift] = 1
    img[row, 188 - shift:192 - shift] = 1

# L: vertical bar, then horizontal bar, each written as one slice.
img[40:160, 246:250] = 1
img[158:162, 246:320] = 1

# NOTE: the squiggly target curve drawn in the first panel is omitted here.

# A wide blur (sigma=12, weighted x5) added on top of the sharp strokes,
# then a narrower blur (sigma=3) over the sum.
wide = gaussian_filter(img, sigma=12)
blurred = gaussian_filter(5 * wide + img, sigma=3)
print(np.max(blurred))
# Logistic squashing with slope 2: 1 / (1 + e^(-2x)).
z = 1 / (1 + np.exp(-2 * blurred))
heatmap2d(z[20:, 55:340], 'GAIL')
##########################

#################### SS-GAIL
# Panel 3: the "ML" strokes plus a fan of graded lines joining the squiggly
# target curve to every point on the letters. Each line's value ramps from
# 0 at the squiggle end up to 0.999 at the letter end.
img = np.zeros((300, 400))
print(img.shape)

# the outer lines in M (rows 40..159)
for row in range(40, 160):
    img[row, 118 - row // 4:122 - row // 4] = 1
    img[row, 158 + row // 4:162 + row // 4] = 1

# the inner lines in M (rows 40..99)
for row in range(40, 100):
    img[row, 88 + row // 2:92 + row // 2] = 1
    img[row, 188 - row // 2:192 - row // 2] = 1

# vertical line in L
img[40:160, 246:250] = 1
# horizontal line in L
img[158:162, 246:320] = 1

# Points (row, col) along the L: down the vertical bar, then along the bottom
# edge (row 161) of the horizontal bar.
Lx = list(range(40, 160)) + [161] * 74
Ly = [246] * 120 + list(range(246, 320))
# Number of L points (120 vertical + 74 horizontal = 194).
print(len(Lx))

# Points (row, col) along the M, traced as one path: up the left outer stroke,
# down the left inner stroke, up the right inner stroke, down the right outer
# stroke. Descending runs start at the last drawn row (159 / 99) so every
# point lies on a stroke drawn above.
Mx = (list(range(159, 39, -1)) + list(range(40, 100))
      + list(range(99, 39, -1)) + list(range(40, 160)))
My = ([120 - r // 4 for r in range(159, 39, -1)]
      + [90 + r // 2 for r in range(40, 100)]
      + [190 - r // 2 for r in range(99, 39, -1)]
      + [160 + r // 4 for r in range(40, 160)])

# Fan out lines from the squiggle to each letter point. The squiggle is the
# curve row = 258 + int(20*sin(c/15)) for c in 246..319. It is shifted left by
# 20 columns for the L and by 140 columns for the M. Letter point i is paired
# with the squiggle point at column offset i*74/n_points. Pixels written
# later overwrite earlier ones.
for rows, cols, col_shift in ((Lx, Ly, 20), (Mx, My, 140)):
    n_points = len(rows)
    for i, (x1, y1) in enumerate(zip(rows, cols)):
        squig = i * 74 / n_points
        # Start point on the squiggle. It is the same for all 1000
        # interpolation steps, so compute it once per line.
        x0 = 258 + int(20 * np.sin((squig + 246) / 15))
        y0 = squig + 246 - col_shift
        for j in range(1000):
            alpha = 0.001 * j
            x_alpha = x0 + alpha * (x1 - x0)
            y_alpha = y0 + alpha * (y1 - y0)
            img[int(x_alpha), int(y_alpha)] = alpha

# Slight blur to soften the rasterised lines.
img = gaussian_filter(img, sigma=0.6)
heatmap2d(img[20:, 55:340], 'SS-GAIL')








# Adjust spacing between the panels of `fig`, then open the plot window.
fig.tight_layout()  
plt.show()