from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import tqdm
import gc
import argparse, json, os

# Buffer of activation outputs captured by ``hook``; each hooked module call
# appends one entry, in call order.
nerons = []


def hook(module, input, output):
    """Forward hook that records a detached copy of a module's output.

    The signature follows the forward-hook calling convention
    ``(module, input, output)``; only ``output`` is used. The detached tensor
    is appended to the module-level ``nerons`` list. Returns ``None``.
    """
    captured = output.detach()
    nerons.append(captured)

def iter_batches(items, size):
    """Yield consecutive slices of *items* holding at most *size* elements.

    The final slice may be shorter than *size*, so no trailing item is dropped.
    """
    for start in range(0, len(items), size):
        yield items[start:start + size]


def count_active(activations, attention_mask):
    """Count, per neuron, the non-padding positions with a positive activation.

    Args:
        activations: Output of a hooked activation module. Assumed to be laid
            out as (batch, seq, hidden) so that ``attention_mask`` broadcasts
            over the last axis -- TODO confirm for the model in use.
        attention_mask: (batch, seq) tensor; positions equal to 0 are ignored.

    Returns:
        An integer tensor with one count per entry of the last axis.
        NaN activations compare false against 0 and are counted as inactive.
    """
    keep = attention_mask.unsqueeze(-1) != 0
    return ((activations > 0) & keep).sum(dim=(0, 1))


def load_prompts(path):
    """Read the ``prompt`` field from every non-blank line of a JSON-lines file."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line)['prompt'] for line in f if line.strip()]


def main():
    """Accumulate per-layer activation counts for one model/language pair.

    Writes two files under ``result_<model_path>/``: ``sum_*`` holds, per
    hooked layer, the number of non-padding tokens on which each neuron's
    activation output was positive; ``ave_*`` holds the same counts divided by
    the total number of non-padding tokens. Does nothing if the ``ave_*`` file
    already exists.
    """
    parser = argparse.ArgumentParser()
    # Both arguments are interpolated into paths; without them the script
    # would silently work on "None".
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--lang", type=str, required=True)
    args = parser.parse_args()

    outputpath = f'result_{args.model_path}/'
    outputfile = outputpath + f'sum_{args.model_path}_{args.lang}.result'
    outputfile2 = outputpath + f'ave_{args.model_path}_{args.lang}.result'

    # The averaged file is the "already done" marker for this pair.
    if os.path.exists(outputfile2):
        return

    # model_path may contain path separators (e.g. "org/model"), in which case
    # the result files live in a sub-directory of outputpath; create the
    # actual parent directories up front instead of failing at save time.
    os.makedirs(outputpath, exist_ok=True)
    for path in (outputfile, outputfile2):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # Load the prompts before the model so an empty/missing file fails fast.
    inputs = load_prompts(f'{args.lang}.json')
    if not inputs:
        raise SystemExit(f'no prompts found in {args.lang}.json')

    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    model = model.cuda()
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    handles = [
        layer.mlp.act_fn.register_forward_hook(hook)
        for layer in model.transformer.h
    ]
    # One accumulator per hooked layer, sized lazily from the first captured
    # activation so no hidden size has to be hard-coded. Counts are kept as
    # integers so long runs stay exact.
    counts = [None] * len(handles)

    batchsize = 8
    tokens = 0
    num_batches = (len(inputs) + batchsize - 1) // batchsize

    try:
        for batched_inputs in tqdm.tqdm(iter_batches(inputs, batchsize), total=num_batches):
            # Drop the previous batch's activations in place so the hook keeps
            # appending to the same module-level list.
            nerons.clear()
            torch.cuda.empty_cache()
            inputs_tokenized = tokenizer(
                batched_inputs,
                return_tensors="pt",
                max_length=2048,
                return_token_type_ids=False,
                padding=True,
                truncation=True,
            )
            inputs_tokenized = {k: v.cuda() for k, v in inputs_tokenized.items()}
            attention_mask = inputs_tokenized['attention_mask']
            tokens += int((attention_mask > 0).sum())
            with torch.no_grad():
                model(**inputs_tokenized)
            if len(nerons) != len(counts):
                raise RuntimeError(
                    f'expected {len(counts)} captured activations, got {len(nerons)}'
                )
            for j, activations in enumerate(nerons):
                y = count_active(activations, attention_mask).cpu()
                counts[j] = y if counts[j] is None else counts[j] + y
    finally:
        # Always detach the hooks so the model is left unmodified.
        for handle in handles:
            handle.remove()
        nerons.clear()

    if tokens == 0:
        raise SystemExit('no non-padding tokens were processed; nothing to save')

    # Saved tensors are float32, matching the original accumulator dtype.
    actstate = [c.to(torch.float32) for c in counts]
    ave = [total / tokens for total in actstate]

    # Write the marker file (ave) last so an interrupted run is not mistaken
    # for a finished one.
    torch.save(actstate, outputfile)
    torch.save(ave, outputfile2)

    del model
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()

