import torch
from PIL import Image
import numpy as np
import torch.nn as nn
import numpy as np


def color_map(prob, colors):
    """Blend per-class colors by per-pixel class weights.

    Args:
        prob: Tensor of per-class weights laid out as ``(C, H, W)``, or as
            ``(1, C, H, W)`` with a singleton leading dimension.
        colors: Tensor of shape ``(C, K)`` with one row per class
            (presumably ``K == 3`` for RGB output).

    Returns:
        Tensor of shape ``(H, W, K)`` whose values are clamped to
        ``[0, 255]``. Its dtype is the promoted dtype of ``prob`` and
        ``colors``.
    """
    # Drop the leading dimension only for 4-D input. An unconditional
    # ``squeeze(0)`` would also remove the class axis of a single-class
    # ``(1, H, W)`` tensor and make the permute below fail.
    if prob.dim() == 4:
        prob = prob.squeeze(0)

    # Move the class axis last so each pixel is a row vector of weights.
    weights = prob.permute(1, 2, 0)

    # matmul rejects mixed dtypes (e.g. float64 weights with a float32
    # palette), so bring both operands to a common dtype first. When the
    # dtypes already agree, ``.to`` returns the tensors unchanged.
    common = torch.promote_types(weights.dtype, colors.dtype)
    patch = torch.matmul(weights.to(common), colors.to(common))

    return patch.clamp(0, 255)

if __name__ == "__main__":
    # Load the saved weight array and bound every entry to [0, 1].
    # NOTE(review): the load path is an empty placeholder; set it before running.
    raw = np.load("")
    prob = torch.from_numpy(raw).clamp(0, 1)

    # NOTE(review): the palette is an empty placeholder; presumably it needs
    # one color row per class -- confirm against the data being loaded.
    colors = torch.tensor([], dtype=torch.float)

    # Render the blended image and write it out as 8-bit data.
    rendered = color_map(prob, colors).detach().numpy()
    Image.fromarray(rendered.astype(np.uint8)).save("test.png")
