"""
Module for loading and working with datasets in the context of SCRaWl.
"""
from collections.abc import Generator
from functools import partial
from itertools import combinations
from pathlib import Path
from typing import Literal

import torch
from lightning.pytorch import LightningDataModule
from sklearn.model_selection import train_test_split
from toponetx.classes import SimplicialComplex as ToponetxSimplicialComplex
from toponetx.transform import graph_to_clique_complex
from torch.utils.data import DataLoader
from torch_geometric.datasets import GNNBenchmarkDataset, LRGBDataset, TUDataset
from torch_geometric.transforms import RemoveIsolatedNodes
from tqdm import tqdm

from scrawl.simplicial import SimplicialComplex, SimplicialData
from scrawl.transformers import (
    pad_data,
    smart_to_networkx,
    toponetx_to_data,
    torch_geometric_to_data,
)
from scrawl.utils import compose
from scrawl.walker import RandomWalk, Walker

# Root directory for raw data files and cached, pre-processed datasets: the
# "data" folder located two directory levels above this file.
DATA_DIR = Path(__file__).parent.parent / "data"


def collate_data(batch: list[SimplicialData]) -> SimplicialData:
    """
    Merge several simplicial data objects into a single disjoint-union object.

    Boundary matrices of all complexes are combined block-diagonally, the
    per-rank tensors ``data[rank]`` are stacked row-wise, and for every rank an
    auxiliary tensor records which batch element each simplex came from.

    Parameters
    ----------
    batch : list[SimplicialData]
        Simplicial data objects to merge. Must be non-empty.

    Returns
    -------
    SimplicialData
        Merged object whose dtype and device are taken from ``batch[0]``. If
        ``batch[0]`` carries an auxiliary tensor at index ``-1``, the
        corresponding tensors of all elements are concatenated as well.
    """
    top_dim = max(item.domain.dim for item in batch)

    def _blocks_for_rank(rank: int) -> list[torch.Tensor]:
        # One block per batch element that contributes rows to this rank.
        blocks = []
        for item in batch:
            domain = item.domain
            if rank <= domain.dim:
                blocks.append(domain.boundary[rank])
            elif (rank - 1) in domain.boundary:
                # The complex stops at rank - 1: add an empty-column block so
                # the rows of the merged rank-`rank` matrix stay aligned with
                # the columns of the merged rank-(rank - 1) matrix.
                blocks.append(torch.zeros((domain.boundary[rank - 1].shape[1], 0)))
        return blocks

    merged_boundary = {
        rank: torch.block_diag(*_blocks_for_rank(rank))
        for rank in range(1, top_dim + 1)
    }

    merged = SimplicialData(
        SimplicialComplex(merged_boundary),
        dtype=batch[0].dtype,
        device=batch[0].device,
    )

    for rank in range(top_dim + 1):
        members = [
            (index, item)
            for index, item in enumerate(batch)
            if rank <= item.domain.dim
        ]
        merged[rank] = torch.vstack([item[rank] for _, item in members])

        # Entry k holds the index in `batch` of the element that owns the
        # k-th merged simplex of this rank.
        owners = torch.hstack(
            [
                torch.full((item.domain.shape[rank],), index, dtype=torch.int64)
                for index, item in members
            ]
        )
        merged.set_aux_tensor(rank, owners)

    if batch[0].has_aux_tensor(-1):
        merged.set_aux_tensor(
            -1, torch.hstack([item.aux_tensor(-1) for item in batch])
        )

    return merged


def collate_walks(
    batch: SimplicialData, steps: int
) -> tuple[SimplicialData, dict[int, list[RandomWalk]]]:
    """
    Sample one group of random walks per simplex rank of a batch.

    A single walker using both lower and upper connections (up to the batch's
    top dimension) starts a walk from every simplex of every rank.

    Parameters
    ----------
    batch : SimplicialData
        Simplicial data, typically the output of ``collate_data``.
    steps : int
        Number of steps each walk takes.

    Returns
    -------
    batch : SimplicialData
        The input batch, returned unchanged.
    walk_groups : dict[int, list[RandomWalk]]
        Mapping from rank to the random walks started at the simplices of that
        rank.
    """
    domain = batch.domain
    walker = Walker(
        domain,
        use_lower_connections=True,
        use_upper_connections=True,
        max_rank=domain.dim,
    )

    walk_groups: dict[int, list[RandomWalk]] = {
        rank: walker.random_walks(
            rank, torch.arange(count, dtype=torch.int64), steps
        )
        for rank, count in enumerate(domain.shape)
    }

    return batch, walk_groups


