import matplotlib.pyplot as plt
import torch


def visualize_batch(train_dl, show=4):
    """Plot low/high-resolution image pairs from the first batch of a loader.

    The first batch yielded by ``train_dl`` is unpacked as
    ``(low_res, high_res)``. Up to ``show`` pairs are drawn in a two-row
    grid: low-resolution images on top, high-resolution images below.

    Args:
        train_dl: Iterable whose items unpack into two image batches.
            Each image is expected to be a channel-first tensor with
            values in [-1, 1] (they are rescaled with ``(x + 1) / 2``).
        show: Maximum number of image pairs (columns) to display.

    Raises:
        ValueError: If there is no image to display (empty batch or
            ``show < 1``).
    """
    low_res, high_res = next(iter(train_dl))

    n_cols = min(len(low_res), show)
    if n_cols < 1:
        raise ValueError("visualize_batch needs at least one image to display")

    # Slice before rescaling so only the displayed images are processed.
    # detach()/cpu() make .numpy() safe for CUDA or grad-tracking tensors.
    low_res = (low_res[:n_cols].detach().cpu() + 1) / 2  # Rescale to [0, 1]
    high_res = (high_res[:n_cols].detach().cpu() + 1) / 2  # Rescale to [0, 1]

    # squeeze=False keeps `axs` two-dimensional even with a single column,
    # so the axs[row, col] indexing below also works for a batch of one.
    _, axs = plt.subplots(2, n_cols, figsize=(20, 8), squeeze=False)

    rows = (("Low Resolution", low_res), ("High Resolution", high_res))
    for row, (title, images) in enumerate(rows):
        for col in range(n_cols):
            ax = axs[row, col]
            # imshow expects channel-last (H, W, C) data.
            ax.imshow(images[col].permute(1, 2, 0).numpy())
            ax.axis("off")
            ax.set_title(title)

    plt.tight_layout()
    plt.show()


def _to_display_image(image, clamp=False):
    """Convert one channel-first tensor in [-1, 1] to an (H, W, C) array in [0, 1]."""
    # detach() so tensors that still track gradients can be converted.
    image = image.detach().permute(1, 2, 0).cpu()
    if clamp:
        # Generated samples may overshoot the valid range.
        image = image.clamp(-1, 1)
    return (image.numpy() + 1.0) * 0.5


def visualize_results_imgs(y, pred_x0, x0, traj, index=0, show=5):
    """Show a low-res / sampled / original comparison and sampled trajectory frames.

    The first figure shows element ``index`` of ``y``, ``pred_x0`` and ``x0``
    side by side. If ``traj`` is given, a second figure shows up to ``show``
    frames taken at evenly spaced positions along the trajectory.

    Args:
        y: Batch of low-resolution images (channel-first, values in [-1, 1]).
        pred_x0: Batch of sampled images; clamped to [-1, 1] for display.
        x0: Batch of original images.
        traj: Optional sequence of sampling steps (a stacked tensor or a
            list of per-step batches), or ``None`` to skip the second figure.
        index: Batch element shown in the comparison figure.
        show: Maximum number of trajectory frames to display. The trajectory
            figure is skipped when this is < 1 or ``traj`` is empty.
    """
    panels = (
        ("Low Resolution Image", y, False),
        ("Sampled Image", pred_x0, True),
        ("Original Image", x0, False),
    )

    plt.figure(figsize=(12, 5))
    for position, (title, images, clamp) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(_to_display_image(images[index], clamp=clamp))
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

    if traj is None:
        return

    n_steps = len(traj)
    # Never request more frames than exist; this avoids duplicate panels.
    n_frames = min(show, n_steps)
    if n_frames < 1:
        # Nothing to draw (and a zero-width figure would raise).
        return

    # Evenly spaced step indices from the first to the last trajectory entry.
    indices = torch.linspace(0, n_steps - 1, n_frames, dtype=torch.long).tolist()

    plt.figure(figsize=(n_frames * 5, 6))
    for position, idx in enumerate(indices, start=1):
        plt.subplot(1, n_frames, position)
        # NOTE(review): the trajectory always shows batch element 0, not
        # `index` -- confirm whether `traj` carries the whole batch.
        plt.imshow(_to_display_image(traj[idx][0], clamp=True))
        # NOTE(review): the last frame is labelled 1/len(traj) rather than 0;
        # confirm this matches the sampler's time convention.
        plt.title(f"t = {1 - idx / n_steps:.4f}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()
