from PIL import Image
import os

def obs_to_img(obs, path=None):
    """Write a single 3-channel observation tensor to disk as an image.

    Args:
        obs: A tensor (torch-style API: ``detach``/``permute``/``numpy``)
            holding one observation. It is either 3-D, or 4-D with a leading
            batch axis of size 1. If the first remaining axis has size 3 it
            is treated as the channel axis and moved last; otherwise the
            channels must already be last. Pixel data must be ``uint8``.
            The tensor may live on any device and may require grad; it is
            not modified.
        path: Destination file. Defaults to ``img.png`` two directories above
            the directory containing this module (the previous fixed location).

    Returns:
        The path the image was written to.

    Raises:
        ValueError: If ``obs`` cannot be arranged into a 3-D, 3-channel array.
        TypeError: If the pixel data is not ``uint8``.
    """
    o = obs.detach()
    if o.dim() == 4:
        # Drop only the batch axis. A bare squeeze() would also collapse any
        # spatial axis of size 1 and corrupt the layout.
        o = o.squeeze(0)

    if o.dim() == 3 and o.shape[0] == 3:
        # NOTE(review): permute(2, 1, 0) maps (C, A, B) -> (B, A, C). For a
        # (C, H, W) tensor this yields (W, H, C), i.e. a transposed image;
        # permute(1, 2, 0) would preserve orientation. Kept unchanged because
        # the producer's axis order is not visible here -- confirm.
        o = o.permute(2, 1, 0)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if o.dim() != 3 or o.shape[2] != 3:
        raise ValueError(
            f"expected a 3-channel image tensor, got shape {tuple(o.shape)}"
        )

    # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU.
    arr = o.cpu().numpy()
    if arr.dtype != "uint8":
        # Forcing mode "RGB" on other dtypes silently reinterprets raw bytes.
        raise TypeError(f"expected uint8 pixel data, got {arr.dtype}")

    if path is None:
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "../..", "img.png"
        )
    # For a uint8 array of shape (rows, cols, 3), Pillow infers mode "RGB".
    Image.fromarray(arr).save(path)
    return path