import os
import torch
from torch_geometric.data import Data, InMemoryDataset
import pandas as pd
import glob
from tqdm import tqdm


from custom_modules.loader.custom_loaders import *
from custom_modules.loader.utils import *
import os.path as osp

from custom_modules.loader.utils import *
from torch_geometric.graphgym.loader import set_dataset_attr
from custom_modules.loader.split_generator import prepare_splits, set_dataset_splits
from custom_modules.transform.posenc_stats import compute_posenc_stats
from custom_modules.transform.task_preprocessing import task_specific_preprocessing
from custom_modules.transform.transforms import (
    pre_transform_in_memory,
)
import time
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch

from custom_modules.loader.synthetic_dataset import SyntheticDataset


import logging


def check_and_load_processed_eig(dataset_dir, dataset_name, num_eig=32):
    """Load a cached eigen-processed dataset if one exists.

    Looks for ``<dataset_dir>/<dataset_name>_eigen_<num_eig>_processed.pt``.
    If it exists, the file is unpickled and its stored eigen data is
    truncated to the first ``num_eig`` entries along dimension 1.

    Args:
        dataset_dir: Directory holding the processed (cached) files. A
            leading ``~`` is expanded to the user's home directory.
        dataset_name: Base name of the graph/dataset whose cache is wanted.
        num_eig: Eigen count that appears in the cache file name and that the
            loaded eigen tensors are truncated to. Defaults to 32, the value
            that was previously hard-coded.

    Returns:
        The unpickled dataset object with ``data.eigvecs_sn`` and
        ``data.eigvals_sn`` truncated, or ``None`` if no cache file exists.
    """
    # Imported locally so this function does not depend on ``pickle`` leaking
    # into the module namespace through a wildcard import.
    import pickle

    processed_path = os.path.join(
        os.path.expanduser(dataset_dir),
        f"{dataset_name}_eigen_{num_eig}_processed.pt",
    )
    # isfile (not exists): a directory with this name is "no cache", not a crash.
    if not os.path.isfile(processed_path):
        return None

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point this at caches produced by this pipeline.
    with open(processed_path, "rb") as f:
        dataset = pickle.load(f)
    print(f"Loaded processed dataset from {processed_path}")

    # Assumes eigvecs_sn is 2-D (nodes, k) and eigvals_sn is 3-D (nodes, k, 1)
    # with k >= num_eig, as the slicing below implies — TODO confirm against
    # the PE computation that wrote the cache.
    dataset.data.eigvecs_sn = dataset.data.eigvecs_sn[:, :num_eig]
    dataset.data.eigvals_sn = dataset.data.eigvals_sn[:, :num_eig, :]
    return dataset


# Driver: for every synthetic dataset family directory under ``dataset_dir``,
# take each pickled graph that has no cached eigen-processed file yet, compute
# the positional-encoding statistics enabled in the GraphGym config, attach
# name/task/node-id metadata, and write the result via ``save_processed_eig``.

# ``glob`` does not expand "~", so expand it explicitly; otherwise the pattern
# matches nothing and the script silently processes zero datasets.
dataset_dir = os.path.expanduser("~/graph-datasets/graphworld/")


syn_dirs = glob.glob(os.path.join(dataset_dir, "*"))


args = parse_args()
# Load config file
set_cfg(cfg)
load_cfg(cfg, args)

# The config is fixed for the whole run, so parse the enabled positional
# encodings once instead of once per dataset file.
pe_enabled_list = []
for key, pecfg in cfg.items():
    if key.startswith("posenc_") and pecfg.enable:
        pe_name = key.split("_", 1)[1]
        pe_enabled_list.append(pe_name)
        if hasattr(pecfg, "kernel"):
            # Generate kernel times if functional snippet is set.
            # NOTE(review): eval() executes arbitrary code from the config
            # file; only run with trusted configs.
            if pecfg.kernel.times_func:
                pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
            logging.info(
                f"Parsed {pe_name} PE kernel times / steps: "
                f"{pecfg.kernel.times}"
            )

for syn_dir in syn_dirs:
    eig_dir = os.path.join(syn_dir, "eigen_processed")
    dataset_file_names = glob.glob(os.path.join(syn_dir, "*.pkl"))
    print(syn_dir)
    print(len(dataset_file_names))
    for dataset_file_name in dataset_file_names:
        # NOTE(review): split(".")[0] truncates at the FIRST dot, so names like
        # "a.v1.pkl" and "a.v2.pkl" map to the same cache; kept as-is so
        # existing cache files stay valid.
        graph_name = os.path.basename(dataset_file_name).split(".")[0]

        # Check the cache before unpickling the raw graph, so already-processed
        # graphs are skipped without loading them.
        if check_and_load_processed_eig(eig_dir, graph_name) is not None:
            print("passing as eigen processed")
            continue

        with open(dataset_file_name, "rb") as f:
            raw_graph = pickle.load(f)

        dataset = SyntheticDataset("placeholder", [raw_graph])
        dataset.data.x = dataset.data.x.to(torch.float32)

        if pe_enabled_list:
            start = time.perf_counter()
            logging.info(
                f"Precomputing Positional Encoding statistics: "
                f"{pe_enabled_list} for all graphs..."
            )
            # Estimate directedness based on 10 graphs to save time.
            is_undirected = all(d.is_undirected() for d in dataset[:10])
            logging.info(f"  ...estimated to be undirected: {is_undirected}")
            pre_transform_in_memory(
                dataset,
                partial(
                    compute_posenc_stats,
                    pe_types=pe_enabled_list,
                    is_undirected=is_undirected,
                    cfg=cfg,
                ),
                show_progress=True,
            )
            elapsed = time.perf_counter() - start
            timestr = (
                time.strftime("%H:%M:%S", time.gmtime(elapsed)) + f"{elapsed:.2f}"[-3:]
            )
            logging.info(f"Done! Took {timestr}")

        dataset_graph_name = os.path.join(syn_dir, graph_name)
        num_graphs = len(dataset)
        set_dataset_attr(
            dataset,
            "dataset_name",
            [dataset_graph_name] * num_graphs,
            num_graphs,
        )

        set_dataset_attr(
            dataset,
            "dataset_task_name",
            [f"{dataset_graph_name}_node_classification"] * num_graphs,
            num_graphs,
        )

        # One integer id per entry of data.y (0 .. len(y)-1).
        set_dataset_attr(
            dataset,
            "node_id",
            torch.arange(len(dataset.data.y), dtype=torch.long),
            num_graphs,
        )

        save_processed_eig(dataset, eig_dir, graph_name)
