
import os
import warnings
from collections import defaultdict

import numpy as np
import pandas as pd

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

from utils.data.preprocessing import preprocess_node_features, preprocess_edge_features


def get_labels_from_graph(graph) -> np.ndarray:
    """Return the labels stored in ``graph.y`` as a NumPy array.

    The tensor is moved to the CPU before being converted.
    """
    label_tensor = graph.y
    host_tensor = label_tensor.cpu()
    return host_tensor.numpy()


def get_adjacency_dict_from_graph(graph, return_list: bool = False) -> dict:
    """Build a symmetric neighbour mapping from ``graph.edge_index``.

    Every edge ``(u, v)`` is recorded in both directions, so ``v`` is a
    neighbour of ``u`` and ``u`` is a neighbour of ``v``. Only nodes that
    appear in at least one edge become keys. Keys and neighbour entries are
    the NumPy scalars obtained by iterating over the edge index array.

    Args:
        graph: Object exposing an ``edge_index`` tensor with two rows.
        return_list: When True, return a plain ``dict`` whose values are
            sorted lists. When False (default), return the underlying
            ``defaultdict(set)``; looking up an absent node on it yields
            (and inserts) an empty set.

    Returns:
        Mapping from node to its neighbours.
    """
    sources, targets = graph.edge_index.cpu().numpy()

    neighbours = defaultdict(set)
    for src, dst in zip(sources, targets):
        # Store both directions so the mapping ignores edge orientation.
        neighbours[src].add(dst)
        neighbours[dst].add(src)

    if not return_list:
        return neighbours
    return {node: sorted(adjacent) for node, adjacent in neighbours.items()}



def get_clusters_from_graph(graph, min_adjacency: int = None) -> tuple[set, dict[int, list[int]]]:
    """Group node indices by their label in ``graph.y``.

    Args:
        graph: Object exposing ``y`` (one label per node) and, when
            ``min_adjacency`` is given, ``edge_index``.
        min_adjacency: Optional minimum number of distinct neighbours a node
            must have to be kept. ``None`` (default) keeps every node.

    Returns:
        A pair ``(nodes, clusters)``: the set of kept node indices, and a
        ``defaultdict(list)`` mapping each label to the kept node indices
        carrying it, in increasing index order.
    """
    filter_by_degree = min_adjacency is not None
    # The neighbour mapping is only needed when a degree threshold is requested.
    neighbours = get_adjacency_dict_from_graph(graph=graph) if filter_by_degree else None

    kept_nodes = set()
    members_by_label = defaultdict(list)
    for node, label in enumerate(graph.y.cpu().numpy()):
        if filter_by_degree and not (len(neighbours[node]) >= min_adjacency):
            continue
        members_by_label[label].append(node)
        kept_nodes.add(node)

    return kept_nodes, members_by_label


