import click
import wandb
import torch
from circuit_tracer import ReplacementModel
from pathlib import Path


import sys
sys.path.append("src/") 
from src.model_utils import *
from src.circuit_analysis_tools import *
from src.visualizations import visualize_sampled_activations_heatmap
from src.input_invariant_feature_description import *

@click.command()
@click.option(
    "--model_name",
    required=True,
    help="Name of the model to load; must contain 'gemma'.",
)
def main(model_name):
    """Run the feature analysis for a supported model and log it to wandb.

    \f
    Everything after the form feed above is hidden from ``--help`` output.

    Args:
        model_name: Model identifier forwarded to
            ``ReplacementModel.from_pretrained`` and used as the wandb run
            name. Only names containing ``"gemma"`` are configured.

    Raises:
        click.BadParameter: If ``model_name`` does not match a configured
            model family.
    """
    # Resolve the per-family settings before any side effects, so that an
    # unsupported model fails fast with a usage error instead of starting a
    # wandb run and then crashing on unbound `model` / `save_dir`.
    if "gemma" in model_name:
        transcoder_name = "gemma"
        save_dir = Path("results/gemma_2_2b/")
    else:
        raise click.BadParameter(
            f"unsupported model {model_name!r}; only Gemma models are configured",
            param_hint="--model_name",
        )

    wandb.init(
        project="autointerp",
        name=model_name,
        resume="allow",
        allow_val_change=True
    )

    # Prefer the GPU when one is available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ReplacementModel.from_pretrained(
        model_name,
        transcoder_name,
        dtype=torch.bfloat16,
        device=device,
    )

    analyze_all_features(model, save_dir=save_dir)

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported. `main` is a click command, so calling it with no
# arguments lets click parse the command line and supply `model_name`.
if __name__ == "__main__":
    main()