import torch
import torch.nn.functional as F


def get_clean_mapping(mapping, vocab_dict):
    """Reduce *mapping* to its non-trivial entries that lie inside a vocabulary.

    A pair ``src -> dst`` from ``mapping.items()`` is kept only when:

    * it is not an identity pair (``src`` differs from ``dst``), and
    * both ``src`` and ``dst`` are members of *vocab_dict*.

    Only ``in`` membership is used on *vocab_dict*, so any container that
    supports ``in`` works. Surviving pairs keep the iteration order of
    *mapping*. The number of kept pairs is printed to stdout as a side effect.

    Args:
        mapping: Dict-like object whose ``items()`` yields ``(key, value)`` pairs.
        vocab_dict: Container tested with ``in`` for both keys and values.

    Returns:
        A new ``dict`` holding only the retained pairs; *mapping* is not modified.
    """
    cleaned = {
        src: dst
        for src, dst in mapping.items()
        # Identity check first, then vocabulary membership of both sides.
        if src != dst and src in vocab_dict and dst in vocab_dict
    }
    print(f"Mapping size: {len(cleaned)}")
    return cleaned