import torch

def topk_index(tensor, k=-1):
    """Return the multi-dimensional indices of the ``k`` largest elements.

    The tensor is flattened, ranked with ``topk`` (largest first), and each
    flat position is converted back into a coordinate tuple for the
    tensor's original shape.

    Args:
        tensor: Tensor of any rank whose elements are ranked.
        k: Number of indices to return. A negative value, or a value larger
            than the number of elements, selects every element.

    Returns:
        A list of ``k`` tuples of Python ints, one coordinate per dimension
        of ``tensor``, ordered from the largest element to the smallest.
        For a 0-dim tensor each tuple is empty.
    """
    total = tensor.numel()
    # Clamp instead of letting ``topk`` raise when more elements are
    # requested than exist; negative ``k`` keeps its "everything" meaning.
    if k < 0 or k > total:
        k = total
    flat_index = tensor.flatten().topk(k=k)[1]

    # Unravel the flat positions: peel off the fastest-varying (last)
    # dimension first, then move outwards.
    columns = []
    for dim in reversed(tensor.shape):
        columns.append((flat_index % dim).tolist())
        flat_index = flat_index // dim

    if not columns:
        # 0-dim tensor: there are no coordinates to zip, but the selected
        # scalar element(s) still need an (empty) index each.
        return [()] * k

    # ``columns`` holds the last dimension first; restore the natural order
    # and transpose into one tuple per selected element.
    return list(zip(*reversed(columns)))

# For each model, rank activation coordinates on the HH prompt sets in two
# ways (overall standard deviation, and RMS difference between the two
# prompt sets) and save each ranking to disk.
model_names = ['Llama-2-7b-hf', 'gemma-7b', 'Mistral-7B-v0.1']
filenames = ['hh_harmless_prompt_last_token', 'hh_helpful_prompt_last_token']
for model_name in model_names:
    # One saved activation object per prompt set, in the order of `filenames`.
    # presumably each is a tensor whose dim 0 indexes prompts — TODO confirm
    activations = [
        torch.load(f'Alignment/output/activations/{model_name}/{filename}.pt')
        for filename in filenames
    ]

    # Ranking 1: standard deviation along dim 0 of both sets stacked together.
    activation_std = torch.concat(activations, 0).std(0)
    index_by_std = topk_index(activation_std)
    # The ranking is stored as element 1 of a 6-tuple padded with zeros.
    # NOTE(review): looks like a layout expected by a downstream loader — confirm.
    torch.save((0, torch.tensor(index_by_std), 0, 0, 0, 0), f'Alignment/hooked_llama/neuron_activation/{model_name}_std_on_hh_prompt_last_token.pt')

    # Ranking 2: root-mean-square difference between the two sets, paired
    # row by row. Truncate BOTH sets to the common length so the subtraction
    # is well-defined whichever file has more rows.
    first, second = activations
    n_pairs = min(first.shape[0], second.shape[0])
    dist = (first[:n_pairs] - second[:n_pairs]).square().mean(0).sqrt()
    prompt_difference_rank = topk_index(dist)
    torch.save((0, torch.tensor(prompt_difference_rank), 0, 0, 0, 0), f'Alignment/hooked_llama/neuron_activation/{model_name}_difference_on_hh_prompt_last_token.pt')