import torch

def get_activation_recorder(cache_container, name, n_heads=None, d_head=None):
    """
    Build a forward hook that copies a module's output to CPU and stores it under ``name``.

    The hook uses ``outp[0]`` when the module returns a tuple, otherwise ``outp``.
    The tensor is detached and moved to CPU. If both ``n_heads`` and ``d_head``
    are given, the tensor must be 3-D ``[batch, seq, d_model]``; it is viewed as
    ``[batch, seq, n_heads, d_head]``.

    Storage depends on the type of ``cache_container``:
      * ``list``: ``cache_container[b][name] = value[b]`` for every sample ``b``.
        The list must hold at least ``batch`` dict-like entries.
      * ``dict``: ``cache_container[name]`` receives ``value[0]`` when the batch
        size is 1 (batch dim dropped), otherwise the full tensor.

    Args:
        cache_container: list of per-sample dicts, or a single dict.
        name: key under which the activation is stored.
        n_heads: number of heads to split the last dim into (optional).
        d_head: size of each head (optional; used only together with ``n_heads``).

    Returns:
        A callable ``hook_fn(module, inp, outp)`` for ``register_forward_hook``.

    Raises (when the hook fires):
        ValueError: if head-splitting is requested and ``d_model != n_heads * d_head``.
        TypeError: if ``cache_container`` is neither a list nor a dict.
    """
    def hook_fn(module, inp, outp):
        value = outp[0] if isinstance(outp, tuple) else outp
        value = value.detach().cpu()
        # Head split is requested only by callers that pass both sizes
        # (the LLaMA attention hooks in this file).
        if n_heads is not None and d_head is not None:
            batch_size, seq_len, d_model = value.shape
            # Explicit raise instead of `assert`: asserts vanish under `python -O`.
            if d_model != n_heads * d_head:
                raise ValueError(
                    f"d_model({d_model}) != n_heads({n_heads}) * d_head({d_head})"
                )
            value = value.view(batch_size, seq_len, n_heads, d_head)
        if isinstance(cache_container, list):
            for b in range(value.shape[0]):
                cache_container[b][name] = value[b]
        elif isinstance(cache_container, dict):
            # NOTE: a batch of one is stored without its batch dimension.
            cache_container[name] = value[0] if value.shape[0] == 1 else value
        else:
            raise TypeError(f"cache_container must be list or dict, got {type(cache_container)}")
    return hook_fn


def register_llava_hooks(model, cache_list):
    """
    Attach activation recorders to the vision tower, the projector and the LLaMA layers.

    Keys written into ``cache_list`` (through ``get_activation_recorder``):
      * ``clip_layer{i}_attn`` / ``clip_layer{i}_mlp`` for each vision encoder layer,
      * ``mm_projector`` for the multimodal projector,
      * ``llama_layer{i}_attn`` (last dim split into heads) and
        ``llama_layer{i}_mlp`` (no split) for each language-model layer.

    NOTE(review): the head split is a reshape of the ``self_attn`` module output.
    Whether that output is before or after the output projection depends on the
    model implementation; confirm before treating the slices as per-head contributions.

    Returns:
        List of hook handles; call ``.remove()`` on each to detach.
    """
    handles = []

    vision_layers = model.vision_tower.vision_model.encoder.layers
    for idx, block in enumerate(vision_layers):
        for suffix, attr in (("attn", "self_attn"), ("mlp", "mlp")):
            submodule = getattr(block, attr)
            recorder = get_activation_recorder(cache_list, f"clip_layer{idx}_{suffix}")
            handles.append(submodule.register_forward_hook(recorder))

    projector_recorder = get_activation_recorder(cache_list, "mm_projector")
    handles.append(model.multi_modal_projector.register_forward_hook(projector_recorder))

    n_heads = model.config.text_config.num_attention_heads
    d_head = model.config.text_config.hidden_size // n_heads
    for idx, block in enumerate(model.language_model.layers):
        attn_recorder = get_activation_recorder(
            cache_list, f"llama_layer{idx}_attn", n_heads, d_head
        )
        handles.append(block.self_attn.register_forward_hook(attn_recorder))
        mlp_recorder = get_activation_recorder(cache_list, f"llama_layer{idx}_mlp")
        handles.append(block.mlp.register_forward_hook(mlp_recorder))

    return handles