def create_graph(folder: str, tp: str, index: int = 0):
    """Build a ``Data`` graph from the CSV files in ``<folder>/<tp>/<index>``.

    Three files are read from that directory:

    * ``node_features.csv`` - one row per node; its ``node_id`` column is
      split off and the remaining columns go through
      ``preprocess_node_features``.
    * ``clustering.csv`` - its first column is used as the index and the next
      column as the cluster label, both cast to int64. A node whose id is
      not found there gets its own node id as label.
    * ``edge_features.csv`` - columns ``a`` and ``b`` hold the endpoint node
      ids; the remaining columns go through ``preprocess_edge_features``.
      Rows with an endpoint that is not a known node id are dropped.

    Both preprocessing helpers are expected to return NumPy arrays (their
    results are cast to float32 and wrapped as tensors).

    Args:
        folder: Root folder of the dataset.
        tp: Name of the sub-folder below ``folder``.
        index: Instance number, used as the innermost directory name.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``, ``y`` and
        ``num_nodes`` set.
    """
    instance_dir = os.path.join(folder, tp, str(index))

    # nodes: ids are separated from the feature columns before preprocessing
    node_table = pd.read_csv(os.path.join(instance_dir, "node_features.csv"))
    node_ids = node_table.pop("node_id")
    node_matrix = preprocess_node_features(node_table)

    # labels: look up each node id in the clustering table
    cluster_table = pd.read_csv(os.path.join(instance_dir, "clustering.csv"), index_col=0)
    cluster_of = cluster_table.iloc[:, 0]
    cluster_of.index = cluster_of.index.astype("int64")
    cluster_of = cluster_of.astype("int64")
    labels = node_ids.map(cluster_of)
    labels = labels.fillna(node_ids)  # unmatched nodes fall back to their own id
    labels = labels.astype("int64").to_numpy(copy=False)

    # edges: translate endpoint node ids into row positions of the node table
    edge_table = pd.read_csv(os.path.join(instance_dir, "edge_features.csv"))
    position_of = pd.Series(np.arange(len(node_ids), dtype=np.int64), index=node_ids.values, name="node_idx")
    for endpoint_column, position_column in (("a", "src_idx"), ("b", "dst_idx")):
        edge_table = edge_table.join(position_of.rename(position_column), on=endpoint_column)
    # An unmatched endpoint shows up as NaN after the join; discard those edges.
    edge_table = edge_table.dropna(subset=["src_idx", "dst_idx"])
    endpoints = [
        edge_table[position_column].to_numpy(dtype=np.int64, copy=False)
        for position_column in ("src_idx", "dst_idx")
    ]
    edge_index = torch.from_numpy(np.vstack(endpoints))

    # edge attributes: everything except the endpoint/helper columns
    helper_columns = {"a", "b", "src_idx", "dst_idx"}
    feature_columns = [name for name in edge_table.columns if name not in helper_columns]
    edge_matrix = preprocess_edge_features(edge_table[feature_columns])
    edge_attr = torch.from_numpy(edge_matrix.astype(np.float32, copy=False))

    # assemble the graph
    x = torch.from_numpy(node_matrix.astype(np.float32, copy=False))
    y = torch.from_numpy(labels.astype(np.int64, copy=False))
    return Data(x=x, edge_index=edge_index.long(), edge_attr=edge_attr, y=y, num_nodes=int(x.shape[0]))


def load_graph(folder: str, tp: str, index: int = 0,
               use_edge_features: bool = False,
               undirected: bool = True
               ):
    """Load the graph of one dataset instance, building and caching it if needed.

    The graph is cached at ``<folder>/<tp>/<index>/graph.pt``. When that file
    does not exist, the graph is built with ``create_graph`` and saved there.

    Args:
        folder: Root folder of the dataset.
        tp: Name of the sub-folder below ``folder``.
        index: Instance number, used as the innermost directory name.
        use_edge_features: Keep ``edge_attr`` when True; otherwise it is set
            to ``None`` on the returned graph.
        undirected: When True, symmetrise the edges with ``to_undirected``.
            With edge features, the attributes of edges that get merged are
            averaged (``reduce="mean"``).

    Returns:
        The graph object, adjusted according to the two flags.
    """
    path = os.path.join(folder, tp, str(index), "graph.pt")
    if os.path.exists(path):
        # The cache stores a whole pickled graph object, which needs full
        # unpickling. ``weights_only`` is passed explicitly because its
        # default differs between PyTorch releases; being explicit also
        # avoids the FutureWarning about the implicit default, so no global
        # warnings filter has to be installed.
        # NOTE(security): full unpickling can execute arbitrary code; only
        # point this function at cache files from a trusted location.
        graph = torch.load(path, weights_only=False)
    else:
        # Freshly built graph: cache it and use it directly instead of
        # reading it back from disk.
        graph = create_graph(folder, tp, index)
        torch.save(graph, path)

    if not use_edge_features:
        graph.edge_attr = None

    if undirected:
        if use_edge_features:
            graph.edge_index, graph.edge_attr = to_undirected(
                graph.edge_index, edge_attr=graph.edge_attr,
                num_nodes=graph.num_nodes, reduce="mean")
        else:
            # Pass num_nodes here too, consistent with the branch above.
            graph.edge_index = to_undirected(graph.edge_index, num_nodes=graph.num_nodes)

    return graph
