from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, Union

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

import vis_datasets.loading as data_loading
from vis_datasets.lib.dataset_accessor import (
    LightningDataAccessor,
    DatasetStage,
)
from .datasets import (
    BackgroundsDataset,
    Composite2DDataset,
    Composite2DDatasetConfig,
    FGConfig,
    BGConfig,
    ForegroundsDataset,
    ImageForegroundsConfig,
    ImageBackgroundsConfig,
)

# A foreground source: either a config from which a ForegroundsDataset is built,
# or a LightningDataAccessor whose per-stage dataset is used instead.
FGSource = Union[FGConfig, LightningDataAccessor]
# A background source: either a config from which a BackgroundsDataset is built,
# or a LightningDataAccessor whose per-stage dataset is used instead.
BGSource = Union[BGConfig, LightningDataAccessor]

@dataclass
class Objects2DSamplingSeedsConfig:
    """Explicit sampling seeds, one per dataset split.

    Use this instead of a single integer seed when each split should get an
    independently chosen seed (see ``_get_seed``).
    """
    training: int  # seed for the "train" split
    validation: int  # seed for the "val" split
    test: int  # seed for the "test" split

@dataclass
class Objects2DDataConfig:
    """Configuration for :class:`Objects2DData`.

    Attributes:
        foregrounds: Foreground source - a foreground config, or a data
            accessor that supplies the foreground dataset per split.
        backgrounds: Background source - a background config, or a data
            accessor that supplies the background dataset per split.
        composition: Settings passed to ``Composite2DDataset``.
        n_training_samples, n_val_samples, n_test_samples: Sample counts
            passed as ``n_samples`` when a split's datasets are built from
            configs.
        sampling_seed: A single int (offset per split by ``_get_seed``) or
            explicit per-split seeds.
        loader: Options forwarded to ``data_loading.load`` for every split.
    """
    foregrounds: FGSource
    backgrounds: BGSource
    composition: Composite2DDatasetConfig
    # NOTE: the commented-out defaults below cannot simply be re-enabled:
    # dataclass fields with defaults must come after all fields without
    # defaults (sampling_seed and loader below have none).
    n_training_samples: int# = 50000
    n_val_samples: int# = 10000
    n_test_samples: int# = 10000
    sampling_seed: Union[int, Objects2DSamplingSeedsConfig]
    loader: data_loading.DataLoaderConfig

