import os
import numpy as np

import matplotlib.pyplot as plt

def _to_hwc(arr):
    """Move the leading axis last and clip to [0, 1] for matplotlib display/saving."""
    # transpose(1, 2, 0) requires a 3-D array; presumably channels-first
    # (C, H, W) -> channels-last (H, W, C) — confirm against the data writer.
    return arr.transpose(1, 2, 0).clip(0, 1)


def main(
    file_path="output_CIFAR10",
    img_idx=0,
    file_idx=0,
    C=1,
    mus=(100, 50, 30, 20, 10, 5, 3, 2, 1),
):
    """Plot and save original/noisy/denoised images across noise levels.

    For each ``mu`` in ``mus`` this loads
    ``<file_path>/mu_<mu:04d>_C_<int(C)>_batch_<file_idx>.npz`` and takes
    sample ``img_idx`` from its ``img``, ``noisy`` and ``denoised`` arrays.
    It draws a 2-row grid on the current pyplot figure:

    - Top row: the original image, then the noisy image for each ``mu``.
    - Bottom row: the original again, then the matching denoised image.

    Every image is also written as a PNG to the current working directory.
    The figure is then shown and closed.

    Args:
        file_path: Directory containing the ``.npz`` batch files.
        img_idx: Index of the sample within each batch to visualise.
        file_idx: Batch number embedded in the file name.
        C: Value embedded (as ``int(C)``) in the file name.
        mus: Noise levels to visualise, in column order.
    """
    n_cols = len(mus) + 1  # one leading column for the original image

    for col, mu in enumerate(mus):
        batch_file = os.path.join(
            file_path, f"mu_{mu:04d}_C_{int(C)}_batch_{file_idx}.npz"
        )
        # np.load on an .npz returns an NpzFile holding the archive open;
        # the context manager closes it. Indexing reads the arrays into memory,
        # so they remain valid after the file is closed.
        with np.load(batch_file) as data:
            imgs = data["img"][img_idx]
            noisy = data["noisy"][img_idx]
            denoised = data["denoised"][img_idx]

        # x*0.5+0.5 maps [-1, 1] onto [0, 1]; presumably undoes a [-1, 1]
        # normalisation applied when the data was produced.
        original = _to_hwc(imgs * 0.5 + 0.5)
        # NOTE(review): `noisy` is not rescaled like img/denoised; presumably
        # it is already stored in [0, 1] — confirm against the data writer.
        noisy_hwc = _to_hwc(noisy)
        denoised_hwc = _to_hwc(denoised * 0.5 + 0.5)

        if col == 0:
            # The original image heads both rows; plot and save it once.
            plt.subplot(2, n_cols, 1)
            plt.title("Original")
            plt.imshow(original)
            plt.axis("off")
            plt.subplot(2, n_cols, n_cols + 1)
            plt.imshow(original)
            plt.axis("off")
            plt.imsave(f"Original{img_idx}.png", original)

        # Raw string: "\mu" is a LaTeX command, not a Python escape sequence.
        plt.subplot(2, n_cols, col + 2)
        plt.title(rf"$\mu$={mu}")
        plt.imshow(noisy_hwc)
        plt.axis("off")
        plt.subplot(2, n_cols, n_cols + col + 2)
        plt.imshow(denoised_hwc)
        plt.axis("off")

        plt.imsave(f"Noisy{img_idx}_mu_{mu:04d}.png", noisy_hwc)
        plt.imsave(f"Denoised{img_idx}_mu_{mu:04d}.png", denoised_hwc)

    plt.tight_layout()
    plt.show()
    plt.close()

if __name__ == "__main__":
    # Visualise samples 0, 1 and 2 (of batch file 0, main's default), one figure each.
    for sample_idx in (0, 1, 2):
        main(img_idx=sample_idx)