import torch
import pandas as pd
import os
import pickle

from torch_geometric.graphgym.config import cfg
import torch.nn.functional as F
from torch_geometric.loader import GraphSAINTRandomWalkSampler
from torch_geometric.data import Data, InMemoryDataset


class GraphSAINTRandomWalkSampler_custom(GraphSAINTRandomWalkSampler):
    """GraphSAINT random-walk sampler that also exposes per-node walk sampling."""

    def sample_neighbours(self, node_idx):
        """Run a random walk from ``node_idx`` and return the visited nodes.

        Args:
            node_idx: Tensor of start node indices; it is flattened before
                being handed to ``self.adj.random_walk``.

        Returns:
            The result of ``self.adj.random_walk`` (called with
            ``self.walk_length``) flattened to one dimension.
        """
        walks = self.adj.random_walk(node_idx.flatten(), self.walk_length)
        return walks.view(-1)


def calc_k_hops(
    node_idx,
    num_hops,
    data_obj,
    # edge_index,
    hop_cutoff=100,
    get_undirected_hops=False,
    num_nodes=None,
):
    """Sample one neighbourhood per node and report the largest one.

    Args:
        node_idx: Iterable of node identifiers; each element is passed
            individually to ``data_obj.sample_neighbours``.
        num_hops: Accepted for interface compatibility; not used here.
        data_obj: Object exposing ``sample_neighbours(node_id)`` that returns
            a tensor of sampled node indices.
        hop_cutoff: Accepted for interface compatibility; not used here.
        get_undirected_hops: Accepted for interface compatibility; not used
            here.
        num_nodes: Accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(neighbourhoods, max_size)`` where ``neighbourhoods`` is the
        list of sampled tensors (one per entry of ``node_idx``, in order) and
        ``max_size`` is the largest first-dimension size among them (``0``
        when ``node_idx`` is empty).
    """
    # One sampler call per node, in input order.
    neighbourhoods = [data_obj.sample_neighbours(node_id) for node_id in node_idx]

    # The widest neighbourhood decides how far callers must pad.
    largest = max((sampled.size(0) for sampled in neighbourhoods), default=0)
    return neighbourhoods, largest


def pad_hops(k_hop_neigh_idx, max_size):
    """Right-pad 1-D index tensors to a common length and build a validity mask.

    Padding uses index ``0``, so node ``0`` has to be reserved for padded
    features by whoever consumes the result.

    Args:
        k_hop_neigh_idx: Sequence of 1-D index tensors. It is not modified.
        max_size: Target length of every padded row.

    Returns:
        A tuple ``(padded, seq_mask)``:
        ``padded`` is a ``torch.long`` tensor of shape
        ``(len(k_hop_neigh_idx), max_size)``;
        ``seq_mask`` is a boolean tensor of shape
        ``(1, len(k_hop_neigh_idx), max_size)`` that is ``True`` on real
        entries and ``False`` on padding.
    """
    padded_rows = []
    mask_rows = []
    for neigh in k_hop_neigh_idx:
        length = neigh.shape[0]
        pad_size = max_size - length
        padded_rows.append(F.pad(neigh, (0, pad_size), "constant", 0))
        # True for the original entries, False for the appended padding.
        mask_rows.append(
            torch.cat(
                [
                    neigh.new_ones(length, dtype=torch.bool),
                    neigh.new_zeros(pad_size, dtype=torch.bool),
                ],
                dim=0,
            )
        )

    # torch.stack already allocates a new tensor, so a dtype cast is enough;
    # re-wrapping it in torch.tensor() would copy again and emit a warning.
    padded = torch.stack(padded_rows, dim=0).to(torch.long)
    seq_mask = torch.stack(mask_rows).unsqueeze(0)

    return padded, seq_mask


def pad_hops_dict(k_hop_neigh_idx_dict, max_size):
    """Pad every list of index tensors in a dict to a shared length.

    Padding uses index ``0``. The input dict and its lists are left untouched;
    new containers are returned.

    Args:
        k_hop_neigh_idx_dict: Mapping from key to a list of 1-D index tensors.
        max_size: Target length of every padded row, shared across all keys.

    Returns:
        A tuple ``(padded_dict, seq_mask_dict)``. For each key,
        ``padded_dict[key]`` is the stacked padded tensor of shape
        ``(len(list), max_size)`` and ``seq_mask_dict[key]`` is a boolean
        tensor of the same shape that is ``True`` on real entries and
        ``False`` on padding.
    """
    padded_dict = {}
    seq_mask_dict = {}
    for key, neigh_list in k_hop_neigh_idx_dict.items():
        padded_rows = []
        mask_rows = []
        for neigh in neigh_list:
            length = neigh.shape[0]
            pad_size = max_size - length
            padded_rows.append(
                torch.nn.functional.pad(neigh, (0, pad_size), "constant", 0)
            )
            # True for the original entries, False for the appended padding.
            mask_rows.append(
                torch.cat(
                    [
                        neigh.new_ones(length, dtype=torch.bool),
                        neigh.new_zeros(pad_size, dtype=torch.bool),
                    ],
                    dim=0,
                )
            )
        seq_mask_dict[key] = torch.stack(mask_rows)
        padded_dict[key] = torch.stack(padded_rows, dim=0)

    return padded_dict, seq_mask_dict


