import numpy as np

import torch
import torch_geometric.transforms as T

from torch_geometric.datasets import *
from torch.utils.data import DataLoader as torch_DataLoader
from torch.utils.data import ConcatDataset
from torch_geometric.graphgym.config import cfg
from custom_modules.loader.utils import hops_multi
from einops import repeat


def load_pyg(name, dataset_dir):
    """
    Load PyG dataset objects. (More PyG datasets will be supported)

    ``name`` selects the dataset class:

    * ``"TU_<name>"`` -> ``TUDataset``; ``"TU_IMDB"`` is mapped to
      ``IMDB-MULTI`` and loaded with a ``T.Constant()`` transform.
    * ``"<Family>_<sub>"`` for families such as ``Planetoid``, ``Airports``,
      ``Twitch``, ``LINKXDataset``, ... -> ``<Family>(root, <sub>)``.
    * A handful of fixed names (``Karate``/``KarateClub``, ``MNIST``, ``PPI``,
      ``QM7b``, ``QM9``, ...).
    * Any other name is looked up as a callable in this module's global
      namespace (which includes ``torch_geometric.datasets``) and called as
      ``cls(root)``.

    After loading, ``dataset.data.num_nodes`` is filled in if the dataset has
    no ``num_nodes`` attribute. A zero node-feature column is added if it has
    no ``x`` attribute.

    Args:
        name (string): dataset name
        dataset_dir (string): data directory; the dataset root becomes
            ``"{dataset_dir}/{name}"``.

    Returns: PyG dataset object

    Raises:
        ValueError: if ``name`` matches no branch and no callable of that
            name exists in the module namespace.
    """
    dataset_dir = "{}/{}".format(dataset_dir, name)
    if name[:3] == "TU_":
        if name[3:] == "IMDB":
            name = "IMDB-MULTI"
            dataset = TUDataset(dataset_dir, name, transform=T.Constant())
        else:
            dataset = TUDataset(dataset_dir, name[3:])
    elif name in ("Karate", "KarateClub"):
        dataset = KarateClub()
    elif "Coauthor" in name:
        if "CS" in name:
            dataset = Coauthor(dataset_dir, name="CS")
        else:
            dataset = Coauthor(dataset_dir, name="Physics")
    elif "AttributedGraphDataset" in name:
        dataset = AttributedGraphDataset(dataset_dir, name.split("_")[1])
    elif "HeterophilousGraphDataset" in name:
        dataset = HeterophilousGraphDataset(dataset_dir, name.split("_")[1])
    elif "Airports" in name:
        dataset = Airports(dataset_dir, name.split("_")[1])
    elif "Twitch" in name:
        dataset = Twitch(dataset_dir, name.split("_")[1])
    elif "LINKXDataset" in name:
        dataset = LINKXDataset(dataset_dir, name.split("_")[1])
    elif "Planetoid" in name:
        dataset = Planetoid(dataset_dir, name.split("_")[1])
    elif "CitationFull" in name:
        # "Cora_ML" itself contains an underscore, so split("_")[1] would
        # yield only "Cora"; special-case it.
        if "Cora_ML" in name:
            dataset = CitationFull(dataset_dir, "Cora_ML")
        else:
            dataset = CitationFull(dataset_dir, name.split("_")[1])
    elif "Amazon" in name:
        if "Computers" in name:
            dataset = Amazon(dataset_dir, name="Computers")
        elif "Products" in name:
            dataset = AmazonProducts(dataset_dir)
        else:
            dataset = Amazon(dataset_dir, name="Photo")
    elif "Reddit" in name:
        if "Reddit2" in name:
            dataset = Reddit2(dataset_dir)
        else:
            dataset = Reddit(dataset_dir)
    elif name == "MNIST":
        dataset = MNISTSuperpixels(dataset_dir)
    elif name == "PPI":
        dataset = PPI(dataset_dir)
    elif name == "QM7b":
        dataset = QM7b(dataset_dir)
    elif name == "QM9":
        dataset = QM9(dataset_dir)
    elif "SnapDataset" in name:
        dataset = SNAPDataset(dataset_dir, name=name.split("_")[1])
    else:
        # Look the class up by name instead of eval-ing a generated source
        # string: eval broke on paths containing quotes or backslashes and
        # would execute arbitrary code embedded in name/dataset_dir.
        dataset_cls = globals().get(name)
        if not callable(dataset_cls):
            raise ValueError(f"Unknown dataset name: {name!r}")
        dataset = dataset_cls(dataset_dir)
    if not hasattr(dataset, "num_nodes"):
        if not hasattr(dataset, "edge_index"):
            # No edge_index attribute: use the number of labels as node count.
            dataset.data.num_nodes = len(dataset.data.y)
        else:
            # Largest node index referenced by an edge, plus one.
            # NOTE(review): undercounts trailing isolated nodes — confirm OK.
            dataset.data.num_nodes = dataset.data.edge_index.max().item() + 1
    if not hasattr(dataset, "x"):
        # No node features: attach a single all-zero feature column.
        dataset = preformat_add_zero_node_features(dataset)
    return dataset


