"""
Code adapted from
Fuchsgruber, Dominik, et al. "Graph Neural Networks for Edge Signals: Orientation Equivariance and Invariance.", ICLR 2025
link: https://openreview.net/forum?id=XWBE90OYlH
"""

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from utils.stratified_split import scsplit


def combine_edges(features: dict, flows: dict) -> tuple[dict, dict, dict, dict]:
    """
    Drop zero-flow edges and record which of the remaining edges are
    directed or undirected, for the loss computation during model training.

    An edge ``(u, v)`` is kept only if its flow is non-zero and, when the
    reverse edge ``(v, u)`` is also a key of ``features``, the reverse flow is
    non-zero as well. A kept edge whose reverse is in ``features`` is labeled
    undirected (1) in both directions; any other kept edge is labeled
    directed (0).

    Args:
        features (dict): Features of the edges, keyed by ``(u, v)`` tuples.
        flows (dict): Flows of the edges, keyed by ``(u, v)`` tuples.

    Returns:
        tuple[dict, dict, dict, dict]:
            Combined features, combined flows, undirected-edge labels
            (1 = undirected, 0 = directed) and unique edges. The unique edges
            hold every directed edge and only one direction of every
            undirected pair, so a pair is counted once in train/val/test splits.

    Raises:
        KeyError: If an edge of ``features`` has no entry in ``flows``.
    """
    combined_features = {}
    combined_flows = {}
    undirected_edges = {}
    unique_edges = {}  # for train/val/test splits, keep undirected edges only once

    for e, feat in features.items():
        u, v = e
        reverse = (v, u)
        has_reverse = reverse in features

        # Keep the edge only if neither it nor its (existing) reverse has zero flow.
        keep = flows[e] != 0 and (not has_reverse or flows[reverse] != 0)
        if not keep:
            continue

        combined_features[e] = feat

        if not has_reverse:
            # Directed edge.
            undirected_edges[e] = 0
            unique_edges[e] = 1
            combined_flows[e] = flows[e]
        elif reverse not in undirected_edges:
            # Undirected edge seen for the first time: register both directions
            # at once, but list only this direction as unique.
            undirected_edges[e] = 1
            undirected_edges[reverse] = 1
            unique_edges[e] = 1

            # Both edges will be treated as separate directed edges
            combined_flows[e] = flows[e]
            combined_flows[reverse] = flows[reverse]

    return combined_features, combined_flows, undirected_edges, unique_edges


def relabel_nodes(
    features: dict, flows: dict, undirected_edges: dict, unique_edges: dict
) -> tuple[dict, dict, dict, dict, dict]:
    """
    Relabels nodes to 0, ..., N - 1 range and updates the edges.

    Node ids are assigned in order of first appearance in the keys of
    ``features``, so the mapping is reproducible across interpreter runs
    (iterating a ``set`` of nodes is not, e.g. for string node names).

    Args:
        features (dict): Features of the edges, keyed by ``(u, v)`` tuples.
        flows (dict): Flows of the edges.
        undirected_edges (dict): Whether the edges are undirected (label 1) or directed (label 0).
        unique_edges (dict): Edges kept once per undirected pair.

    Returns:
        tuple[dict, dict, dict, dict, dict]:
            Relabeled features, relabeled flows, relabeled undirected edges,
            relabeled unique edges, and the node mapping (old node -> new id).

    Raises:
        KeyError: If ``flows``, ``undirected_edges`` or ``unique_edges`` contain
            a node that does not occur in any key of ``features``.
    """
    # dict.fromkeys de-duplicates while preserving first-appearance order.
    ordered_nodes = dict.fromkeys(node for edge in features for node in edge)
    node_mapping = {node: idx for idx, node in enumerate(ordered_nodes)}

    def _relabel(edge_dict: dict) -> dict:
        # Rewrite every (u, v) key with the new node ids; values are untouched.
        return {
            (node_mapping[u], node_mapping[v]): value
            for (u, v), value in edge_dict.items()
        }

    return (
        _relabel(features),
        _relabel(flows),
        _relabel(undirected_edges),
        _relabel(unique_edges),
        node_mapping,
    )


