"""
Visualize the expressions in the form of images
"""
import matplotlib.pyplot as plt
from src.Models.models import ParseModelOutput
from src.utils.train_utils import chamfer
import numpy as np
from src.utils.train_utils import prepare_input_op, beams_parser, validity, image_from_expressions
import ot

# Rendering / parsing configuration, passed to ParseModelOutput below.
# NOTE(review): canvas_shape presumably is the (H, W) pixel size of rendered
# images, and max_len the maximum program length — confirm against ParseModelOutput.
canvas_shape = [64, 64]
max_len = 10

# Load the terminal symbols of the grammar, one symbol per line.
with open("terminals.txt", "r") as file:
    # Strip only the trailing newline. The previous `e[0:-1]` slice also removed
    # a real character from the last symbol when the file lacks a final newline.
    unique_draw = [line.rstrip("\n") for line in file]



# Fill the expressions that you want to render
#expressions = ["c(32,32,28)c(32,32,24)-s(32,32,28)s(32,32,20)-+t(32,32,20)+", "c(32,32,28)c(32,32,24)-"]
img1 = ["c(16,48,8)s(32,16,12)+"]
img2 = ["c(16,48,8)c(24,40,12)+"] #s(24,40,12)
parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len, canvas_shape)
predicted_img1 = image_from_expressions(parser, img1)
predicted_img2 = image_from_expressions(parser, img2)

# Pixel coordinates of every pixel equal to 1 in the first rendering of each
# batch, one point per row: shape (num_pixels, ndim).
# NOTE(review): assumes image_from_expressions returns a batch indexable by [0]
# whose entries are 2-D binary images — confirm.
index_1 = np.argwhere(predicted_img1[0] == 1)
index_2 = np.argwhere(predicted_img2[0] == 1)
if index_1.shape[0] == 0 or index_2.shape[0] == 0:
    # An empty point cloud has no mass to transport; fail with a clear message
    # instead of a zero-size reduction error from np.amax.
    raise ValueError("cannot compare: a rendered image has no pixels equal to 1")

# Uniform mass on every "on" pixel so each image carries total mass 1.
a, b = ot.unif(index_1.shape[0]), ot.unif(index_2.shape[0])

# Pairwise ground cost between the two pixel clouds, scaled into [0, 1] so the
# resulting distance does not depend on the canvas size.
M = ot.dist(index_1, index_2)
M_max = np.amax(M)
if M_max > 0:
    # M_max is 0 only when both images are the same single pixel; dividing would
    # yield 0/0 = NaN, whereas the unscaled all-zero cost already gives distance 0.
    M = M / M_max

print(ot.emd2(a, b, M))

# Optional: uncomment to display the first rendering.
#plt.imshow(predicted_img1[0], cmap="Greys")
#plt.grid(False)
#plt.axis("off")
#plt.show()