import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import sort_edge_index


class DataToFloat(BaseTransform):
    """Cast node features ``data.x`` to ``torch.float32``.

    Implemented via ``forward`` (not ``__call__``) so that
    ``BaseTransform.__call__`` shallow-copies the input first and the caller's
    ``Data`` object is not modified in place. Graphs without ``x`` are returned
    unchanged.
    """

    def forward(self, data):
        if data.x is not None:
            data.x = data.x.to(torch.float32)
        return data


class LabelsToInt(BaseTransform):
    """Cast targets ``data.y`` to ``torch.long``.

    Implemented via ``forward`` (not ``__call__``) so that
    ``BaseTransform.__call__`` shallow-copies the input first and the caller's
    ``Data`` object is not modified in place. Graphs without ``y`` are returned
    unchanged.
    """

    def forward(self, data):
        if data.y is not None:
            data.y = data.y.to(torch.long)
        return data


class NodeOneHot(BaseTransform):
    """One-hot encode integer node types stored in ``data.x``.

    ``data.x`` must hold exactly one integer type id per node, as ``[N]`` or
    ``[N, 1]``. It is replaced by a float tensor of shape
    ``[N, num_node_types]``.

    Args:
        num_node_types: Number of distinct node types (width of the one-hot).

    Raises:
        ValueError: If ``data.x`` holds more than one value per node.
    """

    def __init__(self, num_node_types: int):
        super().__init__()
        self.num_node_types = num_node_types

    def forward(self, data):
        # reshape (unlike view) also works on non-contiguous tensors.
        idx = data.x.reshape(-1).long()
        # A multi-column x would otherwise be flattened into N*k rows that no
        # longer line up with the nodes.
        if idx.numel() != data.x.size(0):
            raise ValueError(
                f"NodeOneHot expects one type id per node, got data.x of "
                f"shape {tuple(data.x.shape)}"
            )
        data.x = F.one_hot(idx, num_classes=self.num_node_types).to(torch.float)
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.num_node_types})"


class EdgeOneHot(BaseTransform):
    """One-hot encode integer edge types stored in ``data.edge_attr``.

    ``data.edge_attr`` must hold exactly one integer type id per edge, as
    ``[E]`` or ``[E, 1]``. It is replaced by a float tensor of shape
    ``[E, num_edge_types]``.

    Args:
        num_edge_types: Number of distinct edge types (width of the one-hot).

    Raises:
        ValueError: If ``data.edge_attr`` holds more than one value per edge.
    """

    def __init__(self, num_edge_types: int):
        super().__init__()
        self.num_edge_types = num_edge_types

    def forward(self, data):
        # reshape (unlike view) also works on non-contiguous tensors.
        idx = data.edge_attr.reshape(-1).long()
        # A multi-column edge_attr would otherwise be flattened into E*k rows
        # that no longer line up with the edges.
        if idx.numel() != data.edge_attr.size(0):
            raise ValueError(
                f"EdgeOneHot expects one type id per edge, got data.edge_attr "
                f"of shape {tuple(data.edge_attr.shape)}"
            )
        data.edge_attr = F.one_hot(idx, num_classes=self.num_edge_types).to(
            torch.float
        )
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.num_edge_types})"


class UnsqueezeY(BaseTransform):
    """Append a trailing singleton dimension to a 1-D target ``data.y``.

    A target of shape ``[K]`` becomes ``[K, 1]``; for example, a single
    graph-level value of shape ``[1]`` becomes ``[1, 1]``. Targets of any other
    rank (including 0-dim scalars) are returned unchanged.
    """

    def forward(self, data):
        if data.y.dim() == 1:
            data.y = data.y.unsqueeze(-1)
        return data


class SqueezeY(BaseTransform):
    """Remove singleton dimensions from the target ``data.y``.

    By default every size-1 dimension is removed (``Tensor.squeeze()``). Note
    that this also collapses a ``[1, 1]`` target to a 0-dim scalar, e.g. node
    labels stored as ``[N, 1]`` on a single-node graph. Pass ``dim`` to squeeze
    only that dimension instead; for example, ``dim=-1`` maps ``[N, 1]`` to
    ``[N]`` for every ``N``.

    Args:
        dim: Dimension to squeeze, or ``None`` (default) to squeeze all
            size-1 dimensions.
    """

    def __init__(self, dim: "int | None" = None) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, data):
        if self.dim is None:
            data.y = data.y.squeeze()
        else:
            data.y = data.y.squeeze(self.dim)
        return data


class SortNodes(BaseTransform):
    """
    Sort the nodes of the graph according to the node label.

    Nodes are reordered by ascending ``data.y``; ties keep their original
    relative order. ``data.y`` and, if present, ``data.x`` are permuted
    accordingly. ``data.edge_index`` is relabelled to the new node ids and then
    sorted by (source, target). ``data.edge_weight`` and ``data.edge_attr``,
    when present, are permuted together with the edges.

    ``data.y`` must hold one label per node, with shape ``[N]`` or ``[N, 1]``.

    NOTE(review): other node-level attributes (e.g. ``pos`` or split masks)
    are NOT permuted by this transform. Confirm that the datasets it is used
    on carry none.
    """

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def forward(self, data: Data) -> Data:
        assert data.edge_index is not None
        assert data.y is not None

        num_nodes = data.num_nodes
        y = data.y
        # Sorting a [N, 1] tensor directly would sort each row on its own and
        # yield a meaningless all-zero permutation, so flatten a single label
        # per node and reject anything else.
        if y.dim() == 0 or y.size(0) != num_nodes or y.numel() != num_nodes:
            raise ValueError(
                f"SortNodes expects one label per node (shape [{num_nodes}] "
                f"or [{num_nodes}, 1]), got data.y of shape {tuple(y.shape)}"
            )

        # sort_idx[new_id] = old_id; stable so the order of ties is deterministic.
        _, sort_idx = torch.sort(y.reshape(-1), stable=True)

        # Inverse permutation new_id_of[old_id] = new_id. Relabelling every
        # edge endpoint is then a single gather instead of a per-node scan
        # over the whole edge_index (O(N * E)).
        new_id_of = torch.empty_like(sort_idx)
        new_id_of[sort_idx] = torch.arange(num_nodes, device=sort_idx.device)
        edge_index_renamed = new_id_of[data.edge_index]

        if data.x is not None:
            data.x = data.x[sort_idx]
        data.y = y[sort_idx]

        # Sort edge_index_renamed so edges are ordered by source. Only the
        # per-edge tensors that actually exist are passed along: a None entry
        # in the list would be indexed by sort_edge_index and fail.
        edge_fields = {
            key: value
            for key, value in (
                ("edge_weight", data.edge_weight),
                ("edge_attr", data.edge_attr),
            )
            if value is not None
        }
        data.edge_index, sorted_fields = sort_edge_index(
            edge_index_renamed,
            edge_attr=list(edge_fields.values()),
            num_nodes=num_nodes,
        )
        for key, value in zip(edge_fields, sorted_fields):
            setattr(data, key, value)

        return data