from typing import Tuple, Optional
import torch


def get_neighborhood(
    positions: torch.Tensor,  # [num_positions, 4]
    cutoff_out: float,
    cutoff_in: float,
    batch: Optional[torch.Tensor] = None,  # [num_particles,]
    num_graphs: Optional[int] = None,
    use_cutoff: Optional[bool] = False,
) -> torch.Tensor:
    """Build the edge index of the per-graph fully connected particle graphs.

    Every ordered pair of distinct particles that share a graph becomes an
    edge. When ``use_cutoff`` is true, edges are additionally filtered on a
    Minkowski-type interval between the two endpoints.

    Args:
        positions: Per-particle coordinates; component 0 is the one that
            enters the interval with a positive sign.
        cutoff_out: Upper bound of the cutoff window (exclusive).
        cutoff_in: Lower bound of the cutoff window (exclusive).
        batch: Graph id of every particle. If omitted, all particles are
            placed in a single graph (and ``num_graphs`` is forced to 1).
        num_graphs: Number of graphs in ``batch``. If omitted while ``batch``
            is given, it is inferred as the number of distinct ids in
            ``batch``.
        use_cutoff: Whether to apply the interval filter.

    Returns:
        A ``[2, n_edges]`` integer tensor; row 0 holds senders and row 1
        receivers, so it can be unpacked as ``sender, receiver = result``.
    """
    num_particles = positions.shape[0]
    if batch is None:
        # Keep the implicit batch on the same device as `positions` so the
        # edge index and the cutoff mask computed below live together.
        batch = torch.zeros(
            num_particles, dtype=torch.long, device=positions.device
        )
        num_graphs = 1
    elif num_graphs is None:
        # The graph builder walks one group per distinct batch id.
        num_graphs = int(torch.unique(batch).numel())

    # Fully connected graph within each batch element.
    edge_index = get_long_fully_connected_graph(
        batch=batch,
        num_graphs=num_graphs,
    )
    if not use_cutoff:
        return edge_index

    sender, receiver = edge_index
    vectors = positions[sender] - positions[receiver]  # [n_edges, 4]
    vectors_squared = vectors.pow(2)
    # 2*v0^2 - sum(v^2) == v0^2 - (sum of the remaining squared components).
    interval = 2 * vectors_squared[..., 0] - vectors_squared.sum(dim=-1)
    # NOTE(review): the interval is squared once more before being compared
    # with cutoff**2, i.e. the test is effectively on |interval|. Kept as in
    # the original; confirm this is the intended cutoff definition.
    lengths = interval ** 2  # [n_edges]

    # Keep edges strictly inside the (cutoff_in, cutoff_out) window.
    inside_window = (cutoff_in ** 2 < lengths) & (lengths < cutoff_out ** 2)
    return edge_index[:, inside_window]  # [2, n_edges_kept]


def get_long_fully_connected_graph(
    batch: torch.Tensor,  # [N_atoms,]
    num_graphs: int,
):
    """Return the edge index of a fully connected graph per batch element.

    Node ids are assigned consecutively: the first group of nodes gets ids
    ``0 .. c0-1``, the next ``c0 .. c0+c1-1`` and so on, where the group
    sizes ``c_i`` are the occurrence counts of the sorted distinct values of
    ``batch``. This assumes ``batch`` lists the nodes grouped graph by graph
    in ascending id order.

    Args:
        batch: Graph id of every node.
        num_graphs: Number of groups to connect; must not exceed the number
            of distinct values in ``batch``.

    Returns:
        An int64 tensor of shape ``[2, n_edges]`` on ``batch.device`` holding
        every ordered pair ``(i, j)``, ``i != j``, of nodes in the same
        graph. For each graph the ``i < j`` pairs come first, followed by
        the reversed pairs. Shape ``[2, 0]`` if there are no edges.

    Raises:
        IndexError: If ``num_graphs`` exceeds the number of distinct batch
            ids.
    """
    _, counts = torch.unique(batch, return_counts=True)
    # Plain Python ints keep the offset arithmetic free of 0-dim tensors.
    group_sizes = counts.tolist()
    device = batch.device

    per_graph_edges = []
    offset = 0
    for graph_id in range(num_graphs):
        size = group_sizes[graph_id]
        node_ids = torch.arange(
            offset, offset + size, dtype=torch.int64, device=device
        )
        forward = torch.combinations(node_ids, r=2)  # [n_pairs, 2], i < j
        # Flipping both axes swaps sender/receiver (and reverses row order),
        # yielding the j > i direction of every pair.
        backward = torch.flip(forward, dims=[0, 1])
        per_graph_edges.append(torch.cat((forward, backward), dim=0))
        offset += size

    if not per_graph_edges:
        # torch.cat rejects an empty list; return a well-formed empty index.
        return torch.empty((2, 0), dtype=torch.int64, device=device)
    return torch.cat(per_graph_edges, dim=0).T