def normalize_flows(
    features: dict,
    flows: dict,
    undirected_edges: dict,
    unique_edges: dict,
    normalize_by_max_flow: bool = True,
) -> tuple[dict, dict, dict, dict]:
    """
    Converts flow estimation instance to a non-negative one for directed edges and normalizes all values.

    A directed edge (label 0) with a negative flow is replaced by the reversed
    edge carrying the positive flow; its features, label and unique-edge entry
    move to the reversed key. Undirected edges keep their orientation and sign.
    Edges of ``features`` without a flow are copied to the new features and
    new undirected edges only.

    Args:
        features (dict): Features of the edges.
        flows (dict): Flows of the edges.
        undirected_edges (dict): Whether the edges are undirected (label 1) or directed (label 0).
        unique_edges (dict): Edges kept once per undirected pair.
        normalize_by_max_flow (bool, optional): Divide every flow by the largest
            absolute flow. If False, flows are left unscaled. Defaults to True.

    Returns:
        tuple[dict, dict, dict, dict]:
            New features, normalized flows, new undirected edges and new unique edges.
    """
    normalized_flows = {}
    new_features = {}
    new_undirected_edges = {}
    new_unique_edges = {}

    max_flow = 1.0
    if normalize_by_max_flow:
        max_flow = max((abs(f) for f in flows.values()), default=0.0)
        if max_flow == 0:
            # No flows, or only zero flows: nothing to scale, avoid dividing by zero.
            max_flow = 1.0

    for edge in features:
        if edge not in flows:
            new_features[edge] = features[edge]
            new_undirected_edges[edge] = undirected_edges[edge]
            continue

        flow = flows[edge]
        target = edge
        # Flip only directed edges with flows against their direction; undirected
        # edges just get normalized, no need for flipping edges as in EIGN.
        if undirected_edges[edge] == 0 and flow < 0:
            target = (edge[1], edge[0])
            flow = -flow

        normalized_flows[target] = flow / max_flow
        new_features[target] = features[edge]
        new_undirected_edges[target] = undirected_edges[edge]
        if edge in unique_edges:
            new_unique_edges[target] = unique_edges[edge]

    return new_features, normalized_flows, new_undirected_edges, new_unique_edges


def normalize_features(features: dict) -> dict:
    """
    Normalizes features using standard scaler.

    All feature vectors are stacked into one matrix (one row per edge, in the
    iteration order of ``features``), passed through
    ``StandardScaler.fit_transform`` and mapped back to their edges.

    Args:
        features (dict): Features of the edges; every value is an array whose
            first dimension is the number of features.

    Returns:
        dict: Normalized features with the same keys as ``features``. An empty
            input yields an empty dict.
    """
    if not features:
        # Nothing to fit the scaler on.
        return {}

    num_features = next(iter(features.values())).shape[0]
    feature_matrix = np.zeros((len(features), num_features))
    for i, feat in enumerate(features.values()):
        feature_matrix[i] = feat

    scaled = StandardScaler().fit_transform(feature_matrix)

    # Row i of the scaled matrix belongs to the i-th key of ``features``.
    return dict(zip(features, scaled))


