"""
Module containing transformers for converting between different data formats.
"""
import networkx as nx
import torch
from toponetx.classes import SimplicialComplex as ToponetxSimplicialComplex
from toponetx.transform import graph_to_clique_complex
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

from scrawl.simplicial import SimplicialComplex, SimplicialData


def smart_to_networkx(data: Data, **kwargs) -> nx.Graph:
    """
    Convert a torch geometric data object to a networkx graph.

    This method is a wrapper around `torch_geometric.utils.to_networkx` and tries to
    be smart about what node and edge attributes to copy over. The node attribute
    ``x``, the edge attribute ``edge_attr`` and the graph attribute ``y`` are copied
    whenever they are set (i.e. not ``None``) on `data`. Self-loops are removed by
    default; pass ``remove_self_loops=False`` to keep them.

    Parameters
    ----------
    data : Data
        The torch geometric data object.
    **kwargs : keyword arguments
        Additional keyword arguments passed to `torch_geometric.utils.to_networkx`.
        ``remove_self_loops`` defaults to ``True`` but may be overridden here.

    Returns
    -------
    networkx.Graph
        The networkx graph.
    """
    node_attr = [name for name in ("x",) if getattr(data, name) is not None]
    edge_attr = [name for name in ("edge_attr",) if getattr(data, name) is not None]
    graph_attr = [name for name in ("y",) if getattr(data, name) is not None]

    # Use a default rather than a hard-coded keyword argument: if the caller also
    # supplied ``remove_self_loops`` in **kwargs, the call would otherwise fail with
    # "got multiple values for keyword argument".
    kwargs.setdefault("remove_self_loops", True)

    return to_networkx(data, node_attr, edge_attr, graph_attr, **kwargs)


def toponetx_to_sc(simplicial_complex: ToponetxSimplicialComplex) -> SimplicialComplex:
    """
    Transform a TopoNetX simplicial complex into our own simplicial complex format.

    For every rank from 1 up to and including the dimension of the complex, the
    unsigned incidence matrix reported by TopoNetX is densified and wrapped in a
    torch tensor. The resulting rank -> matrix mapping is handed to our
    `SimplicialComplex` constructor.

    Parameters
    ----------
    simplicial_complex : toponetx.classes.SimplicialComplex
        The simplicial complex.

    Returns
    -------
    SimplicialComplex
        The simplicial complex.
    """
    top_rank = simplicial_complex.dim

    # Dense copy of each sparse unsigned incidence matrix, keyed by rank.
    matrices_by_rank = {
        r: torch.from_numpy(simplicial_complex.incidence_matrix(r, signed=False).toarray())
        for r in range(1, top_rank + 1)
    }

    return SimplicialComplex(matrices_by_rank)


def toponetx_to_data(
    simplicial_complex: ToponetxSimplicialComplex,
    attr_names: str | list[str] | None,
    dtype: torch.dtype,
    device: torch.device | None = None,
) -> SimplicialData:
    """
    Transform a TopoNetX simplicial complex into our own simplicial data format.

    The structure of the complex is converted with `toponetx_to_sc`. If attribute
    names are given, every rank receives a feature matrix with one row per simplex.
    The matrix is built by horizontally stacking the requested attributes in the
    given order. Scalar attributes contribute a single column. Attributes that yield
    no values at a rank are skipped for that rank.

    Parameters
    ----------
    simplicial_complex : toponetx.classes.SimplicialComplex
        The simplicial complex.
    attr_names : str or list[str] or None
        The names of the attributes to copy over as simplicial data. A single string
        is treated as a one-element list; ``None`` or an empty list copies nothing.
    dtype : torch.dtype
        The dtype of the data supported on the simplicial complexes.
    device : torch.device, optional
        Device on which to store the data. ``None`` means the CPU.

    Returns
    -------
    SimplicialData
        The simplicial data.
    """
    target_device = torch.device("cpu") if device is None else device

    result = SimplicialData(
        toponetx_to_sc(simplicial_complex),
        dtype=dtype,
        device=target_device,
    )

    names = [attr_names] if isinstance(attr_names, str) else attr_names
    if not names:
        # Nothing to copy: return the bare structure without any feature matrices.
        return result

    for rank in range(simplicial_complex.dim + 1):
        # Start from a zero-width matrix so columns can be appended one attribute
        # at a time.
        result[rank] = torch.empty(
            (simplicial_complex.shape[rank], 0),
            dtype=result.dtype,
            device=result.device,
        )

        for name in names:
            raw_values = simplicial_complex.get_simplex_attributes(name, rank=rank)
            column = torch.tensor(
                list(raw_values.values()),
                dtype=result.dtype,
                device=result.device,
            )

            # Scalar attributes become a single feature column.
            if len(column.shape) == 1:
                column = column.unsqueeze(1)

            # Attributes absent at this rank produce no rows and are skipped.
            if column.size(0) != 0:
                result[rank] = torch.hstack((result[rank], column))

    return result


