from functools import partial
from einops import einsum, rearrange
import logging
import torch
import torch.nn as nn
import re
import os
from discovery_utils import ablate_head_llava, ablate_mlp_llava, ablate_head_llava_plus, dynamic_counterbalance_head_llava, dynamic_counterbalance_head_llava_dosample

def get_mean_activations_llava(caches, key, batch_indices=None):
    """Average one cached activation over a set of samples.

    Typically used to supply the per-layer / per-head mean for mean ablation.

    Args:
        caches: Sequence of per-sample cache dicts (activation name -> tensor).
        key: Activation name to average, such as 'llama_layer12_attn'.
        batch_indices: Optional iterable of positions in ``caches`` to use;
            every sample is used when it is None.

    Returns:
        The mean of the collected tensors, taken over a new leading
        (sample) axis created by ``torch.stack``.

    Raises:
        ValueError: If none of the selected caches contains ``key``.
    """
    selected = range(len(caches)) if batch_indices is None else batch_indices
    # Samples that never recorded this activation are skipped, not treated as errors.
    collected = [caches[idx][key] for idx in selected if key in caches[idx]]
    if len(collected) == 0:
        raise ValueError(f"No values found for key {key}")
    stacked = torch.stack(collected, dim=0)
    return stacked.mean(dim=0)

def register_pruning_hooks_llava(model, circuit_heads, circuit_mlps, patching_cache, ablation_scheme="mean"):
    """Register ablation hooks on every head and MLP that is NOT in the circuit.

    Args:
        model: LLaVA-style model exposing ``config.text_config`` and
            ``model.language_model.layers``.
        circuit_heads: Iterable of ``(layer, head)`` pairs to keep. Pairs may
            be lists, tuples or any 2-element sequence of integer-like values.
        circuit_mlps: Iterable of integer-like layer indices whose MLP is kept.
        patching_cache: Passed through to the ablation hook factories.
        ablation_scheme: Passed through to the ablation hook factories.

    Returns:
        List of hook handles (all head hooks first, then MLP hooks); call
        ``.remove()`` on each to restore the model.

    Raises:
        ValueError / TypeError: If an entry of ``circuit_heads`` is not a
            2-element sequence of integer-like values.
    """
    n_layers = model.config.text_config.num_hidden_layers
    n_heads = model.config.text_config.num_attention_heads

    # Normalise to hashable int tuples: a list-vs-tuple comparison such as
    # ``[0, 1] in [(0, 1)]`` is always False, which would silently ablate
    # heads that are supposed to be kept. Sets also make each lookup O(1).
    kept_heads = {(int(lay), int(hd)) for lay, hd in circuit_heads}
    kept_mlps = {int(lay) for lay in circuit_mlps}

    hooks = []
    # Head ablation
    for layer in range(n_layers):
        for head in range(n_heads):
            if (layer, head) in kept_heads:
                continue
            handle = model.model.language_model.layers[layer].self_attn.register_forward_hook(
                ablate_head_llava(layer, head, ablation_scheme, patching_cache)
            )
            hooks.append(handle)
    # MLP ablation
    for layer in range(n_layers):
        if layer in kept_mlps:
            continue
        handle = model.model.language_model.layers[layer].mlp.register_forward_hook(
            ablate_mlp_llava(layer, ablation_scheme, patching_cache)
        )
        hooks.append(handle)
    return hooks