def hops_multi(
    main_dict,
    subset_keys,
    dataset_name_to_index,
    node_dataset_id,
    node_idx,
    num_hops,
    hop_cutoff=100,
    get_undirected_hops=False,
):
    """Sample and pad neighbourhoods for nodes drawn from several datasets.

    For every key in ``subset_keys`` that owns at least one entry of
    ``node_idx`` (according to ``node_dataset_id``), neighbourhoods are
    sampled with ``calc_k_hops`` against ``main_dict[key]``. All results are
    padded to the widest neighbourhood found and scattered back into the rows
    of the corresponding nodes.

    Args:
        main_dict: Mapping from dataset key to the object passed to
            ``calc_k_hops`` as ``data_obj``.
        subset_keys: Dataset keys to consider.
        dataset_name_to_index: Mapping from dataset key to the identifier
            used in ``node_dataset_id``.
        node_dataset_id: Per-node dataset identifier, compared element-wise
            against ``dataset_name_to_index[key]``.
        node_idx: Node indices, indexable by the boolean membership mask.
        num_hops: Forwarded to ``calc_k_hops``.
        hop_cutoff: Forwarded to ``calc_k_hops``.
        get_undirected_hops: Forwarded to ``calc_k_hops``.

    Returns:
        A tuple ``(k_hop_neigh_idx, seq_mask)`` of shape
        ``(len(node_idx), max_size)``: a ``torch.long`` tensor of padded
        neighbour indices and a boolean tensor marking real entries. Rows of
        nodes that belong to none of the processed keys stay all-zero /
        all-``False``.
    """
    k_hop_neigh_idx_dict = {}
    # (key, membership mask) for every dataset that owns at least one node;
    # computed once here and reused when scattering the padded rows back.
    selections = []
    max_size = 0
    for key in subset_keys:
        belongs = node_dataset_id == dataset_name_to_index[key]
        if not belongs.any():
            continue
        k_hop_neigh_idx_dict[key], max_size_node = calc_k_hops(
            node_idx[belongs],
            num_hops,
            main_dict[key],
            hop_cutoff,
            get_undirected_hops,
            num_nodes=None,
        )
        max_size = max(max_size, max_size_node)
        selections.append((key, belongs))

    k_hop_neigh_idx_pad_dict, seq_mask_dict = pad_hops_dict(
        k_hop_neigh_idx_dict, max_size
    )

    k_hop_neigh_idx = torch.zeros((len(node_idx), max_size), dtype=torch.long)
    seq_mask = torch.zeros((len(node_idx), max_size), dtype=torch.bool)
    for key, belongs in selections:
        k_hop_neigh_idx[belongs] = k_hop_neigh_idx_pad_dict[key]
        seq_mask[belongs] = seq_mask_dict[key]

    return k_hop_neigh_idx, seq_mask