def torch_geometric_to_data(
    data: Data,
    dtype: torch.dtype,
    device: torch.device | None = None,
    *,
    max_dim: int | None = None,
) -> SimplicialData:
    """
    Transform a torch geometric data object into our own simplicial data format.

    The graph represented by the torch geometric data object is converted into a
    simplicial complex by computing the clique complex of the graph. If `max_dim` is
    specified, the clique complex is only computed up to the specified dimension.

    Node features (``x``) and edge features (``edge_attr``) are carried over when
    present. A target ``y`` is attached as the auxiliary tensor at index ``-1`` in
    two cases:

    * ``y`` is 1-D: it is stored unchanged.
    * ``y`` has shape ``(1, C)``: it is treated as one-hot encoded and converted to
      its argmax index (``torch.int``).

    Parameters
    ----------
    data : Data
        The torch geometric data object.
    dtype : torch.dtype
        The dtype of the data supported on the simplicial complexes.
    device : torch.device, optional
        Device on which to store the data. ``None`` means the CPU.
    max_dim : int, optional
        The maximum dimension of the simplicial complex to construct.

    Returns
    -------
    SimplicialData
        The simplicial data.
    """
    target_device = device if device is not None else torch.device("cpu")

    # Only request feature attributes that actually exist on the input graph.
    feature_names = [
        name
        for name, value in (("x", data.x), ("edge_attr", data.edge_attr))
        if value is not None
    ]

    nx_graph = smart_to_networkx(data, to_undirected=True)
    clique_complex = graph_to_clique_complex(nx_graph, max_dim)
    simplicial_data = toponetx_to_data(
        clique_complex, feature_names, dtype=dtype, device=target_device
    )

    labels = data.y
    if labels is not None:
        if len(labels.shape) == 1:
            simplicial_data.set_aux_tensor(-1, labels)
        elif len(labels.shape) == 2 and labels.shape[0] == 1:
            # This is a graph classification task, but the target is one-hot
            # encoded already. Convert to integer targets.
            # NOTE(review): this branch yields a 0-dim ``torch.int`` tensor, whereas
            # the 1-D branch above keeps the original shape and dtype (e.g. (1,)
            # long). Confirm downstream consumers accept both.
            class_index = labels.argmax(dim=1).squeeze()
            simplicial_data.set_aux_tensor(-1, class_index.to(torch.int))

    return simplicial_data


def pad_data(data: SimplicialData, sizes: list[int]) -> SimplicialData:
    """
    Pad the data to the specified sizes.

    The feature matrix of each rank is widened with zero columns until it has
    ``sizes[rank]`` columns. Entries of `sizes` beyond the dimension of the
    underlying complex are ignored. Ranks not covered by `sizes` are left
    untouched. The data object is modified in place and also returned.

    Parameters
    ----------
    data : SimplicialData
        The simplicial data.
    sizes : list[int]
        The sizes to pad the data to.

    Returns
    -------
    SimplicialData
        The padded simplicial data.

    Raises
    ------
    RuntimeError
        If the data is already larger than the specified sizes.
    """
    # zip stops at whichever runs out first: the ranks of the complex or `sizes`.
    for rank, target in zip(range(data.domain.dim + 1), sizes):
        current = data[rank].shape[1]

        if current > target:
            raise RuntimeError(
                f"Cannot pad data to a smaller size. Wanted to pad rank {rank} to size {target} but the data is already of size {current}."
            )

        if current < target:
            zero_columns = torch.zeros(
                (data[rank].shape[0], target - current),
                dtype=data.dtype,
                device=data.device,
            )
            data[rank] = torch.hstack((data[rank], zero_columns))

    return data
