import torch
from torchvision.utils import save_image

# (fill RGB, border RGB) on a 0-255 scale, keyed by element type index.
_TYPE_COLORS = {
    0: ((211.0, 241.0, 253.0), (92.0, 191.0, 249.0)),
    1: ((214.0, 254.0, 231.0), (119.0, 251.0, 146.0)),
    2: ((224.0, 254.0, 208.0), (133.0, 251.0, 81.0)),
    3: ((255.0, 255.0, 209.0), (255.0, 255.0, 118.0)),
    4: ((249.0, 218.0, 206.0), (235.0, 86.0, 45.0)),
}


def _rgb_column(rgb):
    """Return a 0-255 RGB triple as a (3, 1, 1) tensor scaled by 1/255."""
    return torch.tensor([c / 255.0 for c in rgb]).view(3, 1, 1)


def print_img(x, path):
    """Draw the boxes encoded in ``x`` on a white canvas and save it to ``path``.

    Only ``x[0]`` is read, and it is indexed as a 2-D tensor:

    * ``x[0][1][0]`` and ``x[0][1][1]`` are the canvas height and width
      divided by 1000.
    * Element ``i`` (``i`` in 0..15) occupies columns ``4*i .. 4*i+3``.
    * Row 0 of an element holds ``(x, y, w, h)``: ``x``/``y`` map -1 to pixel 0
      and +1 to the full canvas extent, ``w``/``h`` are fractions of the canvas
      width/height.
    * Row 2 of an element selects type 0-3 (first strictly positive entry
      wins); if none is positive, a positive first entry in row 3 selects
      type 4.

    Elements are drawn in order, so later ones paint over earlier ones.  Each
    typed element is filled with its type colour and outlined with a
    one-pixel border; an element with no positive type flag is painted plain
    white.  Drawing stops at the first element whose four clamped pixel
    coordinates sum to 5 or less (treated as the end of the element list).

    Args:
        x: Indexable whose first item is the 2-D layout tensor described above.
        path: Destination passed through to ``torchvision.utils.save_image``.

    Raises:
        IndexError: If ``x[0]`` has fewer rows/columns than an element that is
            reached requires.
    """
    layout = x[0]
    height = round(1000 * layout[1][0].item())
    width = round(1000 * layout[1][1].item())
    canvas = torch.ones(3, height, width)

    # Convert the palette once per call instead of once per element.
    colors = {
        kind: (_rgb_column(fill), _rgb_column(border))
        for kind, (fill, border) in _TYPE_COLORS.items()
    }

    for i in range(16):
        elem = layout[:, 4 * i:4 * i + 4]
        x_n = elem[0][0].item()
        y_n = elem[0][1].item()
        w_n = elem[0][2].item()
        h_n = elem[0][3].item()

        # Inclusive pixel corners, clamped to the canvas.
        left = max(0, min(width, round((x_n + 1) * (width / 2.0))))
        top = max(0, min(height, round((y_n + 1) * (height / 2.0))))
        right = max(
            0, min(width, round((2.0 * w_n + x_n + 1) * (width / 2.0)))
        )
        bottom = max(
            0, min(height, round((2.0 * h_n + y_n + 1) * (height / 2.0)))
        )

        # A (near-)zero box marks the end of the real elements.
        if left + top + right + bottom <= 5:
            break

        type_flags = elem[2:4, :]
        kind = next((t for t in range(4) if type_flags[0][t] > 0), -1)
        if kind == -1 and type_flags[1][0] > 0:
            kind = 4

        # View onto the element's rectangle; every write below goes through
        # this view, so nothing outside the clipped box is ever touched and
        # an empty (degenerate) box is a no-op.
        region = canvas[
            :, top:min(bottom + 1, height), left:min(right + 1, width)
        ]

        if kind == -1:
            # Untyped element: blank the area back to white, no border.
            region[...] = 1.0
            continue

        fill, border = colors[kind]
        region[...] = fill
        region[:, :1, :] = border   # top edge
        region[:, -1:, :] = border  # bottom edge
        region[:, :, :1] = border   # left edge
        region[:, :, -1:] = border  # right edge

    # Per the torchvision docs, nrow only affects batched input and
    # value_range is only consulted when normalize=True; both are kept so the
    # call stays identical to the previous behaviour.
    save_image(canvas, path, nrow=4, normalize=False, value_range=(-1, 1))

# Batch driver: try to load each sampled layout tensor and render it to a PNG.
# NOTE(review): pathin/pathout are empty placeholders that do not depend on i,
# so every iteration attempts the same load until real paths are filled in.
# NOTE(review): torch.load unpickles its input -- only point it at trusted files.
for i in range(100):  # number of layouts that are sampled
    pathin = ''  # .pt file path
    pathout = ''  # .png image save path
    try:
        x = torch.load(pathin)
    except FileNotFoundError:
        # No tensor file for this layout: skip it and move on to the next.
        pass
    else:
        print_img(x, pathout)