def save_processed_eig(dataset, dataset_dir, dataset_name):
    """Pickle ``dataset`` into the ``eigen_processed`` cache directory.

    The file is written to
    ``<dataset_dir>/eigen_processed/<dataset_name>_eigen_<max_freqs>_processed.pt``
    where ``max_freqs`` is read from ``cfg.posenc_SignNet.eigen.max_freqs``.

    Args:
        dataset: Object to serialise with ``pickle``.
        dataset_dir: Parent directory of the ``eigen_processed`` folder.
        dataset_name: Name used as the file-name prefix.
    """
    dataset_dir = os.path.join(dataset_dir, "eigen_processed")
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(dataset_dir, exist_ok=True)

    pecfg = cfg.posenc_SignNet
    processed_path = os.path.join(
        dataset_dir,
        f"{dataset_name}_eigen_{pecfg.eigen.max_freqs}_processed.pt",
    )

    # Dump to a temporary file first and move it into place afterwards, so a
    # failed or interrupted dump never leaves a truncated cache file behind.
    tmp_path = f"{processed_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(dataset, f)
        os.replace(tmp_path, processed_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Saved processed dataset to {processed_path}")


def check_processed_eig(dataset_dir, dataset_name):
    """Load a cached, pickled dataset from ``eigen_processed`` if it exists.

    The expected location is
    ``<dataset_dir>/eigen_processed/<dataset_name>_eigen_<max_freqs>_processed.pt``
    where ``max_freqs`` is read from ``cfg.posenc_SignNet.eigen.max_freqs``.
    The ``eigen_processed`` directory is created when it is missing.

    Args:
        dataset_dir: Parent directory of the ``eigen_processed`` folder.
        dataset_name: Name used as the file-name prefix.

    Returns:
        The unpickled object, or ``None`` when no cache file exists.
    """
    dataset_dir = os.path.join(dataset_dir, "eigen_processed")
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(dataset_dir, exist_ok=True)

    pecfg = cfg.posenc_SignNet
    processed_path = os.path.join(
        dataset_dir,
        f"{dataset_name}_eigen_{pecfg.eigen.max_freqs}_processed.pt",
    )
    print(processed_path)
    if not os.path.exists(processed_path):
        return None

    print(f"File {processed_path} already exists. Loading dataset.")
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only point this at cache directories written by trusted runs.
    with open(processed_path, "rb") as f:
        return pickle.load(f)


def load_dataset_from_pt(data_cfg):
    """Look up the cached dataset described by ``data_cfg``.

    When ``data_cfg.format`` starts with ``"PyG-"``, the text after the first
    dash is appended to ``data_cfg.dir`` as a sub-directory before the lookup.

    Args:
        data_cfg: Config object providing ``dir``, ``format`` and
            ``dataset_name``.

    Returns:
        Whatever ``check_processed_eig`` returns for the resolved directory
        and ``data_cfg.dataset_name``.
    """
    base_dir = data_cfg.dir
    fmt = data_cfg.format
    if fmt.startswith("PyG-"):
        # Everything after the first dash names the PyG dataset family.
        _, _, pyg_dataset_id = fmt.partition("-")
        base_dir = os.path.join(base_dir, pyg_dataset_id)
    return check_processed_eig(base_dir, data_cfg.dataset_name)


def load_dataset_from_pt_syn(processed_path):
    """Unpickle the object stored at ``processed_path``.

    Args:
        processed_path: Path of the pickle file to read.

    Returns:
        The unpickled object, or ``None`` when the path does not exist.
    """
    if not os.path.exists(processed_path):
        return None
    with open(processed_path, "rb") as handle:
        return pickle.load(handle)


def split_data_object(data):
    """Split a batched ``Data`` object into one ``Data`` object per graph.

    Graph membership comes from ``data.batch``. For each distinct graph id
    the node rows of ``x`` (and ``y`` when present) are selected, the edges
    whose two endpoints both belong to that graph are kept, and their
    endpoints are renumbered to positions within the selected node rows.

    The renumbering is derived from the node mask rather than from the set of
    nodes that appear in an edge, so graphs containing isolated nodes keep
    edge indices that match the rows of ``x``.

    Args:
        data: Object with ``x``, ``edge_index``, ``batch`` and optionally
            ``y`` / ``edge_attr`` attributes.
            Assumes ``batch`` holds one graph id per row of ``x`` and
            ``edge_index`` has shape ``(2, num_edges)`` — TODO confirm
            against callers.

    Returns:
        List of ``Data`` objects, ordered by ascending graph id.
    """
    data_list = []

    # Graph id of each edge endpoint; identical for every graph, so hoisted.
    src_graph = data.batch[data.edge_index[0]]
    dst_graph = data.batch[data.edge_index[1]]

    for graph_id in torch.unique(data.batch):
        # Nodes belonging to the current graph.
        mask = data.batch == graph_id

        x = data.x[mask]
        y = data.y[mask] if data.y is not None else None

        # Keep only edges that live entirely inside the current graph.
        edge_mask = (src_graph == graph_id) & (dst_graph == graph_id)

        # Running count of selected nodes: at every selected position this is
        # the row that node occupies in ``x`` after masking.
        local_index = torch.cumsum(mask, dim=0) - 1
        edge_index = local_index[data.edge_index[:, edge_mask]]

        edge_attr = data.edge_attr[edge_mask] if data.edge_attr is not None else None

        data_list.append(Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr))

    return data_list