class Objects2DData(LightningDataAccessor):
    """Data accessor that builds composited 2D-object datasets per split.

    Each split ("train", "val", "test") combines a foreground source and a
    background source as configured by :class:`Objects2DDataConfig`. The
    resulting datasets are registered with ``set_dataset`` and served through
    ``data_loading.load`` using ``config.loader``.
    """

    def __init__(self, config: Objects2DDataConfig) -> None:
        """Store the configuration; datasets are only created in :meth:`setup`.

        Args:
            config: Sources, composition settings, split sizes, seeds and
                loader options for all splits.
        """
        super().__init__()
        self.config = config

    def prepare_data(self) -> None:
        """Verify that on-disk image sources exist before datasets are built.

        Only image-based sources (``ImageForegroundsConfig`` /
        ``ImageBackgroundsConfig``) are checked; other source types are not
        inspected here.

        Raises:
            FileNotFoundError: If the image directory of an image-based
                foreground or background source does not exist.
        """
        fg = self.config.foregrounds
        bg = self.config.backgrounds
        fg_missing = isinstance(fg, ImageForegroundsConfig) and not fg.fg_dir.exists()
        bg_missing = isinstance(bg, ImageBackgroundsConfig) and not bg.bg_dir.exists()
        if fg_missing or bg_missing:
            raise FileNotFoundError(
                "Could not find foreground or background images. "
                "Make sure to download them from AWS using the instructions "
                "in the README."
            )
        # TODO: automatic download (``aws s3 cp --recursive`` from
        # ``s3://si-score-dataset/{foregrounds,backgrounds}/``) was sketched
        # here but is disabled; users currently fetch the images manually.

    def setup(self, stage: Optional[str] = None) -> None:
        """Create and register the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` builds the train and val splits, ``"validate"``
                only the val split, and ``"test"`` the test split. ``None``
                builds all three. Any other value builds nothing.
        """
        if stage is None or stage == "fit":
            self._build_stage_dataset("train", self.config.n_training_samples)
            self._build_stage_dataset("val", self.config.n_val_samples)
        elif stage == "validate":
            # Lightning passes "validate" for ``trainer.validate()``; without
            # this branch no val dataset would be registered for that call.
            self._build_stage_dataset("val", self.config.n_val_samples)
        if stage is None or stage == "test":
            self._build_stage_dataset("test", self.config.n_test_samples)

    def _build_stage_dataset(self, stage_type: str, n_samples: int) -> None:
        """Build the composite dataset for ``stage_type`` and register it."""
        self.set_dataset(stage_type, _create_objects_2d_dataset(
            foregrounds_config=self.config.foregrounds,
            backgrounds_config=self.config.backgrounds,
            composite_config=self.config.composition,
            n_samples=n_samples,
            sampling_seed=_get_seed(self.config.sampling_seed, stage_type),
            stage=stage_type,
        ))

    def _loader(self, stage_type: str, train: bool) -> DataLoader:
        """Wrap the registered dataset for ``stage_type`` in a DataLoader."""
        return data_loading.load(
            self.get_dataset(stage_type),
            train=train,
            config=self.config.loader,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the loader for the training split (``train=True``)."""
        return self._loader("train", train=True)

    def val_dataloader(self) -> DataLoader:
        """Return the loader for the validation split."""
        return self._loader("val", train=False)

    def test_dataloader(self) -> DataLoader:
        """Return the loader for the test split."""
        return self._loader("test", train=False)

def _create_objects_2d_dataset(
    foregrounds_config: FGSource,
    backgrounds_config: BGSource,
    composite_config: Composite2DDatasetConfig,
    n_samples: int,
    sampling_seed: int,
    stage: DatasetStage,
) -> Dataset:
    """Build the composite 2D-objects dataset for one split.

    Args:
        foregrounds_config: A foreground config, from which a
            ``ForegroundsDataset`` of ``n_samples`` items is built, or a data
            accessor whose dataset for ``stage`` is used as-is.
        backgrounds_config: Same as ``foregrounds_config``, for backgrounds.
        composite_config: Settings for ``Composite2DDataset``.
        n_samples: Number of samples for datasets built from configs; not
            used for accessor sources.
        sampling_seed: Seed for a config-built foreground dataset. A
            config-built background dataset uses ``sampling_seed + 593``.
        stage: Split name used to fetch datasets from accessor sources.

    Returns:
        A ``Composite2DDataset`` combining the foreground and background data.
    """
    if isinstance(foregrounds_config, LightningDataAccessor):
        # setup() is called without a stage; presumably this prepares every
        # split of the accessor - confirm against LightningDataAccessor.
        foregrounds_config.setup()
        # Previously this branch returned the raw foreground dataset, silently
        # skipping the backgrounds and the composition step.
        foregrounds_data = foregrounds_config.get_dataset(stage)
    else:
        foregrounds_data = ForegroundsDataset(
            foregrounds_config,
            n_samples=n_samples,
            sampling_seed=sampling_seed,
        )

    if isinstance(backgrounds_config, LightningDataAccessor):
        backgrounds_config.setup()
        backgrounds_data = backgrounds_config.get_dataset(stage)
    else:
        backgrounds_data = BackgroundsDataset(
            backgrounds_config,
            n_samples=n_samples,
            # Offset seed, presumably so backgrounds are not sampled in
            # lockstep with foregrounds.
            sampling_seed=sampling_seed + 593,
        )

    return Composite2DDataset(
        composite_config,
        foregrounds=foregrounds_data,
        backgrounds=backgrounds_data,
    )

def _get_seed(
    config: Union[int, Objects2DSamplingSeedsConfig],
    data_type: DatasetStage,
) -> int:
    """Return the sampling seed for one dataset split.

    Args:
        config: Either explicit per-split seeds, or a single int seed. A
            single int is used as-is for "train" and offset by 110 for "val"
            and by 220 for "test", so the splits draw different samples.
        data_type: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        The seed for ``data_type``.

    Raises:
        ValueError: If ``data_type`` is not a known split name.
    """
    if data_type not in ("train", "val", "test"):
        raise ValueError(
            f"Unknown dataset stage {data_type!r}; "
            "expected one of 'train', 'val', 'test'."
        )
    if isinstance(config, Objects2DSamplingSeedsConfig):
        return {
            "train": config.training,
            "val": config.validation,
            "test": config.test,
        }[data_type]
    if data_type == "train":
        return config
    return config + (110 if data_type == "val" else 220)