def continuous_idx_split(
    values,
    train_size: float,
    val_size: float,
    test_size: float | None = None,
    random_state: int | None = None,
    stratify: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Performs a continuous stratified split of the indices based on the values."""
    data = pd.DataFrame({"y": values})
    if test_size is None:
        test_size = 1 - train_size - val_size

    if stratify:
        train_val, test = scsplit(
            data,
            stratify=data["y"],
            train_size=train_size + val_size,
            test_size=test_size,
            random_state=random_state,
        )
        train, val = scsplit(
            train_val,
            stratify=train_val["y"],
            train_size=train_size / (train_size + val_size),
            test_size=val_size / (train_size + val_size),
            random_state=random_state,
        )
        idx_train = np.array(train.index)
        idx_val = np.array(val.index)
        idx_test = np.array(test.index)

    else:
        idxs = np.arange(len(data))
        rng = np.random.default_rng(random_state)
        rng.shuffle(idxs)
        idx_train = idxs[: int(train_size * len(data))]
        idx_val = idxs[
            int(train_size * len(data)) : int((train_size + val_size) * len(data))
        ]
        idx_test = idxs[int((train_size + val_size) * len(data)) :]

    assert (
        len(np.intersect1d(idx_train, idx_val)) == 0
        and len(np.intersect1d(idx_train, idx_test)) == 0
        and len(np.intersect1d(idx_val, idx_test)) == 0
    )

    return idx_train, idx_val, idx_test


def add_opposite_edges(
    labeled_idx,
    idx_train,
    unique_edges,
    undirected_mask,
    index_edge_mapping,
    labels,
    edge_index_mapping,
):
    """
    Map split positions to edge indices and, for every selected undirected
    edge, also add the edge in the opposite direction.

    Despite its name, ``idx_train`` can be the positions of any split
    (train, validation or test).

    Args:
        labeled_idx (torch.Tensor): Indices of the labeled edges.
        idx_train (np.ndarray): Positions into ``labeled_idx`` selected for the split.
        unique_edges (dict): Unique edges.
        undirected_mask (torch.Tensor): Mask for undirected edges.
        index_edge_mapping (dict): Mapping from edge index to edge.
        labels (dict): Labels of the edges.
        edge_index_mapping (dict): Mapping from edge to edge index.

    Returns:
        np.ndarray: The selected edge indices followed by the indices of the
            opposite edges, in selection order. Always a NumPy array, also
            when no opposite edge had to be added.

    Raises:
        AssertionError: If the opposite of a selected undirected edge has no
            label or is itself listed in ``unique_edges``.
    """
    selected = np.asarray(labeled_idx[idx_train])

    # Collect the opposite edges first and concatenate once, instead of
    # re-allocating the result array for every undirected edge.
    opposite_indices = []
    for edge_idx in selected.tolist():
        if not undirected_mask[edge_idx]:
            continue
        cur_edge = index_edge_mapping[edge_idx]
        opposite_edge = (cur_edge[1], cur_edge[0])
        assert opposite_edge in labels and opposite_edge not in unique_edges
        opposite_indices.append(edge_index_mapping[opposite_edge])

    return np.concatenate(
        (selected, np.asarray(opposite_indices, dtype=selected.dtype))
    )


def create_pyg_graph_transductive(
    equi_features: dict,
    inv_features: dict,
    undirected_edges: dict,
    unique_edges: dict,
    labels: dict,
    val_ratio: float,
    test_ratio: float,
    add_noisy_flow_to_input: bool = False,
    add_interpolation_flow_to_input: bool = False,
    add_zeros_to_flow_input: bool = False,
    stratified_split: bool = False,
    seed: int | None = None,
    interpolation_label_size: float = 0.75,
    num_training_splits: int = 20,
) -> tuple[Data, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Creates a PyTorch Geometric data object based on features and labels.
    It also creates ``num_training_splits`` random transductive splits of the
    labeled edges into train/validation/test.

    Only edges that have a label and are listed in ``unique_edges`` are split;
    the opposite direction of every undirected edge is then added to the same
    subset, so both directions of a pair always share a split.

    Args:
        equi_features (dict): Orientation-equivariant features of the edges.
            Its key order defines the edge order of the graph.
        inv_features (dict): Orientation-invariant features of the edges.
        undirected_edges (dict): Whether the edges are undirected (label 1) or directed (label 0).
        unique_edges (dict): Edges kept once per undirected pair.
        labels (dict): Labels of the edges. Edges without a label keep ``y = 0``
            and are never part of a split.
        val_ratio (float): Proportion of validation edges.
        test_ratio (float): Proportion of testing edges.
        add_noisy_flow_to_input (bool, optional): Not used by this function.
            Defaults to False.
        add_interpolation_flow_to_input (bool, optional): Not used by this function.
            Defaults to False.
        add_zeros_to_flow_input (bool, optional): Whether to prepend a column of
            zeros to the orientation-equivariant features. Defaults to False.
        stratified_split (bool, optional): Whether to stratify the splits on the
            labels. Defaults to False.
        seed (int | None, optional): Base seed of the splits. Split ``k`` uses
            ``seed + k`` so that the splits differ from each other while
            staying reproducible. Defaults to None (unseeded).
        interpolation_label_size (float, optional): Not used by this function.
            Defaults to 0.75.
        num_training_splits (int, optional): Number of splits to create.
            Defaults to 20.

    Returns:
        tuple[Data, torch.Tensor, torch.Tensor, torch.Tensor]:
            The PyTorch Geometric data object (``train_mask``, ``val_mask`` and
            ``test_mask`` are lists with one boolean mask per split), the edge
            indices of one direction of every undirected pair, the edge indices
            of the matching opposite directions, and the boolean mask of
            directed edges.
    """
    num_edges = len(equi_features)

    num_equi_features = next(iter(equi_features.values())).shape[0]
    num_inv_features = next(iter(inv_features.values())).shape[0]
    equi_edge_attr = torch.zeros(num_edges, num_equi_features)
    inv_edge_attr = torch.zeros(num_edges, num_inv_features)

    edge_index = torch.zeros(2, num_edges, dtype=torch.long)
    undirected_mask = torch.zeros(num_edges, dtype=torch.bool)
    directed_mask = torch.zeros(num_edges, dtype=torch.bool)
    labeled_mask = torch.zeros(num_edges, dtype=torch.bool)
    y = torch.zeros(num_edges, dtype=torch.float)

    index_edge_mapping = {}
    edge_index_mapping = {}

    for i, e in enumerate(equi_features):
        index_edge_mapping[i] = e
        edge_index_mapping[e] = i
        edge_index[:, i] = torch.tensor([e[0], e[1]])
        equi_edge_attr[i] = torch.tensor(equi_features[e])
        inv_edge_attr[i] = torch.tensor(inv_features[e])
        undirected_mask[i] = undirected_edges[e]
        directed_mask[i] = 1 - undirected_edges[e]

        # Unlabeled edges keep y = 0 and stay out of the labeled mask.
        if e in labels:
            y[i] = labels[e]
            if e in unique_edges:
                labeled_mask[i] = True

    # Pair up the two directions of every undirected edge. The set keeps the
    # "already paired" lookup constant-time.
    undirected_mask_1 = []
    undirected_mask_2 = []
    paired = set()
    for e, is_undirected in undirected_edges.items():
        if not is_undirected:
            continue
        forward = edge_index_mapping[e]
        if forward in paired:
            continue
        backward = edge_index_mapping[(e[1], e[0])]
        undirected_mask_1.append(forward)
        undirected_mask_2.append(backward)
        paired.update((forward, backward))

    # Explicit dtype so that empty lists still give index tensors.
    undirected_mask_1 = torch.tensor(undirected_mask_1, dtype=torch.long)
    undirected_mask_2 = torch.tensor(undirected_mask_2, dtype=torch.long)

    all_train_mask, all_val_mask, all_test_mask = [], [], []

    # The labeled edges do not change between splits.
    labeled_idx = torch.nonzero(labeled_mask, as_tuple=True)[0]

    for split_id in range(num_training_splits):
        # Offset the seed per split; reusing the same seed would reproduce
        # the same split num_training_splits times.
        split_seed = None if seed is None else seed + split_id

        split_indices = continuous_idx_split(
            y[labeled_idx],
            1 - val_ratio - test_ratio,
            val_ratio,
            test_ratio,
            split_seed,
            stratify=stratified_split,
        )

        split_masks = []
        for idx_subset in split_indices:
            full_idx = add_opposite_edges(
                labeled_idx,
                idx_subset,
                unique_edges,
                undirected_mask,
                index_edge_mapping,
                labels,
                edge_index_mapping,
            )
            mask = torch.zeros(num_edges, dtype=torch.bool)
            mask[full_idx] = True
            split_masks.append(mask)
        train_mask, val_mask, test_mask = split_masks

        assert (
            not torch.any(train_mask & val_mask)
            and not torch.any(val_mask & test_mask)
            and not torch.any(train_mask & test_mask)
        )

        all_train_mask.append(train_mask)
        all_val_mask.append(val_mask)
        all_test_mask.append(test_mask)

    if add_zeros_to_flow_input:
        equi_edge_attr = torch.cat([torch.zeros(num_edges, 1), equi_edge_attr], dim=-1)

    num_nodes = int(edge_index.max().item()) + 1
    data = Data(
        num_nodes=num_nodes,
        edge_index=edge_index,
        equi_edge_attr=equi_edge_attr,
        inv_edge_attr=inv_edge_attr,
        undirected_mask=undirected_mask,
        y=y,
        train_mask=all_train_mask,
        val_mask=all_val_mask,
        test_mask=all_test_mask,
    )

    return data, undirected_mask_1, undirected_mask_2, directed_mask
