# Adapted from: {REDACTED LINK}
import os

import numpy as np
import scipy.io
import torch
from ogb.nodeproppred import NodePropPredDataset


class NCDataset:
    """Single-graph node-classification dataset holding torch tensors.

    Based off of ogb NodePropPredDataset
    {REDACTED LINK}
    Gives torch tensors instead of numpy arrays.

    Usage after construction (``graph`` and ``label`` are filled in by the
    loader functions, not by ``__init__``):

        split_idx = dataset.get_idx_split(train_prop=0.5, valid_prop=0.25)
        train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
        graph, label = dataset[0]

    Where the graph is a dictionary of the following form:
        dataset.graph = {'edge_index': edge_index,
                         'edge_feat': None,
                         'node_feat': node_feat,
                         'num_nodes': num_nodes}
    For additional documentation, see OGB Library-Agnostic Loader {REDACTED LINK}
    """

    def __init__(self, name):
        """Create an empty dataset.

        Args:
            name (str): name of the dataset.
        """
        self.name = name
        self.graph = {}
        self.label = None

    def get_idx_split(self, train_prop, valid_prop, generator=None):
        """Randomly partition the node indices into train/valid/test.

        Args:
            train_prop: The proportion of dataset for train split. Between 0 and 1.
            valid_prop: The proportion of dataset for validation split. Between 0 and 1.
            generator: Optional ``torch.Generator`` used for the permutation;
                when None the global torch RNG is used.

        Returns:
            dict with keys "train", "valid" and "test" mapping to index
            tensors. The test split receives every index left over after the
            train and validation splits.
        """
        n = int(self.label.shape[0])
        num_train = int(n * train_prop)
        num_valid = int(n * valid_prop)
        # ``generator=None`` is randperm's default, so no branching is needed.
        perm = torch.randperm(n, generator=generator)
        valid_end = num_train + num_valid
        return {
            "train": perm[:num_train],
            "valid": perm[num_train:valid_end],
            "test": perm[valid_end:],
        }

    def __getitem__(self, idx):
        """Return ``(graph, label)`` for index 0, the only valid index.

        Raises:
            IndexError: for any other index. Raising IndexError (rather than
                asserting) lets the sequence iteration protocol terminate, so
                ``list(dataset)`` and ``for item in dataset`` work.
        """
        if idx != 0:
            raise IndexError(
                f"{self.__class__.__name__} holds a single graph; index must be 0, got {idx!r}"
            )
        return self.graph, self.label

    def __len__(self):
        """The dataset always contains exactly one graph."""
        return 1

    def __repr__(self):
        return f"{self.__class__.__name__}({len(self)})"


def even_quantile_labels(vals, nclasses, verbose=False):
    """Partition vals into nclasses by a quantile based split.

    The first class holds values below the 1/nclasses quantile, the second
    class values below the 2/nclasses quantile, and so on; the last class
    holds everything at or above the final quantile. Intervals are half-open
    ``[lower, upper)``.

    Args:
        vals: np array of values to bin.
        nclasses (int): number of classes to produce; must be >= 1.
        verbose (bool): if True, print the interval assigned to each class.

    Returns:
        np array of int64 class labels, one per entry of vals. Entries that
        fall in no interval (e.g. NaN) keep the placeholder label -1.

    Raises:
        ValueError: if nclasses is smaller than 1.
    """
    if nclasses < 1:
        raise ValueError(f"nclasses must be at least 1, got {nclasses}")

    label = np.full(vals.shape[0], -1, dtype=np.int64)
    interval_lst = []
    lower = -np.inf
    for k in range(nclasses - 1):
        upper = np.quantile(vals, (k + 1) / nclasses)
        interval_lst.append((lower, upper))
        # Logical AND of the two boolean masks selects the half-open interval.
        label[(vals >= lower) & (vals < upper)] = k
        lower = upper
    label[vals >= lower] = nclasses - 1
    interval_lst.append((lower, np.inf))
    if verbose:
        print("Class Label Intervals:")
        for class_idx, (low, high) in enumerate(interval_lst):
            print(f"Class {class_idx}: [{low}, {high})")
    return label


