import os
import torch
from torch_geometric.data import Data, InMemoryDataset
import pandas as pd
import glob
from tqdm import tqdm


# from custom_modules.loader.custom_loaders import *
# from custom_modules.loader.utils import *
import os.path as osp

# from custom_modules.loader.utils import *
from torch_geometric.graphgym.loader import set_dataset_attr

# from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
# from custom_modules.transform.posenc_stats import compute_posenc_stats
# from custom_modules.transform.task_preprocessing import task_specific_preprocessing
# from custom_modules.transform.transforms import (
#     pre_transform_in_memory,
# )
import time
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch
import pickle


def split_data_object(data):
    """Split a batched graph object into one ``Data`` object per graph.

    ``data.batch`` assigns every node to a graph id. For each distinct id the
    matching rows of ``x`` (and ``y``, if present) are taken in their original
    order. Edges are kept only when both endpoints belong to that graph, and
    their endpoints are renumbered to the node's position inside the graph, so
    the returned ``edge_index`` indexes into the returned ``x``.

    Args:
        data: object exposing ``x``, ``edge_index`` (2 x E), ``batch`` and
            ``y`` / ``edge_attr`` (``None`` when absent). ``y`` is indexed with
            the node mask, so it is assumed to be node-level — TODO confirm.

    Returns:
        list: one ``Data`` per distinct value of ``data.batch``, in ascending
        graph-id order (the order ``torch.unique`` returns).
    """
    data_list = []
    batch = data.batch
    src, dst = data.edge_index[0], data.edge_index[1]

    # Global -> graph-local node index lookup. Each graph writes the entries
    # of its own nodes before reading them, so stale values from previously
    # processed graphs are never used.
    local_index = torch.full(
        (batch.numel(),), -1, dtype=torch.long, device=batch.device
    )

    for graph_id in torch.unique(batch):
        node_mask = batch == graph_id
        node_ids = node_mask.nonzero(as_tuple=True)[0]
        local_index[node_ids] = torch.arange(node_ids.numel(), device=batch.device)

        x = data.x[node_mask]
        y = data.y[node_mask] if data.y is not None else None

        # Keep only edges whose both endpoints belong to this graph.
        edge_mask = node_mask[src] & node_mask[dst]
        # Renumber by node position rather than by the set of nodes that occur
        # in edges, so isolated nodes do not shift the indices of later nodes.
        edge_index = local_index[data.edge_index[:, edge_mask]]

        edge_attr = data.edge_attr[edge_mask] if data.edge_attr is not None else None

        data_list.append(Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr))

    return data_list


