import os
import torch
from torch_geometric.data import Data, InMemoryDataset
import pandas as pd
import glob
from tqdm import tqdm


from custom_modules.loader.custom_loaders import *
from custom_modules.loader.utils import *
import os.path as osp

from custom_modules.loader.utils import *
from torch_geometric.graphgym.loader import set_dataset_attr
from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
from custom_modules.transform.posenc_stats import compute_posenc_stats
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)
import time
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch


def split_data_object(data):
    """Split one batched graph object into a list of per-graph ``Data`` objects.

    Nodes are assigned to graphs through ``data.batch``. An edge is kept for a
    graph only when both of its endpoints belong to that graph, and its
    endpoints are renumbered to the position of each node inside the graph's
    own ``x`` tensor.

    Args:
        data: Object exposing ``x``, ``y``, ``edge_index``, ``edge_attr`` and
            ``batch``. ``y`` and ``edge_attr`` may be ``None``.
            Assumes ``batch`` holds one graph id per node and ``edge_index``
            is an integer tensor of shape ``(2, num_edges)`` -- TODO confirm
            against callers.

    Returns:
        A list with one ``Data`` object per distinct value in ``data.batch``,
        in the (sorted) order returned by ``torch.unique``.
    """
    data_list = []
    batch = data.batch
    src, dst = data.edge_index[0], data.edge_index[1]

    # local_index[i] is the row that global node ``i`` occupies in the ``x``
    # tensor of its own graph. It is filled graph by graph below.
    local_index = torch.zeros(batch.size(0), dtype=torch.long, device=batch.device)

    for graph_id in torch.unique(batch):
        # Nodes that belong to the current graph.
        mask = batch == graph_id

        x = data.x[mask]
        y = data.y[mask] if data.y is not None else None

        # Boolean-mask selection keeps nodes in ascending global order, so the
        # local position of the k-th selected node is simply k.
        local_index[mask] = torch.arange(
            int(mask.sum()), dtype=torch.long, device=batch.device
        )

        # Keep only edges whose two endpoints are both inside this graph.
        edge_mask = mask[src] & mask[dst]

        # Relabel through the lookup table rather than through
        # ``torch.unique(..., return_inverse=True)``: the latter only numbers
        # nodes that occur in some edge, which shifts indices whenever the
        # graph has an isolated node, and it fails on graphs without edges.
        edge_index = local_index[data.edge_index[:, edge_mask]]

        edge_attr = data.edge_attr[edge_mask] if data.edge_attr is not None else None

        data_list.append(Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr))

    return data_list