def load_snap_patents_mat(
    nclass=5,
    mat_path=None,
    data_dir=None,
    download=False,
    url=None,
    drive_id=None,
    quiet=False,
):
    """
    Load the SNAP-Patents dataset from a MATLAB .mat file and produce an NCDataset.

    The .mat file is expected to provide the keys "edge_index", "node_feat",
    "num_nodes" and "years"; labels are obtained by quantile-binning "years".

    Args:
        nclass (int): Number of quantile bins for year labels.
        mat_path (str | None): Full path to snap_patents.mat. If None, constructed from data_dir/cwd.
        data_dir (str | None): Directory to look for/create the dataset file if mat_path is None.
        download (bool): If True, attempt to download when file is missing using url or drive_id.
        url (str | None): Direct download URL (used if provided).
        drive_id (str | None): Google Drive file id (used if provided and url is None).
        quiet (bool): Passed to the downloader to reduce output verbosity.

    Returns:
        NCDataset: Dataset with graph and quantile-binned labels.

    Raises:
        FileNotFoundError: If the file is missing and download is False, or if
            the download finished without producing the file.
        ValueError: If download is True but neither url nor drive_id is given.
        RuntimeError: If a download is needed but gdown is not installed.
    """
    if mat_path is None:
        base_dir = data_dir if data_dir is not None else os.getcwd()
        os.makedirs(base_dir, exist_ok=True)
        mat_path = os.path.join(base_dir, "snap_patents.mat")

    if not os.path.exists(mat_path):
        if not download:
            raise FileNotFoundError(
                f"Dataset file not found at {mat_path}. Set download=True and provide url or drive_id, or pass an existing mat_path."
            )
        if not (url or drive_id):
            raise ValueError("download=True requires either url or drive_id.")
        try:
            import gdown
        except ImportError as e:
            raise RuntimeError(
                "gdown is required for downloading. Install it or provide an existing mat_path."
            ) from e

        # An explicit mat_path may point into a directory that does not exist yet.
        parent_dir = os.path.dirname(mat_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # url takes precedence over drive_id when both are supplied.
        if url:
            gdown.download(url=url, output=mat_path, quiet=quiet)
        else:
            gdown.download(id=drive_id, output=mat_path, quiet=quiet)

        # Fail here with a clear message instead of letting loadmat trip over
        # a file that the downloader never wrote.
        if not os.path.exists(mat_path):
            raise FileNotFoundError(
                f"Download did not produce a dataset file at {mat_path}."
            )

    fulldata = scipy.io.loadmat(mat_path)

    edge_index = torch.tensor(fulldata["edge_index"], dtype=torch.long)

    node_feat_arr = fulldata["node_feat"]
    if hasattr(node_feat_arr, "todense"):
        # Sparse input: densify, and strip the np.matrix wrapper todense() returns.
        node_feat_arr = np.asarray(node_feat_arr.todense())
    node_feat = torch.tensor(node_feat_arr, dtype=torch.float)

    # The value is stored as an array; .item() extracts the single element
    # without relying on int() over a multi-dimensional array, which NumPy
    # has deprecated.
    num_nodes = int(np.asarray(fulldata["num_nodes"]).item())

    years = np.array(fulldata["years"]).flatten()
    label = torch.tensor(even_quantile_labels(years, nclass), dtype=torch.long)

    dataset = NCDataset("snap_patents")
    dataset.graph = {
        "edge_index": edge_index,
        "edge_feat": None,
        "node_feat": node_feat,
        "num_nodes": num_nodes,
    }
    dataset.label = label
    return dataset


def load_arxiv_year_dataset(nclass=5):
    """Build the "arxiv-year" NCDataset from OGB's "ogbn-arxiv" graph.

    The graph dictionary of the OGB dataset is reused, with its "edge_index"
    and "node_feat" entries converted to torch tensors. Labels are the
    quantile bins of the graph's "node_year" entry.

    Args:
        nclass (int): Number of quantile bins for year labels.

    Returns:
        NCDataset: Dataset whose label is a long tensor reshaped to a single column.
    """
    dataset = NCDataset("arxiv-year")

    graph = NodePropPredDataset(name="ogbn-arxiv").graph
    for key in ("edge_index", "node_feat"):
        graph[key] = torch.as_tensor(graph[key])
    dataset.graph = graph

    year_bins = even_quantile_labels(graph["node_year"].flatten(), nclass)
    dataset.label = torch.tensor(year_bins, dtype=torch.long).reshape(-1, 1)
    return dataset


def _ensure_pokec_profiles(base_dir):
    """Download and extract the Pokec profiles file into base_dir if it is missing; return the .txt path."""
    import gzip
    import shutil
    import urllib.request

    profiles_url = "https://snap.stanford.edu/data/soc-pokec-profiles.txt.gz"
    profiles_gz = os.path.join(base_dir, "soc-pokec-profiles.txt.gz")
    profiles_txt = os.path.join(base_dir, "soc-pokec-profiles.txt")

    if os.path.exists(profiles_txt):
        return profiles_txt

    # Both steps write to a ".part" file and rename on success, so an
    # interrupted run never leaves a truncated file that a later run would
    # mistake for a complete one.
    if not os.path.exists(profiles_gz):
        print(f"Downloading {profiles_url}...")
        partial_gz = profiles_gz + ".part"
        urllib.request.urlretrieve(profiles_url, partial_gz)
        os.replace(partial_gz, profiles_gz)

    print("Extracting profiles file...")
    partial_txt = profiles_txt + ".part"
    with gzip.open(profiles_gz, "rb") as f_in, open(partial_txt, "wb") as f_out:
        # Stream the decompressed data instead of holding it all in memory.
        shutil.copyfileobj(f_in, f_out)
    os.replace(partial_txt, profiles_txt)
    return profiles_txt


def _read_pokec_profiles(profiles_txt):
    """Parse the tab-separated profiles file into ({node_id: gender}, {node_id: age}) for rows whose gender is 0 or 1."""
    gender_labels = {}
    age_features = {}
    with open(profiles_txt, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 4:
                continue
            try:
                node_id = int(parts[0])
                gender = int(parts[3]) if parts[3] and parts[3] != "null" else -1
                # NOTE(review): the SNAP soc-Pokec column listing appears to put
                # completion_percentage in column 2 and AGE in column 7 --
                # confirm that column 2 is really the intended "age" feature.
                age = int(parts[2]) if parts[2] and parts[2] != "null" else 0
            except ValueError:
                # Rows with non-numeric fields are skipped.
                continue
            if gender in (0, 1):
                gender_labels[node_id] = gender
                age_features[node_id] = age
    return gender_labels, age_features


def _remap_pokec_edges(edge_array, valid_nodes):
    """Keep edges whose endpoints are both in valid_nodes (sorted, unique) and renumber them to positions in valid_nodes."""
    valid_ids = np.asarray(valid_nodes, dtype=np.int64)
    keep = np.isin(edge_array[0], valid_ids) & np.isin(edge_array[1], valid_ids)
    # Every surviving id is present in the sorted valid_ids, so searchsorted
    # returns its exact position, i.e. the new contiguous node index.
    remapped = np.searchsorted(valid_ids, edge_array[:, keep])
    return remapped.astype(np.int64, copy=False)


def load_pokec_snap(data_dir=None):
    """Load the Pokec social network as an NCDataset with gender labels.

    The graph structure comes from PyG's SNAPDataset ("soc-pokec"); the
    per-node profile file is downloaded from SNAP and extracted on first use.
    Only nodes whose profile has gender 0 or 1 are kept, they are renumbered
    contiguously in ascending order of their original id, and edges touching
    a dropped node are discarded.

    Args:
        data_dir (str | None): Directory used for the PyG dataset root and
            the profile files. Defaults to "./data".

    Returns:
        NCDataset: Dataset whose label is the gender per kept node and whose
        node_feat is a single standardized column built from the parsed
        "age" value.
    """
    from torch_geometric.datasets import SNAPDataset

    base_dir = data_dir if data_dir is not None else "./data"
    os.makedirs(base_dir, exist_ok=True)

    print("Loading graph structure via PyG SNAPDataset...")
    pyg_dataset = SNAPDataset(root=base_dir, name="soc-pokec")
    pyg_data = pyg_dataset[0]

    profiles_txt = _ensure_pokec_profiles(base_dir)

    print("Loading profiles and extracting features...")
    gender_labels, age_features = _read_pokec_profiles(profiles_txt)

    valid_nodes = sorted(gender_labels)
    num_nodes = len(valid_nodes)
    print(f"Found {num_nodes} nodes with valid gender labels")

    pyg_edge_index = pyg_data.edge_index.numpy()
    # Vectorized remap: avoids a Python-level loop over every edge and yields
    # a well-formed (2, 0) array when no edge survives the filter.
    edge_index = torch.tensor(
        _remap_pokec_edges(pyg_edge_index, valid_nodes), dtype=torch.long
    )
    print(f"Mapped {edge_index.shape[1]} edges (from {pyg_edge_index.shape[1]} original)")

    labels = torch.tensor([gender_labels[nid] for nid in valid_nodes], dtype=torch.long)
    ages = np.array([age_features[nid] for nid in valid_nodes], dtype=np.float32)
    # Standardize; the epsilon guards against a zero standard deviation.
    ages = (ages - ages.mean()) / (ages.std() + 1e-8)
    node_feat = torch.tensor(ages, dtype=torch.float).unsqueeze(1)

    dataset = NCDataset("pokec")
    dataset.graph = {
        "edge_index": edge_index,
        "edge_feat": None,
        "node_feat": node_feat,
        "num_nodes": num_nodes,
    }
    dataset.label = labels
    return dataset


def load_nc_dataset(dataname, **kwargs):
    """Dispatch to the loader matching dataname.

    Args:
        dataname (str): "snap-patents" / "snap_patents" or "pokec".
        **kwargs: Forwarded unchanged to load_snap_patents_mat for the
            SNAP-Patents names; for "pokec" only "data_dir" is read
            (falling back to "./data" when absent).

    Returns:
        NCDataset: The dataset produced by the selected loader.

    Raises:
        ValueError: If dataname is not one of the supported names.
    """
    if dataname == "pokec":
        pokec_dir = kwargs.get("data_dir", "./data")
        return load_pokec_snap(data_dir=pokec_dir)

    patents_aliases = ("snap-patents", "snap_patents")
    if dataname not in patents_aliases:
        raise ValueError(f"Unsupported dataset: {dataname}")
    return load_snap_patents_mat(**kwargs)


# Lookup table from dataset/file key to a download identifier.
# NOTE(review): despite the "_url" name the values are bare id strings,
# presumably Google Drive file ids -- confirm. Nothing in the loaders
# defined in this file reads this table; verify against external callers.
dataset_drive_url = {
    "twitch-gamer_feat": "1fA9VIIEI8N0L27MSQfcBzJgRQLvSbrvR",
    "twitch-gamer_edges": "1XLETC6dG3lVl7kDmytEJ52hvDMVdxnZ0",
    "snap-patents": "1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia",
    "pokec": "1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y",
    "yelp-chi": "1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ",
    "wiki_views": "1p5DlVHrnFgYm3VsNIzahSsvCD424AyvP",  # Wiki 1.9M
    "wiki_edges": "14X7FlkjrlUgmnsYtPwdh-gGuFla4yb5u",  # Wiki 1.9M
    "wiki_features": "1ySNspxbK-snNoAZM7oxiWGvOnTRdSyEK",  # Wiki 1.9M
}

# Lookup table from dataset name to a download identifier for split files.
# NOTE(review): values look like Google Drive file ids, presumably for
# precomputed train/valid/test splits -- confirm. Not read by any loader
# defined in this file.
splits_drive_url = {
    "snap-patents": "12xbBRqd8mtG_XkNLH8dRRNZJvVM4Pw-N",
    "pokec": "1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_",
}
