import pickle
from os.path import join

import numpy as np
import torch

from magnetic_edge_gnn.datasets.dataset_utils import (
    combine_edges,
    create_pyg_graph_transductive,
    normalize_features,
    normalize_flows,
    relabel_nodes,
)

from .transductive_dataset import TransductiveDataset


class TrafficLADataset(TransductiveDataset):
    """
    Edge flow interpolation dataset built on the Traffic data used in the paper
    "Combining Physics and Machine Learning for Network Flow Estimation"
    (https://openreview.net/pdf?id=l0V53bErniB).
    """

    def preprocess(self) -> None:
        """
        Build the graph from the raw pickle files and store it on disk.

        Reads ``features_traffic.pkl`` and ``flows_traffic.pkl`` from
        ``self.dataset_path``, runs them through the shared preprocessing
        helpers, and saves the resulting graph object as
        ``graph-{self.seed}.pt`` in the same directory.
        """

        def read_pickle(filename):
            # NOTE: unpickling can execute arbitrary code, so only trusted
            # dataset files should ever be placed in ``dataset_path``.
            with open(join(self.dataset_path, filename), "rb") as handle:
                return pickle.load(handle)

        raw_features = read_pickle("features_traffic.pkl")
        raw_flows = read_pickle("flows_traffic.pkl")

        # Restrict every feature vector to the highway type entries
        # (e.g., motorway, motorway link, trunk, ...).
        highway_features = {
            edge: vector[4:-3] for edge, vector in raw_features.items()
        }

        # Drop flow outliers: only values strictly below the 95th percentile
        # of all flow values are kept.
        flow_values = np.array(list(raw_flows.values()))
        threshold = np.percentile(flow_values, q=95)
        kept_flows = {}
        for edge, flow in raw_flows.items():
            if flow < threshold:
                kept_flows[edge] = flow

        # Graph preprocessing pipeline: combine edges, relabel nodes, then
        # normalize the flows and the features.
        merged_features, merged_flows, edge_list = combine_edges(
            features=highway_features, flows=kept_flows
        )
        relabeled_features, relabeled_flows, edge_list, _node_mapping = (
            relabel_nodes(
                features=merged_features,
                flows=merged_flows,
                undirected_edges=edge_list,
            )
        )
        scaled_features, scaled_flows, edge_list = normalize_flows(
            features=relabeled_features,
            flows=relabeled_flows,
            undirected_edges=edge_list,
        )
        final_features = normalize_features(scaled_features)

        # All features are passed as invariant features; the equivariant
        # features get a fresh empty array per key.
        inv_features = dict(final_features.items())
        equi_features = {edge: np.zeros(0) for edge in final_features}

        # Create the PyG graph from the dictionaries.
        graph = create_pyg_graph_transductive(
            equi_features=equi_features,
            inv_features=inv_features,
            undirected_edges=edge_list,
            labels=scaled_flows,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            add_noisy_flow_to_input=False,
        )

        output_path = join(self.dataset_path, f"graph-{self.seed}.pt")
        torch.save(graph, output_path)
