import os
from pathlib import Path
import torch
from fade.data import DictionaryDataset
from tqdm import tqdm


def save_sparse_activations(
    activations,
    save_dir=Path("actmax"),
    layers_to_save=None
):
    """
    Save sparse activations for selected layers.

    For every selected layer that has at least one sample, a single file
    ``<save_dir>/layer_<layer>/<min_sample_id>_<max_sample_id>.pt`` is written
    with ``torch.save``. It contains a dict mapping each sample id to::

        {
            "values": <sparse COO copy of the sample's "values" tensor>,
            "indices": {
                "features": <positions where "values" is nonzero>,
                "tokens": <the sample's "indices" entries at those positions>,
            },
        }

    Args:
        activations: {layer: {sample_id: {"values": tensor, "indices": tensor}}}.
            "values" is treated as 1-D here (presumably [d_mlp]); "indices"
            must be indexable by the same feature positions.
        save_dir (str | Path): Root output directory, created if missing.
            Plain strings are accepted as well as Path objects.
        layers_to_save (list[int] or None): Which layers to save. None = all
            keys of ``activations``.
    """
    # Accept str as well as Path: the joins below use the "/" operator, which
    # a plain str does not support (generate_activations defaults to a str).
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    if layers_to_save is None:
        layers_to_save = list(activations.keys())

    for layer in layers_to_save:
        layer_acts = activations[layer]
        if not layer_acts:  # nothing collected for this layer -> write no file
            continue

        layer_dir = save_dir / f"layer_{layer}"
        layer_dir.mkdir(parents=True, exist_ok=True)

        sparse_data = {}
        for sample_id, act in layer_acts.items():
            # Dense -> sparse COO so only nonzero entries are stored.
            sparse_vals = act["values"].detach().cpu().to_sparse()

            # For a 1-D tensor the COO index matrix has a single row: the
            # positions (feature IDs) of the nonzero values.
            nz_features = sparse_vals.indices()[0]
            # Keep only the token positions that belong to those features.
            nz_tokens = act["indices"].detach().cpu()[nz_features]

            sparse_data[sample_id] = {
                "values": sparse_vals,  # sparse tensor, same shape as "values"
                "indices": {
                    "features": nz_features,  # feature IDs
                    "tokens": nz_tokens,      # token positions
                },
            }

        first_id, last_id = min(sparse_data), max(sparse_data)
        torch.save(sparse_data, layer_dir / f"{first_id}_{last_id}.pt")


@torch.no_grad()
def generate_activations(
    model,
    dataset,
    device,
    batch_size,
    checkpoint=20,
    layers_to_save=None,
    save_dir="activations/"
):
    """
    Extract max activations + indices from specific model layers for all samples.

    For each sample and each selected layer, the activations over the sample's
    non-padding tokens are reduced with ``torch.max`` along the sequence axis.
    This yields, per feature, the max value and the position (relative to the
    first non-padding token) where it occurs. Results are written with
    ``save_sparse_activations``:
    - whenever a sample with ``(sample_id + 1) % checkpoint == 0`` has been processed;
    - once more at the end for whatever remains.

    Args:
        model: Must expose ``transcoders`` (only its length is used),
            ``tokenizer`` and ``get_activations(input_ids)``. The second value
            returned by ``get_activations`` is indexed here as
            [layer, batch, seq, feature] — assumption, confirm against the model.
        dataset: Dictionary of {id: text} pairs
        device: Device to run inference on
        batch_size: Number of samples per batch
        checkpoint: Save interval in terms of processed sample IDs. The check
            uses ``sample_id + 1``, so ids are presumably consecutive integers.
        layers_to_save (list[int] or None): Which layers to process. None = all
        save_dir (str | Path): Root output directory handed to
            ``save_sparse_activations``.

    Returns:
        The activations accumulated since the last periodic save, in the
        ``{layer: {sample_id: {"values": ..., "indices": ...}}}`` layout.
        These are also written by the final save. Earlier samples exist
        only on disk.
    """
    def simple_collate(batch):
        # Keep ids and texts as plain lists instead of the default collate.
        keys, values = zip(*batch)
        return list(keys), list(values)

    # The saving code joins paths with "/", which a str (like the default
    # "activations/") does not support, so normalize to Path up front.
    save_dir = Path(save_dir)

    dataloader = torch.utils.data.DataLoader(
        DictionaryDataset(dataset),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=simple_collate
    )

    n_layers = len(model.transcoders)
    if layers_to_save is None:
        layers_to_save = list(range(n_layers))

    stored_activations = {layer: {} for layer in layers_to_save}

    for batch_ids, batch_sequences in tqdm(dataloader, total=len(dataloader), mininterval=0.5):
        tokens = model.tokenizer(
            batch_sequences,
            padding=True,
            return_tensors="pt",
            padding_side="left"
        ).to(device)

        _, activations = model.get_activations(tokens.input_ids)

        batch_len = activations.shape[2]
        for i, sample_id in enumerate(batch_ids):
            # Left padding: the real tokens are the last `sequence_len` positions.
            sequence_len = tokens["attention_mask"][i].sum().item()
            for layer in layers_to_save:
                # [seq_len, d_mlp]. Deliberately no .squeeze(): for a
                # single-token sample it would drop the sequence axis and the
                # max below would then reduce over features instead.
                sample_activation = activations[layer, i, batch_len - sequence_len :, :]

                # max values + indices along sequence dimension
                max_vals, max_indices = torch.max(sample_activation, dim=0)

                stored_activations[layer][sample_id] = {
                    "values": max_vals.detach().cpu(),       # [d_mlp]
                    "indices": max_indices.detach().cpu()    # [d_mlp]
                }

            # periodic save, then start a fresh accumulator
            if (sample_id + 1) % checkpoint == 0:
                print(f"Saving activations at index {sample_id}")
                save_sparse_activations(stored_activations, layers_to_save=layers_to_save, save_dir=save_dir)
                stored_activations = {layer: {} for layer in layers_to_save}

    # flush remaining activations
    print("Final save of remaining activations...")
    save_sparse_activations(stored_activations, layers_to_save=layers_to_save, save_dir=save_dir)

    return stored_activations
