import torch


def batch_stack(props):
    """
    Combine a list of per-datapoint values into a single batched tensor.

    Three cases are handled, decided by inspecting the first element only:

    * elements are not tensors: the list is passed to ``torch.tensor``;
    * elements are 0-dimensional tensors: they are stacked with
      ``torch.stack``;
    * otherwise: the tensors are zero-padded along their leading axis with
      ``torch.nn.utils.rnn.pad_sequence`` (``batch_first=True``), so the
      result has the batch as its first axis.

    Parameters
    ----------
    props : list of Pytorch Tensors
        Values to combine, one entry per datapoint. Must be non-empty.

    Returns
    -------
    props : Pytorch tensor
        The batched tensor.

    Notes
    -----
    TODO : Review whether the behavior when elements are not tensors is safe.
    """
    first = props[0]

    # Plain Python values (numbers, nested lists, ...): let torch build the tensor.
    if not torch.is_tensor(first):
        return torch.tensor(props)

    # Scalars need no padding, only stacking.
    if first.dim() == 0:
        return torch.stack(props)

    # Variable-length tensors: pad the leading axis with zeros.
    return torch.nn.utils.rnn.pad_sequence(props, batch_first=True, padding_value=0)


def drop_zeros(props, to_keep):
    """
    Select entries along the second axis of a batched property.

    Used to remove padding from a batch when the whole dataset was padded to
    the largest molecule size. Properties whose first element is not a tensor,
    or is a 0-dimensional tensor, are returned untouched.

    Parameters
    ----------
    props : Pytorch tensor
        Batched property; the batch is the first axis.
    to_keep : Pytorch tensor
        Index (e.g. a boolean mask) applied along the second axis of
        ``props`` to choose which entries are retained.

    Returns
    -------
    props : Pytorch tensor
        The property with only the retained entries along the second axis,
        or the input itself when no selection applies.

    Notes
    -----
    TODO : Review whether the behavior when elements are not tensors is safe.
    """
    first = props[0]

    # Nothing to slice: non-tensor content or one scalar per datapoint.
    leave_untouched = (not torch.is_tensor(first)) or first.dim() == 0
    if leave_untouched:
        return props

    return props[:, to_keep, ...]


class PreprocessQM9:
    """
    Collate QM9 datapoints into padded batches with atom and edge masks.

    Parameters
    ----------
    use_ghost_nodes : bool, optional
        If True, no atom slot is dropped and slots with atomic number 0
        ("ghost nodes") are included in the atom mask. If False, atom slots
        that are empty for every molecule of the batch are removed and only
        atoms with a positive atomic number are marked valid.
    """

    def __init__(self, use_ghost_nodes=False):
        self.use_ghost_nodes = use_ghost_nodes
        # Must exist before add_trick() is used; without it add_trick()
        # raised AttributeError.
        self.tricks = []

    def add_trick(self, trick):
        """
        Register a trick by appending it to ``self.tricks``.

        Parameters
        ----------
        trick : object
            The trick to store. This class only stores it.
        """
        self.tricks.append(trick)

    def collate_fn(self, batch):
        """
        Collation function that collates datapoints into the batch format for cormorant

        mainly does the following:
        stacks properties of different elements into batch
        remove extra 0s if current batch does not contain the largest data point across dataset
        creates atom_mask and edge_mask for the batch

        Parameters
        ----------
        batch : list of datapoints
            The data to be collated. Each datapoint is a dict; the keys of the
            first datapoint define the properties of the batch. The keys
            'atomic_numbers' and 'formal_charges_one_hot' are required, and
            'adj_list' is required unless ghost nodes are used. 'adj_matrix'
            and 'morgan_fingerprint' are optional.

        Returns
        -------
        batch : dict of Pytorch tensors
            The collated data, with two extra entries: 'atom_mask' of shape
            (batch_size, n_nodes) and 'edge_mask' of shape
            (batch_size * n_nodes * n_nodes, 1).
        """
        batch = {prop: batch_stack([mol[prop] for mol in batch]) for prop in batch[0].keys()}

        # When using ghost nodes, do not drop any node
        if not self.use_ghost_nodes:
            to_keep = batch['atomic_numbers'].sum(0) > 0

            # These properties are not indexed by atom along axis 1, so they
            # are set aside and handled separately from the generic drop.
            adj_list = batch['adj_list']
            adj_matrix = batch.get('adj_matrix')
            morgan_fingerprint = batch.get('morgan_fingerprint')
            handled_separately = ('adj_list', 'adj_matrix', 'morgan_fingerprint')
            batch = {
                key: drop_zeros(prop, to_keep)
                for key, prop in batch.items()
                if key not in handled_separately
            }

            # different structure and thus different indices to drop for adj_list
            to_keep_adj_list = adj_list.sum(0).sum(-1) > 0
            batch['adj_list'] = drop_zeros(adj_list, to_keep_adj_list)

            # adj_matrix has two atom axes; both must be filtered
            if adj_matrix is not None:
                batch['adj_matrix'] = adj_matrix[:, to_keep][:, :, to_keep]

            # Restore the fingerprint unchanged. The key was removed from
            # `batch` above, so presence is tested on the saved value rather
            # than on the rebuilt dict.
            if morgan_fingerprint is not None:
                batch['morgan_fingerprint'] = morgan_fingerprint

        if self.use_ghost_nodes:
            # include atoms with type 0: "ghost nodes"
            atom_mask = batch['atomic_numbers'] >= 0
        else:
            atom_mask = batch['atomic_numbers'] > 0
        batch['atom_mask'] = atom_mask

        # An edge is valid when both of its endpoints are valid atoms
        batch_size, n_nodes = atom_mask.size()
        edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)

        # mask diagonal (no self-edges)
        diag_mask = ~torch.eye(
            edge_mask.size(1), dtype=torch.bool, device=atom_mask.device
        ).unsqueeze(0)
        edge_mask *= diag_mask

        batch['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

        # mask extra formal charges
        batch['formal_charges_one_hot'] = batch['formal_charges_one_hot'] * atom_mask.unsqueeze(-1)

        return batch
