import os
import torch
import numpy as np
from typing import Sequence, Optional, List
from pytorch_lightning import LightningDataModule
from utils.utils import (
    select_simplexes_and_relations,
)
from torch_geometric.loader import DataLoader as DataLoaderPyG
from torch_geometric.data import Dataset as DatasetPyG
from sklearn.model_selection import train_test_split
import torch
import os
import numpy as np


class DACSDataset(DatasetPyG):
    """Dataset of hetero-graphs representing Dynamical Activity Complexes.

    Each sample is read on demand from a file in *root* with ``torch.load``,
    pruned with ``select_simplexes_and_relations`` and has its node feature
    tensors (``x``) moved to *device*.

    Parameters
    ----------
    root : str
        Directory containing the saved ``HeteroData`` DACs (one per file).
    split_indices : Sequence[int]
        Indices (relative to the sorted listing of *all* files in *root*)
        that belong to the current split. The order of *split_indices* is
        preserved.
    device : str
        Device on which the ``x`` feature tensors will be placed.
    max_dim : int, default 2
        Highest simplex dimension to keep.
    relations : list[str], optional
        Which adjacency relations to retain; cf.
        ``select_simplexes_and_relations``. ``None`` is treated as ``[]``.
    """

    def __init__(
        self,
        root: str,
        split_indices: Sequence[int],
        device: str,
        max_dim: int = 2,
        relations: Optional[List[str]] = None,
    ):
        super().__init__(root)
        # Keep the caller's path verbatim (the PyG base class may normalise it).
        self.root = root
        # Sorted so that positions are stable; this ordering must match the one
        # used when *split_indices* were produced.
        self.dacs_files = sorted(os.listdir(self.root))
        self.split_indices = split_indices

        self.device = device
        self.max_dim = max_dim
        self.relations = relations or []

    def len(self) -> int:
        """Return the number of samples in this split."""
        return len(self.split_indices)

    def get(self, idx: int):
        """Load, prune and return the *idx*-th sample of this split."""
        file_idx = self.split_indices[idx]
        # NOTE(review): on recent torch versions ``torch.load`` defaults to
        # ``weights_only=True``, which may refuse to unpickle ``HeteroData``;
        # confirm against the torch version in use.
        data = torch.load(os.path.join(self.root, self.dacs_files[file_idx]))
        data = select_simplexes_and_relations(data, self.max_dim, self.relations)

        # Only node features are moved here; other tensors are left as loaded.
        for node_type in data.x_dict:
            data[node_type].x = data[node_type].x.to(self.device)

        return data


class DACSDataModule(LightningDataModule):
    """Lightning data module for graph classification on DAC hetero-graphs.

    Reads its settings from ``cfg.dataset`` (paths, batch size, device,
    simplex/relation filters) and ``cfg.training.seed``. It builds or reuses a
    stratified 60/20/20 train/val/test split of the files in
    ``<root>/<dacs_path>`` and exposes a PyG dataloader for each split.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg  # keep full cfg for future convenience
        dcfg = cfg.dataset

        self.random_seed: int = cfg.training.seed
        self.root: str = dcfg.root
        self.dacs_path: str = dcfg.dacs_path
        self.split_number: int = dcfg.split_number

        self.batch_size: int = dcfg.batch_size
        self.device: str | torch.device = dcfg.device
        self.max_dim: int = dcfg.max_dim
        self.relations: List[str] = dcfg.relations

        # Datasets are created in ``setup``; these are the attributes the
        # dataloader methods read.
        self.train_dataset: Optional[DACSDataset] = None
        self.val_dataset: Optional[DACSDataset] = None
        self.test_dataset: Optional[DACSDataset] = None

    def prepare_data(self) -> None:
        """
        Prepare the data, such as downloading or saving preprocessed files.
        Runs only once on a single process.
        """
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test datasets.

        The split is read from ``<root>/split_<split_number>.pt`` if that file
        exists. Otherwise a stratified split is computed from the graph labels
        and saved there, so later runs reuse it. *stage* is ignored: all three
        datasets are always built.
        """
        dacs_dir = os.path.join(self.root, self.dacs_path)
        split_file = os.path.join(self.root, f"split_{self.split_number}.pt")

        if os.path.exists(split_file):
            splits = torch.load(split_file)
            train_idx, val_idx, test_idx = (
                splits["train"],
                splits["val"],
                splits["test"],
            )
        else:
            train_idx, val_idx, test_idx = self._compute_splits(dacs_dir)
            torch.save(
                {"train": train_idx, "val": val_idx, "test": test_idx},
                split_file,
            )

        self.train_dataset = self._make_dataset(dacs_dir, train_idx)
        self.val_dataset = self._make_dataset(dacs_dir, val_idx)
        self.test_dataset = self._make_dataset(dacs_dir, test_idx)

    def _compute_splits(self, dacs_dir: str):
        """Return stratified 60/20/20 (train, val, test) file-index lists."""
        dacs = sorted(os.listdir(dacs_dir))
        # Loading every graph is expensive, so labels are only gathered when a
        # new split actually has to be computed.
        labels = [torch.load(os.path.join(dacs_dir, dac)).y.item() for dac in dacs]
        indices = np.arange(len(dacs))

        train_idx, val_test_idx, _, y_val_test = train_test_split(
            indices,
            labels,
            stratify=labels,
            train_size=0.6,
            random_state=self.random_seed,
        )
        # Split the remaining 40% evenly into validation and test.
        val_idx, test_idx = train_test_split(
            val_test_idx,
            stratify=y_val_test,
            train_size=0.5,
            random_state=self.random_seed,
        )
        # Plain int lists (unlike numpy arrays) need no extra safe globals when
        # the split file is re-read with torch.load(weights_only=True).
        return train_idx.tolist(), val_idx.tolist(), test_idx.tolist()

    def _make_dataset(self, dacs_dir: str, indices) -> DACSDataset:
        """Build a ``DACSDataset`` over *indices* using this module's settings."""
        return DACSDataset(
            root=dacs_dir,
            split_indices=indices,
            device=self.device,
            max_dim=self.max_dim,
            relations=self.relations,
        )

    # ------------------------------------------------------------------
    # Dataloaders
    # ------------------------------------------------------------------
    def train_dataloader(self):
        # Reshuffle the training samples every epoch; without this, every
        # epoch sees the same fixed order.
        return DataLoaderPyG(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoaderPyG(
            self.val_dataset,
            batch_size=self.batch_size,
        )

    def test_dataloader(self):
        return DataLoaderPyG(
            self.test_dataset,
            batch_size=self.batch_size,
        )
