from typing import Literal
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from vis_analysis_utils.visualize import images

# Dataset split selector: the set of values show_dataset_samples accepts
# for its ``data_type`` argument.
DataType = Literal[
    "train",
    "val",
    "test",
]

def show_dataset_samples(
    data: pl.LightningDataModule,
    n_samples: int,
    data_type: DataType,
    figsize=(20, 10),
    show_titles: bool = True,
) -> plt.Figure:
    """Plot sample images from one split of a Lightning data module.

    Args:
        data: Data module to visualize. Its ``setup()`` is called (with no
            stage argument) before the requested dataloader is built.
        n_samples: Number of images to plot; forwarded to
            ``images.show_images`` as ``n_images``.
        data_type: Which split to draw samples from: ``"train"``,
            ``"val"`` or ``"test"``.
        figsize: Figure size forwarded to ``images.show_images``.
        show_titles: Forwarded unchanged to ``images.show_images``.

    Returns:
        The figure produced by ``images.show_images``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported splits.
            This is checked before ``data.setup()`` runs, so an invalid
            argument fails fast instead of after a full data setup.
    """
    # Split name -> name of the LightningDataModule hook that builds its
    # dataloader. Resolved by name after setup(), the same point at which
    # the hook was looked up before.
    hook_names = {
        "train": "train_dataloader",
        "val": "val_dataloader",
        "test": "test_dataloader",
    }
    valid_types = tuple(hook_names)
    # Tuple membership compares with ==, so even an unhashable data_type
    # produces a ValueError (not a TypeError from a dict lookup).
    if data_type not in valid_types:
        raise ValueError(
            f"data_type must be one of {valid_types}, got {data_type!r}"
        )

    # NOTE(review): setup() is called without a stage. Data modules whose
    # setup() signature requires a stage argument will fail here; confirm
    # against the Lightning version and data modules in use.
    data.setup()
    vis_loader = getattr(data, hook_names[data_type])()

    fig = images.show_images(
        vis_loader,
        n_images=n_samples,
        figsize=figsize,
        show_titles=show_titles,
    )
    return fig