class WalkingDataModule(LightningDataModule):
    """
    Base class for data modules that pre-sample random walks.

    Parameters
    ----------
    walk_steps : tuple[int, int] | int
        Number of steps to sample for training and validation. A pair is read
        as ``(train_steps, val_steps)``; a single value is used for both.
    """

    def __init__(self, walk_steps: tuple[int, int] | int) -> None:
        """
        Base class for data modules that pre-sample random walks.

        Parameters
        ----------
        walk_steps : tuple[int, int] | int
            Number of steps to sample for training and validation. A pair
            (tuple or list) is read as ``(train_steps, val_steps)``; a single
            value is used for both.

        Raises
        ------
        ValueError
            If ``walk_steps`` is a sequence whose length is not 2.
        """
        super().__init__()

        # Lists are accepted as well as tuples: config files (e.g. YAML)
        # usually deliver sequences as lists, which would otherwise be stored
        # whole as the step count.
        if isinstance(walk_steps, (tuple, list)):
            if len(walk_steps) != 2:
                raise ValueError(
                    "walk_steps must be an int or a pair (train_steps, "
                    f"val_steps), got {walk_steps!r}."
                )
            self.walk_steps_train, self.walk_steps_val = walk_steps
        else:
            self.walk_steps_train = self.walk_steps_val = walk_steps

    def _get_collate_fn(self, mode: Literal["train", "val"]):
        """
        Collate function to use in data loaders in this class.

        Parameters
        ----------
        mode : {"train", "val"}
            Whether to use the training or validation collate function.

        Returns
        -------
        callable
            Collate function that merges a list of samples and then samples
            random walks with the number of steps configured for ``mode``.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"train"`` nor ``"val"``.
        """
        if mode not in ("train", "val"):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}.")

        steps = self.walk_steps_train if mode == "train" else self.walk_steps_val
        concrete_collate_walks = partial(collate_walks, steps=steps)
        # By their signatures, collate_data (list -> SimplicialData) must run
        # first and collate_walks (SimplicialData -> (batch, walks)) second.
        return compose(concrete_collate_walks, collate_data)


