"""Converts a PyTorch Geometric Batch object to an AtomGraphs object for the Orb models."""

import torch
import torch_geometric.data
from orb_models.forcefield.base import AtomGraphs


def to_atom_graphs(
    topology: torch_geometric.data.Batch,
    batch: torch.Tensor,
    num_graphs: int,
) -> AtomGraphs:
    """Build an Orb ``AtomGraphs`` from a PyTorch Geometric ``Batch``.

    Args:
        topology: Batched graph. Row 0 of ``edge_index`` is used as the edge
            senders and row 1 as the edge receivers. Optional attributes are
            copied over only when present and not ``None``.
        batch: Integer tensor used to count nodes per graph; presumably the
            graph index of every node (verify against the caller).
        num_graphs: Minimum length of the per-graph count tensors, so that
            trailing graphs without nodes or edges still get a zero entry.

    Returns:
        An ``AtomGraphs`` carrying the connectivity, the per-graph node and
        edge counts, and whichever optional node, edge and system features
        were found on ``topology``. Targets, ids, tags and neighbour-list
        settings are left as ``None`` and ``half_supercell`` is ``False``.
    """

    def _present(keys: tuple[str, ...]) -> dict[str, torch.Tensor]:
        # Keep only attributes that exist on ``topology`` and are not None,
        # preserving the order in which the keys are listed.
        return {
            key: getattr(topology, key)
            for key in keys
            if topology.get(key, None) is not None
        }

    edge_index = topology.edge_index
    senders, receivers = edge_index[0], edge_index[1]

    # Per-graph sizes: an edge is attributed to the graph of its sender.
    nodes_per_graph = torch.bincount(batch, minlength=num_graphs)
    edges_per_graph = torch.bincount(batch[senders], minlength=num_graphs)

    node_features = _present(
        (
            "atom_type_index",
            "atom_code_index",
            "residue_code_index",
            "residue_sequence_index",
            "residue_index",
        )
    )
    edge_features = _present(("bond_mask",))
    system_features = _present(("num_residues", "loss_weight"))

    return AtomGraphs(
        senders=senders,
        receivers=receivers,
        n_node=nodes_per_graph,
        n_edge=edges_per_graph,
        node_features=node_features,
        edge_features=edge_features,
        system_features=system_features,
        node_targets=None,
        edge_targets=None,
        system_targets=None,
        system_id=None,
        fix_atoms=None,
        tags=None,
        radius=None,
        max_num_neighbors=None,
        half_supercell=False,
    )
