from typing import Union

import numpy as np

from datasets.numpy import ConcatNumpyDataset, FolderNumpyDataset, ConstantNumpyDataset
from src.configs.dataset import DatasetConfig
from src.configs.edm_sampler import EDMSamplerConfig
from src.configs.edm_scheduler import EDMSchedulerConfig
from src.configs.numpy_folder import NumpyFolderNoiseImageLabelSigmaConfig, NumpyFolderNoiseLabelConfig, \
    NumpyFolderImageLabelConfig, NumpyFolderNoiseConfig, NumpyFolderNoiseImageSigmaConfig, NumpyFolderImageConfig
from src.datasets.image import NumpyImageDataset
from src.datasets.image_label import NumpyImageLabelDataset
from src.datasets.lines import NumpyLinesConditionalDataset, NumpyLinesUnconditionalDataset
from src.datasets.noise import NumpyNoiseDataset
from src.datasets.noise_label import NumpyNoiseLabelDataset
from src.utils.edm import get_sigmas


def create_lines_unconditional_dataset(
        edm_scheduler_config: EDMSchedulerConfig,
        edm_sampler_config: EDMSamplerConfig,
        dataset_config: DatasetConfig,
        configs: list[NumpyFolderNoiseImageSigmaConfig]
) -> NumpyLinesUnconditionalDataset:
    """Assemble a ``NumpyLinesUnconditionalDataset`` from folder configs.

    Every entry of ``configs`` contributes one folder-backed noise dataset,
    one folder-backed image dataset and one constant sigma dataset; the
    per-entry datasets of each kind are joined with ``ConcatNumpyDataset``
    in the order the entries appear in ``configs``.

    Args:
        edm_scheduler_config: Supplies ``sigma_min``, ``sigma_max`` and
            ``rho`` for ``get_sigmas``.
        edm_sampler_config: Supplies ``num_steps`` for ``get_sigmas``.
        dataset_config: Supplies ``image_shape``, used as the ``data_shape``
            of the noise and image folder datasets.
        configs: Per-source settings; each entry provides ``noise`` and
            ``image`` sub-configs (``folder``, ``num_samples``,
            ``start_index``) plus ``time_step`` and ``num_samples``.

    Returns:
        The dataset built from the concatenated noise, image and sigma parts.
    """
    sigma_table: np.ndarray = get_sigmas(
        num_steps=edm_sampler_config.num_steps,
        sigma_min=edm_scheduler_config.sigma_min,
        sigma_max=edm_scheduler_config.sigma_max,
        rho=edm_scheduler_config.rho,
    )

    # The parts are built in the same order as the keyword arguments below
    # (noise, image, sigma).
    noise_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.noise.folder,
            num_samples=config.noise.num_samples,
            start_index=config.noise.start_index,
        )
        for config in configs
    ])
    image_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.image.folder,
            num_samples=config.image.num_samples,
            start_index=config.image.start_index,
        )
        for config in configs
    ])
    # Each entry's sigma is the table value at its ``time_step``.
    # NOTE(review): the length comes from ``config.num_samples`` while the
    # noise/image parts use their own ``num_samples`` -- presumably these are
    # expected to agree; confirm against the config definition.
    sigma_parts = ConcatNumpyDataset([
        ConstantNumpyDataset(
            value=np.asarray(sigma_table[config.time_step]),
            length=config.num_samples,
        )
        for config in configs
    ])

    return NumpyLinesUnconditionalDataset(
        noise_dataset=noise_parts,
        image_dataset=image_parts,
        sigma_dataset=sigma_parts,
    )