class SchoolContactsDataModule(LightningDataModule):
    """
    Data module for the social contacts datasets.

    Parameters
    ----------
    name : {"high-school", "primary-school"}
        Name of the dataset.
    train_size : float, default = 0.8
        Probability with which each node is assigned to the training mask.
    max_rank : int or None, default = None
        Maximum rank of simplices to consider. Hyperedges with more than
        ``max_rank + 1`` nodes are replaced by all of their faces with exactly
        ``max_rank + 1`` nodes. ``None`` keeps every hyperedge as-is.
    """

    name: Literal["high-school", "primary-school"]

    NUM_CLASSES = {"high-school": 10, "primary-school": 12}

    def __init__(
        self,
        name: Literal["high-school", "primary-school"],
        train_size: float = 0.8,
        max_rank: int | None = None,
    ) -> None:
        """
        Data module for the social contacts datasets.

        Parameters
        ----------
        name : {"high-school", "primary-school"}
            Name of the dataset.
        train_size : float, default = 0.8
            Probability with which each node is assigned to the training mask.
        max_rank : int or None, default = None
            Maximum rank of simplices to consider; ``None`` means no limit.
        """
        super().__init__()
        self.name = name
        self.train_split = train_size
        self.max_rank = max_rank

    @property
    def data_dir(self) -> Path:
        """
        Path to the data directory containing the raw data.

        Returns
        -------
        Path
            Path to the data directory.
        """
        return DATA_DIR / f"contact-{self.name}"

    @property
    def num_classes(self) -> int:
        """
        Number of target classes.

        Returns
        -------
        int
            Number of target classes.
        """
        return self.NUM_CLASSES[self.name]

    def setup(self, stage: str | None = None) -> None:
        """
        Set-up the social contacts data module by reading and pre-processing the dataset.

        Reads ``node-labels.txt`` (one integer label per line, line ``i``
        belonging to node ``i + 1``) and ``hyperedges.txt`` (comma-separated
        node ids per line), builds a simplicial complex, and draws a random
        per-node training mask.

        Parameters
        ----------
        stage : str, optional
            Not used.
        """
        simplicial_complex = ToponetxSimplicialComplex()
        with (self.data_dir / "node-labels.txt").open(encoding="utf-8") as labels_file:
            for i, label in enumerate(labels_file):
                # Skip blank lines (e.g. a trailing empty line) instead of
                # crashing in int(); node ids remain tied to line numbers.
                if not label.strip():
                    continue
                # nodes are 1-indexed in this dataset
                simplicial_complex.add_simplex([i + 1], label=int(label))

        with (self.data_dir / "hyperedges.txt").open(encoding="utf-8") as edges_file:
            for line in edges_file:
                if not line.strip():
                    continue
                simplex = [int(x) for x in line.split(",")]

                if self.max_rank is not None and len(simplex) > self.max_rank + 1:
                    # A rank-r simplex has r + 1 vertices, so the largest faces
                    # allowed have max_rank + 1 vertices.
                    simplicial_complex.add_simplices_from(
                        combinations(simplex, self.max_rank + 1)
                    )
                else:
                    simplicial_complex.add_simplex(simplex)

        # One Bernoulli(train_split) draw per node: True marks a training node.
        train_mask = torch.empty(
            simplicial_complex.shape[0], dtype=torch.bool
        ).bernoulli_(self.train_split)

        self.dataset = toponetx_to_data(simplicial_complex, "label", dtype=torch.int64)
        self.dataset.set_aux_tensor(0, train_mask)

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader.

        Returns
        -------
        DataLoader
            Dataloader yielding the whole complex as a single, un-batched item.
        """
        return DataLoader([self.dataset], batch_size=None)

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        Returns
        -------
        DataLoader
            Dataloader yielding the whole complex as a single, un-batched item.
        """
        return DataLoader([self.dataset], batch_size=None)


