import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import argparse





def show_images(path, rho="", output_dir="/home/ubuntu/results/mixed_diffusion/gibbs_samples"):
    """Tile one ``rho`` setting's sample images into a single 12x10 grid and save it.

    Grid layout (all axes hidden):

    * row 0, column 0: ``noisy_0.png``
    * rows 1-10: ``{rho}/{index}_0.png`` for index 0..99, filled row-major
    * row 11, column 0: ``original_0.png``

    The remaining cells of rows 0 and 11 are left blank.

    Args:
        path: Directory that holds ``noisy_0.png``, ``original_0.png`` and the
            ``rho`` subdirectory of indexed samples.
        rho: Name of the subdirectory under ``path`` containing the indexed
            samples; also embedded in the output file name.
        output_dir: Directory where ``output_{rho}.png`` is written. Defaults
            to the location previously hard-coded in this function.

    If any expected image is missing, a message is printed, the figure is
    closed, and nothing is saved.
    """
    # 10 sample rows, framed by a "noisy" row on top and an "original" row at the bottom.
    n_rows, n_cols = 12, 10

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16.5, 15))
    fig.subplots_adjust(wspace=0.1, hspace=0.1)

    for row in range(n_rows):
        for col in range(n_cols):
            ax = axes[row, col]
            ax.axis('off')  # hide ticks/labels on every cell, including blank ones
            if row == 0 or row == n_rows - 1:
                # The first and last rows hold a single reference image in column 0.
                if col != 0:
                    continue
                img_name = "noisy_0.png" if row == 0 else "original_0.png"
            else:
                index = (row - 1) * n_cols + col
                # Use the ``rho`` argument rather than the script-level global ``args``,
                # which does not exist when this function is imported elsewhere.
                img_name = f"{rho}/{index}_0.png"

            img_path = os.path.join(path, img_name)
            if not os.path.exists(img_path):
                print(f"Image {img_path} does not exist")
                plt.close(fig)
                return
            ax.imshow(mpimg.imread(img_path))

    output_path = os.path.join(output_dir, f"output_{rho}.png")
    # Save this specific figure instead of relying on pyplot's "current figure".
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)  # free the figure's memory

    print(f"Saved final image as {output_path}")


def main(args):
    """Render the Gibbs-sample grid for the ``rho`` chosen on the command line.

    Args:
        args: Parsed argument namespace; only ``args.rho`` is read.
    """
    samples_dir = "/home/ubuntu/results/mixed_diffusion/gibbs_samples"
    show_images(path=samples_dir, rho=args.rho)


if __name__ == "__main__":
    # Command-line entry point: choose which rho subdirectory to render.
    cli = argparse.ArgumentParser(description="Show images")
    cli.add_argument(
        "--rho",
        type=str,
        default="exponential_70_20",
        help="Rho value",
    )
    args = cli.parse_args()  # bound at module level
    main(args)