import transformer_lens
import torch
# Load OPT-1.3B into TransformerLens from a local checkpoint directory.
# NOTE(review): HookedTransformer.from_pretrained looks its first argument up in
# TransformerLens's table of official model names/aliases; in the releases I know
# of, an arbitrary filesystem path is not found there and raises ValueError. If
# that happens, load the checkpoint from this path with transformers'
# AutoModelForCausalLM / AutoTokenizer and call
# from_pretrained("facebook/opt-1.3b", hf_model=..., tokenizer=...) instead --
# confirm against the installed TransformerLens version.
MODEL_PATH = "/data/ssliangruc3/models/opt-1.3b"
model = transformer_lens.HookedTransformer.from_pretrained(MODEL_PATH)

# Run the model once, caching every hooked activation for later decomposition.
logits, cache = model.run_with_cache("Hello World")

# Split the final residual stream into per-component contributions: with
# mode="attn" only attention outputs (plus the embedding terms decompose_resid
# includes by default) are returned; MLP outputs are excluded.
# Shape: [num_components, batch, pos, d_model]; labels[i] names component i.
residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
print('labels', labels)

# Token whose logit we attribute. A string answer is converted with
# model.to_single_token, so it must encode to exactly one token.
# NOTE(review): GPT-2-style BPE (used by OPT) distinguishes "Thanks" from
# " Thanks"; the natural continuation after a prompt is usually the
# space-prefixed form -- confirm which one is intended.
answer = "Thanks"

# Direct contribution of each component to the answer's logit (logit_attrs
# applies the final LayerNorm scale itself). Shape: [num_components, batch, pos].
logit_attrs = cache.logit_attrs(residual_stream, answer)

# The next-token prediction is made at the last position of the (single)
# prompt, so select it before taking argmax. argmax over the full
# [num_components, batch, pos] tensor would return a flattened index that does
# not correspond to an entry of `labels`.
last_token_attrs = logit_attrs[:, 0, -1]
most_important_component_idx = torch.argmax(last_token_attrs).item()
print(labels[most_important_component_idx])