def create_lines_conditional_dataset(
        edm_scheduler_config: EDMSchedulerConfig,
        edm_sampler_config: EDMSamplerConfig,
        dataset_config: DatasetConfig,
        configs: list[NumpyFolderNoiseImageLabelSigmaConfig]
) -> NumpyLinesConditionalDataset:
    """Assemble a ``NumpyLinesConditionalDataset`` from folder configs.

    Every entry of ``configs`` contributes one folder-backed noise dataset,
    one folder-backed image dataset, one folder-backed label dataset and one
    constant sigma dataset; the per-entry datasets of each kind are joined
    with ``ConcatNumpyDataset`` in the order the entries appear in
    ``configs``.

    Args:
        edm_scheduler_config: Supplies ``sigma_min``, ``sigma_max`` and
            ``rho`` for ``get_sigmas``.
        edm_sampler_config: Supplies ``num_steps`` for ``get_sigmas``.
        dataset_config: Supplies ``image_shape``, used as the ``data_shape``
            of the noise and image folder datasets.
        configs: Per-source settings; each entry provides ``noise``,
            ``image`` and ``label`` sub-configs (``folder``,
            ``num_samples``, ``start_index``) plus ``time_step`` and
            ``num_samples``.

    Returns:
        The dataset built from the concatenated noise, image, label and
        sigma parts.
    """
    sigma_table: np.ndarray = get_sigmas(
        num_steps=edm_sampler_config.num_steps,
        sigma_min=edm_scheduler_config.sigma_min,
        sigma_max=edm_scheduler_config.sigma_max,
        rho=edm_scheduler_config.rho,
    )

    # The parts are built in the same order as the keyword arguments below
    # (noise, image, label, sigma).
    noise_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.noise.folder,
            num_samples=config.noise.num_samples,
            start_index=config.noise.start_index,
        )
        for config in configs
    ])
    image_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.image.folder,
            num_samples=config.image.num_samples,
            start_index=config.image.start_index,
        )
        for config in configs
    ])
    # Labels are read with an empty data shape and int64 dtype.
    label_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=(),
            data_type=np.int64,
            folder=config.label.folder,
            num_samples=config.label.num_samples,
            start_index=config.label.start_index,
        )
        for config in configs
    ])
    # Each entry's sigma is the table value at its ``time_step``.
    # NOTE(review): the length comes from ``config.num_samples`` while the
    # other parts use their own ``num_samples`` -- presumably these are
    # expected to agree; confirm against the config definition.
    sigma_parts = ConcatNumpyDataset([
        ConstantNumpyDataset(
            value=np.asarray(sigma_table[config.time_step]),
            length=config.num_samples,
        )
        for config in configs
    ])

    return NumpyLinesConditionalDataset(
        noise_dataset=noise_parts,
        image_dataset=image_parts,
        label_dataset=label_parts,
        sigma_dataset=sigma_parts,
    )


def create_noise_dataset(
        dataset_config: DatasetConfig,
        configs: list[NumpyFolderNoiseConfig]
) -> NumpyNoiseDataset:
    """Build a ``NumpyNoiseDataset`` from one or more noise folders.

    Each entry of ``configs`` yields one ``FolderNumpyDataset`` read as
    float64 with ``dataset_config.image_shape`` as its data shape; the
    per-entry datasets are joined with ``ConcatNumpyDataset`` in the order
    the entries appear in ``configs``.

    Args:
        dataset_config: Supplies ``image_shape`` for the noise arrays.
        configs: Per-source settings; each entry provides a ``noise``
            sub-config with ``folder``, ``num_samples`` and ``start_index``.

    Returns:
        The dataset wrapping the concatenated noise folders.
    """
    # The annotation was ``Union[list[...]]``; a single-argument Union is the
    # same type as its argument, so the plain form is used to match the
    # sibling factory functions.
    noise_parts = [
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.noise.folder,
            num_samples=config.noise.num_samples,
            start_index=config.noise.start_index,
        )
        for config in configs
    ]
    return NumpyNoiseDataset(noise_dataset=ConcatNumpyDataset(noise_parts))