class GNNBenchmarkDataModule(WalkingDataModule):
    """
    Data module for the GNN benchmark datasets.

    Parameters
    ----------
    name : str
        Name of the dataset.
    train_size : float, default = 0.8
        Stored on the module but currently unused: the dataset is loaded with
        its predefined train/test/val splits.
    batch_size : int, default = 32
        Batch size.
    walk_steps : tuple[int, int] | int | None, default = None
        Number of random-walk steps sampled for training and validation
        batches. Required before a data loader can be built.
    """

    name: str
    train_data: list[SimplicialData]
    test_data: list[SimplicialData]
    val_data: list[SimplicialData]

    FEATURE_SIZES = {
        "CIFAR10": [10, 10, 10],
    }

    def __init__(
        self,
        name: str,
        train_size: float = 0.8,
        batch_size: int = 32,
        walk_steps: tuple[int, int] | int | None = None,
    ) -> None:
        """
        Data module for the GNN benchmark datasets.

        Parameters
        ----------
        name : str
            Name of the dataset.
        train_size : float, default = 0.8
            Stored on the module; the predefined splits are used.
        batch_size : int, default = 32
            Batch size.
        walk_steps : tuple[int, int] | int | None, default = None
            Number of random-walk steps for training and validation batches.
            Added as a trailing optional argument so existing constructor calls
            keep working; data loaders require it to be set.
        """
        super().__init__(walk_steps)
        self.name = name
        self.train_size = train_size
        self.batch_size = batch_size

    @property
    def cache_dir(self) -> Path:
        """
        Directory for caching.

        Returns
        -------
        Path
            Path to the cache directory.
        """
        return DATA_DIR / "gnn-benchmark"

    def cache_file(self, split: Literal["train", "test", "val"]) -> Path:
        """
        Path to the cached simplicial complexes.

        Parameters
        ----------
        split : {"train", "test", "val"}
            Split of the dataset.

        Returns
        -------
        Path
            Path to the cached simplicial complexes.

        Raises
        ------
        ValueError
            If ``split`` is not one of the known splits.
        """
        if split not in ["train", "test", "val"]:
            raise ValueError(
                f"split must be one of 'train', 'test', or 'val', got {split}."
            )
        return self.cache_dir / f"{self.name}-{split}.pt"

    def prepare_data(self) -> None:
        """
        Set-up the GNN benchmark data module by downloading and pre-processing the dataset.

        Each split is converted to a list of ``SimplicialData`` and cached;
        splits whose cache file already exists are skipped.
        """

        def transform_fn(
            dataset,
        ) -> Generator[SimplicialData, None, None]:  # numpydoc ignore=GL08
            for data in tqdm(dataset, desc="Converting to Simplicial Complexes"):
                yield torch_geometric_to_data(data, dtype=torch.float32)

        for split in ["train", "test", "val"]:
            if self.cache_file(split).exists():
                continue

            dataset = GNNBenchmarkDataset(
                str(self.cache_dir),
                self.name,
                split=split,
                transform=RemoveIsolatedNodes(),
            )
            simplicial_complexes = list(transform_fn(dataset))
            torch.save(simplicial_complexes, self.cache_file(split))

    def setup(self, stage: str | None = None) -> None:
        """
        Set-up the GNN benchmark data module by loading the data.

        Parameters
        ----------
        stage : str, optional
            Not used.
        """
        for split in tqdm(
            ["train", "test", "val"], desc="Setup GNN benchmark data module..."
        ):
            # The cache holds pickled SimplicialData objects written by
            # prepare_data, so full unpickling is required (torch >= 2.6
            # defaults to weights_only=True, which rejects them). Only load
            # cache files you produced yourself.
            simplicial_complexes = torch.load(
                self.cache_file(split), weights_only=False
            )
            setattr(self, f"{split}_data", simplicial_complexes)

    def _require_walk_steps(self, mode: Literal["train", "val"]) -> None:
        """
        Raise a clear error when no walk steps were configured for ``mode``.

        Parameters
        ----------
        mode : {"train", "val"}
            Data loader being built.

        Raises
        ------
        ValueError
            If the walk steps for ``mode`` are ``None``.
        """
        if getattr(self, f"walk_steps_{mode}") is None:
            raise ValueError(
                f"{type(self).__name__} needs `walk_steps` to build the {mode} "
                "data loader; pass walk_steps=... to the constructor."
            )

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader, yielding ``(batch, walk_groups)`` pairs.
        """
        self._require_walk_steps("train")
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            # Shuffle training samples, as LRGBDataModule does.
            shuffle=True,
            collate_fn=self._get_collate_fn("train"),
        )

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader, yielding ``(batch, walk_groups)`` pairs.
        """
        self._require_walk_steps("val")
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            collate_fn=self._get_collate_fn("val"),
        )


