import os
import torch
from torch_geometric.data import Data, InMemoryDataset
import pandas as pd
import glob
from tqdm import tqdm


from custom_modules.loader.custom_loaders import *
from custom_modules.loader.utils import *
import os.path as osp

from custom_modules.loader.utils import *
from torch_geometric.graphgym.loader import set_dataset_attr
from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
from custom_modules.transform.posenc_stats import compute_posenc_stats
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)
import time
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch


def split_data_object(data):
    """Split a batched ``Data`` object into one ``Data`` object per graph.

    Args:
        data: Object exposing ``x``, ``y``, ``edge_index``, ``edge_attr`` and
            ``batch``. ``batch`` assigns a graph id to every node; ``y`` and
            ``edge_attr`` may be ``None``.

    Returns:
        list: One ``Data`` per distinct id in ``data.batch`` (in the sorted
        order produced by ``torch.unique``). Each holds the node rows of that
        graph and the edges whose two endpoints both belong to it, with
        ``edge_index`` renumbered to index into the graph's own node rows.
        Edges connecting two different graphs are dropped.
    """
    data_list = []

    has_labels = data.y is not None
    has_edge_attr = data.edge_attr is not None

    # Graph id of each edge endpoint; independent of the graph being extracted,
    # so it is computed once instead of once per graph.
    src_graph = data.batch[data.edge_index[0]]
    dst_graph = data.batch[data.edge_index[1]]

    for graph_id in torch.unique(data.batch):
        # Nodes belonging to the current graph
        mask = data.batch == graph_id

        x = data.x[mask]
        y = data.y[mask] if has_labels else None

        # Keep only edges fully contained in the current graph
        edge_mask = (src_graph == graph_id) & (dst_graph == graph_id)
        edge_index = data.edge_index[:, edge_mask]

        # Map global node ids to positions within ``x``: a node's local index
        # is the number of nodes of this graph that precede it. Deriving the
        # mapping from the node mask (rather than from the ids that happen to
        # appear in edges) keeps isolated nodes aligned with their rows and
        # also works when the graph has no edges at all.
        local_index = torch.cumsum(mask.to(torch.long), dim=0) - 1
        edge_index = local_index[edge_index]

        edge_attr = data.edge_attr[edge_mask] if has_edge_attr else None

        data_list.append(
            Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr)
        )

    return data_list


class NetworkRepository(InMemoryDataset):
    """In-memory dataset built from plain-text graph files stored in ``root``.

    For every ``<name>.edges`` file found directly in ``root`` the following
    companion files are used:

    * ``<name>.node_labels`` -- node labels (read unconditionally).
    * ``<name>.node_attrs``  -- optional integer node attributes; a zero
      vector with one entry per label is used when the file is absent.
    * ``<name>.graph_idx``   -- optional graph id per node; when present the
      data is split into one graph per id via ``split_data_object``.

    Indices and labels read from these files are shifted by -1 (1-based to
    0-based). Edge attributes are always ``None``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load()

    def load(self):
        """Populate ``self.data`` / ``self.slices`` from the processed file.

        The processed file is built first (via ``process``) if it is missing.
        """
        processed_path = self.processed_paths[0]
        print(processed_path)
        if not os.path.exists(processed_path):
            self.process()
        # NOTE(review): newer PyTorch releases may default ``torch.load`` to
        # ``weights_only=True``, which can reject pickled ``Data`` objects --
        # confirm against the pinned torch version.
        self.data, self.slices = torch.load(processed_path)

    @property
    def raw_file_names(self):
        """Names of all entries directly inside ``root``."""
        return os.listdir(self.root)

    @property
    def processed_file_names(self):
        """Name of the single cache file written by ``process``."""
        return ["processed.pt"]

    def process(self):
        """Parse every ``.edges`` file in ``root`` and save the collated result.

        Raises:
            FileNotFoundError: If ``root`` contains no ``.edges`` file.
        """
        self.data_list = []

        # Sorted so the graph order does not depend on ``os.listdir`` ordering.
        edge_files = sorted(
            name for name in self.raw_file_names if name.endswith(".edges")
        )
        if not edge_files:
            raise FileNotFoundError(f"No '.edges' file found in {self.root}")

        for filename in edge_files:
            base_name = filename[: -len(".edges")]
            edge_index = self.load_edges(os.path.join(self.root, filename))

            y = self._read_node_labels(
                os.path.join(self.root, f"{base_name}.node_labels")
            )

            attrs_path = os.path.join(self.root, f"{base_name}.node_attrs")
            if os.path.exists(attrs_path):
                values = pd.read_csv(
                    attrs_path, header=None, dtype=int, engine="python"
                ).values
                x = torch.tensor(values.squeeze(), dtype=torch.long)
            else:
                x = torch.zeros(y.shape[0], dtype=torch.long)

            # Drop edges that reference a node outside [0, num_nodes). The
            # lower bound matters because the -1 shift in ``load_edges`` turns
            # an id of 0 into -1, which would otherwise silently wrap around
            # when used as an index.
            num_nodes = x.shape[0]
            valid_edge_mask = ((edge_index >= 0) & (edge_index < num_nodes)).all(
                dim=0
            )
            edge_index = edge_index[:, valid_edge_mask]

            # ``<name>.link_labels`` is not read: edge attributes are always
            # discarded, so parsing the file would be wasted work.
            edge_attr = None

            idx_path = os.path.join(self.root, f"{base_name}.graph_idx")
            if os.path.exists(idx_path):
                values = pd.read_csv(
                    idx_path, header=None, dtype=int, engine="python"
                ).values
                batch = torch.tensor(values.squeeze() - 1, dtype=torch.long)
            else:
                batch = None

            data = Data(
                x=x, edge_index=edge_index, y=y, edge_attr=edge_attr, batch=batch
            )
            # Handled per file (inside the loop) so that every ``.edges`` file
            # contributes its graph(s) instead of only the last one iterated.
            if batch is not None:
                self.data_list.extend(split_data_object(data))
            else:
                self.data_list.append(data)

        data, slices = self.collate(self.data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _read_node_labels(self, file_path):
        """Read node labels from ``file_path`` as a 0-based ``long`` tensor.

        Two layouts are tried in order:

        1. Two columns ``(node index, label)``: the result has one entry per
           node index up to the largest one; nodes not listed get ``-1``.
        2. Fallback, used when the first attempt raises for any reason: the
           first column holds one label per row.
        """
        try:
            df = pd.read_csv(
                file_path,
                sep=None,
                engine="python",
                usecols=[0, 1],
                header=None,
            )
            node_indices = df[0].values - 1
            node_labels = df[1].values - 1
            num_nodes = int(node_indices.max()) + 1
            y = torch.full((num_nodes,), -1, dtype=torch.long)
            y[node_indices] = torch.tensor(node_labels, dtype=torch.long)
            return y
        except Exception:
            # Deliberately broad: any failure of the two-column parse falls
            # back to the single-column layout.
            df = pd.read_csv(file_path, engine="python", header=None)
            return torch.tensor(df[0].values - 1, dtype=torch.long)

    def load_edges(self, filepath):
        """Read an edge list and return it as a ``(2, num_edges)`` long tensor.

        The first two columns are taken as source and target; every value is
        shifted by -1 to convert 1-based ids to 0-based ones.
        """
        edges_df = pd.read_csv(
            filepath, header=None, names=["source", "target"], sep=None, engine="python"
        )
        # Only correct for files whose node ids start at 1.
        edges_df -= 1

        return torch.tensor(edges_df.values.T, dtype=torch.long)
