import matplotlib.pyplot as plt
import numpy as np
import os
from itertools import chain
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset

class MergedDataLoader:
    """Iterate over two data loaders back to back as if they were one.

    All batches of ``dataloader1`` are yielded first, followed by all batches
    of ``dataloader2``. The object can be iterated any number of times (for
    example once per epoch); each iteration restarts both loaders.

    Args:
        dataloader1: First loader. Must be iterable, support ``len()`` and
            expose a ``dataset`` attribute that supports ``len()``.
        dataloader2: Second loader, with the same requirements.

    Attributes:
        dataloader1: The first wrapped loader.
        dataloader2: The second wrapped loader.
        length: Total number of samples, i.e. the sum of both dataset lengths.
    """

    def __init__(self, dataloader1, dataloader2):
        self.dataloader1 = dataloader1
        self.dataloader2 = dataloader2
        # Sample count (not batch count); kept as a public attribute.
        self.length = len(dataloader1.dataset) + len(dataloader2.dataset)

    @property
    def iterable(self):
        """Return a fresh chained iterator over both loaders."""
        return chain(self.dataloader1, self.dataloader2)

    def __iter__(self):
        # A chain object is a one-shot iterator, so it has to be rebuilt on
        # every call; reusing a stored one would leave every pass after the
        # first one empty.
        return chain(self.dataloader1, self.dataloader2)

    def __len__(self):
        # Each loader already knows its own batch count, which accounts for
        # its own batch size and any trailing partial batch.
        return len(self.dataloader1) + len(self.dataloader2)

def show_images(batch_images, binary_pred, cols=2, figsize=(10, 10), savedir="", title=""):
    """
    Displays a batch of RGB images in a grid and saves them as PDF files.

    The grid is saved to ``<savedir>/<title>.pdf`` and shown with
    ``plt.show()``. Afterwards every image is saved on its own to
    ``<savedir>/<title>_<index>_label_<label>.pdf`` (index starting at 1).
    The label of image ``i`` is ``np.sum(binary_pred[i][0])``; it is used as
    the panel title in the grid and in the per-image file name.

    Args:
        batch_images (numpy.ndarray): The batch of images to display.
                                      Must be 4-dimensional with a last axis
                                      of size 3, e.g. (batch_size, 256, 256, 3).
        binary_pred: Indexable with one entry per image; ``binary_pred[i][0]``
                     is summed to obtain the label of image ``i``.
        cols (int): Number of columns in the display grid.
        figsize (tuple): Size of the figure for displaying the images.
        savedir (str): Directory the PDFs are written to. Created if missing;
                       an empty string means the current working directory.
        title (str): Base name of the saved files.

    Raises:
        TypeError: If ``batch_images`` is not a numpy.ndarray.
        AssertionError: If ``batch_images`` is not 4-D with 3 channels last.
        ValueError: If ``batch_images`` contains no images.
    """
    # Check if the input is a numpy ndarray
    if not isinstance(batch_images, np.ndarray):
        raise TypeError("The input must be a numpy.ndarray.")

    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`, while callers still see the same exception type.
    if batch_images.ndim != 4 or batch_images.shape[-1] != 3:
        raise AssertionError(
            "Input batch must have shape (batch_size, 256, 256, 3) for RGB images."
        )

    batch_size = batch_images.shape[0]
    if batch_size == 0:
        raise ValueError("batch_images is empty; there is nothing to display.")

    # The directory has to exist before the first savefig call. An empty
    # savedir is skipped because os.makedirs("") raises FileNotFoundError.
    if savedir:
        os.makedirs(savedir, exist_ok=True)

    labels = [np.sum(binary_pred[i][0]) for i in range(batch_size)]
    rows = (batch_size + cols - 1) // cols  # Ceiling division

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without `.flat`.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    try:
        for i, ax in enumerate(axes.flat):
            if i < batch_size:
                ax.imshow(batch_images[i])
                ax.set_title(labels[i])
            # Hide the axes of used panels and of unused trailing panels alike
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(savedir, title + ".pdf"))
        plt.show()
    finally:
        plt.close(fig)  # Do not let grid figures pile up across calls

    # Save each image separately
    for i, label in enumerate(labels):
        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            ax.imshow(batch_images[i])
            ax.axis('off')  # Hide the axes
            fig.tight_layout()
            fig.savefig(os.path.join(savedir, f"{title}_{i + 1}_label_{label}.pdf"))
        finally:
            plt.close(fig)  # Close the figure after saving to free memory

        