"""
Edge dataset classes for transductive tasks.

Code adapted from:
Fuchsgruber, Dominik, et al. "Graph Neural Networks for Edge Signals: Orientation Equivariance and Invariance." ICLR 2025.
Link: https://openreview.net/forum?id=XWBE90OYlH
""" 

import numpy as np
import pandas as pd
from abc import abstractmethod
from os.path import exists, join
from tempfile import NamedTemporaryFile
from torch.utils.data import Dataset
from torch_geometric.data import Data
from src.t3_edge_regression.dataset_utils import (
    combine_edges,
    create_pyg_graph_transductive,
    normalize_features,
    normalize_flows,
    relabel_nodes,
)
import torch


class TransductiveDataset(Dataset):
    """Base class for transductive edge datasets.

    Subclasses implement :meth:`preprocess`, which must write the processed
    graph (a PyG ``Data`` object or a list of them) to ``self.filename``.
    The constructor runs ``preprocess`` when requested or when the cache file
    is missing, then loads the cached object(s) into ``self.graphs``.
    """

    def __init__(
        self,
        split: str,
        dataset_name: str,
        dataset_path: str,
        val_ratio: float = 0.1,
        test_ratio: float = 0.8,
        seed: float | None = None,
        arbitrary_orientation: bool = True,
        orientation_equivariant_labels: bool = False,
        cache_file: str | None = None,
        preprocess: bool | None = None,
    ):
        """
        Abstract dataset class for transductive tasks.

        Args:
            split (str): Data split to load. Should be one of: ["train", "val", "test"].
            dataset_name (str): Name of the dataset.
            dataset_path (str): Path to the dataset directory.
            val_ratio (float, optional): Ratio of validation data. Defaults to 0.1.
            test_ratio (float, optional): Ratio of test data. Defaults to 0.8.
            seed (float | None, optional): Random seed. It is also used to name the default
                cache file ``graph-{seed}.pt``. Defaults to None.
            arbitrary_orientation (bool, optional): Whether to arbitrarily orient the edges.
                Stored as an attribute; not used by this base class itself. Defaults to True.
            orientation_equivariant_labels (bool, optional): Whether the labels are orientation-equivariant or not.
                Stored as an attribute; not used by this base class itself. Defaults to False.
            cache_file (str | None, optional): Path of the processed-graph cache file.
                Defaults to ``<dataset_path>/graph-{seed}.pt``.
            preprocess (bool | None, optional): If truthy, always re-run :meth:`preprocess`;
                otherwise it only runs when the cache file does not exist yet. Defaults to None.

        Raises:
            ValueError: If ``split`` is not one of "train", "val", "test".
            FileNotFoundError: If the cache file still does not exist after preprocessing.
        """
        super().__init__()

        if split not in ("train", "val", "test"):
            raise ValueError(
                f"The split should be in ['train', 'val', 'test']. Split {split} is not supported!"
            )

        self.split = split
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.seed = seed
        self.arbitrary_orientation = arbitrary_orientation
        self.orientation_equivariant_labels = orientation_equivariant_labels

        if cache_file is not None:
            self.filename = cache_file
        else:
            self.filename = join(self.dataset_path, f"graph-{self.seed}.pt")

        if preprocess or not exists(self.filename):
            self.preprocess()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not exists(self.filename):
            raise FileNotFoundError(
                f"preprocess() did not create the cache file {self.filename!r}."
            )
        # The cache holds pickled PyG objects, which torch>=2.6 refuses to load with the
        # default weights_only=True. SECURITY: this unpickles the file, so only load
        # cache files produced by `preprocess`, never untrusted ones.
        data = torch.load(self.filename, weights_only=False)

        # Normalize to a list so that __len__/__getitem__ work for one or many graphs.
        self.graphs = [data] if isinstance(data, Data) else data

    @abstractmethod
    def preprocess(self, filename: str | None = None) -> None:
        """Build the processed graph(s) and save them to ``self.filename``.

        The constructor calls this without arguments. ``filename`` is optional only
        for backward compatibility with the previous signature; implementations
        should write to ``self.filename``.

        Note: ``Dataset`` is not an ``abc.ABC``, so ``@abstractmethod`` is not
        enforced at instantiation time. Raising here makes a missing override fail
        loudly instead of silently doing nothing.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess()."
        )

    def __len__(self):
        """Return the number of cached graphs."""
        return len(self.graphs)

    def __getitem__(self, idx: int):
        """Return the cached graph at position ``idx``."""
        return self.graphs[idx]


class TNTPFlowDenoisingInterpolationDataset(TransductiveDataset):
    """Traffic edge-flow dataset built from TransportationNetworks (TNTP) files.

    ``dataset_name`` has the form ``"<dataset>-<task>"``, e.g.
    ``"traffic-anaheim-simulation"``. ``<dataset>`` must be a key of
    :attr:`supported_datasets` and ``<task>`` an element of :attr:`supported_tasks`.
    """

    supported_tasks = ["simulation"]
    # Maps the public dataset key to the TNTP file prefix (e.g. "Anaheim_net.tntp").
    supported_datasets = {
        "traffic-anaheim": "Anaheim",
        "traffic-barcelona": "Barcelona",
        "traffic-chicago": "ChicagoSketch",
        "traffic-winnipeg": "Winnipeg",
    }

    def __init__(
        self,
        split: str,
        dataset_name: str,
        dataset_path: str,
        val_ratio: float = 0.1,
        test_ratio: float = 0.5,
        seed: float | None = None,
        arbitrary_orientation: bool = True,
        orientation_equivariant_labels: bool = False,
        cache_file: str | None = None,
        preprocess: bool | None = None,
        interpolation_label_size: float = 0.75,
    ):
        """
        Dataset class for the edge flow denoising and edge flow interpolation tasks for the traffic datasets from the
        TransportationNetworks GitHub repository (https://github.com/bstabler/TransportationNetworks).

        Args:
            split (str): Data split to load. Should be one of: ["train", "val", "test"].
            dataset_name (str): Name of the dataset, formatted as "<dataset>-<task>".
            dataset_path (str): Directory containing the "<Name>_net.tntp" and "<Name>_flow.tntp" files.
            val_ratio (float, optional): Ratio of validation data. Defaults to 0.1.
            test_ratio (float, optional): Ratio of test data. Defaults to 0.5.
            seed (float | None, optional): Random seed; also names the default cache file. Defaults to None.
            arbitrary_orientation (bool, optional): Whether to arbitrarily orient the edges.
                Defaults to True.
            orientation_equivariant_labels (bool, optional): Whether the labels are orientation-equivariant or not.
                Defaults to False.
            cache_file (str | None, optional): Path of the processed-graph cache file. Defaults to None
                (use the base-class default path).
            preprocess (bool | None, optional): Force re-running preprocessing if truthy. Defaults to None.
            interpolation_label_size (float, optional): Forwarded unchanged to
                ``create_pyg_graph_transductive``. Defaults to 0.75.

        Raises:
            ValueError: If ``dataset_name`` is malformed, or the dataset or task is unsupported.
        """
        try:
            dataset, task = dataset_name.rsplit("-", 1)
        except ValueError:
            raise ValueError(
                f"The dataset name should have the form '<dataset>-<task>'. Got {dataset_name!r}."
            ) from None
        if dataset not in self.supported_datasets:
            raise ValueError(
                f"The dataset should be in {self.supported_datasets.keys()}. The dataset {dataset} is not supported!"
            )
        if task not in self.supported_tasks:
            # `supported_tasks` is a list: calling `.keys()` on it would raise AttributeError
            # and hide the intended ValueError.
            raise ValueError(
                f"The task should be in {self.supported_tasks}. The task {task} is not supported!"
            )

        # These must be set before super().__init__, which may call self.preprocess().
        self.dataset = self.supported_datasets[dataset]
        self.task = task
        self.interpolation_label_size = interpolation_label_size
        super().__init__(
            split=split,
            dataset_name=dataset_name,
            dataset_path=dataset_path,
            val_ratio=val_ratio,
            test_ratio=test_ratio,
            seed=seed,
            arbitrary_orientation=arbitrary_orientation,
            orientation_equivariant_labels=orientation_equivariant_labels,
            cache_file=cache_file,
            preprocess=preprocess,
        )

    def _read_features(self):
        """Read "<Name>_net.tntp" into a dict {(init_node, term_node): feature vector}.

        Each vector is the six numerical link attributes followed by a one-hot
        encoding of ``link_type``.
        """
        features_file = join(self.dataset_path, f"{self.dataset}_net.tntp")
        features_df = pd.read_csv(features_file, skiprows=8, sep="\t")
        features_df.columns = [s.strip().lower() for s in features_df.columns]
        # Drop the "~" and ";" columns, which carry no link attributes.
        features_df = features_df.drop(columns=["~", ";"])

        # Sorted so that the one-hot column of each link type does not depend on set iteration order.
        link_types = sorted({int(t) for t in features_df["link_type"]})
        link_type2idx = {link_type: idx for idx, link_type in enumerate(link_types)}

        numerical_columns = ["capacity", "length", "free_flow_time", "b", "power", "toll"]
        features = {}
        for _, row in features_df.iterrows():
            u, v = int(row["init_node"]), int(row["term_node"])
            numerical_features = np.array([row[col] for col in numerical_columns])

            # One-hot encode the link type.
            link_type_features = np.zeros(len(link_type2idx))
            link_type_features[link_type2idx[int(row["link_type"])]] = 1

            features[(u, v)] = np.concatenate([numerical_features, link_type_features])
        return features

    def _read_flows(self):
        """Read "<Name>_flow.tntp" into a dict {(from, to): volume}."""
        flows_file = join(self.dataset_path, f"{self.dataset}_flow.tntp")
        flows_df = pd.read_csv(flows_file, sep="\t")
        flows_df.columns = [s.strip().lower() for s in flows_df.columns]
        return {
            (int(row["from"]), int(row["to"])): row["volume"]
            for _, row in flows_df.iterrows()
        }

    def preprocess(self):
        """Build the PyG graph from the TNTP files and cache it.

        Saves the graph to ``self.filename``. It also saves the tuple
        ``(undirected_mask_1, undirected_mask_2, directed_mask)`` to
        ``<dataset_path>/undirected_masks.pt``.
        """
        features = self._read_features()
        flows = self._read_flows()

        # Pre-process the graph.
        features, flows, undirected_edges, unique_edges = combine_edges(
            features=features, flows=flows
        )
        features, flows, undirected_edges, unique_edges, node_mapping = relabel_nodes(
            features=features,
            flows=flows,
            undirected_edges=undirected_edges,
            unique_edges=unique_edges,
        )
        features, flows, undirected_edges, unique_edges = normalize_flows(
            features=features,
            flows=flows,
            undirected_edges=undirected_edges,
            unique_edges=unique_edges,
        )
        features = normalize_features(features)

        # All link attributes are used as orientation-invariant features.
        # The equivariant feature vectors are empty.
        inv_features = dict(features)
        equi_features = {k: np.zeros(0) for k in features}

        # NOTE(review): self.seed is not passed to create_pyg_graph_transductive.
        # Unless that helper or the caller seeds the global RNG, the split is
        # presumably not tied to `seed`. Confirm.
        data, undirected_mask_1, undirected_mask_2, directed_mask = (
            create_pyg_graph_transductive(
                equi_features=equi_features,
                inv_features=inv_features,
                undirected_edges=undirected_edges,
                unique_edges=unique_edges,
                labels=flows,
                val_ratio=self.val_ratio,
                test_ratio=self.test_ratio,
                add_noisy_flow_to_input=False,
                add_interpolation_flow_to_input=False,
                add_zeros_to_flow_input=True,  # There are no equivariant features available.
                interpolation_label_size=self.interpolation_label_size,
            )
        )

        torch.save(data, self.filename)
        # NOTE: this path depends neither on the seed nor on cache_file, so every
        # preprocess run overwrites it.
        torch.save(
            (undirected_mask_1, undirected_mask_2, directed_mask),
            join(self.dataset_path, "undirected_masks.pt"),
        )