def register_llama_attn_hooks(model, cache_list):
    """
    Attach a head-splitting activation recorder to every LLaMA ``self_attn`` module.

    Each recorder stores under ``llama_layer{i}_attn``, with the last dimension
    reshaped to ``(num_attention_heads, hidden_size // num_attention_heads)``.
    ``cache_list`` may be a list of per-sample dicts or a single dict; see
    ``get_activation_recorder`` for how each is filled.

    Returns:
        List of hook handles; call ``.remove()`` on each to detach.
    """
    text_cfg = model.config.text_config
    heads = text_cfg.num_attention_heads
    head_dim = text_cfg.hidden_size // heads
    return [
        layer.self_attn.register_forward_hook(
            get_activation_recorder(cache_list, f"llama_layer{idx}_attn", heads, head_dim)
        )
        for idx, layer in enumerate(model.language_model.layers)
    ]


def run_with_cache_llava(model, processor, image, prompt, device="cuda"):
    """
    Run one image + one prompt through the model (written for LLaVA-1.5-7B-HF).

    Delegates to ``batch_run_with_cache_llava`` with a one-element sample list.
    It then takes index 0 of the returned logits and of every cached activation.

    NOTE(review): taking index 0 is a per-sample extraction only if every cache
    tensor from the batch function keeps a leading batch dimension for
    single-sample runs; confirm.

    Returns:
        ``(logits, cache)``: ``logits`` is the batch logits at index 0, and
        ``cache`` maps each hook name to its activation at index 0.
    """
    single_sample = [{"image": image, "prompt": prompt}]
    batched_logits, batched_cache = batch_run_with_cache_llava(
        model, processor, single_sample, device=device
    )
    sample_cache = {hook_name: act[0] for hook_name, act in batched_cache.items()}
    return batched_logits[0], sample_cache



def batch_run_with_cache_llava(model, processor, samples, device="cuda", batch_size=4,
                               pixel_dtype=torch.float16):
    """
    Run the model over ``samples`` in chunks and record LLaMA attention activations.

    Equivalent in output layout to large-batch inference. Internally the samples
    are split into chunks of ``batch_size`` to lower VRAM pressure per forward
    pass. Logits and activations are moved to CPU and concatenated along dim 0
    in ``samples`` order.

    Args:
        model: LLaVA model; hooks go on ``model.language_model.layers[i].self_attn``.
        processor: callable accepting ``text=``, ``images=``, ``return_tensors="pt"``,
            ``padding=True`` and returning a mapping of tensors.
        samples: sequence of dicts with ``"image"`` and ``"prompt"`` keys.
        device: device the processed inputs are moved to.
        batch_size: number of samples per forward pass (must be >= 1).
        pixel_dtype: dtype ``pixel_values`` are cast to (default ``torch.float16``).

    Returns:
        logits: CPU tensor whose dim 0 has ``len(samples)`` entries.
        cache: dict mapping hook name -> CPU tensor whose dim 0 indexes samples.
            The batch dim is kept even for single-sample chunks.

    Raises:
        ValueError: if ``samples`` is empty or ``batch_size`` < 1.

    NOTE(review): ``padding=True`` pads each chunk independently. Chunks whose
    inputs tokenize to different lengths therefore yield tensors with different
    sequence lengths, and the final ``torch.cat`` raises. Confirm the inputs
    share a length, or pad across chunks, before relying on this.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(samples) == 0:
        raise ValueError("samples must be non-empty")

    logits_chunks = []
    cache_chunks = []

    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        # One dict per sample puts the recorder in list mode (value[b] per sample).
        # Dict mode would store value[0] for a single-sample chunk and lose the
        # batch dim, which breaks the final concatenation.
        per_sample = [{} for _ in chunk]
        hooks = register_llama_attn_hooks(model, per_sample)
        try:
            inputs = processor(
                text=[s["prompt"] for s in chunk],
                images=[s["image"] for s in chunk],
                return_tensors="pt",
                padding=True
            )
            for k, v in inputs.items():
                if k == "pixel_values":
                    inputs[k] = v.to(device, dtype=pixel_dtype)
                else:
                    inputs[k] = v.to(device)

            with torch.no_grad():
                outputs = model(**inputs)
        finally:
            # Always detach hooks so a failed chunk does not leave them on the model.
            for h in hooks:
                h.remove()

        logits_chunks.append(outputs.logits.cpu())
        cache_chunks.append({
            hook_name: torch.stack([sample[hook_name] for sample in per_sample], dim=0)
            for hook_name in per_sample[0]
        })

        del outputs, inputs, hooks, per_sample
        torch.cuda.empty_cache()

    all_logits = torch.cat(logits_chunks, dim=0)
    # Concatenate each hook's activations along the batch dim, chunk by chunk.
    all_cache = {
        hook_name: torch.cat([c[hook_name] for c in cache_chunks], dim=0)
        for hook_name in cache_chunks[0]
    }
    return all_logits, all_cache