class LRGBDataModule(WalkingDataModule):
    """
    Data module for the LRGB datasets.

    Parameters
    ----------
    name : str
        Name of the dataset.
    walk_steps : tuple[int, int] | int
        Number of steps to sample for training and validation.
    train_size : float, default = 0.8
        Stored on the module but currently unused: the dataset is loaded with
        its predefined train/test/val splits.
    batch_size : int, default = 32
        Batch size.
    """

    name: str
    train_data: list[SimplicialData]
    test_data: list[SimplicialData]
    val_data: list[SimplicialData]

    # Per-rank feature sizes that converted samples are padded to.
    FEATURE_SIZES = {
        "PascalVOC-SP": [14, 2, 0],
        "PCQM-Contact": [0, 0, 0],
        "Peptides-func": [9, 3, 0],
    }

    TARGETS = {
        "PascalVOC-SP": {"mode": "node-classification", "num_classes": 21},
        "Peptides-func": {"mode": "sc-classification", "num_classes": 10},
    }

    def __init__(
        self,
        name: str,
        walk_steps: tuple[int, int] | int,
        train_size: float = 0.8,
        batch_size: int = 32,
    ) -> None:
        """
        Data module for the LRGB datasets.

        Parameters
        ----------
        name : str
            Name of the dataset.
        walk_steps : tuple[int, int] | int
            Number of steps to sample for training and validation; a pair is
            read as ``(train_steps, val_steps)``.
        train_size : float, default = 0.8
            Stored on the module; the predefined splits are used.
        batch_size : int, default = 32
            Batch size.
        """
        super().__init__(walk_steps)
        self.name = name
        self.train_size = train_size
        self.batch_size = batch_size

    @property
    def target_conf(self) -> dict:
        """
        Prediction target configuration.

        Returns
        -------
        dict
            Dictionary containing the prediction target configuration.

        Raises
        ------
        KeyError
            If ``TARGETS`` has no entry for this dataset (e.g. "PCQM-Contact").
        """
        return self.TARGETS[self.name]

    @property
    def cache_dir(self) -> Path:
        """
        Directory for caching.

        Returns
        -------
        Path
            Path to the cache directory.
        """
        return DATA_DIR / "lrgb"

    def cache_file(self, split: Literal["train", "test", "val"]) -> Path:
        """
        Path to the cached simplicial complexes.

        Parameters
        ----------
        split : {"train", "test", "val"}
            Split of the dataset.

        Returns
        -------
        Path
            Path to the cached simplicial complexes.

        Raises
        ------
        ValueError
            If ``split`` is not one of the known splits.
        """
        if split not in ["train", "test", "val"]:
            raise ValueError(
                f"split must be one of 'train', 'test', or 'val', got {split}."
            )
        return self.cache_dir / f"{self.name}-{split}.pt"

    def prepare_data(self) -> None:
        """
        Prepare the LRGB data module by downloading and pre-processing the dataset.

        Each split is converted to ``SimplicialData``, padded to
        ``FEATURE_SIZES[name]``, and cached; splits whose cache file already
        exists are skipped.
        """

        def transform_fn(
            dataset,
        ) -> Generator[SimplicialData, None, None]:  # numpydoc ignore=GL08
            for data in tqdm(dataset, desc="Converting to Simplicial Complexes"):
                data = torch_geometric_to_data(data, dtype=torch.float32)
                yield pad_data(data, self.FEATURE_SIZES[self.name])

        for split in ["train", "test", "val"]:
            if self.cache_file(split).exists():
                continue

            dataset = LRGBDataset(
                str(self.cache_dir),
                self.name,
                split=split,
                transform=RemoveIsolatedNodes(),
            )
            simplicial_complexes = list(transform_fn(dataset))
            torch.save(simplicial_complexes, self.cache_file(split))

    def setup(self, stage: str | None = None) -> None:
        """
        Set-up the LRGB data module by loading the data.

        Parameters
        ----------
        stage : str, optional
            Not used.
        """
        for split in tqdm(["train", "test", "val"], desc="Setup LRGB data module..."):
            # The cache holds pickled SimplicialData objects written by
            # prepare_data, so full unpickling is required (torch >= 2.6
            # defaults to weights_only=True, which rejects them). Only load
            # cache files you produced yourself.
            simplicial_complexes = torch.load(
                self.cache_file(split), weights_only=False
            )
            setattr(self, f"{split}_data", simplicial_complexes)

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader, yielding ``(batch, walk_groups)`` pairs.
        """
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self._get_collate_fn("train"),
        )

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader, yielding ``(batch, walk_groups)`` pairs.
        """
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=4,
            collate_fn=self._get_collate_fn("val"),
        )


