from abc import abstractmethod
from os.path import exists, join

import torch
from torch.utils.data import Dataset

from magnetic_edge_gnn.datasets.dataset_utils import (
    random_orientation,
)


class InductiveDataset(Dataset):
    def __init__(
        self,
        split: str,
        dataset_name: str,
        dataset_path: str,
        val_ratio: float = 0.1,
        test_ratio: float = 0.2,
        seed: float | None = None,
        arbitrary_orientation: bool = True,
        orientation_equivariant_labels: bool = False,
        cache_file: str | None = None,
        preprocess: bool | None = None,
    ):
        """
        Abstract dataset class for inductive tasks.

        Args:
            split (str): Data split to load. Should be one of: ["train", "val", "test"].
            dataset_name (str): Name of the dataset.
            dataset_path (str): Path to the dataset.
            val_ratio (float, optional): Ratio of validation data. Defaults to 0.1.
            test_ratio (float, optional): Ratio of test data. Defaults to 0.2.
            seed (float, optional): Random seed. Defaults to 0.
            arbitrary_orientation (bool, optional): Whether to arbitrarily orient the edges.
                Defaults to False.
            orientation_equivariant_labels (bool, optional): Whether the labels are orientation-equivariant or not.
                Defaults to False.
        """
        super().__init__()

        if split not in ["train", "val", "test"]:
            raise ValueError(
                f"The split should be in ['train', 'val', 'test']. Split {split} is not supported!"
            )

        self.split = split
        self.dataset_path = dataset_path
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.seed = seed
        self.arbitrary_orientation = arbitrary_orientation
        self.orientation_equivariant_labels = orientation_equivariant_labels

        if cache_file is not None:
            self.filename = cache_file
        else:
            self.filename = join(self.dataset_path, f"graph-{self.seed}.pt")

        if preprocess or not exists(self.filename):
            self.preprocess()
        assert exists(self.filename)
        data = torch.load(self.filename)

        self.graphs = data[self.split]

        if self.arbitrary_orientation:
            self.graphs = [
                random_orientation(
                    data,
                    orientation_equivariant_labels=self.orientation_equivariant_labels,
                )
                for data in self.graphs
            ]

    @abstractmethod
    def preprocess(self):
        pass

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx: int):
        return self.graphs[idx]
