import os
import pickle
from typing import Any

import torch as t
from acdc.TLACDCExperiment import TLACDCExperiment
from jaxtyping import Float, Int
from torch import Tensor


def run_model_batched(
    model, toks: Int[Tensor, "batch seq"], batch_size: int = 100
) -> Float[Tensor, "batch seq d_vocab"]:
    """
    Apply ``model`` to ``toks`` chunk by chunk and stitch the results together.

    ``toks`` is split along its first dimension into consecutive slices of at
    most ``batch_size`` rows. Each slice is passed to ``model`` in order, and
    the outputs are concatenated along dim 0.

    NOTE: no ``t.no_grad()`` is applied here. Wrap the call yourself if
    gradients are not needed. An empty ``toks`` leaves nothing to concatenate,
    and ``t.cat`` fails on that.
    """
    return t.cat(
        [model(toks[start : start + batch_size]) for start in range(0, len(toks), batch_size)],
        dim=0,
    )


def save_with_pickle(data: Any, filename: str) -> None:
    """
    Pickle ``data`` to ``filename``, creating parent directories as needed.

    Args:
        data: Any picklable object.
        filename: Destination path. A bare filename with no directory
            component (e.g. ``"data.pkl"``) is written to the current
            working directory.

    Raises:
        OSError: If the parent directory cannot be created or the file cannot
            be opened/written. A message is printed before re-raising.
    """
    try:
        directory = os.path.dirname(filename)
        # os.makedirs("") raises FileNotFoundError, so only create directories
        # when the path actually contains one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filename, "wb") as file:
            pickle.dump(data, file)
    except OSError as e:
        print(f"Failed to save file {filename}: {e}")
        raise


def load_with_pickle(filename: str) -> Any:
    """
    Read and return the object pickled in ``filename``.

    Warning: unpickling can execute arbitrary code, so only load files you
    trust.

    Raises:
        OSError: If the file cannot be opened/read. A message is printed
            before re-raising.
    """
    try:
        with open(filename, "rb") as handle:
            obj = pickle.load(handle)
    except OSError as err:
        print(f"Failed to load file {filename}: {err}")
        raise
    return obj


def add_all_hooks(experiment: TLACDCExperiment) -> TLACDCExperiment:
    """
    Adds all edges to the experiment object.

    Calls ``reset_hooks()`` on ``experiment.model``, then asks the experiment
    to set up both sender and receiver hooks with ``doing_acdc_runs=False``.
    The same ``experiment`` object is modified in place and returned.
    """
    model = experiment.model
    # presumably clears any previously registered hooks -- confirm against
    # the model's reset_hooks implementation
    model.reset_hooks()
    hook_options = {
        "doing_acdc_runs": False,
        "add_sender_hooks": True,
        "add_receiver_hooks": True,
    }
    experiment.setup_model_hooks(**hook_options)
    return experiment