def preformat_add_zero_node_features(dataset):
    """Attach a single all-zero node-feature column to ``dataset``.

    The object is modified in place: ``dataset.x`` is set to a tensor of
    shape ``(dataset.num_nodes, 1)`` filled with zeros of dtype
    ``torch.long``.

    Args:
        dataset: object exposing ``num_nodes`` and accepting attribute
            assignment (e.g. a PyG dataset or data object).

    Returns:
        The same ``dataset`` object, now carrying ``x``.
    """
    feature_shape = (dataset.num_nodes, 1)
    dataset.x = torch.zeros(feature_shape, dtype=torch.long)
    return dataset


class Collate:
    """Collate callable that merges samples from several datasets into one batch.

    Intended to be passed as ``collate_fn`` to a ``DataLoader``. Each batch
    item is a dict; the keys read here are ``task_type``,
    ``dataset_id_name``, ``node_dataset_main_graph_dict_key``, ``node_id``,
    ``node_dataset_task_name`` and ``node_label``.
    """

    def __init__(
        self,
        graph_dataset_dict,
        key_to_idx,
        saint_samplers,
        hop_cutoff,
        model_num_latents,
    ):
        """
        Args:
            graph_dataset_dict: indexable container of graphs; entries are
                addressed via ``key_to_idx`` and must expose ``num_nodes``.
            key_to_idx: maps a main-graph key (str) to an index into
                ``graph_dataset_dict``.
            saint_samplers: forwarded to ``hops_multi`` as ``main_dict``.
            hop_cutoff: forwarded to ``hops_multi``.
            model_num_latents: number of latent tokens appended to every
                token-type sequence (they receive token type 2).
        """
        self.graph_dataset_dict = graph_dataset_dict
        self.saint_samplers = saint_samplers
        self.key_to_idx = key_to_idx

        self.hop_cutoff = hop_cutoff
        self.model_num_latents = model_num_latents

    def __call__(self, batch):
        """Build the batch dictionary from a list of item dicts.

        Returns:
            dict with keys ``graph_names``, ``graph_seq_len``,
            ``node_dataset_indices``, ``batch_wise_neigh_index``,
            ``token_type_id``, ``token_type_id_graph``, ``main_graph_dict``,
            ``key_to_idx``, ``output_task_indices``, ``output_values``,
            ``task_indices`` and ``graph_name_both``. The neighbour and
            token-type tensors are empty ``(0, 0)`` tensors when the batch has
            no item of the corresponding task type.
        """
        from collections import defaultdict

        node_classification_batch = [
            item for item in batch if item["task_type"] == "node_classification"
        ]

        # Sorted unique dataset ids across the whole batch, and for every item
        # the position of its dataset id within that sorted array.
        graph_name_both, node_dataset_indices = np.unique(
            [item["dataset_id_name"] for item in batch],
            return_inverse=True,
        )

        # Sorted unique dataset ids of the node-classification items only.
        graph_dataset_keys_node = np.unique(
            [item["dataset_id_name"] for item in node_classification_batch]
        )

        # dataset id -> key of its graph in the main graph dict.
        dataset_key_map = {
            item["dataset_id_name"]: item["node_dataset_main_graph_dict_key"]
            for item in batch
        }
        graph_names = [dataset_key_map[key] for key in graph_name_both]

        dataset_name_to_index = {name: idx for idx, name in enumerate(graph_name_both)}

        node_ids = torch.tensor([item["node_id"] for item in batch], dtype=torch.long)

        # Group batch positions by task type and by dataset-task name, and
        # collect the labels per dataset-task name.
        output_values = defaultdict(list)
        dataset_indices_efficient = defaultdict(list)
        task_indices_efficient = defaultdict(list)
        # Both task keys must always exist in the returned dict.
        task_indices_efficient["node_classification"] = []
        task_indices_efficient["graph_classification"] = []
        for index, item in enumerate(batch):
            task_indices_efficient[item["task_type"]].append(index)
            dataset_indices_efficient[item["node_dataset_task_name"]].append(index)
            output_values[item["node_dataset_task_name"]].append(item["node_label"])

        dataset_indices_efficient = dict(dataset_indices_efficient)
        task_indices_efficient = dict(task_indices_efficient)
        output_values = {
            key: torch.stack(labels) for key, labels in output_values.items()
        }

        # Node count per graph: string keys are resolved through key_to_idx,
        # anything else is assumed to be a graph object itself.
        graph_seq_len = [
            self.graph_dataset_dict[self.key_to_idx[graph_name]].num_nodes
            if isinstance(graph_name, str)
            else graph_name.num_nodes
            for graph_name in graph_names
        ]

        # Running node offsets per graph (graph_name_both order); presumably
        # the offset of each graph when all batch graphs are concatenated.
        graph_cumsum = np.cumsum([0] + graph_seq_len)

        nc_indices = task_indices_efficient["node_classification"]
        gc_indices = task_indices_efficient["graph_classification"]
        node_classification_node_ids = node_ids[nc_indices]
        node_classification_dataset_indices = node_dataset_indices[nc_indices]

        batch_wise_neigh_index = torch.empty(0, 0, dtype=torch.long)
        token_type_id = torch.empty(0, 0, dtype=torch.long)
        token_type_id_graph = torch.empty(0, 0, dtype=torch.long)

        if len(nc_indices) > 0:
            k_hop_neigh_idx, seq_mask = hops_multi(
                main_dict=self.saint_samplers,
                subset_keys=graph_dataset_keys_node,
                dataset_name_to_index=dataset_name_to_index,
                node_dataset_id=node_classification_dataset_indices,
                node_idx=node_classification_node_ids,
                num_hops=1,
                hop_cutoff=self.hop_cutoff,
                get_undirected_hops=True,
            )

            # Shift per-graph neighbour indices by their graph's offset.
            batch_wise_neigh_index = k_hop_neigh_idx + torch.LongTensor(
                graph_cumsum[node_classification_dataset_indices]
            ).unsqueeze(1)

            # Token types: 0 for the first column, 1 for the remaining
            # neighbour columns, 2 for each latent token.
            token_type_id = torch.LongTensor(
                [0]
                + [1] * (k_hop_neigh_idx.shape[1] - 1)
                + [2] * self.model_num_latents
            )
            token_type_id = repeat(
                token_type_id, "n -> b n", b=len(node_classification_node_ids)
            )

        if len(gc_indices) > 0:
            # Graph-level items carry only latent tokens (type 2).
            token_type_id_graph = torch.LongTensor([2] * self.model_num_latents)
            token_type_id_graph = repeat(
                token_type_id_graph,
                "n -> b n",
                b=len(gc_indices),
            )

        return {
            "graph_names": graph_names,
            "graph_seq_len": graph_seq_len,
            "node_dataset_indices": torch.LongTensor(node_dataset_indices),
            "batch_wise_neigh_index": batch_wise_neigh_index,
            "token_type_id": token_type_id,
            "token_type_id_graph": token_type_id_graph,
            "main_graph_dict": self.graph_dataset_dict,
            "key_to_idx": self.key_to_idx,
            "output_task_indices": dataset_indices_efficient,
            "output_values": output_values,
            "task_indices": task_indices_efficient,
            "graph_name_both": graph_name_both,
        }


