import torch


def adjacency_error(adj_pred, adj_gt, num_atoms):
    """Count correctly predicted edges and fully correct molecules in a batch.

    Every prediction is cropped to its molecule's ``n_nodes x n_nodes`` block
    and reduced with ``argmax`` over the last axis. The result is compared with
    the equally cropped ground truth on the strict upper triangle only, so each
    undirected edge is counted once and the diagonal is ignored.

    Args:
        adj_pred: Iterable of per-molecule prediction tensors whose last axis
            is arg-maxed (presumably per-edge class scores -- confirm with the
            caller).
        adj_gt: Iterable of per-molecule ground-truth adjacency tensors holding
            the labels that the arg-maxed prediction is compared against.
        num_atoms: Sized iterable with the number of nodes of each molecule.
            Entries may be Python ints or single-element tensors.

    Returns:
        Tuple ``(correct_edges, total_edges, correct_molecules,
        total_molecules)``; the first two entries are Python floats.

    Raises:
        AssertionError: If the edge bookkeeping yields a negative count, or if
            fewer molecules were processed than ``len(num_atoms)`` (``zip``
            stops at the shortest input).
    """
    correct_edges = 0.
    total_edges = 0.
    correct_molecules = 0.
    total_molecules = 0.
    for pred, gt, n_nodes in zip(adj_pred, adj_gt, num_atoms):
        # Accept both Python ints and (0-dim / single-element) tensors.
        n_nodes = int(n_nodes)
        pred = torch.argmax(pred[:n_nodes, :n_nodes], -1)
        # No-op when both tensors already live on the same device.
        gt = gt[:n_nodes, :n_nodes].to(pred.device)

        # Number of entries strictly above the main diagonal.
        n_unique_edges = n_nodes * (n_nodes - 1) // 2

        # torch.triu zeroes the diagonal and the lower triangle on both sides,
        # so those entries always compare equal. Summing the match matrix thus
        # over-counts by n_unique_edges (lower triangle) + n_nodes (diagonal);
        # subtracting them leaves the matches above the diagonal.
        matches = torch.triu(pred, diagonal=1) == torch.triu(gt, diagonal=1)
        new_correct_edges = int(matches.sum()) - n_unique_edges - n_nodes
        assert new_correct_edges >= 0, "something wrong with the evaluation logic"

        correct_edges += new_correct_edges
        total_edges += n_unique_edges

        if new_correct_edges == n_unique_edges:
            correct_molecules += 1
        total_molecules += 1

    assert total_molecules == len(num_atoms)
    # The accumulators are plain Python numbers, so no ``.item()`` is needed;
    # this also keeps an empty batch from crashing.
    return float(correct_edges), float(total_edges), correct_molecules, total_molecules

def reconstruction_error(adj_pred_batch, atom_types_pred_batch, formal_charges_pred_batch, adj_gt_batch, gt_batch):
    """Accumulate edge, atom-type, formal-charge and whole-molecule accuracy counts.

    For every molecule in the batch the predictions and ground truth are
    cropped to the molecule's real node count and compared element-wise.
    Edges are compared on the strict upper triangle of the adjacency matrix
    only. A molecule counts as correct when all of its edges, atom types and
    formal charges match.

    Args:
        adj_pred_batch: Indexable batch of predicted adjacency matrices. They
            are compared to the ground truth with ``==``, so they are
            presumably already discrete labels -- confirm with the caller.
        atom_types_pred_batch: Indexable batch of predicted atom types,
            compared against the ``argmax`` indices of the one-hot ground truth.
        formal_charges_pred_batch: Indexable batch of predicted formal charges.
        adj_gt_batch: Indexable batch of ground-truth adjacency matrices.
        gt_batch: Mapping with the keys ``'atomic_numbers_one_hot'`` (its first
            dimension is the batch size), ``'formal_charges'`` and
            ``'num_atoms'``. ``'num_atoms'`` entries may be Python ints or
            single-element tensors.

    Returns:
        Dict with the keys ``correct_edges``, ``total_edges``,
        ``correct_atom_types``, ``correct_formal_charges``, ``total_atoms``,
        ``correct_molecules`` and ``total_molecules``, all Python floats.

    Raises:
        AssertionError: If more formal charges match than there are nodes
            (e.g. through unintended broadcasting) or the edge bookkeeping
            yields a negative count.
    """
    correct_edges = 0.
    total_edges = 0.

    correct_atom_types = 0.
    correct_formal_charges = 0.
    total_atoms = 0.

    correct_molecules = 0.
    total_molecules = 0.

    batch_size = gt_batch['atomic_numbers_one_hot'].shape[0]
    for i in range(batch_size):
        # Accept both Python ints and (0-dim / single-element) tensors.
        n_nodes = int(gt_batch['num_atoms'][i])
        adj_pred = adj_pred_batch[i][:n_nodes, :n_nodes]
        # Ground truth is moved to the prediction's device, like the features
        # below; this is a no-op when both are already on the same device.
        adj_gt = adj_gt_batch[i][:n_nodes, :n_nodes].to(adj_pred.device)

        atom_types_pred = atom_types_pred_batch[i][:n_nodes]
        formal_charges_pred = formal_charges_pred_batch[i][:n_nodes]

        atom_types_gt = gt_batch['atomic_numbers_one_hot'][i][:n_nodes, :].float().argmax(-1).to(atom_types_pred.device)
        formal_charges_gt = gt_batch['formal_charges'][i][:n_nodes].to(formal_charges_pred.device)

        # Features Accuracy
        new_correct_atom_types = int((atom_types_pred == atom_types_gt).sum())
        correct_atom_types += new_correct_atom_types
        new_correct_formal_charges = int((formal_charges_pred == formal_charges_gt).sum())
        assert new_correct_formal_charges <= n_nodes, f"number of correct charges {new_correct_formal_charges} should not be greated than number of nodes {n_nodes}"
        correct_formal_charges += new_correct_formal_charges
        total_atoms += n_nodes

        # Edges Accuracy
        # Number of entries strictly above the main diagonal.
        n_unique_edges = n_nodes * (n_nodes - 1) // 2
        # torch.triu zeroes the diagonal and the lower triangle on both sides,
        # so those entries always compare equal. Summing the match matrix thus
        # over-counts by n_unique_edges (lower triangle) + n_nodes (diagonal);
        # subtracting them leaves the matches above the diagonal.
        matches = torch.triu(adj_pred, diagonal=1) == torch.triu(adj_gt, diagonal=1)
        new_correct_edges = int(matches.sum()) - n_unique_edges - n_nodes
        assert new_correct_edges >= 0, "something wrong with the evaluation logic"

        correct_edges += new_correct_edges
        total_edges += n_unique_edges

        if new_correct_edges == n_unique_edges and new_correct_atom_types == n_nodes and new_correct_formal_charges == n_nodes:
            correct_molecules += 1
        total_molecules += 1

    assert total_molecules == batch_size
    results = {
        'correct_edges': correct_edges,
        'total_edges': total_edges,

        'correct_atom_types': correct_atom_types,
        'correct_formal_charges': correct_formal_charges,
        'total_atoms': total_atoms,

        'correct_molecules': correct_molecules,
        'total_molecules': total_molecules,
    }
    return results
