from dataclasses import dataclass
from copy import deepcopy
from typing import Union

import pytorch_lightning as pl

from vis_datasets.loading import DataLoaderConfig
# from datasets import objects_2d
from datasets.objects_2d import (
    OBJECTS as OBJ_2D_OBJECTS,
    TRANSFORMS_WITH_NONE as OBJ_2D_TRANSFORMS,
    FGSource,
    BGSource,
    ImageForegroundsConfig,
    ImageBackgroundsConfig,
    CompositeTransform,
    Composite2DDatasetConfig,
    Objects2DDataConfig,
    Objects2DData,
)


@dataclass
class DatasetParameters:
    """Static description of what a dataset variant can offer.

    Attributes:
        class_objects: The object classes available in the dataset (entries of
            the dataset's object registry).
        transforms: Identifiers of the transforms the dataset supports
            (presumably the keys of its transform registry — confirm).
    """

    class_objects: list
    transforms: list

# Registry of per-dataset parameters, keyed by short dataset name.
# Values are materialized as plain lists so consumers get independent copies
# of the underlying object / transform registries.
DATASET_PARAMETERS = dict(
    obj2d=DatasetParameters(
        class_objects=list(OBJ_2D_OBJECTS),
        transforms=list(OBJ_2D_TRANSFORMS.keys()),
    ),
)


@dataclass
class DataConfig:
    """Experiment-level data settings consumed by the dataset factory.

    Attributes:
        sampling_seed: Seed used to draw samples from the dataset.
        img_size: Image size applied to the foreground, background and
            composition configs.
        n_training_samples: Number of samples in the training split.
        n_val_samples: Number of samples in the validation split.
        n_test_samples: Number of samples in the test split.
        batch_size: Batch size passed to the data-loader configuration.
    """

    sampling_seed: int
    img_size: int
    n_training_samples: int
    n_val_samples: int
    n_test_samples: int
    batch_size: int

def _resolve_foregrounds(
    foregrounds: Union[FGSource, list[int], tuple[int, ...]],
    transforms: CompositeTransform,
    img_size: int,
):
    """Turn the ``foregrounds`` argument of create_obj2d_dataset into a foreground source."""
    if isinstance(foregrounds, (list, tuple)):
        # Copy the indices so later mutation of the caller's sequence cannot
        # leak into the config (mirrors the deepcopy used for config inputs).
        return ImageForegroundsConfig(
            fg_class_indices=list(foregrounds),
            transforms=transforms,
            img_size=img_size,
        )
    if isinstance(foregrounds, pl.LightningDataModule):
        # Ready-made data modules are used as-is: neither img_size nor
        # transforms is applied to them.
        return foregrounds
    # Any other foreground config: work on a copy so the caller's object is
    # left untouched, then force the experiment-wide settings onto it.
    resolved = deepcopy(foregrounds)
    resolved.img_size = img_size
    resolved.transforms = transforms
    return resolved


def _resolve_backgrounds(backgrounds: Union[BGSource, None], img_size: int):
    """Turn the ``backgrounds`` argument of create_obj2d_dataset into a background source."""
    if backgrounds is None:
        return ImageBackgroundsConfig(img_size=img_size)
    if isinstance(backgrounds, pl.LightningDataModule):
        # Used as-is; img_size is not applied to a data module.
        return backgrounds
    # Copy so the caller's config is not mutated by the img_size override.
    resolved = deepcopy(backgrounds)
    resolved.img_size = img_size
    return resolved


def create_obj2d_dataset(
    config: DataConfig,
    transforms: CompositeTransform,
    foregrounds: Union[FGSource, list[int], tuple[int, ...]],
    backgrounds: Union[BGSource, None] = None,
    normalize: bool = True,
    class_from_foreground: bool = True,
    fg_bg_correlation: float = 0.0,
    fg_availability: float = 1.0,
) -> Objects2DData:
    """Assemble an ``Objects2DData`` from foreground, background and composition configs.

    Args:
        config: Experiment settings. ``img_size`` is applied to every part;
            ``sampling_seed`` goes to both the composition and the dataset
            config; ``batch_size`` goes to the loader config; the three
            ``n_*_samples`` fields set the split sizes.
        transforms: Transforms attached to the foreground config. They are
            ignored when ``foregrounds`` is a ``LightningDataModule`` and are
            never applied to backgrounds.
        foregrounds: One of the following.
            - A list or tuple of class indices. A new
              ``ImageForegroundsConfig`` is built from a copy of them.
            - A ``LightningDataModule``. It is used unchanged, so
              ``img_size`` and ``transforms`` are not applied.
            - Any other foreground config. It is deep-copied, and
              ``img_size`` and ``transforms`` are overwritten on the copy.
              The caller's object is not modified.
        backgrounds: One of the following.
            - ``None``. A default ``ImageBackgroundsConfig`` is built at
              ``config.img_size``.
            - A ``LightningDataModule``. It is used unchanged.
            - Any other background config. It is deep-copied, and
              ``img_size`` is overwritten on the copy.
        normalize: Forwarded unchanged to ``Composite2DDatasetConfig``.
        class_from_foreground: Forwarded unchanged to
            ``Composite2DDatasetConfig``.
        fg_bg_correlation: Forwarded unchanged to
            ``Composite2DDatasetConfig``.
        fg_availability: Forwarded unchanged to ``Composite2DDatasetConfig``.

    Returns:
        The constructed ``Objects2DData``. ``setup()`` is not called here.
        Presumably the caller (or a Lightning ``Trainer``) does that; confirm
        with the call sites.
    """
    foregrounds_config = _resolve_foregrounds(foregrounds, transforms, config.img_size)
    backgrounds_config = _resolve_backgrounds(backgrounds, config.img_size)

    # Only batch_size is set; all other loader options use DataLoaderConfig's defaults.
    loader_config = DataLoaderConfig(batch_size=config.batch_size)

    composition_config = Composite2DDatasetConfig(
        img_size=config.img_size,
        sampling_seed=config.sampling_seed,
        normalize=normalize,
        class_from_foreground=class_from_foreground,
        fg_bg_correlation=fg_bg_correlation,
        fg_availability=fg_availability,
    )
    dataset_config = Objects2DDataConfig(
        foregrounds=foregrounds_config,
        backgrounds=backgrounds_config,
        composition=composition_config,
        n_training_samples=config.n_training_samples,
        n_val_samples=config.n_val_samples,
        n_test_samples=config.n_test_samples,
        sampling_seed=config.sampling_seed,
        loader=loader_config,
    )
    return Objects2DData(dataset_config)