class NetworkRepository(InMemoryDataset):
    """In-memory dataset built from the text files found directly in ``root``.

    For every ``<name>.edges`` file in ``root`` the following sibling files
    are consulted:

    * ``<name>.node_labels`` -- node labels (required),
    * ``<name>.node_attrs``  -- node attributes (optional, zeros otherwise),
    * ``<name>.graph_idx``   -- graph id per node (optional; when present the
      data is split into one graph per id).

    Indices and labels in the files are shifted down by one before use.
    Edge attributes are always set to ``None``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load()

    def load(self):
        """Populate ``self.data`` / ``self.slices`` from the processed file.

        If the processed file does not exist yet it is produced first, so the
        dataset is always populated when this method returns.
        """
        processed_path = self.processed_paths[0]
        print(processed_path)
        if not os.path.exists(processed_path):
            self.process()
        # NOTE(review): on PyTorch releases where ``torch.load`` defaults to
        # ``weights_only=True`` this call may need ``weights_only=False`` to
        # restore the saved objects -- confirm against the pinned version.
        self.data, self.slices = torch.load(processed_path)

    @property
    def raw_file_names(self):
        """Names of all entries located directly in ``self.root``."""
        return os.listdir(self.root)

    @property
    def processed_file_names(self):
        """Name of the single collated output file."""
        return ["processed.pt"]

    def process(self):
        """Read every ``.edges`` file in ``root``, collate and save the result.

        Raises:
            FileNotFoundError: If ``root`` contains no ``.edges`` file.
        """
        self.data_list = []

        # Sorted so the resulting graph order does not depend on the
        # arbitrary order returned by ``os.listdir``.
        edge_files = sorted(
            name for name in self.raw_file_names if name.endswith(".edges")
        )
        if not edge_files:
            raise FileNotFoundError(f"No '.edges' file found in {self.root}")

        suffix_len = len(".edges")
        for filename in edge_files:
            self.data_list.extend(self._read_graphs(filename[:-suffix_len]))

        data, slices = self.collate(self.data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _read_graphs(self, base_name):
        """Build the list of ``Data`` objects described by ``<base_name>.*``."""
        edge_index = self.load_edges(os.path.join(self.root, f"{base_name}.edges"))
        y = self._load_node_labels(
            os.path.join(self.root, f"{base_name}.node_labels")
        )

        attrs_path = os.path.join(self.root, f"{base_name}.node_attrs")
        if os.path.exists(attrs_path):
            values = pd.read_csv(attrs_path, header=None, engine="python").values
            x = torch.tensor(values.squeeze(), dtype=torch.long)
        else:
            # No attribute file: one zero feature per labelled node.
            x = torch.zeros(y.shape[0], dtype=torch.long)

        # Drop edges that point past the last node known to ``x``.
        num_nodes = x.shape[0]
        valid_edge_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
        edge_index = edge_index[:, valid_edge_mask]

        graph_idx_path = os.path.join(self.root, f"{base_name}.graph_idx")
        if os.path.exists(graph_idx_path):
            values = pd.read_csv(graph_idx_path, header=None, engine="python").values
            # Out-of-place subtraction: the array handed out by pandas may be
            # read-only, in which case ``values -= 1`` would raise.
            batch = torch.tensor((values - 1).squeeze(), dtype=torch.long)
        else:
            batch = None

        # Edge attributes are intentionally not used; ``.link_labels`` files
        # are therefore not read at all.
        data = Data(x=x, edge_index=edge_index, y=y, edge_attr=None, batch=batch)
        if batch is not None:
            return split_data_object(data)
        return [data]

    @staticmethod
    def _load_node_labels(file_path):
        """Return the label tensor stored in a ``.node_labels`` file.

        The file is first read as ``(node index, label)`` pairs; nodes that are
        not listed get the label ``-1``. If that fails for any reason the file
        is read again as a single column of labels, one row per node.
        """
        try:
            df = pd.read_csv(
                file_path,
                sep=None,
                engine="python",
                usecols=[0, 1],
                header=None,
            )
            node_indices = df[0].values - 1
            node_labels = df[1].values - 1
            y = torch.full((node_indices.max() + 1,), -1, dtype=torch.long)
            y[node_indices] = torch.tensor(node_labels, dtype=torch.long)
        except Exception:
            # Deliberately broad: any failure of the two-column layout
            # (missing column, undetectable separator, non-integer index, ...)
            # selects the single-column layout instead. A missing file still
            # raises from the read below.
            df = pd.read_csv(file_path, engine="python", header=None)
            y = torch.tensor(df[0].values - 1, dtype=torch.long)
        return y

    def load_edges(self, filepath):
        """Read a two-column edge list and return it as a ``(2, E)`` long tensor.

        The separator is auto-detected. Both columns are shifted down by one,
        i.e. the file is expected to use 1-based node indices.
        """
        edges_df = pd.read_csv(
            filepath, header=None, names=["source", "target"], sep=None, engine="python"
        )
        edges_df = edges_df - 1

        return torch.tensor(edges_df.values.T, dtype=torch.long)


def check_and_load_processed_eig(
    dataset_dir, dataset_name, dataset_dir_with_err, num_eig=32
):
    """Load a cached, eigen-processed dataset if its file exists.

    The cache file is expected at
    ``<dataset_dir>/<dataset_name>_eigen_<num_eig>_processed.pt``.

    Args:
        dataset_dir: Directory that holds the cached files.
        dataset_name: Name used as the file-name prefix.
        dataset_dir_with_err: Unused; kept so existing callers keep working.
        num_eig: Number of eigen-components encoded in the file name and kept
            along the second axis of ``eigvecs_sn`` / ``eigvals_sn``.

    Returns:
        The unpickled dataset with ``data.eigvecs_sn`` and ``data.eigvals_sn``
        truncated to ``num_eig`` entries, or ``None`` when no cache file exists.
    """
    # Local import so the function does not depend on ``pickle`` being
    # re-exported by one of the module's wildcard imports.
    import pickle

    # NOTE(review): ``os.path.exists`` does not expand a leading ``~``; if
    # callers pass ``~/...`` paths, confirm that the routine writing these
    # files resolves the directory the same way.
    processed_path = os.path.join(
        dataset_dir,
        f"{dataset_name}_eigen_{num_eig}_processed.pt",
    )
    if not os.path.exists(processed_path):
        return None

    # SECURITY: unpickling executes arbitrary code from the file. Only point
    # this at cache directories that are written by trusted code.
    with open(processed_path, "rb") as f:
        dataset = pickle.load(f)
    print(f"Loaded processed dataset from {processed_path}")

    dataset.data.eigvecs_sn = dataset.data.eigvecs_sn[:, :num_eig]
    dataset.data.eigvals_sn = dataset.data.eigvals_sn[:, :num_eig, :]
    return dataset


# Text listing from which the dataset folder names are extracted.
file_path = "custom_modules/loader/dataset/network_repository_dataset_list.txt"

# Directory prefixes used by the processing loop further down.
dataset_folder_name = "~/graph-datasets/"
dataset_dir = "~/graph-datasets/eigen_processed/"
dataset_dir_with_err = "~/graph-datasets/eigen_processed_with_err/"

# Lines that announce a folder start with this marker; the folder name is
# whatever sits between the marker and the first ':' on the line.
_LISTING_PREFIX = "Contents of downloads//"

with open(file_path, "r") as file:
    lines = file.readlines()

dataset_names = [
    line[len(_LISTING_PREFIX):line.find(":")].strip()
    for line in lines
    if line.startswith(_LISTING_PREFIX)
]
import logging


# Bind ``partial`` explicitly; it is used below and is not otherwise imported
# by name in this module.
from functools import partial

# Parse the command line and load the GraphGym configuration.
args = parse_args()
set_cfg(cfg)
load_cfg(cfg, args)

# For every listed dataset: build it, skip it when an eigen-processed cache
# already exists, otherwise precompute the enabled positional encodings,
# attach the naming attributes and save the result.
for dataset_name in dataset_names:

    if dataset_name == "Mutag":
        print("skipping for now: ", dataset_name)
        continue
    print(dataset_name)
    dataset = NetworkRepository(dataset_folder_name + dataset_name)
    dataset_loaded = check_and_load_processed_eig(
        dataset_dir, dataset_name, dataset_dir_with_err
    )

    if dataset_loaded is not None:
        print("passing as eigen processed")
        continue

    # Collect the names of all enabled ``posenc_*`` config sections.
    pe_enabled_list = []
    for key, pecfg in cfg.items():
        if key.startswith("posenc_") and pecfg.enable:
            pe_name = key.split("_", 1)[1]
            pe_enabled_list.append(pe_name)
            if hasattr(pecfg, "kernel"):
                # Generate kernel times if functional snippet is set.
                if pecfg.kernel.times_func:
                    # SECURITY: ``eval`` runs a string taken from the config
                    # file; only use configuration files from trusted sources.
                    pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
                logging.info(
                    f"Parsed {pe_name} PE kernel times / steps: "
                    f"{pecfg.kernel.times}"
                )

    if pe_enabled_list:
        start = time.perf_counter()
        logging.info(
            f"Precomputing Positional Encoding statistics: "
            f"{pe_enabled_list} for all graphs..."
        )
        # Estimate directedness based on 10 graphs to save time.
        is_undirected = all(d.is_undirected() for d in dataset[:10])
        logging.info(f"  ...estimated to be undirected: {is_undirected}")
        pre_transform_in_memory(
            dataset,
            partial(
                compute_posenc_stats,
                pe_types=pe_enabled_list,
                is_undirected=is_undirected,
                cfg=cfg,
            ),
            show_progress=True,
        )
        elapsed = time.perf_counter() - start
        timestr = (
            time.strftime("%H:%M:%S", time.gmtime(elapsed)) + f"{elapsed:.2f}"[-3:]
        )
        logging.info(f"Done! Took {timestr}")

    # The attributes are set directly. ``set_dataset_attr`` works through its
    # side effect on ``dataset`` and has no return value, so its result must
    # not be handed to ``pre_transform_in_memory`` as a transform function.
    set_dataset_attr(
        dataset,
        "dataset_name",
        [dataset_name] * len(dataset),
        len(dataset),
    )

    set_dataset_attr(
        dataset,
        "dataset_task_name",
        [f"{dataset_name}_node_classification"] * len(dataset),
        len(dataset),
    )

    # Save the processed dataset for future use.
    save_processed_eig(dataset, dataset_dir, dataset_name)


def preformat_TUDataset(dataset_dir, name):
    """Return the requested TUDataset, adding a pre-transform where needed.

    ``DD``, ``NCI1``, ``ENZYMES``, ``PROTEINS`` and ``TRIANGLES`` are loaded
    without a pre-transform; ``COLLAB`` and every ``IMDB-*`` dataset are loaded
    with ``T.Constant()`` as pre-transform.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific dataset in the TUDataset class

    Returns:
        PyG dataset object

    Raises:
        ValueError: If ``name`` is none of the supported datasets.
    """
    loaded_as_is = ("DD", "NCI1", "ENZYMES", "PROTEINS", "TRIANGLES")

    if name in loaded_as_is:
        pre_transform = None
    else:
        # Anything else must be one of the datasets that get the constant
        # pre-transform; reject every other name up front.
        if not (name.startswith("IMDB-") or name == "COLLAB"):
            raise ValueError(
                f"Loading dataset '{name}' from TUDataset is not supported."
            )
        pre_transform = T.Constant()

    return TUDataset(dataset_dir, name, pre_transform=pre_transform)

