from typing import Sequence, Optional
from LorentzMACE.tools.torch_tools import to_one_hot

import torch.utils.data
from LorentzMACE.tools import torch_geometric

import importlib
from LorentzMACE.data.neighborhood import get_neighborhood
from LorentzMACE.data.utils import Configuration


class AtomicData(torch_geometric.data.Data):
    """Graph sample built from a single ``Configuration``.

    Stores an edge list, per-node 4-component positions, optional one-hot
    node attributes and an optional ``signal`` tensor as fields of a
    ``torch_geometric`` ``Data`` object.
    """

    # Fields populated by ``__init__`` below.
    edge_index: torch.Tensor
    node_attrs: Optional[torch.Tensor]
    positions: torch.Tensor
    signal: Optional[torch.Tensor]
    # NOTE(review): the fields below are declared but never set by this
    # class; presumably populated elsewhere or stale — confirm before relying
    # on them.
    edge_vectors: torch.Tensor
    edge_lengths: torch.Tensor
    forces: torch.Tensor
    energy: torch.Tensor

    def __init__(
        self,
        edge_index: torch.Tensor,  # [2, n_edges]
        node_attrs: Optional[torch.Tensor],  # [n_nodes, n_node_feats] or None
        positions: torch.Tensor,  # [n_nodes, 4]
        signal: Optional[torch.Tensor]  # [n_particles, n_class]
    ):
        """Validate tensor shapes and store them on the underlying ``Data``.

        Args:
            edge_index: Edge list of shape ``[2, n_edges]``.
            node_attrs: Rank-2 node feature matrix, or ``None``.
            positions: Per-node coordinates of shape ``[n_nodes, 4]``;
                ``n_nodes`` is taken from its first dimension.
            signal: Optional tensor stored unchanged (original comment says
                ``[n_particles, n_class]`` — not checked here).

        Raises:
            AssertionError: if ``positions`` is not ``[n_nodes, 4]``,
                ``edge_index`` is not ``[2, n_edges]`` or ``node_attrs`` is
                not rank 2. (Assertions are skipped under ``python -O``.)
        """
        # Check rank before indexing into ``shape`` so malformed inputs fail
        # the assertion instead of raising IndexError.
        assert len(positions.shape) == 2 and positions.shape[1] == 4
        num_nodes = positions.shape[0]

        assert len(edge_index.shape) == 2 and edge_index.shape[0] == 2
        assert node_attrs is None or len(node_attrs.shape) == 2

        # Aggregate data
        data = {
            'num_nodes': num_nodes,
            'edge_index': edge_index,
            'positions': positions,
            'node_attrs': node_attrs,
            'signal': signal,
        }
        super().__init__(**data)

    @classmethod
    def from_config(cls, config: Configuration, cutoff_in: float,
                    cutoff_out: float) -> 'AtomicData':
        """Build an ``AtomicData`` from a ``Configuration``.

        Args:
            config: Source configuration; its ``positions``, ``attributes``
                and ``signal`` are read.
            cutoff_in: Forwarded to ``get_neighborhood``.
            cutoff_out: Forwarded to ``get_neighborhood``.

        Returns:
            A new instance whose positions use ``torch.get_default_dtype()``,
            whose ``node_attrs`` are a 2-class one-hot encoding of
            ``config.attributes`` (or ``None`` when absent) and whose
            ``signal`` is a ``torch.long`` tensor (or ``None`` when absent).
        """
        positions = torch.tensor(config.positions,
                                 dtype=torch.get_default_dtype())
        edge_index = get_neighborhood(positions=positions,
                                      cutoff_in=cutoff_in,
                                      cutoff_out=cutoff_out)

        node_attrs = None
        if config.attributes is not None:
            # Attributes are integer labels; ``num_classes`` is fixed at 2.
            node_attrs = to_one_hot(
                torch.tensor(config.attributes, dtype=torch.long).unsqueeze(-1),
                num_classes=2)

        # ``__init__`` accepts an optional signal, so mirror the ``attributes``
        # handling instead of crashing on ``None.astype``.
        signal = None
        if config.signal is not None:
            # ``astype`` implies an array-like (presumably NumPy) — confirm.
            signal = torch.tensor(config.signal.astype(int), dtype=torch.long)

        return cls(edge_index=edge_index,
                   positions=positions,
                   node_attrs=node_attrs,
                   signal=signal)


def get_data_loader(
    dataset: Sequence[AtomicData],
    batch_size: int,
    shuffle=True,
    drop_last=False,
) -> torch.utils.data.DataLoader:
    """Wrap *dataset* in a ``torch_geometric`` DataLoader.

    Args:
        dataset: Samples to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.
        drop_last: Whether to drop a trailing incomplete batch.

    Returns:
        A ``torch_geometric.dataloader.DataLoader`` built with the given
        arguments (all passed by keyword).
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    loader_cls = torch_geometric.dataloader.DataLoader
    loader = loader_cls(dataset=dataset, **loader_options)
    return loader