import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt

def save_sample_images(save_path, name, epoch, img_list):
    """Render a batch of images as an 8x8 grid and save it as a PNG.

    The output file is written to ``<save_path>/<name>-<NNN>.png`` where
    ``NNN`` is ``epoch + 1`` zero-padded to three digits. ``save_path`` is
    created if it does not exist yet.

    Args:
        save_path: Directory the PNG is written to.
        name: File-name prefix.
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        img_list: 4-D array-like indexed as ``[image, channel, row, col]``.
            When axis 1 has size 1 the images are drawn with the ``'gray'``
            colormap. At most the first 64 images are drawn; if fewer are
            supplied the remaining grid cells are left empty.

    Returns:
        None.
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_path, exist_ok=True)

    grid_rows, grid_cols = 8, 8
    # Never index past the end of the batch (e.g. a smaller final batch).
    n_images = min(grid_rows * grid_cols, img_list.shape[0])
    # The channel count is the same for every image, so decide the colormap
    # once; cmap=None is imshow's default behaviour.
    cmap = 'gray' if img_list.shape[1] == 1 else None

    fig = plt.figure(figsize=(4, 4))
    try:
        for i in range(n_images):
            ax = fig.add_subplot(grid_rows, grid_cols, i + 1)
            # Move channels last for imshow, then drop a singleton channel axis.
            img = np.squeeze(np.transpose(img_list[i, :, :, :], (1, 2, 0)))
            ax.imshow(img, cmap=cmap)
            ax.axis('off')

        fig.tight_layout(pad=0)
        fig.subplots_adjust(wspace=0.0, hspace=0.0)
        fig.savefig(os.path.join(save_path, '%s-%03d.png' % (name, epoch + 1)))
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # without this, one figure would leak per call.
        plt.close(fig)
    return