class NetworkRepository(InMemoryDataset):
    """In-memory dataset built from Network Repository style text files.

    Every ``<base>.edges`` file found directly in ``root`` is read together
    with its optional companions ``<base>.node_labels``, ``<base>.node_attrs``
    and ``<base>.graph_idx``. The collated result is stored in
    ``processed.pt``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        """Initialise the dataset rooted at ``root`` and load its processed data."""
        super().__init__(root, transform, pre_transform)
        self.load()

    def load(self):
        """Populate ``self.data`` / ``self.slices`` from the processed file.

        The processed file is produced first when it does not exist yet.
        """
        processed_path = self.processed_paths[0]
        print(processed_path)
        if not os.path.exists(processed_path):
            self.process()
        # NOTE(review): newer PyTorch releases change the default of
        # torch.load's ``weights_only`` argument; confirm pickled Data
        # objects still load on the deployed version.
        self.data, self.slices = torch.load(processed_path)

    @property
    def raw_file_names(self):
        """Names of all entries found directly in ``self.root``."""
        return os.listdir(self.root)

    @property
    def processed_file_names(self):
        """Name of the single processed file."""
        return ["processed.pt"]

    def process(self):
        """Parse the raw files, collate the graphs and save them.

        Raises:
            FileNotFoundError: If ``self.root`` contains no ``.edges`` file.
        """
        self.data_list = []
        print(self.root)
        raw_names = self.raw_file_names
        print(raw_names)

        suffix = ".edges"
        # Sorted so the graph order does not depend on os.listdir ordering.
        for filename in sorted(raw_names):
            if not filename.endswith(suffix):
                continue
            data = self._read_graph(filename[: -len(suffix)])
            if data.batch is not None:
                # A graph_idx file marks several graphs stored in one file.
                self.data_list.extend(split_data_object(data))
            else:
                self.data_list.append(data)

        if not self.data_list:
            raise FileNotFoundError(f"No '.edges' file found in {self.root}")

        data, slices = self.collate(self.data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _read_graph(self, base_name):
        """Build one ``Data`` object from the files sharing ``base_name``."""
        edge_index = self.load_edges(os.path.join(self.root, f"{base_name}.edges"))
        y = self._read_node_labels(os.path.join(self.root, f"{base_name}.node_labels"))

        attrs_path = os.path.join(self.root, f"{base_name}.node_attrs")
        if os.path.exists(attrs_path):
            values = pd.read_csv(attrs_path, header=None, engine="python").values
            x = torch.tensor(values.squeeze(), dtype=torch.long)
        else:
            # No attributes available: one zero feature per labelled node.
            x = torch.zeros(y.shape[0], dtype=torch.long)

        # Drop edges that reference a node index beyond the feature rows.
        last_node = x.shape[0] - 1
        valid_edge_mask = (edge_index[0] <= last_node) & (edge_index[1] <= last_node)
        edge_index = edge_index[:, valid_edge_mask]

        # NOTE(review): edge attributes are not attached. The previous code
        # read ``<base>.link_labels`` and then discarded the values, so the
        # read is omitted here and ``edge_attr`` stays None.
        edge_attr = None

        graph_idx_path = os.path.join(self.root, f"{base_name}.graph_idx")
        if os.path.exists(graph_idx_path):
            graph_ids = pd.read_csv(graph_idx_path, header=None, engine="python").values
            # Shift the ids in the file down by one.
            batch = torch.tensor((graph_ids - 1).squeeze(), dtype=torch.long)
        else:
            batch = None

        return Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr, batch=batch)

    def _read_node_labels(self, file_path):
        """Read node labels, shifting indices and labels down by one.

        A two-column ``(node index, label)`` layout is tried first; nodes
        missing from the file get label ``-1``. If that fails for any reason
        the file is re-read as a single column holding one label per line.
        """
        try:
            df = pd.read_csv(
                file_path,
                sep=None,
                engine="python",
                usecols=[0, 1],
                header=None,
            )
            node_indices = df[0].values - 1
            node_labels = df[1].values - 1
            num_nodes = int(node_indices.max()) + 1
            y = torch.full((num_nodes,), -1, dtype=torch.long)
            y[node_indices] = torch.tensor(node_labels, dtype=torch.long)
        except Exception:
            # Deliberately broad: any failure of the two-column parse falls
            # back to the single-column layout.
            df = pd.read_csv(file_path, engine="python", header=None)
            y = torch.tensor(df[0].values - 1, dtype=torch.long)
        return y

    def load_edges(self, filepath):
        """Read an edge list and return it as a ``(2, num_edges)`` long tensor.

        Every index read from the file is decremented by one
        unconditionally, i.e. the file is treated as 1-based.
        """
        edges_df = pd.read_csv(
            filepath, header=None, names=["source", "target"], sep=None, engine="python"
        )
        return torch.tensor((edges_df - 1).values.T, dtype=torch.long)
