import torch
from torch.utils.data import Dataset
from utils.utils import (
    get_dataset,
    get_folder_from_dset,
    get_dataset_split
)
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from pytorch_lightning import LightningDataModule
from data_preprocess.compute_connectivity import save_heterodata_object
from utils.utils import select_simplexes_and_relations
import os


class NodeClassificationDataset(Dataset):
    """Single-sample dataset holding an entire graph and its labels.

    Full-batch training treats the whole graph as one sample, so the dataset
    has length one and every index maps to the same ``(data, y)`` pair.

    Args:
        data: the graph object served to the model.
        y: the labels returned alongside ``data``.
    """

    def __init__(self, data, y):
        self.data, self.y = data, y

    def __len__(self):
        # The whole graph is the only sample.
        num_samples = 1
        return num_samples

    def __getitem__(self, idx):
        # ``idx`` is not consulted: every index resolves to the whole graph.
        sample = (self.data, self.y)
        return sample


def custom_collate_fn(batch):
    """Return the first sample of ``batch`` unchanged (full-batch training).

    With ``batch_size=1`` the DataLoader passes a list holding the single
    ``(data, y)`` sample. Handing it back directly bypasses the DataLoader's
    default collation of the graph object.

    Raises:
        IndexError: if ``batch`` is empty.
    """
    sample = batch[0]
    return sample