def create_noise_label_dataset(
        dataset_config: DatasetConfig,
        configs: list[NumpyFolderNoiseLabelConfig]
) -> NumpyNoiseLabelDataset:
    """Build a ``NumpyNoiseLabelDataset`` from noise and label folders.

    Each entry of ``configs`` yields one noise ``FolderNumpyDataset``
    (float64, ``dataset_config.image_shape``) and one label
    ``FolderNumpyDataset`` (int64, empty data shape). The per-entry datasets
    of each kind are joined with ``ConcatNumpyDataset`` in the order the
    entries appear in ``configs``.

    Args:
        dataset_config: Supplies ``image_shape`` for the noise arrays.
        configs: Per-source settings; each entry provides ``noise`` and
            ``label`` sub-configs with ``folder``, ``num_samples`` and
            ``start_index``.

    Returns:
        The dataset pairing the concatenated noise and label folders.
    """
    # The annotation was ``Union[list[...]]``; a single-argument Union is the
    # same type as its argument, so the plain form is used to match the
    # sibling factory functions.
    noise_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.noise.folder,
            num_samples=config.noise.num_samples,
            start_index=config.noise.start_index,
        )
        for config in configs
    ])
    label_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=(),
            data_type=np.int64,
            folder=config.label.folder,
            num_samples=config.label.num_samples,
            start_index=config.label.start_index,
        )
        for config in configs
    ])
    return NumpyNoiseLabelDataset(
        noise_dataset=noise_parts,
        label_dataset=label_parts,
    )


def create_image_dataset(
        dataset_config: DatasetConfig,
        configs: list[NumpyFolderImageConfig]
) -> NumpyImageDataset:
    """Build a ``NumpyImageDataset`` from one or more image folders.

    Each entry of ``configs`` yields one ``FolderNumpyDataset`` read as
    float64 with ``dataset_config.image_shape`` as its data shape; the
    per-entry datasets are joined with ``ConcatNumpyDataset`` in the order
    the entries appear in ``configs``.

    Args:
        dataset_config: Supplies ``image_shape`` for the image arrays.
        configs: Per-source settings; each entry provides an ``image``
            sub-config with ``folder``, ``num_samples`` and ``start_index``.

    Returns:
        The dataset wrapping the concatenated image folders.
    """
    image_parts = [
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.image.folder,
            num_samples=config.image.num_samples,
            start_index=config.image.start_index,
        )
        for config in configs
    ]
    return NumpyImageDataset(image_dataset=ConcatNumpyDataset(image_parts))


def create_image_label_dataset(
        dataset_config: DatasetConfig,
        configs: list[NumpyFolderImageLabelConfig]
) -> NumpyImageLabelDataset:
    """Build a ``NumpyImageLabelDataset`` from image and label folders.

    Each entry of ``configs`` yields one image ``FolderNumpyDataset``
    (float64, ``dataset_config.image_shape``) and one label
    ``FolderNumpyDataset`` (int64, empty data shape). The per-entry datasets
    of each kind are joined with ``ConcatNumpyDataset`` in the order the
    entries appear in ``configs``.

    Args:
        dataset_config: Supplies ``image_shape`` for the image arrays.
        configs: Per-source settings; each entry provides ``image`` and
            ``label`` sub-configs with ``folder``, ``num_samples`` and
            ``start_index``.

    Returns:
        The dataset pairing the concatenated image and label folders.
    """
    image_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=dataset_config.image_shape,
            data_type=np.float64,
            folder=config.image.folder,
            num_samples=config.image.num_samples,
            start_index=config.image.start_index,
        )
        for config in configs
    ])
    label_parts = ConcatNumpyDataset([
        FolderNumpyDataset(
            data_shape=(),
            data_type=np.int64,
            folder=config.label.folder,
            num_samples=config.label.num_samples,
            start_index=config.label.start_index,
        )
        for config in configs
    ])
    return NumpyImageLabelDataset(
        image_dataset=image_parts,
        label_dataset=label_parts,
    )
