import matplotlib.pyplot as plt

from utils.visualize.images import show_dataset_samples
from utils.data import create_obj2d_dataset, DataConfig
from datasets.objects_2d import (
    RandomForegroundsConfig,
    RandomBackgroundsConfig,
)


def show_rw_samples(
    transform: str,
    n_samples: int = 4,
    seed: int = 423,
    random: bool = False,
    n_classes: int = 60,
    img_size: int = 32,
) -> plt.Figure:
    """Build a small test-only Objects-2D dataset and plot ``n_samples`` images.

    The dataset is created with zero training/validation samples and
    ``n_samples`` test samples. The "test" split is passed to
    ``show_dataset_samples`` for display.

    Args:
        transform: Name of the single transform given to
            ``create_obj2d_dataset``. When ``random`` is true, it is also
            given to ``RandomForegroundsConfig``.
        n_samples: Number of test samples to generate and display. It is
            also used as the batch size.
        seed: Base seed. It is used directly as ``sampling_seed`` for
            ``DataConfig``. Fixed offsets of it seed the random
            foreground/background configs.
        random: If true, use ``RandomForegroundsConfig`` foregrounds and
            ``2 * n_samples`` backgrounds from ``RandomBackgroundsConfig``.
            Otherwise use the foreground class indices
            ``0 .. n_classes - 1`` with ``backgrounds=None``.
        n_classes: Number of foreground classes (default 60, the value that
            was previously hard-coded).
        img_size: Image size forwarded as ``img_size`` to every config
            (default 32, the value that was previously hard-coded).

    Returns:
        The figure returned by ``show_dataset_samples``. ``tight_layout`` is
        applied and ``plt.show()`` is called before it is returned.
    """
    config = DataConfig(
        sampling_seed=seed,
        img_size=img_size,
        # Only the test split is populated; it is the one displayed below.
        n_training_samples=0,
        n_val_samples=0,
        n_test_samples=n_samples,
        batch_size=n_samples,
    )

    if random:
        # The foreground and background generators each get their own
        # seed, derived from ``seed`` by a fixed offset, so that they
        # differ from each other and from ``sampling_seed``.
        foregrounds = RandomForegroundsConfig(
            n_classes=n_classes,
            transforms=[transform],
            img_size=img_size,
            seed=seed + 224,
        )
        backgrounds = RandomBackgroundsConfig(
            n_backgrounds=2 * n_samples,
            img_size=img_size,
            seed=seed + 32,
        )
    else:
        foregrounds = list(range(n_classes))
        # NOTE(review): presumably None selects the dataset's default
        # backgrounds — confirm against create_obj2d_dataset.
        backgrounds = None

    dataset = create_obj2d_dataset(
        config,
        transforms=[transform],
        foregrounds=foregrounds,
        backgrounds=backgrounds,
        # NOTE(review): presumably disabled so pixel values stay in a
        # directly displayable range — confirm.
        normalize=False,
    )
    fig = show_dataset_samples(
        dataset,
        n_samples,
        "test",
        show_titles=False,
        figsize=(3, 1),
    )
    fig.tight_layout()
    plt.show()
    return fig