class TUDataModule(LightningDataModule):
    """
    Data module for the TU datasets.

    Each graph is lifted to its clique complex and stored as ``SimplicialData``
    with the graph label attached as auxiliary tensor ``-1``.

    Parameters
    ----------
    name : str
        Name of the dataset.
    train_size : float, default = 0.8
        Fraction of the dataset to use for training.
    batch_size : int, default = 32
        Batch size.
    """

    name: str
    train_data: list[SimplicialData]
    val_data: list[SimplicialData]

    FEATURE_SIZES = {
        "NCI1": [37, 0, 0],
        "PROTEINS": [3, 0, 0],
        "REDDIT-BINARY": [0, 0, 0],
        "REDDIT-MULTI-5K": [0, 0, 0],
    }

    TARGETS = {
        "NCI1": {"mode": "graph-classification", "num_classes": 2, "walk_prop": 1.0},
        "PROTEINS": {
            "mode": "graph-classification",
            "num_classes": 2,
            "walk_prop": 1.0,
        },
        "REDDIT-BINARY": {
            "mode": "graph-classification",
            "num_classes": 2,
            "walk_prop": 0.2,
        },
        "REDDIT-MULTI-5K": {
            "mode": "graph-classification",
            "num_classes": 5,
            "walk_prop": 0.2,
        },
    }

    def __init__(
        self,
        name: str,
        train_size: float = 0.8,
        batch_size: int = 32,
    ) -> None:
        """
        Data module for the TU datasets.

        Parameters
        ----------
        name : str
            Name of the dataset.
        train_size : float, default = 0.8
            Fraction of the dataset to use for training.
        batch_size : int, default = 32
            Batch size.
        """
        super().__init__()
        self.name = name
        self.train_size = train_size
        self.batch_size = batch_size

    @property
    def target_conf(self) -> dict:
        """
        Prediction target configuration.

        Returns
        -------
        dict
            Dictionary containing the prediction target configuration.
        """
        return self.TARGETS[self.name]

    @property
    def cache_dir(self) -> Path:
        """
        Directory for caching.

        Returns
        -------
        Path
            Path to the cache directory.
        """
        return DATA_DIR / "tu"

    @property
    def cache_file(self) -> Path:
        """
        Path to the cached simplicial complexes.

        Returns
        -------
        Path
            Path to the cached simplicial complexes.
        """
        return self.cache_dir / f"{self.name}.pt"

    def prepare_data(self) -> None:
        """
        Prepare the TU data module by downloading and pre-processing the dataset.

        Does nothing if the cache file already exists.
        """
        if self.cache_file.exists():
            return

        simplicial_complexes = []
        tu_dataset = TUDataset(
            str(self.cache_dir), self.name, transform=RemoveIsolatedNodes()
        )

        for data in tqdm(tu_dataset, desc="Converting to Simplicial Complexes"):
            graph = smart_to_networkx(data, to_undirected=True)
            # Lift the graph to its clique complex (cliques up to max_dim=3).
            simplicial_complex = graph_to_clique_complex(graph, max_dim=3)
            simplicial_data = toponetx_to_data(
                simplicial_complex,
                attr_names=["x", "edge_attr"],
                dtype=torch.float32,
            )
            # Graph-level target stored as auxiliary tensor -1; collate_data
            # concatenates it across the batch.
            simplicial_data.set_aux_tensor(-1, data.y)

            simplicial_complexes.append(simplicial_data)

        torch.save(simplicial_complexes, self.cache_file)

    def setup(self, stage: str | None = None) -> None:
        """
        Set-up the TU data module by splitting the dataset into train and validation.

        Without ``random_state``, ``train_test_split`` draws the split from
        NumPy's global RNG, so seed it for reproducible splits.

        Parameters
        ----------
        stage : str, optional
            Not used.
        """
        # The cache holds pickled SimplicialData objects written by
        # prepare_data, so full unpickling is required (torch >= 2.6 defaults
        # to weights_only=True, which rejects them). Only load cache files
        # you produced yourself.
        self.simplicial_complexes = torch.load(self.cache_file, weights_only=False)
        self.train_data, self.val_data = train_test_split(
            self.simplicial_complexes, train_size=self.train_size
        )

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.
        """
        return DataLoader(
            self.train_data, batch_size=self.batch_size, collate_fn=collate_data
        )

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.
        """
        return DataLoader(
            self.val_data, batch_size=self.batch_size, collate_fn=collate_data
        )
