import torch
from utils import generate_activations
from circuit_tracer import ReplacementModel
from datasets import load_dataset
from pathlib import Path
import click
import wandb
def _pick_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds may not expose ``torch.backends.mps`` at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _pick_device()
print(f"Using device: {device}")


@click.command()
@click.option("--model_name", default="google/gemma-2-2b")
@click.option("--transcoder_name", default="gemma")
@click.option("--checkpoint", type=int, default=20)
@click.option("--batch_size", type=int, default=4)
@click.option("--save_dir", default="activations/gemma-2-2b/")
@click.option("--array_idx", type=int, default=0)
@click.option("--array_size", type=int, default=1)
@click.option("--layers-to-save", multiple=True, type=int)
@click.option(
    "--data_file",
    default="data/transcoders_batch_1.parquet",
    help="Parquet file whose 'text' column is processed.",
)
def main(
    model_name,
    checkpoint,
    batch_size,
    save_dir,
    transcoder_name,
    array_idx,
    array_size,
    layers_to_save,
    data_file="data/transcoders_batch_1.parquet",
):
    """Cache model activations for one shard of a parquet text dataset.

    The rows of the dataset's "train" split are divided into ``array_size``
    contiguous chunks (ceiling division, so trailing chunks may be shorter or
    empty). This invocation processes chunk ``array_idx`` (0-based), e.g. one
    task of a job array. The chunk is passed to ``generate_activations`` as a
    dict mapping each row's index in the full split to its text.

    Raises:
        click.BadParameter: if ``array_size < 1`` or ``array_idx`` is not in
            ``[0, array_size)``.
    """
    # Validate sharding arguments before doing any work: array_size == 0 would
    # divide by zero, and an out-of-range array_idx silently selects no rows.
    if array_size < 1:
        raise click.BadParameter(
            f"must be >= 1, got {array_size}", param_hint="--array_size"
        )
    if not 0 <= array_idx < array_size:
        raise click.BadParameter(
            f"must satisfy 0 <= array_idx < {array_size}, got {array_idx}",
            param_hint="--array_idx",
        )

    # click yields an empty tuple when --layers-to-save is absent; forward None
    # instead. NOTE(review): how generate_activations treats None is defined
    # in utils — presumably "all layers", confirm.
    layers_to_save = list(layers_to_save) if layers_to_save else None

    # One label for every log line: 0-based job index / highest job index.
    job_tag = f"[Job {array_idx}/{array_size - 1}]"

    wandb.init(
        project="activations_caching",
        resume="allow",
        allow_val_change=True,
        config={
            "model_name": model_name,
            "transcoder_name": transcoder_name,
            "checkpoint": checkpoint,
            "batch_size": batch_size,
            "array_idx": array_idx,
            "array_size": array_size,
            "layers_to_save": layers_to_save,
            "data_file": data_file,
        },
    )
    print(
        f"\n{job_tag} Arguments:\n"
        f"  model_name      = {model_name}\n"
        f"  transcoder_name = {transcoder_name}\n"
        f"  data_file       = {data_file}\n"
        f"  save_dir        = {save_dir}\n"
        f"  batch_size      = {batch_size}\n"
        f"  checkpoint      = {checkpoint}\n"
        f"  array_idx       = {array_idx}\n"
        f"  array_size      = {array_size}\n"
        f"  layers_to_save  = {layers_to_save}\n"
    )

    train_split = load_dataset("parquet", data_files=data_file)["train"]
    n = len(train_split)

    # Ceiling division so that array_size chunks together cover every row.
    chunk_size = (n + array_size - 1) // array_size
    # Clamp the start: with ceiling division a trailing job's chunk can begin
    # past the end of the data, which must become an empty range, not start > end.
    start = min(array_idx * chunk_size, n)
    end = min(start + chunk_size, n)

    # Read only this job's rows rather than materialising the whole split.
    # Keys remain global row indices, exactly as in the unsharded numbering.
    dataset = dict(enumerate(train_split[start:end]["text"], start=start))

    if not dataset:
        # Nothing assigned to this job: skip loading the model entirely.
        print(f"{job_tag} No samples assigned ({n} total); nothing to do.")
        return

    print(
        f"{job_tag} "
        f"Processing samples {start}–{end - 1} "
        f"({len(dataset)} samples out of {n} total)"
    )

    model = ReplacementModel.from_pretrained(
        model_name, transcoder_name, dtype=torch.bfloat16, device=device
    )
    print("LOADED MODEL: ", model_name)

    generate_activations(
        model=model,
        dataset=dataset,
        device=device,
        batch_size=batch_size,
        checkpoint=checkpoint,
        save_dir=Path(save_dir),
        layers_to_save=layers_to_save,
    )


if __name__ == "__main__":
    # Script entry point: `main` is a click command, so calling it parses
    # the command-line options into its parameters.
    main()