# from typing import Literal
from dataclasses import dataclass

import numpy as np
import pytorch_lightning as pl

from vis_datasets.natural.cifar import CifarData, CifarDataConfig, DataLoaderConfig
from ..cross_transforms.data import (
    CTDataConfig,
    get_random_configs,
    # reexported
    create_transforms_datasets,
)
from utils.data import (
    DataConfig,
    create_obj2d_dataset,
    DATASET_PARAMETERS,
)
from datasets.objects_2d import (
    RandomForegroundsConfig,
    UniformBackgroundsConfig,
)


# TviDatasetType = Literal["obj2d"]

@dataclass
class RIDataConfig(CTDataConfig):
    """Dataset configuration for this module.

    Adds no fields of its own: every setting is inherited from
    ``CTDataConfig``. It exists as a distinct type so this module's
    entry points can be annotated with their own config class.
    """


def create_single_object_dataset(
    config: RIDataConfig,
    transform_name: str,
    object_id: int,
    random_object: bool,
    n_distance_samples: int,
    normalize: bool = True,
) -> pl.LightningDataModule:
    """Build a dataset containing a single foreground object under one transform.

    Args:
        config: Dataset configuration. ``config.dataset`` must be ``"obj2d"``.
            ``sampling_seed``, ``config_seed``, ``img_size`` and ``batch_size``
            are also read from it.
        transform_name: Name of the single transformation to apply. It is
            forwarded as a one-element list.
        object_id: Identifier of the object. It always offsets the sampling
            seed. When ``random_object`` is false, it is used directly as the
            foreground object config (``[object_id]``). When
            ``random_object`` is true, it offsets the config seed.
        random_object: If true, draw a random foreground/background config
            with ``get_random_configs`` (one class). Otherwise use
            ``[object_id]`` as the foreground and pass ``backgrounds=None``.
        n_distance_samples: Number of samples used for each of the training,
            validation and test splits.
        normalize: Forwarded unchanged to the dataset factory.

    Returns:
        Whatever ``create_obj2d_dataset`` returns for these settings
        (annotated as a ``pl.LightningDataModule``).

    Raises:
        ValueError: If ``config.dataset`` is not ``"obj2d"``.
    """
    # Only the obj2d dataset is supported here.
    if config.dataset != "obj2d":
        raise ValueError(f"Invalid dataset config: {config.dataset}")

    dataset_config = DataConfig(
        # Shift the sampling seed per object_id so that transformations and
        # backgrounds are not sampled identically for every object.
        # The fixed offset of 113 is kept as-is for reproducibility.
        sampling_seed=config.sampling_seed + 113 + object_id,
        img_size=config.img_size,
        n_training_samples=n_distance_samples,
        n_val_samples=n_distance_samples,
        n_test_samples=n_distance_samples,
        batch_size=config.batch_size,
    )

    if random_object:
        # Offset the config seed by object_id so each id yields a different
        # random object/background configuration.
        fg_object_config, bg_config = get_random_configs(
            n_classes=1,
            transforms=[transform_name],
            config_seed=config.config_seed + object_id,
            img_size=config.img_size,
        )
    else:
        fg_object_config = [object_id]
        # NOTE(review): how backgrounds=None is interpreted is decided by
        # create_obj2d_dataset — confirm the intended default there.
        bg_config = None

    return create_obj2d_dataset(
        dataset_config,
        [transform_name],
        fg_object_config,
        backgrounds=bg_config,
        normalize=normalize,
    )
