import argparse
import itertools
import json
import os
import random
import sys
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def vis_unet_feat_map(outs, size=(512, 512)):
    """Render feature maps as RGB PIL images for visual inspection.

    For each tensor in ``outs``, index 0 of the leading dimension is taken.
    The result is summed over its first remaining dimension to produce a
    2-D map. That map is min-max normalized to 0..255, resized, and
    converted to RGB.

    Args:
        outs: Iterable of torch tensors. This presumably holds UNet
            activations shaped (batch, channels, H, W), so that
            ``t[0].sum(dim=0)`` is (H, W). TODO: confirm against callers.
        size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.
            Defaults to the previously hard-coded ``(512, 512)``.

    Returns:
        list of ``PIL.Image.Image`` in RGB mode, one per entry of ``outs``.
    """
    vis_list = []
    for feat in outs:
        # detach() lets tensors that require grad be converted to numpy.
        # float() makes half/bfloat16 maps sum in fp32 and keeps them
        # convertible (numpy has no bfloat16).
        vis = feat[0].detach().float().sum(dim=0).cpu().numpy()
        lo, hi = vis.min(), vis.max()
        span = hi - lo
        if span > 0:
            vis = (vis - lo) / span * 255
        else:
            # A constant map would otherwise divide 0/0 -> NaN, and NaN cast
            # to uint8 is undefined; render it as black instead.
            vis = np.zeros_like(vis)
        vis_list.append(
            Image.fromarray(vis.astype(np.uint8)).resize(size).convert("RGB")
        )
    return vis_list