def create_loader_distributed(
    rank,
    world_size,
    shared_graph_dataset_dict,
    key_to_idx,
    saint_samplers,
    node_dataset_loaded_list_train,
):
    """Create the distributed training data loader.

    Concatenates the training datasets and builds the metadata required by
    ``DistributedSSSampler``. It then wraps everything in a PyTorch
    ``DataLoader`` whose ``collate_fn`` is a ``Collate`` instance.

    Args:
        rank: rank of this process, forwarded to the sampler.
        world_size: number of replicas, forwarded to the sampler.
        shared_graph_dataset_dict: graph container forwarded to ``Collate``.
        key_to_idx: main-graph key -> index mapping forwarded to ``Collate``.
        saint_samplers: forwarded to ``Collate``.
        node_dataset_loaded_list_train: list of training datasets. Datasets
            with ``dataset_type == "NC"`` contribute ``len(dataset)`` items
            that all belong to one graph. Any other dataset contributes one
            graph per item and must expose a per-graph ``num_nodes`` sequence.

    Returns:
        A single ``torch.utils.data.DataLoader`` over the concatenated
        datasets, or ``None`` when ``cfg.dataset_multi.name_list`` is empty.

    Side effects:
        When datasets are configured, resets ``cfg.dataset_config_list`` to
        an empty list.
    """
    from custom_modules.loader.snake_sampler import (
        DistributedSSSampler,
    )

    if len(cfg.dataset_multi.name_list) == 0:
        # Nothing configured; no loader is built.
        return None

    cfg.dataset_config_list = []
    # Per-item metadata (one entry per dataset item).
    graph_ids = []
    unique_graph_id = []
    num_nodes = []
    # One entry per graph.
    ratio_in_epoch = []
    # One entry per dataset.
    seq_len = []
    # Previously tracked as two counters that always held the same value.
    graph_counter = 0
    for node_dataset in node_dataset_loaded_list_train:
        n_items = len(node_dataset)
        if node_dataset.dataset_type == "NC":
            if n_items == 0:
                print(
                    "No labels present in: ",
                    node_dataset.main_graph_dict_key,
                )
            # All items of this dataset share one graph id.
            graph_ids.extend([graph_counter] * n_items)
            unique_graph_id.extend([graph_counter] * n_items)
            graph_counter += 1

            num_nodes.extend([node_dataset.num_nodes] * n_items)
            ratio_in_epoch.append(node_dataset.graph_ratio_in_epoch)
        else:
            # Every item is its own graph with a fresh id.
            new_ids = range(graph_counter, graph_counter + n_items)
            graph_ids.extend(new_ids)
            unique_graph_id.extend(new_ids)
            graph_counter += n_items

            num_nodes.extend(node_dataset.num_nodes)
            ratio_in_epoch.extend([1.0] * len(node_dataset.num_nodes))
        seq_len.append(n_items)

    graph_ids = torch.Tensor(graph_ids)
    unique_graph_id = torch.Tensor(unique_graph_id)
    num_nodes = torch.Tensor(num_nodes)
    ratio_in_epoch = torch.Tensor(ratio_in_epoch)
    seq_len = torch.Tensor(seq_len)

    print("Number of nodes to decode:", len(graph_ids))

    node_dataset_loaded_train = ConcatDataset(node_dataset_loaded_list_train)
    batch_sampler_train = DistributedSSSampler(
        node_dataset_loaded_train,
        graph_ids,
        unique_graph_id,
        num_nodes,
        ratio_in_epoch,
        seq_len,
        cfg.train.sampler_graph_limit,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        batch_size=cfg.train.batch_size,
        seed=cfg.seed,
        accum_gradient_steps=cfg.train.accum_gradient_steps,
    )

    collate_fn = Collate(
        shared_graph_dataset_dict,
        key_to_idx,
        saint_samplers,
        cfg.model.hop_cutoff,
        cfg.model.num_latents,
    )

    train_loader = torch_DataLoader(
        node_dataset_loaded_train,
        num_workers=4,
        pin_memory=False,
        batch_sampler=batch_sampler_train,
        collate_fn=collate_fn,
        prefetch_factor=2,
        persistent_workers=True,
    )

    return train_loader
