import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import random_split
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader

from src.data.ginkgo.dataset import GinkgoDataset
from src.data.ginkgo.simulate_dataset import get_dataset_name, simulate


class GinkgoDataModule(LightningDataModule):
    """`LightningDataModule` for the Ginkgo dataset.

    The Ginkgo dataset provides simulated jet physics data for machine learning research. It
    contains jets represented as binary trees, where nodes are unstable particles and leaves are
    stable jet constituents. Each constituent is characterized by a 4D vector (E, p_x, p_y, p_z).
    The dataset ensures momentum conservation and captures key features like the running of the
    splitting scale and permutation invariance. Jets are generated using a recursive algorithm with
    customizable parameters for different jet types and sizes.

    The dataset is generated on demand by `prepare_data()` (if not already on disk). `setup()`
    then splits it into train/validation/test subsets.
    """

    # Fixed seed so the train/val/test assignment of jets is reproducible across runs.
    _SPLIT_SEED: int = 42

    def __init__(
        self,
        data_dir: str = "data/ginkgo/",
        train_val_test_split: Tuple[float, float, float] = (0.98, 0.01, 0.01),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        jet_type: str = "QCD",
        n_samples: int = 100000,
        min_leaves: int = 5,
        max_leaves: int = 21,
        w_rate: float = 3.0,
        qcd_rate: float = 1.5,
        qcd_mass: float = 30.0,
        pt_min_sqrt: float = 4.0,
        max_n_nodes: int = 21,
    ) -> None:
        """Initialize a `GinkgoDataModule`.

        :param data_dir: Parent data directory. The dataset itself lives in a subdirectory named
            after `dataset_name` (i.e. it depends on the simulator settings below).
            Defaults to `"data/ginkgo/"`.
        :param train_val_test_split: Train, validation and test split, passed to
            `torch.utils.data.random_split` (fractions summing to 1, or absolute lengths).
            Defaults to `(0.98, 0.01, 0.01)`.
        :param batch_size: Batch size. Defaults to `64`.
        :param num_workers: Number of workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        :param jet_type: "QCD" or "W" like jet. Defaults to `"QCD"`.
        :param n_samples: Number of jets in the dataset. Defaults to `100_000`.
        :param min_leaves: Cut on the minimum number of leaves of the jets to save (including). Defaults to `5`.
        :param max_leaves: Cut on the maximum number of leaves of the jets to save (excluding). Defaults to `21`.
        :param w_rate: W-rate parameter used by Ginkgo simulator. Defaults to `3.0`.
        :param qcd_rate: QCD-rate parameter used by Ginkgo simulator. Defaults to `1.5`.
        :param qcd_mass: QCD-mass parameter used by Ginkgo simulator. Defaults to `30.0`.
        :param pt_min_sqrt: Square root of pt-min parameter used by Ginkgo simulator. Defaults to `4.0`.
        :param max_n_nodes: Maximum number of nodes in any graph. Only stored in `hparams` (not
            used directly by this class). Defaults to `21`.
        """
        super().__init__()

        # Must run before `dataset_name` is read below, since that property reads `self.hparams`.
        self.save_hyperparameters(logger=False)
        self.data_dir = Path(data_dir) / self.dataset_name

        # Populated by `setup()`; `random_split` yields `Subset`s of the full dataset.
        self.data_train: Optional[torch.utils.data.Subset] = None
        self.data_val: Optional[torch.utils.data.Subset] = None
        self.data_test: Optional[torch.utils.data.Subset] = None

    @property
    def dataset_name(self) -> str:
        """
        :returns: Dataset name including the settings for Ginkgo.
        """
        return get_dataset_name(
            self.hparams.n_samples,
            self.hparams.jet_type,
            self.hparams.min_leaves,
            self.hparams.max_leaves,
            self.hparams.w_rate,
            self.hparams.qcd_rate,
            self.hparams.qcd_mass,
            self.hparams.pt_min_sqrt,
        )

    def prepare_data(self) -> None:
        """If there is no dataset with the given parameters, generate it using the simulator."""
        # Only the existence of the per-configuration directory is checked. A directory left
        # behind by an interrupted simulation is NOT regenerated.
        # NOTE(review): presumably `simulate` writes into `<data_dir>/<dataset_name>`,
        # i.e. `self.data_dir` — confirm against simulate_dataset.
        if not self.data_dir.exists():
            simulate(
                root=Path(self.hparams.data_dir),
                n_samples=self.hparams.n_samples,
                jet_type=self.hparams.jet_type,
                min_leaves=self.hparams.min_leaves,
                max_leaves=self.hparams.max_leaves,
                w_rate=self.hparams.w_rate,
                qcd_rate=self.hparams.qcd_rate,
                qcd_mass=self.hparams.qcd_mass,
                pt_min_sqrt=self.hparams.pt_min_sqrt,
            )

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        The split is computed once and reused on later calls (Lightning may call `setup` for
        several stages).

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # Use `is None` rather than truthiness. `Subset` defines `__len__`, so empty
        # splits would otherwise look "not loaded" and trigger a re-split.
        if self.data_train is None and self.data_val is None and self.data_test is None:
            dataset = GinkgoDataset(root=self.data_dir)
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split,
                generator=torch.Generator().manual_seed(self._SPLIT_SEED),
            )

    def _make_dataloader(
        self, dataset: Optional[torch.utils.data.Dataset], shuffle: bool
    ) -> DataLoader:
        """Build a PyG `DataLoader` over one split using the batch/worker/memory hparams.

        :param dataset: One of the splits produced by `setup()`.
        :param shuffle: Whether to reshuffle the samples every epoch.
        :return: The configured dataloader.
        :raises RuntimeError: If `setup()` has not been called yet.
        """
        if dataset is None:
            raise RuntimeError(
                "GinkgoDataModule.setup() must be called before requesting a dataloader."
            )
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Create and return the train dataloader (shuffled).

        :return: The train dataloader.
        """
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Create and return the validation dataloader (not shuffled).

        :return: The validation dataloader.
        """
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Create and return the test dataloader (not shuffled).

        :return: The test dataloader.
        """
        return self._make_dataloader(self.data_test, shuffle=False)

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`. Currently a no-op.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. This datamodule keeps no extra state.

        :return: An empty dictionary.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Nothing to restore, since `state_dict()` is empty.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    # Smoke check: constructing with default settings exercises hyperparameter saving and
    # dataset-directory resolution. `__init__` does not call `prepare_data()`, so no
    # simulation is triggered here.
    GinkgoDataModule()
