import os
import math
import torch
import numpy as np
from tqdm import tqdm
from typing import Dict, Optional
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.data import HeteroData
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from utils.utils import (
    select_simplexes_and_relations_for_feat_class,
)
import os
import torch
import numpy as np
import math

class FixedVolumeDataModule(LightningDataModule):
    """PyTorch-Lightning ``DataModule`` for the fixed-volume feature-classification task.

    A ``DataModule`` encapsulates all the steps needed to process and load the data for
    training, validation and testing.  This implementation:

    * loads pre-processed heterogeneous graphs and labels from ``data_path``
      (building and caching the merged tensors on first use);
    * selects the required simplex levels / adjacencies (via ``utils.utils``);
    * constructs *stratified* train / val / test splits that are cached to disk;
    * cuts the graphs according to the splits, pre-groups the samples into
      mini-batches and wraps them in a ``DynamicsDataset``;
    * finally exposes three ``DataLoader`` objects.
    """

    def __init__(self, cfg):
        """Store the configuration; nothing is read from disk until :meth:`setup`.

        Args:
            cfg: Configuration object.  The attributes read here are
                ``cfg.training.seed``, ``cfg.dataset.root``,
                ``cfg.dataset.batch_size`` and ``cfg.dataset.device``;
                :meth:`setup` additionally reads ``cfg.dataset.max_dim``,
                ``cfg.dataset.relations`` and ``cfg.dataset.split_number``.
        """
        super().__init__()
        self.cfg = cfg
        self.data_cfg = cfg.dataset
        self.random_seed: int = cfg.training.seed

        self.data_path: str = self.data_cfg.root
        self.batch_size: int = self.data_cfg.batch_size
        self.device = torch.device(self.data_cfg.device)

        self.train_dataset: Optional[DynamicsDataset] = None
        self.val_dataset: Optional[DynamicsDataset] = None
        self.test_dataset: Optional[DynamicsDataset] = None

    def prepare_data(self) -> None:  # noqa: D401
        """Download or generate raw data (noop – everything is assumed on disk)."""
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the data, create the splits and build the three datasets.

        Args:
            stage: Accepted for Lightning compatibility; it is not used, so
                all three datasets are (re)built on every call.
        """
        # ------------------------------------------------------------------
        # 1.  Load base tensors
        # ------------------------------------------------------------------
        dynamics, labels = self._load_dynamics_and_labels()
        self.all_dynamics = dynamics
        self.labels = labels

        # ------------------------------------------------------------------
        # 2.  Load relations
        # ------------------------------------------------------------------
        edge_index_dict: Dict = torch.load(os.path.join(self.data_path, "relations.pt"))

        # ------------------------------------------------------------------
        # 3.  Select simplexes / adjacencies
        # ------------------------------------------------------------------
        self.heterodata, edge_index_dict = (
            select_simplexes_and_relations_for_feat_class(
                self.all_dynamics,
                edge_index_dict,
                self.data_cfg.max_dim,
                self.data_cfg.relations,
            )
        )

        # Device-resident copy of the non-empty relations.
        self.edge_index_dict = {
            k: v.to(self.device) for k, v in edge_index_dict.items() if v.numel() > 0
        }

        # ------------------------------------------------------------------
        # 4.  Make / load stratified splits
        # ------------------------------------------------------------------
        train_idx, val_idx, test_idx = self._load_or_create_splits(labels)

        train_mask = torch.as_tensor(train_idx, dtype=torch.long)
        val_mask = torch.as_tensor(val_idx, dtype=torch.long)
        test_mask = torch.as_tensor(test_idx, dtype=torch.long)

        # ------------------------------------------------------------------
        # 5.  Cut graphs & build datasets
        # ------------------------------------------------------------------
        # NOTE(review): the datasets are built from ``self.all_dynamics`` and the
        # un-moved, unfiltered ``edge_index_dict`` rather than from
        # ``self.heterodata`` / ``self.edge_index_dict`` computed above.  Kept
        # as-is because the helper's semantics are not visible here - confirm
        # whether the selected / device-resident versions were intended.
        self.train_dataset = self._make_dataset(
            self.all_dynamics, labels, train_mask, edge_index_dict
        )
        self.val_dataset = self._make_dataset(
            self.all_dynamics, labels, val_mask, edge_index_dict
        )
        self.test_dataset = self._make_dataset(
            self.all_dynamics, labels, test_mask, edge_index_dict
        )

    # ---------------------------------------------------------------------
    # DataLoaders
    # ---------------------------------------------------------------------
    # Every dataset item is already a pre-built mini-batch (see
    # ``_create_batches``), hence ``batch_size=1`` plus an unwrapping collate.
    def train_dataloader(self) -> DataLoader:
        """Return the loader over the pre-batched training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=1,
            collate_fn=custom_collate_fn,
            num_workers=0,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the pre-batched validation dataset."""
        return DataLoader(
            self.val_dataset, batch_size=1, collate_fn=custom_collate_fn, num_workers=0
        )

    def test_dataloader(self) -> DataLoader:
        """Return the loader over the pre-batched test dataset."""
        return DataLoader(
            self.test_dataset, batch_size=1, collate_fn=custom_collate_fn, num_workers=0
        )

    # ---------------------------------------------------------------------
    # Helpers
    # ---------------------------------------------------------------------
    def _load_dynamics_and_labels(self):
        """Return ``(dynamics, labels)``, building and caching them if needed.

        When both cache files (``labels`` and ``all_dynamics`` under
        ``data_path``) exist they are loaded directly.  Otherwise every file
        in ``data_path/dynamics`` is read in sorted order, the per-node-type
        ``x`` tensors are stacked along a new leading sample axis, the
        per-file ``y`` values are collected into one tensor, and both results
        are written to the cache files.

        Returns:
            A ``HeteroData`` holding one stacked ``x`` per node type, and a
            tensor with one label per sample.
        """
        dynamics_path = os.path.join(self.data_path, "dynamics")
        labels_path = os.path.join(self.data_path, "labels")
        all_dynamics_path = os.path.join(self.data_path, "all_dynamics")

        if os.path.exists(labels_path) and os.path.exists(all_dynamics_path):
            labels: torch.Tensor = torch.load(labels_path)
            dynamics: HeteroData = torch.load(all_dynamics_path, weights_only=False)
            return dynamics, labels

        per_sample_labels = []
        features_by_type: Dict[str, list] = {}

        for fname in tqdm(
            sorted(os.listdir(dynamics_path)), desc="Building global HeteroData"
        ):
            # The files hold pickled graph objects, not plain tensors, so they
            # need ``weights_only=False`` just like the merged cache above.
            dynamic_i = torch.load(
                os.path.join(dynamics_path, fname), weights_only=False
            )
            per_sample_labels.append(dynamic_i.y)

            for ntype, x in dynamic_i.x_dict.items():
                features_by_type.setdefault(ntype, []).append(x)

        # Stack only once, after *all* files have been read; stacking inside
        # the file loop would turn the lists into tensors after the first file.
        dynamics = HeteroData()
        for ntype, xs in features_by_type.items():
            dynamics[ntype].x = torch.stack(xs)

        # Keep the tensor (not the Python list) so callers can index it.
        labels = torch.tensor(per_sample_labels)

        torch.save(dynamics, all_dynamics_path)
        torch.save(labels, labels_path)
        return dynamics, labels

    def _load_or_create_splits(self, labels: torch.Tensor):
        """Return ``(train_idx, val_idx, test_idx)`` for the configured split.

        The split is read from ``split_<split_number>.pt`` under ``data_path``
        when that file exists.  Otherwise a stratified 60 / 20 / 20 split is
        drawn with ``random_state=self.random_seed`` and written to that file.
        """
        split_path = os.path.join(
            self.data_path, f"split_{self.data_cfg.split_number}.pt"
        )

        if os.path.exists(split_path):
            # The cache written below contains NumPy index arrays, which a
            # weights-only load does not accept.
            splits = torch.load(split_path, weights_only=False)
            return splits["train"], splits["val"], splits["test"]

        # scikit-learn works on array-likes; hand it NumPy so the returned
        # label subset is a plain array whatever the library version.
        labels_np = labels.detach().cpu().numpy()
        indices = np.arange(len(labels_np))
        train_idx, val_test_idx, _, y_val_test = train_test_split(
            indices,
            labels_np,
            stratify=labels_np,
            train_size=0.6,
            random_state=self.random_seed,
        )
        val_idx, test_idx = train_test_split(
            val_test_idx,
            stratify=y_val_test,
            train_size=0.5,
            random_state=self.random_seed,
        )
        torch.save({"train": train_idx, "val": val_idx, "test": test_idx}, split_path)
        return train_idx, val_idx, test_idx

    def _make_dataset(
        self,
        heterodata,
        y: torch.Tensor,
        mask: torch.Tensor,
        edge_index: Dict,
    ) -> "DynamicsDataset":
        """Subset the heterogeneous graph and wrap it in a :class:`DynamicsDataset`.

        Args:
            heterodata: Graph object whose per-node-type ``x`` is indexed
                along its first axis by ``mask``.
            y: Labels, indexed by ``mask`` in the same way.
            mask: Long tensor of sample indices belonging to this split.
            edge_index: Relations handed unchanged to the dataset.

        Returns:
            A dataset whose items are the mini-batches produced by
            :meth:`_create_batches`.
        """
        subset = heterodata.clone()
        for node_type in subset.x_dict:
            subset[node_type].x = subset[node_type].x[mask]
        graphs, labels = self._create_batches(subset, y[mask])
        return DynamicsDataset(graphs, labels, edge_index, self.device)

    def _create_batches(self, heterodata, y):
        """Partition the samples into shuffled mini-batches of ``self.batch_size``.

        The permutation is drawn once from torch's global RNG, so the batch
        composition is fixed for the lifetime of the returned lists.  The last
        batch may be smaller than ``self.batch_size``.

        Returns:
            Two parallel lists: one graph object per batch (with ``x``
            restricted to the batch's samples) and the matching label tensors.
        """
        batch_heterodatas, ys = [], []
        indexes = torch.randperm(len(y))
        n_batches = math.ceil(len(y) / self.batch_size)

        for i in range(n_batches):
            this_batch_indexes = indexes[
                i * self.batch_size : (i + 1) * self.batch_size
            ]
            this_batch_heterodata = heterodata.clone()
            for xx in this_batch_heterodata.x_dict:
                this_batch_heterodata[xx].x = this_batch_heterodata[xx].x[
                    this_batch_indexes
                ]
            batch_heterodatas.append(this_batch_heterodata)
            ys.append(y[this_batch_indexes])

        return batch_heterodatas, ys


class DynamicsDataset(Dataset):
    """Map-style dataset whose items are ``(x_dict, edge_index, y)`` triples.

    ``data`` and ``y`` are parallel sequences indexed by item number; the same
    ``edge_index`` object is returned with every item.
    """

    def __init__(self, data, y, edge_index, device=None):
        """Keep references to the inputs; nothing is copied or moved here.

        Args:
            data: Sequence of objects exposing an ``x_dict`` mapping.
            y: Sequence of tensors, one per entry of ``data``.
            edge_index: Object returned as-is as the second element of every item.
            device: Target passed to ``.to(...)`` for features and targets.
        """
        self.data, self.y = data, y
        self.edge_index, self.device = edge_index, device

    def __len__(self):
        """Return the number of entries in ``data``."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        """Return item ``idx`` with its features and target moved to ``device``."""
        target_device = self.device
        features = self.data[idx].x_dict
        # Values are replaced in the mapping obtained from ``x_dict`` itself.
        for key in features:
            features[key] = features[key].to(target_device)
        targets = self.y[idx].to(target_device)
        return features, self.edge_index, targets


def custom_collate_fn(batch):
    """Unwrap the single pre-built item the ``DataLoader`` hands over.

    The loaders in this module use ``batch_size=1``, so ``batch`` is a
    one-element list; its first element is returned unchanged.
    """
    first_item = batch[0]
    return first_item