class NetworkRepository(InMemoryDataset):
    """In-memory PyG dataset built from one Network Repository folder.

    For every ``<name>.edges`` file directly inside ``root``, the sibling files
    ``<name>.node_labels`` (required), ``<name>.node_attrs`` and
    ``<name>.graph_idx`` (both optional) are read into a ``Data`` object. When
    a ``.graph_idx`` file is present, the graph is split into one ``Data`` per
    graph with ``split_data_object``. All graphs are collated and cached at
    ``self.processed_paths[0]`` (file name ``processed.pt``). ``.link_labels``
    files are not used; ``edge_attr`` is always ``None``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load()

    def load(self):
        """Populate ``self.data`` / ``self.slices`` from the processed file.

        If the processed file does not exist yet, it is built first with
        :meth:`process`, so the dataset is always loaded on return.
        """
        processed_path = self.processed_paths[0]
        print(processed_path)
        if not os.path.exists(processed_path):
            self.process()
        # NOTE(review): torch>=2.6 defaults `torch.load` to weights_only=True,
        # which refuses pickled PyG objects; pass weights_only=False there.
        self.data, self.slices = torch.load(processed_path)

    @property
    def raw_file_names(self):
        # Raw files live directly in `root` (not in a `raw/` subdirectory).
        return os.listdir(self.root)

    @property
    def processed_file_names(self):
        return ["processed.pt"]

    def process(self):
        """Read every ``.edges`` graph under ``root``, collate, and save it.

        Raises:
            FileNotFoundError: if ``root`` contains no ``.edges`` file, or a
                graph has no ``.node_labels`` file.
        """
        self.data_list = []
        # Sorted so the graph order (and thus the saved file) is deterministic.
        for filename in sorted(self.raw_file_names):
            if not filename.endswith(".edges"):
                continue
            base_path = os.path.join(self.root, filename[: -len(".edges")])

            edge_index = self.load_edges(base_path + ".edges")
            y = self._read_node_labels(base_path + ".node_labels")
            x = self._read_node_attrs(base_path + ".node_attrs", y.shape[0])
            edge_index = self._drop_out_of_range_edges(edge_index, x.shape[0])
            batch = self._read_graph_idx(base_path + ".graph_idx")

            # Link labels are deliberately not used as edge features.
            graph = Data(x=x, edge_index=edge_index, y=y, edge_attr=None, batch=batch)

            # Accumulate every file's graphs, not just the last file's.
            if batch is not None:
                self.data_list.extend(split_data_object(graph))
            else:
                self.data_list.append(graph)

        if not self.data_list:
            raise FileNotFoundError(f"No '.edges' file found in {self.root}")

        data, slices = self.collate(self.data_list)
        torch.save((data, slices), self.processed_paths[0])

    @staticmethod
    def _read_node_labels(file_path):
        """Return 0-based node labels as a long tensor.

        Two-column files are read as ``(1-based node index, 1-based label)``;
        nodes that are not listed get label ``-1``. If that parse fails for any
        reason (e.g. a single-column file), the first column is read as one
        1-based label per node.
        """
        try:
            df = pd.read_csv(
                file_path,
                sep=None,
                engine="python",
                usecols=[0, 1],
                header=None,
            )
            node_indices = df[0].values - 1
            node_labels = df[1].values - 1
            num_nodes = int(node_indices.max()) + 1
            y = torch.full((num_nodes,), -1, dtype=torch.long)
            y[torch.as_tensor(node_indices, dtype=torch.long)] = torch.tensor(
                node_labels, dtype=torch.long
            )
        except Exception:
            # Deliberately broad: delimiter sniffing raises csv.Error, which is
            # not a ValueError. Falls back to the single-column format.
            df = pd.read_csv(file_path, engine="python", header=None)
            y = torch.tensor(df[0].values - 1, dtype=torch.long)
        return y

    @staticmethod
    def _read_node_attrs(file_path, num_nodes):
        """Return node attributes as a long tensor, or zeros if the file is absent."""
        if not os.path.exists(file_path):
            return torch.zeros(num_nodes, dtype=torch.long)
        values = pd.read_csv(file_path, header=None, engine="python").values
        # NOTE(review): cast to long truncates real-valued attributes — confirm intended.
        return torch.tensor(values.squeeze(), dtype=torch.long)

    @staticmethod
    def _drop_out_of_range_edges(edge_index, num_nodes):
        """Keep only edges whose both endpoints lie in ``[0, num_nodes)``."""
        in_range = ((edge_index >= 0) & (edge_index < num_nodes)).all(dim=0)
        return edge_index[:, in_range]

    @staticmethod
    def _read_graph_idx(file_path):
        """Return the 0-based per-node graph id tensor, or ``None`` if absent."""
        if not os.path.exists(file_path):
            return None
        values = pd.read_csv(file_path, header=None, engine="python").values - 1
        return torch.tensor(values.squeeze(), dtype=torch.long)

    def load_edges(self, filepath):
        """Read an edge list into a ``(2, E)`` long tensor of 0-based node ids.

        Only the first two columns are used (delimiter is sniffed), so extra
        columns such as edge weights or a trailing delimiter are ignored
        instead of shifting the source/target columns.
        """
        edges_df = pd.read_csv(
            filepath, header=None, usecols=[0, 1], sep=None, engine="python"
        )
        # Files are assumed to use 1-based node ids; convert to 0-based.
        # 0-based files yield negative ids, which `process` filters out.
        edges_df -= 1

        return torch.tensor(edges_df.values.T, dtype=torch.long)


def check_and_load_processed_eig(
    dataset_dir, dataset_name, dataset_dir_with_err, num_eigen=32
):
    """Load a pickled eigen-processed dataset and truncate its eigenpairs.

    Looks for ``<dataset_dir>/<dataset_name>_eigen_<num_eigen>_processed.pt``.
    A leading ``~`` in ``dataset_dir`` is expanded, because ``os.path.exists``
    does not expand it itself. If the file exists, it is unpickled, and
    ``dataset.data.eigvecs_sn`` / ``dataset.data.eigvals_sn`` are sliced to
    their first ``num_eigen`` entries along dimension 1.

    Args:
        dataset_dir: directory holding the processed pickle files.
        dataset_name: file-name prefix of the dataset.
        dataset_dir_with_err: currently unused; kept for caller compatibility.
        num_eigen: eigen count used in the file name and kept after slicing
            (defaults to the previously hard-coded 32).

    Returns:
        The loaded dataset object, or ``None`` if the file does not exist.
    """
    processed_path = os.path.join(
        os.path.expanduser(dataset_dir),
        f"{dataset_name}_eigen_{num_eigen}_processed.pt",
    )
    if not os.path.exists(processed_path):
        return None

    # SECURITY: pickle.load can execute arbitrary code from the file; only use
    # this on processed files produced by your own pipeline.
    with open(processed_path, "rb") as f:
        dataset = pickle.load(f)
    print(f"Loaded processed dataset from {processed_path}")

    dataset.data.eigvecs_sn = dataset.data.eigvecs_sn[:, :num_eigen]
    dataset.data.eigvals_sn = dataset.data.eigvals_sn[:, :num_eigen, :]
    return dataset


file_path = "custom_modules/loader/dataset/network_repository_dataset_list.txt"
# glob / os.path do not expand "~" themselves; without expanduser the
# `.link_labels` filter below never finds any file and every folder passes.
dataset_folder_name = os.path.expanduser("~/graph-datasets/")


with open(file_path, "r") as file:
    lines = file.readlines()

# Folder names come from listing headers of the form
# "Contents of downloads//<folder>:" in the dataset list file.
listing_prefix = "Contents of downloads//"
dataset_names = []
for line in lines:
    if line.startswith(listing_prefix):
        # Text between the prefix and the first ":" (whole remainder if none).
        folder_name = line[len(listing_prefix):].partition(":")[0].strip()
        dataset_names.append(folder_name)

# Keep only folders that contain no `.link_labels` file (no edge attributes).
filtered_folders = [
    folder_name
    for folder_name in dataset_names
    if not glob.glob(os.path.join(dataset_folder_name, folder_name, "*.link_labels"))
]


dataset_dir = os.path.expanduser("~/graph-datasets/eigen_processed/")

dataset_dir_with_err = os.path.expanduser("~/graph-datasets/eigen_processed_with_err/")
import logging


# Parses command-line arguments (result is not used below).
args = parse_args()
dataset_names = filtered_folders
for dataset_name in dataset_names:
    print(dataset_name)
    if dataset_name == "Mutag":
        continue
    dataset = NetworkRepository(dataset_folder_name + dataset_name)
    formatted_dataset_name = dataset_name.replace("-", "_").lower()
    # with open('network_repo_utils/network_repo_config.txt', 'a') as file:
    #     file.write(f'cfg.nr_{formatted_dataset_name}.dataset_name = "{dataset_name}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.format = "{"Network_repository"}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.task = "{"node"}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.task_type = "{"classification"}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.loss_fun = "{"cross_entropy"}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.task_dim = {dataset.data.y.max()+1}\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.split_mode = "{"random"}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.transductive = {"True"}\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.split_index = {"0"}\n')
    #     try:
    #         file.write(f'cfg.nr_{formatted_dataset_name}.feat_dim = {dataset.data.x.shape[1]}\n')
    #     except:
    #         file.write(f'cfg.nr_{formatted_dataset_name}.feat_dim = {1}\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.hidden_dim = {32}\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.activate_fn = "{"torch.nn.ReLU()"}"\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.split = {"[0.6, 0.2, 0.2]"}\n')
    #     file.write(f'cfg.nr_{formatted_dataset_name}.num_nodes = {dataset.data.y.shape[0]}\n\n')

    # Append this dataset's config key to the list of datasets without edge attributes.
    with open("network_repo_utils/network_repo_list_no_edge_attr.txt", "a") as file:
        file.write(f'"nr_{formatted_dataset_name}",\n')

    # with open('network_repo_utils/network_repo_CN.txt', 'a') as file:
    #     file.write(f'cfg.nr_{formatted_dataset_name} = CN()\n')