def _single_generate_llava(model, tokenizer, processor, image, prompt, device="cuda", max_new_tokens=32):
    """Greedy-decode the answer for a single image/prompt pair (no retry logic).

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to decode the generated ids.
        processor: Callable that turns text + images into model inputs.
        image: The single input image.
        prompt: The single text prompt.
        device: Device the inputs are moved to.
        max_new_tokens: Generation budget forwarded to ``model.generate``.

    Returns:
        The decoded text after ``extract_assistant_answer`` clean-up.
    """
    batch = processor(text=[prompt], images=[image], return_tensors="pt", padding=True)

    # Fields that need an explicit dtype on top of the device move.
    # NOTE(review): pixel_values are always cast to float16 — confirm this
    # matches the dtype the model was loaded in.
    forced_dtypes = {"pixel_values": torch.float16, "input_ids": torch.long}
    for name, tensor in list(batch.items()):
        wanted = forced_dtypes.get(name)
        if wanted is None:
            batch[name] = tensor.to(device)
        else:
            batch[name] = tensor.to(device, dtype=wanted)

    generate_kwargs = dict(
        input_ids=batch["input_ids"],
        pixel_values=batch["pixel_values"],
        attention_mask=batch.get("attention_mask", None),
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    with torch.no_grad():
        generated = model.generate(**generate_kwargs)

    # Only one sample was submitted, so only the first sequence is decoded.
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return extract_assistant_answer(decoded)

def _batch_generate_llava_once(model, tokenizer, processor, images, prompts, device="cuda", max_new_tokens=32):
    """Run one greedy (``do_sample=False``) batched generation pass.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to decode the generated ids.
        processor: Callable that turns text + images into padded model inputs.
        images: Images, one per prompt.
        prompts: Text prompts.
        device: Device the inputs are moved to.
        max_new_tokens: Generation budget forwarded to ``model.generate``.

    Returns:
        List of answers, one per generated sequence, each cleaned with
        ``extract_assistant_answer``.
    """
    # NOTE(review): the batch is padded; batched generation with a decoder-only
    # LM normally expects left padding — confirm the processor's padding_side.
    encoded = processor(text=prompts, images=images, return_tensors="pt", padding=True)

    def _place(field, value):
        # pixel_values / input_ids get a fixed dtype; everything else only moves device.
        if field == "pixel_values":
            return value.to(device, dtype=torch.float16)
        if field == "input_ids":
            return value.to(device, dtype=torch.long)
        return value.to(device)

    for field, value in list(encoded.items()):
        encoded[field] = _place(field, value)

    with torch.no_grad():
        sequences = model.generate(
            input_ids=encoded["input_ids"],
            pixel_values=encoded["pixel_values"],
            attention_mask=encoded.get("attention_mask", None),
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    decoded_batch = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return [extract_assistant_answer(decoded) for decoded in decoded_batch]

def _batch_generate_llava_once_dosample(model, tokenizer, processor, images, prompts, device="cuda", max_new_tokens=32):
    """Run one sampled (``do_sample=True``) batched generation pass.

    Identical to the greedy single-pass helper except that sampling is
    enabled, so repeated calls may return different text.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to decode the generated ids.
        processor: Callable that turns text + images into padded model inputs.
        images: Images, one per prompt.
        prompts: Text prompts.
        device: Device the inputs are moved to.
        max_new_tokens: Generation budget forwarded to ``model.generate``.

    Returns:
        List of answers, one per generated sequence, each cleaned with
        ``extract_assistant_answer``.
    """
    model_inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True)

    # Move every field to the target device; two fields also get a fixed dtype.
    for name, tensor in list(model_inputs.items()):
        extra = {}
        if name == "pixel_values":
            extra["dtype"] = torch.float16
        elif name == "input_ids":
            extra["dtype"] = torch.long
        model_inputs[name] = tensor.to(device, **extra)

    with torch.no_grad():
        sampled = model.generate(
            input_ids=model_inputs["input_ids"],
            pixel_values=model_inputs["pixel_values"],
            attention_mask=model_inputs.get("attention_mask", None),
            max_new_tokens=max_new_tokens,
            do_sample=True,
        )

    answers = []
    for raw_text in tokenizer.batch_decode(sampled, skip_special_tokens=True):
        answers.append(extract_assistant_answer(raw_text))
    return answers



def safe_batch_generate(model, tokenizer, processor, images, prompts, device, max_new_tokens, min_len=10, max_retries=3):
    """Generate a batch, re-running it until every answer is long enough.

    Args:
        model, tokenizer, processor, images, prompts, device, max_new_tokens:
            Forwarded to ``_batch_generate_llava_once``.
        min_len: Minimum number of characters (after stripping whitespace)
            each answer must have.
        max_retries: Maximum number of generation attempts.

    Returns:
        The list of answers from the first attempt in which all answers reach
        ``min_len`` characters, or None if no attempt qualifies.
    """
    # NOTE(review): every attempt re-runs the same batch through the greedy
    # helper (do_sample=False), so later attempts may reproduce the first one —
    # confirm the retry is effective.
    for _attempt in range(max_retries):
        answers = _batch_generate_llava_once(
            model,
            tokenizer,
            processor,
            images,
            prompts,
            device=device,
            max_new_tokens=max_new_tokens,
        )
        has_short_answer = any(len(ans.strip()) < min_len for ans in answers)
        if has_short_answer:
            continue
        return answers
    # Every attempt produced at least one too-short answer.
    return None

def batch_generate_llava(model, tokenizer, processor, images, prompts, device="cuda", max_new_tokens=32,
                         min_len=3, retry_times=3):
    """Greedy batched generation with targeted retries for too-short answers.

    The whole batch is generated once; samples whose answer has fewer than
    ``min_len`` whitespace-separated words are then re-generated together, for
    at most ``retry_times`` rounds. A sample that never reaches ``min_len``
    keeps the answer from the initial pass.

    Args:
        model, tokenizer, processor: Forwarded to ``_batch_generate_llava_once``.
        images: Images, one per prompt.
        prompts: Text prompts (same length as ``images``).
        device: Device the inputs are moved to.
        max_new_tokens: Generation budget per sample.
        min_len: Minimum number of words an answer needs to be accepted.
        retry_times: Maximum number of retry rounds.

    Returns:
        List of answers aligned with ``prompts``; an empty list for an empty batch.

    Raises:
        AssertionError: If ``images`` and ``prompts`` differ in length.
    """
    assert len(images) == len(prompts), "Each image requires a prompt"
    # Nothing to generate: avoid handing an empty batch to the processor/model.
    if len(prompts) == 0:
        return []

    def _too_short(text):
        return len(text.split()) < min_len

    # 1. Main batch inference
    clean_texts = _batch_generate_llava_once(
        model, tokenizer, processor, images, prompts, device, max_new_tokens
    )

    # 2. Batch retry, restricted to the samples that are still too short.
    # NOTE(review): decoding is greedy, so a retry can only differ through the
    # changed batch composition/padding — confirm retries actually help here.
    pending = [i for i, text in enumerate(clean_texts) if _too_short(text)]
    for retry_round in range(retry_times):
        if not pending:
            break
        retry_imgs = [images[i] for i in pending]
        retry_prompts = [prompts[i] for i in pending]
        print(f"[INFO] Batch retry round {retry_round+1}: {len(pending)} samples.")
        retry_texts = _batch_generate_llava_once(
            model, tokenizer, processor, retry_imgs, retry_prompts, device, max_new_tokens
        )
        # Write successful retries back in place; carry the rest to the next round.
        still_pending = []
        for idx, retry_text in zip(pending, retry_texts):
            if _too_short(retry_text):
                still_pending.append(idx)
            else:
                clean_texts[idx] = retry_text
                print(f"[INFO] Sample {idx} succeeded after batch retry.")
        pending = still_pending

    return clean_texts

def batch_generate_llava_dosample(model, tokenizer, processor, images, prompts, device="cuda", max_new_tokens=32,
                                  min_len=3, retry_times=3):
    """Sampled batched generation with targeted retries for too-short answers.

    Same procedure as the greedy variant, but every pass goes through
    ``_batch_generate_llava_once_dosample``. Samples whose answer has fewer
    than ``min_len`` whitespace-separated words are re-generated together for
    at most ``retry_times`` rounds; a sample that never reaches ``min_len``
    keeps the answer from the initial pass.

    Args:
        model, tokenizer, processor: Forwarded to the single-pass helper.
        images: Images, one per prompt.
        prompts: Text prompts (same length as ``images``).
        device: Device the inputs are moved to.
        max_new_tokens: Generation budget per sample.
        min_len: Minimum number of words an answer needs to be accepted.
        retry_times: Maximum number of retry rounds.

    Returns:
        List of answers aligned with ``prompts``; an empty list for an empty batch.

    Raises:
        AssertionError: If ``images`` and ``prompts`` differ in length.
    """
    assert len(images) == len(prompts), "Each image requires a prompt"
    # Nothing to generate: avoid handing an empty batch to the processor/model.
    if len(prompts) == 0:
        return []

    def _too_short(text):
        return len(text.split()) < min_len

    # 1. Main batch inference
    clean_texts = _batch_generate_llava_once_dosample(
        model, tokenizer, processor, images, prompts, device, max_new_tokens
    )

    # 2. Batch retry, restricted to the samples that are still too short.
    pending = [i for i, text in enumerate(clean_texts) if _too_short(text)]
    for retry_round in range(retry_times):
        if not pending:
            break
        retry_imgs = [images[i] for i in pending]
        retry_prompts = [prompts[i] for i in pending]
        print(f"[INFO] Batch retry round {retry_round+1}: {len(pending)} samples.")
        retry_texts = _batch_generate_llava_once_dosample(
            model, tokenizer, processor, retry_imgs, retry_prompts, device, max_new_tokens
        )
        # Write successful retries back in place; carry the rest to the next round.
        still_pending = []
        for idx, retry_text in zip(pending, retry_texts):
            if _too_short(retry_text):
                still_pending.append(idx)
            else:
                clean_texts[idx] = retry_text
                print(f"[INFO] Sample {idx} succeeded after batch retry.")
        pending = still_pending

    return clean_texts

def extract_assistant_answer(generated_text: str) -> str:
    """
    Strip the echoed prompt from decoded model output, leaving only the answer.

    Two known prompt prefixes are removed case-insensitively when they appear
    at the very start of the text; afterwards, everything up to and including
    the last "ASSISTANT:" tag (if any) is discarded.
    """
    answer = generated_text.strip()

    # Known prompt prefixes, applied one after the other.
    known_prefixes = (
        r"^<image>\s*Please describe the image in detail\.\s*",
        r"^Please describe the image in detail\.\s*",
    )
    for prefix in known_prefixes:
        answer = re.compile(prefix, re.IGNORECASE).sub("", answer)

    # rpartition yields the whole string as the tail when the tag is absent,
    # so this keeps either the text after the last tag or the text unchanged.
    _, _, tail = answer.rpartition("ASSISTANT:")
    return tail.strip()


def prune_model_llava(model, circuit_heads, circuit_mlps, patching_cache, ablation_scheme="mean", target_layers=None, include_mlps=False, alpha=0.0, global_indices=4):
    """
    Register ablation hooks on the heads (and optionally MLPs) listed in the circuit.

    Unlike a "keep the circuit, ablate the rest" scheme, this function ablates
    exactly the nodes that ARE listed in ``circuit_heads`` / ``circuit_mlps``,
    restricted to ``target_layers``.

    Args:
        model: LLaVA-style model exposing ``config.text_config`` and
            ``model.language_model.layers``.
        circuit_heads: Iterable of ``(layer, head)`` pairs to ablate. Pairs may
            be lists, tuples or any 2-element sequence of integer-like values.
        circuit_mlps: Iterable of integer-like layer indices whose MLP is
            ablated; only read when ``include_mlps`` is True.
        patching_cache: Passed through to the ablation hook factories.
        ablation_scheme: Passed through to the ablation hook factories.
        target_layers: Layers to consider; all layers when None.
        include_mlps: Whether MLP hooks are registered at all.
        alpha: Passed through to ``ablate_head_llava_plus``.
        global_indices: Passed through to ``ablate_head_llava_plus``.

    Returns:
        The list of registered hook handles (for h.remove() cleanup later).

    Raises:
        ValueError / TypeError: If an entry of ``circuit_heads`` is not a
            2-element sequence of integer-like values.
    """
    n_layers = model.config.text_config.num_hidden_layers
    n_heads = model.config.text_config.num_attention_heads
    # Only traverse the target layers
    if target_layers is None:
        target_layers = list(range(n_layers))

    # Normalise to hashable int tuples: ``[0, 1] in [(0, 1)]`` is always False,
    # so tuple-valued pairs would otherwise be silently ignored.
    heads_to_ablate = {(int(lay), int(hd)) for lay, hd in circuit_heads}

    hooks = []
    # Ablate heads in circuit
    for layer in target_layers:
        for head in range(n_heads):
            if (int(layer), head) not in heads_to_ablate:
                continue
            h = model.model.language_model.layers[layer].self_attn.register_forward_hook(
                ablate_head_llava_plus(layer, head, ablation_scheme, patching_cache, alpha, global_indices)
            )
            hooks.append(h)
    # Ablate MLPs in circuit
    if include_mlps:
        mlps_to_ablate = {int(lay) for lay in circuit_mlps}
        ablated_mlp_layers = []
        for layer in target_layers:
            if int(layer) not in mlps_to_ablate:
                continue
            h = model.model.language_model.layers[layer].mlp.register_forward_hook(
                ablate_mlp_llava(layer, ablation_scheme, patching_cache)
            )
            hooks.append(h)
            ablated_mlp_layers.append(layer)
        # Report the layers that actually received a hook; these are ablated,
        # not retained.
        logging.info("MLP ablation executed. Ablated MLP layers: %s", ablated_mlp_layers)
    return hooks

def prune_model_llava_dynamic(model, faithful_heads, hallucination_heads,
                              target_layers=None, alpha=0.2, beta=0.9):
    """Register ``dynamic_counterbalance_head_llava`` hooks on the listed heads.

    For every target layer that contains at least one faithful or
    hallucination head, one forward hook per listed head is attached to that
    layer's ``self_attn`` module. Each hook factory call receives the layer
    index, the head index and the per-layer head lists.

    Args:
        model: LLaVA-style model exposing ``model.language_model.layers``
            (and ``config.text_config.num_hidden_layers`` when
            ``target_layers`` is None).
        faithful_heads: Iterable of ``(layer, head)`` pairs.
        hallucination_heads: Iterable of ``(layer, head)`` pairs.
        target_layers: Layers to consider; all layers when None.
        alpha: Passed through to the hook factory.
        beta: Passed through to the hook factory.

    Returns:
        List of hook handles, in layer order with faithful heads before
        hallucination heads within a layer.
    """
    # The layer count is only needed to build the default layer list.
    if target_layers is None:
        target_layers = list(range(model.config.text_config.num_hidden_layers))

    # Materialise once so one-shot iterables survive the per-layer scans below.
    faithful_heads = list(faithful_heads)
    hallucination_heads = list(hallucination_heads)

    hooks = []
    for layer in target_layers:
        # Only take head_idx for this layer
        f_this = [h for (lay, h) in faithful_heads if lay == layer]
        g_this = [h for (lay, h) in hallucination_heads if lay == layer]
        # If this layer has neither "faithful" nor "hallucination" heads, skip it
        if not (f_this or g_this):
            continue

        self_attn = model.model.language_model.layers[layer].self_attn
        # NOTE(review): a head present in both lists gets two hooks — confirm
        # that this is intended.
        for head_idx in f_this + g_this:
            handle = self_attn.register_forward_hook(
                dynamic_counterbalance_head_llava(
                    layer,
                    head_idx,
                    faithful_heads=f_this,
                    hallucination_heads=g_this,
                    alpha=alpha,
                    beta=beta
                )
            )
            hooks.append(handle)
    return hooks

def prune_model_llava_dynamic_dosample(model, faithful_heads, hallucination_heads,
                              target_layers=None, alpha=0.2, beta=0.9):
    """Register ``dynamic_counterbalance_head_llava_dosample`` hooks on the listed heads.

    For every target layer that contains at least one faithful or
    hallucination head, one forward hook per listed head is attached to that
    layer's ``self_attn`` module. Each hook factory call receives the layer
    index, the head index and the per-layer head lists.

    Args:
        model: LLaVA-style model exposing ``model.language_model.layers``
            (and ``config.text_config.num_hidden_layers`` when
            ``target_layers`` is None).
        faithful_heads: Iterable of ``(layer, head)`` pairs.
        hallucination_heads: Iterable of ``(layer, head)`` pairs.
        target_layers: Layers to consider; all layers when None.
        alpha: Passed through to the hook factory.
        beta: Passed through to the hook factory.

    Returns:
        List of hook handles, in layer order with faithful heads before
        hallucination heads within a layer.
    """
    # The layer count is only needed to build the default layer list.
    if target_layers is None:
        target_layers = list(range(model.config.text_config.num_hidden_layers))

    # Materialise once so one-shot iterables survive the per-layer scans below.
    faithful_heads = list(faithful_heads)
    hallucination_heads = list(hallucination_heads)

    hooks = []
    for layer in target_layers:
        # Only take head_idx for this layer
        f_this = [h for (lay, h) in faithful_heads if lay == layer]
        g_this = [h for (lay, h) in hallucination_heads if lay == layer]
        # If this layer has neither "faithful" nor "hallucination" heads, skip it
        if not (f_this or g_this):
            continue

        self_attn = model.model.language_model.layers[layer].self_attn
        # NOTE(review): a head present in both lists gets two hooks — confirm
        # that this is intended.
        for head_idx in f_this + g_this:
            handle = self_attn.register_forward_hook(
                dynamic_counterbalance_head_llava_dosample(
                    layer,
                    head_idx,
                    faithful_heads=f_this,
                    hallucination_heads=g_this,
                    alpha=alpha,
                    beta=beta
                )
            )
            hooks.append(handle)
    return hooks
