import os
import numpy as np
import torch

# Folders used in the study. The two absolute paths below are hard-coded and
# must be adapted to the local machine.
# Main folder, containing code and figures.
home_folder = '/media/anonymous/llm_brain_asym/'
# Path to the Le Petit Prince fMRI corpus, downloaded from
# https://doi.org/10.18112/openneuro.ds003643.v2.0.5
lpp_path = '/media/anonymous/fmri/openneuro/ds003643-download/'
# Annotations (sub-folder of the corpus), used for aligning text and speech.
annotation_folder = os.path.join(lpp_path, 'annotation')

# Location of the activations from the various LLMs.
llms_activations = os.path.join(home_folder, 'llms_activations')
# Location of the brain correlations for each model, for each layer.
llms_brain_correlations = os.path.join(home_folder, 'llms_brain_correlations')
# All figures in the paper.
fig_folder = os.path.join(home_folder, 'fig')

# NOTE(review): presumably the number of fMRI runs per subject -- confirm.
n_runs = 9
# NOTE(review): presumably the fMRI repetition time (TR) -- confirm.
t_r = 2  # s

# helpers
def make_dir(directory):
    """Create `directory` (and any missing parents) if it does not exist yet.

    Calling this on an already existing directory is a no-op.

    Args:
        directory: path of the directory to create.

    Raises:
        FileExistsError: if `directory` exists but is not a directory.
        OSError: if the directory cannot be created (e.g. permissions).
    """
    # `exist_ok=True` replaces the former "check, then create" sequence, which
    # could fail when another process created the directory between the check
    # and the call to `os.makedirs`.
    os.makedirs(directory, exist_ok=True)

def standardize(v, axis=0):
    """Z-score `v` along `axis`: subtract the mean and divide by the std.

    Mean and standard deviation (NumPy's default, population std) are computed
    with `keepdims=True`, so they broadcast against `v` and the result has the
    same shape as `v`. No guard is applied for a zero standard deviation.

    Args:
        v: array-like of numbers.
        axis: axis (or axes) along which the statistics are computed.

    Returns:
        The centered and scaled array.
    """
    center = np.mean(v, axis=axis, keepdims=True)
    spread = np.std(v, axis=axis, keepdims=True)
    centered = v - center
    return centered / spread


def compute_logprobs(model, tokenizer, sentences, batch_size=512):
    """Return the mean per-token log-probability of each sentence under `model`.

    Sentences are tokenized in batches with right padding. For every sentence
    the log-probability of each token given the preceding ones is read from
    the model logits, and averaged over the non-padding tokens after the first
    one (the first token has no prediction).

    Side effects: `tokenizer.padding_side` is set to 'right', and
    `tokenizer.pad_token_id` is set to `tokenizer.eos_token_id` when it is
    None. These changes persist after the call.

    Args:
        model: callable as `model(**inputs)` returning an object with a
            `.logits` tensor, and exposing `.device` (presumably a Hugging
            Face causal language model -- confirm against callers).
        tokenizer: callable as `tokenizer(list_of_str, return_tensors="pt",
            padding=True)`, returning a mapping with 'input_ids' and
            'attention_mask' that supports `.to(device)`.
        sentences: sliceable sequence of strings.
        batch_size: number of sentences per forward pass.

    Returns:
        1-D float32 NumPy array with one value per sentence. A sentence with
        fewer than two tokens has nothing to average and yields NaN.
    """
    tokenizer.padding_side = 'right'
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    logprobs_sentences = []
    for start in range(0, len(sentences), batch_size):
        # A plain list keeps the tokenizer happy when `sentences` is an
        # array-like rather than a list.
        batch = list(sentences[start:start + batch_size])
        inputs = tokenizer(batch, return_tensors="pt", padding=True).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
            # Logits at position t are paired with the token at position t+1.
            logprobs = outputs.logits[:, :-1].log_softmax(dim=-1)
            targets = inputs['input_ids'][:, 1:, None]
            attention_mask = inputs['attention_mask'][:, 1:]
            # Upcast before averaging: NumPy has no bfloat16, so half-precision
            # results could not be converted with `.numpy()` otherwise.
            logprobs_tokens = torch.gather(logprobs, dim=-1, index=targets).squeeze(-1).float()
            # The mask zeroes out padding positions in the sum; the divisor
            # counts only the real target tokens.
            summed = (logprobs_tokens * attention_mask).sum(dim=-1)
            n_tokens = attention_mask.sum(dim=-1)
            logprobs_sentences.extend((summed / n_tokens).cpu().numpy())
    return np.array(logprobs_sentences)