class NodeClassificationDataModule(LightningDataModule):
    """LightningDataModule for full-batch node classification on a lifted graph.

    The whole graph, a precomputed heterodata object whose node types are
    simplices of increasing order, is served as the single sample of every
    dataloader. The train/val/test splits are exposed as the masks returned
    by ``get_dataset_split`` (``train_mask``, ``val_mask``, ``test_mask``),
    which the training loop applies to the model output.

    ``cfg`` attributes read by this class: ``name``, ``feat_init``,
    ``data_path``, ``max_dim``, ``adjacencies``, ``device`` and
    ``split_number``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.feat_init = self.cfg.feat_init
        self.full_dataset = None  # populated by setup()
        self.dataloader = None
        self.train_mask = None  # populated by setup()
        self.val_mask = None  # populated by setup()
        self.test_mask = None  # populated by setup()
        self.data_path = self.cfg.data_path

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _hetero_data_path(self, folder):
        """Return the on-disk path of the lifted heterodata object."""
        dataset_path = os.path.join(self.data_path, folder)
        return os.path.join(dataset_path, f"{folder}_heterodata_{self.feat_init}")

    def _build_hetero_data(self):
        """(Re)compute and save the lifted heterodata object for ``cfg.max_dim``."""
        # NOTE(review): save_heterodata_object is not given data_path, so it
        # presumably derives its own output location; confirm it matches
        # _hetero_data_path().
        save_heterodata_object(
            dset=self.cfg.name,
            max_dim=self.cfg.max_dim,
            feat_init=self.feat_init,
        )

    @staticmethod
    def _load_hetero_data(path):
        """Load a heterodata object previously written to ``path``."""
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # can refuse to unpickle arbitrary objects such as a HeteroData. If
        # loading fails, pass weights_only=False (the file is produced locally).
        return torch.load(path)

    @staticmethod
    def _standardize_node_features(hetero_data):
        """Z-score the feature columns of every non-empty node type, in place.

        Statistics are computed over all nodes of a type (train, val and test
        alike). Results are written to ``hetero_data[node_type].x`` rather than
        into ``hetero_data.x_dict``. Assuming a PyG ``HeteroData``, ``x_dict``
        is a freshly collected dict on each access, so item assignment into it
        would be lost. The scaled NumPy output is converted back to a tensor so
        later ``.to(device)`` calls keep working.
        """
        scaler = StandardScaler()
        for node_type, x in hetero_data.x_dict.items():
            x = torch.as_tensor(x)
            if x.shape[0] == 0:
                # StandardScaler rejects arrays with zero samples.
                continue
            scaled = scaler.fit_transform(x.detach().cpu().numpy())
            # Integer features would be truncated if cast back to their dtype.
            dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            hetero_data[node_type].x = torch.as_tensor(scaled, dtype=dtype)

    def _full_batch_dataloader(self):
        """Return a DataLoader that yields the whole graph as one batch."""
        # Masking selects the split, so every stage iterates the same
        # one-sample dataset. Loading stays in the main process because
        # setup() may already have moved the graph to cfg.device.
        return DataLoader(
            self.full_dataset,
            batch_size=1,
            collate_fn=custom_collate_fn,
            num_workers=0,
        )

    # ------------------------------------------------------------------ #
    # LightningDataModule hooks
    # ------------------------------------------------------------------ #
    def prepare_data(self) -> None:
        """
        Download the raw dataset and make sure the lifted heterodata exists.

        Runs only once on a single process. Nothing is stored on ``self``;
        ``setup`` reloads everything from disk.
        """
        folder = get_folder_from_dset(self.cfg.name)
        dataset_path = os.path.join(self.data_path, folder)

        # Download the raw dataset only when its folder is missing.
        if not os.path.exists(dataset_path):
            get_dataset(name=self.cfg.name, root_dir=self.data_path)

        hetero_data_path = self._hetero_data_path(folder)
        if not os.path.exists(hetero_data_path):
            # Freshly built with the requested max_dim, so there is no need to
            # reload it just to re-check its node types.
            self._build_hetero_data()
            return

        # A cached lifting may have been computed with a smaller max_dim. If
        # it lacks the higher-order simplex types, recompute it.
        hetero_data = self._load_hetero_data(hetero_data_path)
        if len(hetero_data.node_types) < self.cfg.max_dim + 1:
            self._build_hetero_data()

    def setup(self, stage: str = None) -> None:
        """
        Load the lifted graph, labels and split masks onto ``cfg.device``.

        Args:
            stage: Lightning stage name; not used, since all stages share the
                same full-batch graph.
        """
        folder = get_folder_from_dset(self.cfg.name)
        root_dir = self.data_path
        device = self.cfg.device

        dataset = get_dataset(name=self.cfg.name, root_dir=root_dir)
        hetero_data = self._load_hetero_data(self._hetero_data_path(folder))

        self._standardize_node_features(hetero_data)

        # Keep only simplices up to max_dim and the configured adjacency types.
        hetero_data = select_simplexes_and_relations(
            hetero_data, self.cfg.max_dim, self.cfg.adjacencies
        )

        # Move features and connectivity to the target device.
        # NOTE(review): on a PyG HeteroData these assignments store global
        # attributes that shadow the per-type collections on later
        # ``x_dict`` / ``edge_index_dict`` reads; the node/edge stores
        # themselves are not moved.
        hetero_data.x_dict = {
            k: v.to(device) for k, v in hetero_data.x_dict.items()
        }
        hetero_data.edge_index_dict = {
            k: v.to(device) for k, v in hetero_data.edge_index_dict.items()
        }
        y = dataset.y.to(device)

        self.full_dataset = NodeClassificationDataset(hetero_data, y)

        train_mask, val_mask, test_mask = get_dataset_split(
            self.cfg.name, dataset._data, root_dir, self.cfg.split_number
        )
        self.train_mask = train_mask.to(device)
        self.val_mask = val_mask.to(device)
        self.test_mask = test_mask.to(device)

    def train_dataloader(self):
        """Serve the complete graph; training selects nodes via ``train_mask``."""
        return self._full_batch_dataloader()

    def val_dataloader(self):
        """Serve the complete graph; validation selects nodes via ``val_mask``."""
        return self._full_batch_dataloader()

    def test_dataloader(self):
        """Serve the complete graph; testing selects nodes via ``test_mask``."""
        return self._full_batch_dataloader()