import itertools
import os
from tempfile import NamedTemporaryFile

from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch_geometric.loader import DataLoader

from .datasets import (
    DAGDenoisingDataset,
    ElectricalCircuitsDenoisingInterpolationDataset,
    LongestCycleIdentificationDataset,
    MixedLongestCycleIdentificationDataset,
    RandomWalkDenoisingDataset,
    TNTPFlowDenoisingInterpolationDataset,
    TrafficLADataset,
    TypedTrianglesOrientationDataset,
)


class EdgeLevelTaskDataModule(LightningDataModule):
    """PyTorch Lightning datamodule for edge-level tasks.

    The dataset is selected by ``config.name``; ``setup`` builds the train,
    validation and test splits and the ``*_dataloader`` methods wrap them in
    ``torch_geometric`` data loaders.
    """

    # Tasks available for every multi-task (flow / circuit) dataset family.
    _TASKS = ("denoising", "interpolation", "simulation")
    # Dataset families served by ``TNTPFlowDenoisingInterpolationDataset``.
    _TNTP_NETWORKS = (
        "traffic-anaheim",
        "traffic-barcelona",
        "traffic-chicago",
        "traffic-sioux-falls",
        "traffic-winnipeg",
    )
    # Dataset families served by ``ElectricalCircuitsDenoisingInterpolationDataset``.
    _CIRCUIT_NETWORKS = ("electrical-circuits",)
    # Datasets addressed by one fixed name, mapped to their dataset class.
    # Insertion order defines the tail of ``supported_datasets``.
    _SINGLE_TASK_DATASETS = {
        "traffic-LA": TrafficLADataset,
        "DAG-denoising": DAGDenoisingDataset,
        "longest-cycle-identification": LongestCycleIdentificationDataset,
        "random-walk-denoising": RandomWalkDenoisingDataset,
        "mixed-longest-cycle-identification": MixedLongestCycleIdentificationDataset,
        "typed-triangles-orientation": TypedTrianglesOrientationDataset,
    }

    def __init__(
        self,
        config: DictConfig,
        batch_size: int,
        seed: int | None = None,
        arbitrary_orientation: bool = True,
    ):
        """
        PyTorch Lightning datamodule class for edge-level tasks.

        Args:
            config (DictConfig): Configuration for the dataset. Must provide
                ``name``, ``dataset_path``, ``val_ratio``, ``test_ratio`` and
                ``orientation_equivariant_labels``.
            batch_size (int): Batch size used by all data loaders.
            seed (int, optional): Random seed forwarded to the dataset.
                Defaults to None.
            arbitrary_orientation (bool, optional): Whether to arbitrarily orient
                the edges (forwarded to the dataset). Defaults to True.

        Raises:
            ValueError: If ``config.name`` is not one of the supported datasets.
        """
        super().__init__()

        self.name = config.name
        self.dataset_path = config.dataset_path
        self.val_ratio = config.val_ratio
        self.test_ratio = config.test_ratio
        self.batch_size = batch_size
        self.seed = seed
        self.arbitrary_orientation = arbitrary_orientation
        self.orientation_equivariant_labels = config.orientation_equivariant_labels
        self.config = config
        # Dataset-specific constructor arguments; filled in by ``setup``.
        self.dataset_kwargs = {}

        # "<family>-<task>" names first, followed by the single-task datasets.
        self.supported_datasets = self._task_names(
            self._TNTP_NETWORKS + self._CIRCUIT_NETWORKS
        ) + list(self._SINGLE_TASK_DATASETS)

        if self.name not in self.supported_datasets:
            raise ValueError(f"The dataset {self.name} is not supported!")

    @classmethod
    def _task_names(cls, networks):
        """Return every ``"<network>-<task>"`` name for the given networks."""
        return ["-".join(pair) for pair in itertools.product(networks, cls._TASKS)]

    def _resolve_dataset_cls(self):
        """Return the dataset class for ``self.name``.

        As a side effect, the dataset-specific constructor arguments read from
        the config are merged into ``self.dataset_kwargs``.

        Raises:
            ValueError: If ``self.name`` does not match any known dataset.
        """
        if self.name in self._task_names(self._TNTP_NETWORKS):
            self.dataset_kwargs |= dict(
                interpolation_label_size=self.config.get(
                    "interpolation_label_size", 0.75
                ),
            )
            return TNTPFlowDenoisingInterpolationDataset

        if self.name in self._task_names(self._CIRCUIT_NETWORKS):
            self.dataset_kwargs |= dict(
                include_non_source_voltages=self.config.include_non_source_voltages,
                current_relative_to_voltage=self.config.current_relative_to_voltage,
                interpolation_label_size=self.config.get(
                    "interpolation_label_size", 0.75
                ),
            )
            return ElectricalCircuitsDenoisingInterpolationDataset

        try:
            return self._SINGLE_TASK_DATASETS[self.name]
        except KeyError:
            raise ValueError(f"The dataset {self.name} is not supported!") from None

    def setup(self, stage: str | None = None):
        """Build the train, validation and test datasets.

        All three splits are given the same temporary cache file. Only the
        train split is built with ``preprocess=True``; presumably that call
        writes the cache which the other two splits then read (confirm against
        the dataset classes). The temporary file is removed afterwards, also
        when building a dataset fails.

        Args:
            stage (str, optional): Lightning stage name. Unused; every call
                rebuilds all three splits.

        Raises:
            ValueError: If ``self.name`` does not match any known dataset.
        """
        dataset_cls = self._resolve_dataset_cls()

        # Constructor arguments shared by all three splits.
        shared_kwargs = dict(
            dataset_name=self.name,
            dataset_path=self.dataset_path,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            seed=self.seed,
            arbitrary_orientation=self.arbitrary_orientation,
            orientation_equivariant_labels=self.orientation_equivariant_labels,
            **self.dataset_kwargs,
        )

        # delete=False: the file is addressed by name from the dataset code, so
        # it is removed explicitly in ``finally`` instead of on close.
        cache = NamedTemporaryFile(suffix=".pt", delete=False)
        try:
            with cache:
                self.train_dataset = dataset_cls(
                    split="train",
                    cache_file=cache.name,
                    preprocess=True,  # Creates the file once
                    **shared_kwargs,
                )
                self.val_dataset = dataset_cls(
                    split="val",
                    cache_file=cache.name,
                    **shared_kwargs,
                )
                self.test_dataset = dataset_cls(
                    split="test",
                    cache_file=cache.name,
                    **shared_kwargs,
                )
        finally:
            # Runs on failure too, so a crashed setup does not leak the file.
            os.remove(cache.name)

    def _make_dataloader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a single-process loader with the module batch size."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test split."""
        return self._make_dataloader(self.test_dataset, shuffle=False)
