import os
import os.path as osp
import pytorch_lightning as pl

from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import (
    HeterophilousGraphDataset,
    WikipediaNetwork,
    WebKB,
    LRGBDataset,
    ZINC,
)
from .tu import TUDataset
from torch_geometric.transforms import Constant, Compose
from torch_geometric.loader import DataLoader
from torch_geometric.utils import homophily
from ogb.graphproppred import PygGraphPropPredDataset
from tgp.data import PoolDataLoader

# Local imports
from source.utils import (
    DataToFloat,
    LabelsToInt,
    NodeOneHot,
    EdgeOneHot,
    UnsqueezeY,
    SqueezeY,
    get_train_val_test_datasets,
)
from tgp.datasets import (
    GraphClassificationBench,
    PyGSPDataset,
    EXPWL1Dataset,
    MultipartiteGraphDataset,
)


class GraphTaskDataModule(pl.LightningDataModule):
    """Template for dataset on graph-level tasks.

    Subclasses fill in ``train_dataset``, ``val_dataset`` and ``test_dataset``
    and call :meth:`compute_properties` on the dataset whose statistics
    should be exposed.

    Args:
        args: Configuration object. It must expose ``batch_size`` as an
            attribute and support ``args.get(key, default)``, which is used
            to read the optional ``num_workers`` entry (default ``0``).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def compute_properties(self, dataset):
        """Store feature/class counts and graph-size statistics of ``dataset``.

        Sets ``num_features``, ``num_classes``, ``num_edge_features`` (``0``
        when the dataset reports ``None`` or a non-positive value),
        ``avg_nodes``, ``avg_edges``, ``max_nodes`` and ``max_edges``.

        Args:
            dataset: Iterable, sized dataset whose items expose ``num_nodes``
                and ``num_edges``.

        Raises:
            ZeroDivisionError: If ``dataset`` is empty.
        """
        self.num_features = dataset.num_features
        self.num_classes = dataset.num_classes

        edge_dim = dataset.num_edge_features
        self.num_edge_features = (
            edge_dim if edge_dim is not None and edge_dim > 0 else 0
        )

        # Collect the sizes from the graphs that `dataset` actually yields.
        # Reading totals from `dataset._data` is wrong for index-selected
        # views (e.g. `full_dataset[split_idx]`): such a view keeps the
        # parent's collated storage, so dividing its totals by the length of
        # the subset inflates the averages.
        node_counts = []
        edge_counts = []
        for graph in dataset:
            node_counts.append(graph.num_nodes)
            edge_counts.append(graph.num_edges)

        num_graphs = len(dataset)
        self.avg_nodes = int(sum(node_counts) / num_graphs)
        self.avg_edges = int(sum(edge_counts) / num_graphs)
        self.max_nodes = max(node_counts)
        self.max_edges = max(edge_counts)

    def _build_loader(self, dataset, **loader_kwargs):
        """Create a ``PoolDataLoader`` with the configured batch size/workers."""
        return PoolDataLoader(
            dataset,
            self.args.batch_size,
            num_workers=self.args.get("num_workers", 0),
            **loader_kwargs,
        )

    def train_dataloader(self):
        """Return the shuffled loader over ``train_dataset``."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the loader over ``val_dataset`` (no shuffle argument passed)."""
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the loader over ``test_dataset`` (no shuffle argument passed)."""
        return self._build_loader(self.test_dataset)


class EXPWL1DataModule(GraphTaskDataModule):
    """Data module for ``EXPWL1Dataset`` with fold-based train/val/test splits.

    Args:
        args: Configuration object; ``force_reload``, ``n_folds``,
            ``fold_id`` and ``seed`` are read here.
        root: Directory handed to ``EXPWL1Dataset``.
        pre_transform: ``None``, a single transform, or a ``list`` of
            transforms. ``DataToFloat`` is always appended as the last step.
    """

    def __init__(self, args, root="data/EXPWL1/", pre_transform=None):
        super().__init__(args=args)

        # Build the pre-transform pipeline; DataToFloat always runs last.
        to_float = DataToFloat()
        if pre_transform is None:
            pipeline = to_float
        else:
            user_steps = (
                pre_transform if isinstance(pre_transform, list) else [pre_transform]
            )
            pipeline = Compose([*user_steps, to_float])

        self.dataset = EXPWL1Dataset(
            root, pre_transform=pipeline, force_reload=args.force_reload
        )

        # Record the fold configuration. `self.fold_id` falls back to 0, but
        # the raw `args.fold_id` (possibly None) is what the split helper gets.
        self.n_folds = args.n_folds
        requested_fold = args.fold_id
        if requested_fold is None:
            self.fold_id = 0
        else:
            assert requested_fold < args.n_folds
            self.fold_id = requested_fold

        splits = get_train_val_test_datasets(
            self.dataset, args.seed, args.n_folds, requested_fold
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

        # Statistics are computed on the full dataset, not on a single split.
        self.compute_properties(self.dataset)


class BenchHardDataModule(GraphTaskDataModule):
    """Data module over ``GraphClassificationBench`` with ``easy=False, small=False``.

    The train, val and test splits are loaded as three separate datasets.

    Args:
        args: Configuration object; ``force_reload`` is read here.
        root: Directory handed to ``GraphClassificationBench``.
        pre_transform: Optional collection of transforms passed to
            ``Compose``; ``None`` means no pre-transform.
    """

    def __init__(self, args, root="data/Bench-hard/", pre_transform=None):
        super().__init__(args=args)

        composed = None if pre_transform is None else Compose(pre_transform)

        # Options common to the three splits; only `split` differs.
        shared_kwargs = {
            "easy": False,
            "small": False,
            "pre_transform": composed,
            "force_reload": args.force_reload,
        }
        self.train_dataset = GraphClassificationBench(
            root, split="train", **shared_kwargs
        )
        self.val_dataset = GraphClassificationBench(
            root, split="val", **shared_kwargs
        )
        self.test_dataset = GraphClassificationBench(
            root, split="test", **shared_kwargs
        )

        # Statistics come from the training split only.
        self.compute_properties(self.train_dataset)


class MultipartiteDataModule(GraphTaskDataModule):
    """Data module for ``MultipartiteGraphDataset`` with fold-based splits.

    Args:
        args: Configuration object; ``force_reload``, ``n_folds``,
            ``fold_id`` and ``seed`` are read here.
        root: Directory handed to ``MultipartiteGraphDataset``.
        pre_transform: Optional collection of transforms passed to
            ``Compose``; ``None`` means no pre-transform.
    """

    def __init__(self, args, root="data/Multipartite/", pre_transform=None):
        super().__init__(args=args)

        composed = None if pre_transform is None else Compose(pre_transform)
        self.dataset = MultipartiteGraphDataset(
            root=root, pre_transform=composed, force_reload=args.force_reload
        )

        # Record the fold configuration. `self.fold_id` falls back to 0, but
        # the raw `args.fold_id` (possibly None) is what the split helper gets.
        self.n_folds = args.n_folds
        requested_fold = args.fold_id
        if requested_fold is None:
            self.fold_id = 0
        else:
            assert requested_fold < args.n_folds
            self.fold_id = requested_fold

        splits = get_train_val_test_datasets(
            self.dataset, args.seed, args.n_folds, requested_fold
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

        # Statistics are computed on the full dataset, not on a single split.
        self.compute_properties(self.dataset)


class TUDataModule(GraphTaskDataModule):
    """Data module for a ``TUDataset`` with fold-based train/val/test splits.

    Args:
        name: Dataset name forwarded to ``TUDataset``.
        args: Configuration object. Attributes ``force_reload``, ``n_folds``,
            ``fold_id`` and ``seed`` are read, plus the optional entries
            ``transform`` (may contain ``"constant"`` and/or
            ``"labels_to_int"``), ``clean`` (default ``True``) and
            ``use_node_attr`` (default ``False``) via ``args.get``.
        root: Directory handed to ``TUDataset``.
        pre_transform: ``None``, a single callable transform, or an iterable
            of transforms. These run before the transforms selected through
            ``args.transform``.
    """

    def __init__(self, name, args, root="data/TUDataset", pre_transform=None):
        super().__init__(args=args)

        # `or []` keeps the membership tests valid when the entry is None.
        requested = args.get("transform", []) or []
        config_steps = []
        if "constant" in requested:
            config_steps.append(Constant())
        if "labels_to_int" in requested:
            config_steps.append(LabelsToInt())

        # Accept a single transform as well as an iterable of transforms, as
        # the EXPWL1 and OGB modules do; unpacking a lone callable would
        # otherwise raise a TypeError.
        if pre_transform is None:
            user_steps = []
        elif callable(pre_transform):
            user_steps = [pre_transform]
        else:
            user_steps = list(pre_transform)
        transforms = Compose([*user_steps, *config_steps])

        self.dataset = TUDataset(
            root=root,
            name=name,
            cleaned=args.get("clean", True),
            use_node_attr=args.get("use_node_attr", False),
            pre_transform=transforms,
            force_reload=args.force_reload,
        )

        # Record the fold configuration. `self.fold_id` falls back to 0, but
        # the raw `args.fold_id` (possibly None) is what the split helper gets.
        self.n_folds = args.n_folds
        if args.fold_id is not None:
            assert args.fold_id < args.n_folds
            self.fold_id = args.fold_id
        else:
            self.fold_id = 0
        self.train_dataset, self.val_dataset, self.test_dataset = (
            get_train_val_test_datasets(
                self.dataset, args.seed, args.n_folds, args.fold_id
            )
        )

        # Statistics are computed on the full dataset, not on a single split.
        self.compute_properties(self.dataset)


class OGBDataModule(GraphTaskDataModule):
    """Data module for an OGB graph-property-prediction dataset.

    Uses the split indices returned by the dataset's ``get_idx_split()``.

    Args:
        name: OGB dataset name (e.g. ``"ogbg-molpcba"``).
        args: Configuration object; ``force_reload`` is read here.
        root: Directory handed to ``PygGraphPropPredDataset``.
        pre_transform: ``None``, a single transform, or a ``list`` of
            transforms. For every dataset except ``"ogbg-molpcba"``,
            ``SqueezeY`` is appended as the last step.
    """

    def __init__(self, name, args, root="data/ogb/", pre_transform=None):
        # Imported locally so this block does not depend on the file's
        # top-level import list.
        import shutil

        super().__init__(args=args)

        # Force reload: drop the processed cache so the dataset is rebuilt.
        if args.force_reload:
            processed_path = osp.join(root, name.replace("-", "_"), "processed")
            if osp.exists(processed_path):
                print(f"Force reload: deleting {processed_path}")
                # Delete through the standard library rather than a shell
                # command: this is portable, safe for paths containing
                # spaces or shell metacharacters, and raises on failure
                # instead of silently leaving a stale cache behind.
                if osp.islink(processed_path) or not osp.isdir(processed_path):
                    # Like `rm -r`, remove a symlink or plain file itself.
                    os.remove(processed_path)
                else:
                    shutil.rmtree(processed_path)

        if name != "ogbg-molpcba":
            if pre_transform is None:
                user_steps = []
            elif isinstance(pre_transform, list):
                user_steps = pre_transform
            else:
                user_steps = [pre_transform]
            pre_transform = Compose([*user_steps, SqueezeY()])
        elif isinstance(pre_transform, list):
            # A bare list is not callable; wrap it so it can be applied.
            pre_transform = Compose(pre_transform)

        dataset = PygGraphPropPredDataset(
            name=name, root=root, pre_transform=pre_transform
        )

        split_idx = dataset.get_idx_split()
        self.train_dataset = dataset[split_idx["train"]]
        self.val_dataset = dataset[split_idx["valid"]]
        self.test_dataset = dataset[split_idx["test"]]

        # Statistics come from the training split only.
        self.compute_properties(self.train_dataset)


class LRGBDataModule(GraphTaskDataModule):
    """Data module for an ``LRGBDataset`` loaded as train/val/test splits.

    Args:
        name: Dataset name forwarded to ``LRGBDataset``.
        args: Configuration object; ``force_reload`` is read here.
        root: Directory handed to ``LRGBDataset``.
        pre_transform: Optional collection of transforms passed to
            ``Compose``; ``None`` means no pre-transform.
    """

    def __init__(self, name, args, root="data/lrgb/", pre_transform=None):
        super().__init__(args=args)

        composed = None if pre_transform is None else Compose(pre_transform)

        # Keyword arguments common to all three splits.
        shared_kwargs = {"name": name, "root": root, "pre_transform": composed}

        # `force_reload` is forwarded only when building the training split.
        self.train_dataset = LRGBDataset(
            split="train", force_reload=args.force_reload, **shared_kwargs
        )
        self.val_dataset = LRGBDataset(split="val", **shared_kwargs)
        self.test_dataset = LRGBDataset(split="test", **shared_kwargs)

        # Statistics come from the training split only.
        self.compute_properties(self.train_dataset)


class ZINCDataModule(GraphTaskDataModule):
    """Data module for the ``ZINC`` dataset loaded as train/val/test splits.

    Args:
        args: Configuration object; ``one_hot``, ``subset`` and
            ``force_reload`` are read here.
        root: Directory handed to ``ZINC``.
        pre_transform: ``None``, a single callable transform, or an iterable
            of transforms. These run before ``UnsqueezeY`` and, when
            ``args.one_hot`` is truthy, the node/edge one-hot encoders.
    """

    def __init__(self, args, root="data/zinc/", pre_transform=None):
        super().__init__(args=args)

        new_trans = [UnsqueezeY()]
        if args.one_hot:
            new_trans.append(NodeOneHot(num_node_types=21))
            new_trans.append(EdgeOneHot(num_edge_types=4))

        # Accept an iterable of transforms as well as a single callable, as
        # the other data modules in this file do. Nesting a list inside
        # Compose would leave a non-callable step in the pipeline.
        if pre_transform is None:
            user_steps = []
        elif callable(pre_transform):
            user_steps = [pre_transform]
        else:
            user_steps = list(pre_transform)
        transforms = Compose([*user_steps, *new_trans])

        self.train_dataset = ZINC(
            root=root,
            subset=args.subset,
            split="train",
            pre_transform=transforms,
            force_reload=args.force_reload,
        )  # it's enough to call reload once
        self.val_dataset = ZINC(
            root=root, subset=args.subset, split="val", pre_transform=transforms
        )
        self.test_dataset = ZINC(
            root=root, subset=args.subset, split="test", pre_transform=transforms
        )

        # Statistics come from the training split only.
        self.compute_properties(self.train_dataset)
        # Override the class count derived above with a single output;
        # presumably because ZINC is treated as single-target regression
        # (confirm against the model code).
        self.num_classes = 1


class NodeClassDataModule(pl.LightningDataModule):
    """Data module wrapping a single-graph node-classification dataset.

    The dataset class is chosen from ``name``: ``HeterophilousGraphDataset``,
    ``WikipediaNetwork``, ``Planetoid`` or ``WebKB``.

    Args:
        name: One of the dataset names listed in ``__init__``.
        root: Directory handed to the dataset class.
        transform: Transform forwarded to the dataset class.
        pre_transform: Pre-transform forwarded to the dataset class.
        force_reload: Forwarded to the dataset class (defaults to ``True``).
        **args: Optional extras: ``fold`` (stored as ``self.fold``),
            ``geom_gcn_preprocess`` (``WikipediaNetwork`` only, default
            ``True``) and ``split`` (``Planetoid`` only, default
            ``"public"``).
    """

    def __init__(
        self,
        name,
        root="data/NodeClass/",
        transform=None,
        pre_transform=None,
        force_reload=True,
        **args,
    ):
        super().__init__()
        self.fold = args.get("fold", None)
        heterophilous_datasets = [
            "Roman-empire",
            "Amazon-ratings",
            "Minesweeper",
            "Tolokers",
            "Questions",
        ]
        wikipedia_datasets = ["chameleon", "squirrel"]
        planetoid_datasets = ["Cora", "CiteSeer", "PubMed"]
        webkb_datasets = ["Cornell", "Texas", "Wisconsin"]
        available_datasets = (
            heterophilous_datasets
            + wikipedia_datasets
            + planetoid_datasets
            + webkb_datasets
        )
        assert name in available_datasets, (
            f"Available datasets are {available_datasets}"
        )

        # Keyword arguments accepted by every dataset class used below.
        common_kwargs = {
            "root": root,
            "name": name,
            "transform": transform,
            "pre_transform": pre_transform,
            "force_reload": force_reload,
        }

        if name in heterophilous_datasets:
            self.torch_dataset = HeterophilousGraphDataset(**common_kwargs)
        elif name in wikipedia_datasets:
            self.torch_dataset = WikipediaNetwork(
                geom_gcn_preprocess=args.get("geom_gcn_preprocess", True),
                **common_kwargs,
            )
        elif name in planetoid_datasets:
            # Fall back to Planetoid's own default split instead of raising
            # KeyError when the caller does not pass `split`.
            self.torch_dataset = Planetoid(
                split=args.get("split", "public"),
                **common_kwargs,
            )
        elif name in webkb_datasets:
            self.torch_dataset = WebKB(**common_kwargs)

    def dataloader(self):
        """Return a ``DataLoader`` over the dataset with ``batch_size=1``."""
        return DataLoader(self.torch_dataset, batch_size=1)

    def compute_homophily(self, method="edge_insensitive"):
        """Return the homophily of the first graph in the dataset.

        Args:
            method: Homophily variant forwarded to
                ``torch_geometric.utils.homophily``.
        """
        # Index once so edge_index and y come from the same object.
        graph = self.torch_dataset[0]
        return homophily(graph.edge_index, graph.y, method=method)


class PyGSPDataModule(pl.LightningDataModule):
    """Data module exposing a ``PyGSPDataset`` through a training loader.

    Args:
        args: Configuration object; ``pygsp_graph``, ``force_reload`` and
            ``batch_size`` are read from it.
        pre_transform: Pre-transform forwarded to ``PyGSPDataset``.
    """

    def __init__(self, args, pre_transform=None):
        super().__init__()
        self.args = args

        # The storage location is fixed; it is not configurable via `args`.
        dataset_root = "data/PyGSP"
        dataset_kwargs = {
            "root": dataset_root,
            "name": args.pygsp_graph,
            "pre_transform": pre_transform,
            "force_reload": args.force_reload,
        }
        self.dataset = PyGSPDataset(**dataset_kwargs)
        self.num_features = self.dataset.num_features

    def train_dataloader(self):
        """Return a ``DataLoader`` over the dataset using ``args.batch_size``."""
        batch_size = self.args.batch_size
        return DataLoader(self.dataset, batch_size)
