from os.path import join

import numpy as np
import pandas as pd
import torch

from magnetic_edge_gnn.datasets.dataset_utils import (
    combine_edges,
    create_pyg_graph_transductive,
    normalize_features,
    normalize_flows,
    relabel_nodes,
)

from .transductive_dataset import TransductiveDataset


class TNTPFlowDenoisingInterpolationDataset(TransductiveDataset):
    """
    Transductive edge-flow dataset built from a traffic network of the TransportationNetworks (TNTP) repository.

    The ``dataset_name`` passed to the constructor has the form ``"<dataset>-<task>"``, where ``<dataset>`` is a key
    of ``supported_datasets`` and ``<task>`` is one of ``supported_tasks``, e.g. ``"traffic-anaheim-denoising"``.
    """

    # Tasks selectable through the suffix of ``dataset_name``.
    supported_tasks = ["denoising", "interpolation", "simulation"]
    # Maps the dataset prefix of ``dataset_name`` to the file-name stem used by the TNTP files.
    supported_datasets = {
        "traffic-anaheim": "Anaheim",
        "traffic-barcelona": "Barcelona",
        "traffic-chicago": "ChicagoSketch",
        "traffic-sioux-falls": "SiouxFalls",
        "traffic-winnipeg": "Winnipeg",
    }
    # Numerical link attributes read from ``<stem>_net.tntp``, in feature order.
    # The ``speed`` column is deliberately not used as a feature.
    _numerical_feature_columns = ("capacity", "length", "free_flow_time", "b", "power", "toll")

    def __init__(
        self,
        split: str,
        dataset_name: str,
        dataset_path: str,
        val_ratio: float = 0.1,
        test_ratio: float = 0.5,
        seed: float | None = None,
        arbitrary_orientation: bool = True,
        orientation_equivariant_labels: bool = False,
        cache_file: str | None = None,
        preprocess: bool | None = None,
        interpolation_label_size: float = 0.75,
    ):
        """
        Dataset class for the edge flow denoising, interpolation and simulation tasks on the traffic datasets from
        the TransportationNetworks GitHub repository (https://github.com/bstabler/TransportationNetworks).

        Args:
            split (str): Data split to load. Should be one of: ["train", "val", "test"].
            dataset_name (str): Name of the dataset in the form "<dataset>-<task>", where <dataset> is a key of
                `supported_datasets` and <task> is one of `supported_tasks` (e.g. "traffic-anaheim-denoising").
            dataset_path (str): Directory containing the "<stem>_net.tntp" and "<stem>_flow.tntp" files.
            val_ratio (float, optional): Ratio of validation data. Defaults to 0.1.
            test_ratio (float, optional): Ratio of test data. Defaults to 0.5.
            seed (float, optional): Random seed. Defaults to None.
            arbitrary_orientation (bool, optional): Whether to arbitrarily orient the edges.
                Defaults to True.
            orientation_equivariant_labels (bool, optional): Whether the labels are orientation-equivariant or not.
                Defaults to False.
            cache_file (str, optional): Forwarded unchanged to `TransductiveDataset`. Defaults to None.
            preprocess (bool, optional): Forwarded unchanged to `TransductiveDataset` (presumably controls whether
                `preprocess()` is re-run — confirm against the base class). Defaults to None.
            interpolation_label_size (float, optional): Forwarded to `create_pyg_graph_transductive`; presumably the
                fraction of edges used as labels in the interpolation task — confirm. Defaults to 0.75.

        Raises:
            ValueError: If `dataset_name` is not of the form "<dataset>-<task>", or names an unsupported dataset
                or task.
        """

        dataset, separator, task = dataset_name.rpartition("-")
        if not separator:
            raise ValueError(
                f"The dataset name should have the form '<dataset>-<task>'. Got {dataset_name!r}!"
            )
        if dataset not in self.supported_datasets:
            raise ValueError(
                f"The dataset should be in {self.supported_datasets.keys()}. The dataset {dataset} is not supported!"
            )
        if task not in self.supported_tasks:
            raise ValueError(
                f"The task should be in {self.supported_tasks}. The task {task} is not supported!"
            )

        # These attributes must exist before the base-class constructor runs, since `preprocess()` reads them.
        self.dataset = self.supported_datasets[dataset]
        self.task = task
        self.interpolation_label_size = interpolation_label_size
        super().__init__(
            split=split,
            dataset_name=dataset_name,
            dataset_path=dataset_path,
            val_ratio=val_ratio,
            test_ratio=test_ratio,
            seed=seed,
            arbitrary_orientation=arbitrary_orientation,
            orientation_equivariant_labels=orientation_equivariant_labels,
            cache_file=cache_file,
            preprocess=preprocess,
        )

    def preprocess(self):
        """
        Build the PyG graph for `self.dataset` and save it to `self.filename` with `torch.save`.

        Link attributes become orientation-invariant edge features, and link volumes become the edge labels. The
        current task decides which flow signal `create_pyg_graph_transductive` adds to the input:
        - "denoising": a noisy flow;
        - "interpolation": an interpolation flow;
        - "simulation": zeros.
        """
        features = self._read_tntp_features()
        flows = self._read_tntp_flows()

        # Pre-process the graph.
        features, flows, undirected_edges = combine_edges(
            features=features, flows=flows
        )
        features, flows, undirected_edges, _ = relabel_nodes(
            features=features, flows=flows, undirected_edges=undirected_edges
        )
        features, flows, undirected_edges = normalize_flows(
            features=features, flows=flows, undirected_edges=undirected_edges
        )
        features = normalize_features(features)

        # All link attributes are used as orientation-invariant features; no equivariant features are provided.
        inv_features = dict(features)
        equi_features = {edge: np.zeros(0) for edge in features}

        # Create the PyG graph from the dictionaries.
        data = create_pyg_graph_transductive(
            equi_features=equi_features,
            inv_features=inv_features,
            undirected_edges=undirected_edges,
            labels=flows,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            add_noisy_flow_to_input=(self.task == "denoising"),
            add_interpolation_flow_to_input=(self.task == "interpolation"),
            # There are no equivariant input features available, so the simulation task feeds zeros instead.
            add_zeros_to_flow_input=(self.task == "simulation"),
            interpolation_label_size=self.interpolation_label_size,
        )

        torch.save(data, self.filename)

    def _read_tntp_features(self) -> dict[tuple[int, int], np.ndarray]:
        """Read per-link feature vectors from "<stem>_net.tntp", keyed by (init_node, term_node)."""
        features_file = join(self.dataset_path, f"{self.dataset}_net.tntp")
        # Skips the first 8 lines before the tab-separated header (assumes the standard TNTP metadata block —
        # confirm for any newly added network).
        features_df = pd.read_csv(features_file, skiprows=8, sep="\t")
        features_df.columns = [s.strip().lower() for s in features_df.columns]
        # Drop the "~" and ";" columns, which hold no link data (presumably the TNTP row-start / row-end markers).
        features_df = features_df.drop(["~", ";"], axis=1)

        # One-hot index for every link type present in the file.
        link_type2idx = {
            link_type: idx for idx, link_type in enumerate(set(features_df["link_type"]))
        }
        num_link_types = len(link_type2idx)

        features = {}
        for _, row in features_df.iterrows():
            numerical_features = np.array([row[col] for col in self._numerical_feature_columns])
            link_type_features = np.zeros(num_link_types)
            link_type_features[link_type2idx[int(row["link_type"])]] = 1
            edge = (int(row["init_node"]), int(row["term_node"]))
            features[edge] = np.concatenate([numerical_features, link_type_features])
        return features

    def _read_tntp_flows(self) -> dict[tuple[int, int], float]:
        """Read per-link volumes from "<stem>_flow.tntp", keyed by (from, to)."""
        flows_file = join(self.dataset_path, f"{self.dataset}_flow.tntp")
        flows_df = pd.read_csv(flows_file, sep="\t")
        flows_df.columns = [s.strip().lower() for s in flows_df.columns]
        return {
            (int(row["from"]), int(row["to"])): row["volume"]
            for _, row in flows_df.iterrows()
        }
