
import torch.nn.functional as F
import json

def load_json(path):
    """Read the file at ``path`` and return its parsed JSON content.

    Args:
        path: filesystem path (str or os.PathLike) of a JSON file.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the bytes are not valid UTF-8 or the text is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    # JSON text is UTF-8 (RFC 8259); state the encoding explicitly instead of
    # relying on the platform's locale default (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp)

def get_probs_and_mrrs(model, logits, answer):
    """Collect per-layer answer statistics from ``model`` for one answer.

    Delegates to two methods on ``model`` and bundles their results; this
    function does no computation of its own.

    Args:
        model: object exposing ``prob_of_answer(logits, answer)`` and
            ``rr_per_layer(logits, answer)``.
        logits: passed unchanged to both model methods.
        answer: passed unchanged to both model methods.

    Returns:
        A 2-tuple ``(probs, rrs)`` where ``probs`` is the result of
        ``model.prob_of_answer`` and ``rrs`` is the result of
        ``model.rr_per_layer`` (presumably per-layer reciprocal ranks, going
        by the method name — confirm against the model class).
    """
    # Call order is preserved: probabilities first, then reciprocal ranks.
    answer_probs = model.prob_of_answer(logits, answer)
    reciprocal_ranks = model.rr_per_layer(logits, answer)
    return (answer_probs, reciprocal_ranks)


def from_layer_logits_to_prob_distros(logits):
    """
    Convert per-layer logits into per-layer probability distributions.

    Take in a tensor of shape (n_layers, vocab_size) and return a same-shaped
    float16 NumPy array in which every layer's row is a softmax over the vocab
    dimension. Any number of leading dimensions is accepted (the softmax is
    always taken over the last dimension), and a 1-D input is treated as a
    single distribution.

    The input tensor is not modified and is detached from the autograd graph.

    Args:
        logits: torch tensor whose last dimension indexes the vocabulary.

    Returns:
        numpy.ndarray of dtype float16 with the same shape as ``logits``.
    """
    # Work in float32 on the CPU: float32 keeps the softmax numerically stable
    # for fp16/bf16 inputs, and .numpy() requires a CPU tensor.
    logits_f32 = logits.detach().cpu().float()
    # One vectorised softmax over the vocab axis replaces the per-layer Python
    # loop. F.softmax returns a new tensor, so no defensive clone is needed to
    # protect the caller's input.
    probs = F.softmax(logits_f32, dim=-1)
    # float16 keeps the array small; probabilities below ~6e-8 underflow to 0
    # at this precision.
    return probs.half().numpy()