import math
import torch
import torch.nn.functional as F
import functools
from collections.abc import Callable
from typing import Any, Optional, Tuple

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, \
    ALL_ATTENTION_FUNCTIONS

from utils.logging_utils import logger


class Llama3Model:
    """
    Llama 3 wrapper (HF Llama architecture) used for activation patching and attention ablation.

    Typical HF IDs:
        "meta-llama/Llama-3.1-8B"
        "meta-llama/Llama-3.2-1B"
        "meta-llama/Llama-3.2-3B"
        (or your local path)

    Key notes vs Qwen3:
      1) No per-head q_norm/k_norm in Llama.
      2) GQA: key/value are projected to num_key_value_heads, then repeat_kv to num_heads.
      3) The patched attention forward (attn_forward) returns (attn_output, attn_weights).
      4) The patched decoder forward (decoder_forward) returns a 1-tuple (hidden_states,).
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        """Load tokenizer and model, then initialise every piece of patching state.

        Args:
            model_name: Hugging Face model ID or local path passed to
                ``from_pretrained``.
            device: Device the model is moved to after loading.
        """
        self.model_name = model_name
        self.device = device

        hf_kwargs = {
            "attn_implementation": "eager",  # strongly recommended for deterministic patching logic
        }

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs).to(device)

        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                # A newly added pad token receives an id past the end of the
                # existing embedding table; grow the table so padded inputs do
                # not index out of range. Never shrink a (possibly padded) table.
                vocab_size = len(self.tokenizer)
                if vocab_size > self.model.get_input_embeddings().weight.shape[0]:
                    self.model.resize_token_embeddings(vocab_size)

        # For LlamaForCausalLM: layers live at self.model.model.layers
        cfg = self.model.config
        n_layers = cfg.num_hidden_layers
        self.model_config = {
            "n_heads": cfg.num_attention_heads,
            "n_kv_heads": getattr(cfg, "num_key_value_heads", cfg.num_attention_heads),
            "n_layers": n_layers,
            "hidden_size": cfg.hidden_size,
            "head_dim": getattr(
                cfg,
                "head_dim",
                cfg.hidden_size // cfg.num_attention_heads,
            ),
            "name_or_path": getattr(cfg, "_name_or_path", model_name),
            "attn_hook_names": [f"model.layers.{i}.self_attn.o_proj" for i in range(n_layers)],
            "layer_hook_names": [f"model.layers.{i}" for i in range(n_layers)],
            "prepend_bos": True,
        }

        # Original forward callables saved by break_into() and restored by break_out().
        self._orig_forwards = {"attn": [], "decoder": []}

        # Run-mode flags read by the patched forwards.
        self.is_corrupted_run = False
        self.is_original_run = False

        self.TOKEN_TYPES = None
        self.tp_inds = None  # (B, L_total) expected

        self.token_type_map = []
        self.ablate_edges_map = []  # List[(layer_idx, head_idx, ablation_type)]

        # Activation caches; populated by set_ablation_task_config() and the patched forwards.
        self.corrupted_activations = None
        self.original_activations = None
        self.attention_head_activations = None

        # Per-head vectors collected by the *_extract ablation types.
        self.q_vectors = []
        self.k_vectors = []
        self.v_vectors = []

        self.attention_scores = []

        self.kv_ablation_tp = None

        self.corruption_with_ablation = False
        self.corruption_ablate_edges_map = None

        self._log_memory_usage()

    def _log_memory_usage(self):
        """Log memory statistics for each CUDA device, or a notice when CUDA is absent."""
        if torch.cuda.is_available():
            for dev in range(torch.cuda.device_count()):
                total_bytes = torch.cuda.get_device_properties(dev).total_memory
                reserved_bytes = torch.cuda.memory_reserved(dev)
                allocated_bytes = torch.cuda.memory_allocated(dev)
                # "free" here means total minus reserved, not total minus allocated.
                free_bytes = total_bytes - reserved_bytes
                message = (
                    f"Device {dev}, total_memory: {total_bytes / 1024 / 1024 / 1024:.4}Gb, "
                    f"reserved: {reserved_bytes / 1024 / 1024 / 1024:.4}Gb, "
                    f"allocated: {allocated_bytes / 1024 / 1024 / 1024:.4}Gb, "
                    f"free: {free_bytes / 1024 / 1024 / 1024:.4}Gb"
                )
                logger.info(message)
            return
        logger.info("CUDA not available; skip GPU memory usage logging.")

    def set_token_indices(self, tp_inds: torch.Tensor) -> None:
        """Store the token-type index tensor read by the patched forwards.

        Args:
            tp_inds: Token-type id per position; the patched forwards compare
                it against token-type ids (expected shape ``(B, L_total)``).
        """
        self.tp_inds = tp_inds

    def set_ablation_task_config(
            self,
            token_types: int,
            ablate_edges_map: list[tuple[int, int, str]],
            token_type_map: list[int],
            kv_ablation_tp: list[int] = None,
            corruption_with_ablation: bool = False,
            corruption_ablate_edges_map=None,
    ):
        if corruption_ablate_edges_map is None:
            corruption_ablate_edges_map = []

        self.TOKEN_TYPES = token_types
        n_layers = len(self.model.model.layers)

        self.corrupted_activations = {"input_to_attn_per_type": {layer: None for layer in range(n_layers)}}
        self.attention_head_activations = {layer: None for layer in range(n_layers)}

        self.ablate_edges_map = ablate_edges_map
        self.token_type_map = token_type_map

        if kv_ablation_tp:
            self.kv_ablation_tp = torch.tensor(kv_ablation_tp, dtype=torch.int, device=self.device)
            self.original_activations = {"input_to_attn_per_type": {layer: None for layer in range(n_layers)}}

        self.corruption_with_ablation = corruption_with_ablation
        if self.corruption_with_ablation:
            self.corruption_ablate_edges_map = corruption_ablate_edges_map

    def break_into(self):
        self._orig_forwards = {"decoder": [], "attn": []}
        layers = self.model.model.layers

        for layer_idx in range(len(layers)):
            # attention
            self._orig_forwards["attn"].append(layers[layer_idx].self_attn.forward)
            layers[layer_idx].self_attn.forward = functools.partial(
                self.attn_forward,
                layer=layer_idx,
                self=layers[layer_idx].self_attn,
                model_wrapper=self,
            )

            # decoder layer
            self._orig_forwards["decoder"].append(layers[layer_idx].forward)
            layers[layer_idx].forward = functools.partial(
                self.decoder_forward,
                layer=layer_idx,
                self=layers[layer_idx],
                model_wrapper=self,
            )

    def break_out(self):
        layers = self.model.model.layers
        for layer_idx in range(len(layers)):
            layers[layer_idx].self_attn.forward = self._orig_forwards["attn"][layer_idx]
            layers[layer_idx].forward = self._orig_forwards["decoder"][layer_idx]

    @staticmethod
    def attn_forward(
            self,  # the attention module being patched (model.layers[layer].self_attn)
            hidden_states: torch.Tensor,  # 3-D input; last dim is projected by q/k/v_proj
            position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,  # (cos, sin); unpacked unconditionally below, so effectively required
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Any] = None,  # cache object; only its .update(...) is used
            cache_position: Optional[torch.LongTensor] = None,
            layer: Optional[int] = None,
            model_wrapper: Optional[Any] = None,
            **kwargs,
    ):
        """Attention forward with Q/K/V patching, Q/K/V extraction and score logging.

        ``break_into`` binds this function to each layer with
        ``functools.partial``, supplying ``self`` (the attention module),
        ``layer`` and ``model_wrapper``.

        Steps, in order:
          1. Project q/k/v from ``hidden_states``.
          2. Collect flags from ``model_wrapper.ablate_edges_map`` (skipped in
             corrupted and original runs).
          3. ``q_replace``: overwrite the last-position query with the one
             projected from the cached corrupted activations.
          4. Apply rotary embeddings, update the cache if given, run attention.
          5. ``k_replace`` / ``v_replace``: rebuild q/k/v from the cached
             original/corrupted activations, rerun attention and overwrite the
             last query position of the output and weights.
          6. In a normal run, record requested q/k/v vectors, the attention of
             the last query summed per token type, and the pre-``o_proj``
             head activations on ``model_wrapper``.

        Args:
            self: Attention module providing the projections and config.
            hidden_states: Input activations for this layer.
            position_embeddings: ``(cos, sin)`` pair for the rotary embedding.
            attention_mask: Mask forwarded to the attention implementation.
            past_key_values: Optional cache updated with the new key/value.
            cache_position: Forwarded to the cache update.
            layer: Index of this layer.
            model_wrapper: The owning wrapper holding flags and caches.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has been
            passed through ``o_proj``.

        NOTE(review): ``model_wrapper`` is checked for ``None`` when the flags
        are collected but is dereferenced unconditionally in the extraction
        section, so ``None`` is not actually supported.
        """
        input_shape = hidden_states.shape[:-1]  # all leading dims, i.e. (B, L)
        hidden_shape = (*input_shape, -1, self.head_dim)  # (B, L, heads inferred, head_dim)

        num_heads = self.config.num_attention_heads

        # Project q/k/v and split into heads; transpose gives (B, heads, L, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Patching flags and per-type extraction lists for this layer.
        q_patch_flag = k_patch_flag = v_patch_flag = False
        q_extract_list = []
        k_extract_list = []
        v_extract_list = []

        # Only a normal run (neither corrupted nor original) patches or extracts.
        if model_wrapper is not None and (not model_wrapper.is_corrupted_run) and (not model_wrapper.is_original_run):
            for (layer_idx, head_idx, ablation_type) in model_wrapper.ablate_edges_map:
                # layer_idx None means "every layer".
                if layer_idx is None or layer_idx == layer:
                    # *_replace flags apply to all heads; head_idx is only used by *_extract.
                    if ablation_type == "q_replace":
                        q_patch_flag = True
                    elif ablation_type == "k_replace":
                        k_patch_flag = True
                    elif ablation_type == "v_replace":
                        v_patch_flag = True
                    elif ablation_type == "q_extract":
                        q_extract_list.append(head_idx)
                    elif ablation_type == "k_extract":
                        k_extract_list.append(head_idx)
                    elif ablation_type == "v_extract":
                        v_extract_list.append(head_idx)

        # Apply Q patching
        if q_patch_flag:
            # Note: corrupted seq_len may be different from clean seq_len due to different tokenization
            corrupted_tensor = model_wrapper.corrupted_activations["input_to_attn_per_type"][layer].to(
                device=hidden_states.device,
                dtype=self.q_proj.weight.dtype,
            )  # cached by decoder_forward during the corrupted run

            corrupted_seq_len = corrupted_tensor.shape[1]

            q_patched = self.q_proj(corrupted_tensor).view(hidden_shape[0], corrupted_seq_len, num_heads,
                                                           self.head_dim).transpose(1, 2)
            # In-place: only the last query position is replaced, for all heads.
            query_states[:, :, -1, :] = q_patched[:, :, -1, :]

        # Apply rotary positional embeddings (raises if position_embeddings is None).
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Pick the attention implementation configured on the module.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )  # expected: attn_output (B,L,H,D), attn_weights (B,H,L,L) — matches the indexing below

        if k_patch_flag or v_patch_flag:
            # construct patched k/v
            q_original = query_states
            k_original = key_states
            k_used = key_states
            v_used = value_states

            original_tensor = model_wrapper.original_activations["input_to_attn_per_type"][layer].to(
                device=hidden_states.device,
                dtype=self.k_proj.weight.dtype,
            )

            corrupted_tensor = model_wrapper.corrupted_activations["input_to_attn_per_type"][layer].to(
                device=hidden_states.device,
                dtype=self.k_proj.weight.dtype,
            )

            # NOTE(review): the projections below are reshaped with the clean-run
            # hidden_shape, unlike the Q path above which uses the corrupted
            # sequence length — confirm cached and current lengths always match here.
            if k_patch_flag:
                # Keep original Q unchanged for K ablation
                q_original = self.q_proj(original_tensor).view(hidden_shape).transpose(1, 2)
                k_patched = self.k_proj(corrupted_tensor).view(hidden_shape).transpose(1, 2)
                # rotate original Q & patching K vector
                q_original, k_patched = apply_rotary_pos_emb(q_original, k_patched, cos, sin)

            if v_patch_flag:
                # Keep original Q&K unchanged for V ablation
                q_original = self.q_proj(original_tensor).view(hidden_shape).transpose(1, 2)
                k_original = self.k_proj(original_tensor).view(hidden_shape).transpose(1, 2)
                # rotate original QK vector
                q_original, k_original = apply_rotary_pos_emb(q_original, k_original, cos, sin)
                v_patched = self.v_proj(corrupted_tensor).view(hidden_shape).transpose(1, 2)

            # mask: choose which key/value positions to replace
            # (positions whose token type is listed in kv_ablation_tp)
            tp_pos = None
            if getattr(model_wrapper, "kv_ablation_tp", None) is not None and getattr(model_wrapper, "tp_inds",
                                                                                      None) is not None:
                tp_pos = torch.isin(
                    model_wrapper.tp_inds,
                    model_wrapper.kv_ablation_tp.to(device=model_wrapper.tp_inds.device)
                )  # boolean, same shape as tp_inds

            # Without a position mask nothing is replaced and the recomputation
            # below runs on the unmodified q/k/v.
            if tp_pos is not None:
                # Add head and feature axes so the mask broadcasts over (B,H,L,D).
                tp_pos4 = tp_pos.to(device=hidden_states.device).unsqueeze(1).unsqueeze(-1)
                if k_patch_flag:
                    k_used = torch.where(tp_pos4, k_patched, key_states)
                    query_states = q_original
                    key_states = k_used
                if v_patch_flag:
                    v_used = torch.where(tp_pos4, v_patched, value_states)
                    query_states = q_original
                    # NOTE(review): when both flags are set this discards the
                    # patched keys assigned just above — confirm this is intended.
                    key_states = k_original
                    value_states = v_used

            # Recompute attention with the (possibly) patched tensors. All
            # positions are recomputed; only the last query row is kept below.
            attn_last, attn_w_last = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

            # Update only the last position in attn_output and attn_weights
            # (in place; axis 1 of attn_output and axis 2 of attn_weights are the query position)
            attn_output[:, -1, :, :] = attn_last[:, -1, :, :]
            attn_weights[:, :, -1, :] = attn_w_last[:, :, -1, :]

        # Merge heads back into one feature axis.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Snapshot taken before the output projection; stored on the wrapper below.
        layer_activations = attn_output.clone().detach()
        attn_output = self.o_proj(attn_output)

        # QKV extraction and logging; skipped in corrupted and original runs.
        if not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run:
            for head in range(num_heads):
                if head in q_extract_list:
                    # Query of the last position for this head.
                    model_wrapper.q_vectors.append(query_states[:, head, -1, :].detach().cpu().clone().squeeze(
                        0))  # (head_dim,) when B == 1
                # Several query heads share one key/value head; map to its index.
                kv_head = head // self.num_key_value_groups
                if head in k_extract_list:
                    model_wrapper.k_vectors.append(key_states[:,
                    kv_head, :,
                    :].detach().cpu().clone().squeeze(0))  # (seq_len, head_dim) when B == 1
                if head in v_extract_list:
                    model_wrapper.v_vectors.append(value_states[:,
                    kv_head, :,
                    :].detach().cpu().clone().squeeze(0))  # (seq_len, head_dim) when B == 1

            # Aggregate attention scores on token types: take the last batch
            # element and its last query row, then sum over keys of each type.
            attn_scores = attn_weights.clone().detach()[-1, :, -1, :]  # (num_heads, key_len)
            attn_scores_col = []
            for tp in range(model_wrapper.TOKEN_TYPES):
                attn_scores_col.append(
                    attn_scores[:, model_wrapper.tp_inds[-1] == tp].sum(dim=-1, keepdim=True).cpu())
            # Written into the most recent entry; presumably the caller appends
            # a fresh dict before each run — verify against the caller.
            attn_score_dict = model_wrapper.attention_scores[-1]
            attn_scores_tp = torch.cat(attn_scores_col, dim=1)  # (num_heads, TOKEN_TYPES)
            attn_score_dict[layer] = attn_scores_tp

            # Assign the attention activations to the attention head activations before projection
            model_wrapper.attention_head_activations[
                layer] = layer_activations  # same leading dims as the input, heads merged in the last axis

        return attn_output, attn_weights

    @staticmethod
    def decoder_forward(
            hidden_states: torch.Tensor,  # (B,L,hidden)
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Any] = None,  # Llama3 uses this name
            use_cache: Optional[bool] = False,
            cache_position: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            self=None,  # corresponds to model.layers[layer]
            layer=None,
            model_wrapper=None,
            **kwargs,
    ):
        # Llama3 LAYER ORDER differs from Gemma2:
        #   residual = x
        #   x = input_layernorm(x)
        #   x = self_attn(x)
        #   x = residual + x
        #   residual = x
        #   x = post_attention_layernorm(x)
        #   x = mlp(x)
        #   x = residual + x
        #   return x

        residual = hidden_states  # Shape: (B,L,hidden)
        hidden_states = self.input_layernorm(hidden_states)

        # Cache input-to-attn activations and apply ablation logic
        if model_wrapper.is_corrupted_run:
            model_wrapper.corrupted_activations["input_to_attn_per_type"][layer] = (
                hidden_states.detach().clone().cpu()
            )  # Shape: (B,L,hidden)

            if model_wrapper.corruption_with_ablation:
                for (from_tp, to_tp, ablation_type) in model_wrapper.corruption_ablate_edges_map:
                    if ablation_type == "zero":
                        B, _, L, _ = attention_mask.shape
                        mask_q = (model_wrapper.tp_inds == to_tp).unsqueeze(2).expand(B, L, L)
                        mask_k = (model_wrapper.tp_inds == from_tp).unsqueeze(1).expand(B, L, L)
                        zero_positions = mask_q & mask_k
                        attention_mask = attention_mask.masked_fill(zero_positions.unsqueeze(1), -torch.inf)
        else:
            if model_wrapper.is_original_run:
                model_wrapper.original_activations["input_to_attn_per_type"][layer] = (
                    hidden_states.detach().clone().cpu()
                )  # Shape: (B,L,hidden)

            for (from_tp, to_tp, ablation_type) in model_wrapper.ablate_edges_map:
                if ablation_type == "zero":
                    B, _, L, _ = attention_mask.shape
                    mask_q = (model_wrapper.tp_inds == to_tp).unsqueeze(2).expand(B, L, L)
                    mask_k = (model_wrapper.tp_inds == from_tp).unsqueeze(1).expand(B, L, L)
                    zero_positions = mask_q & mask_k
                    attention_mask = attention_mask.masked_fill(zero_positions.unsqueeze(1), -torch.inf)

        # Self Attention (calls our overridden attn_forward)
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Llama3: residual add immediately, THEN post_attention_layernorm -> mlp -> residual add
        hidden_states = residual + attn_out

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # IMPORTANT Llama3 DIFFERENCE:
        #   - Must return Tensor, not tuple (HF Llama3Model.forward expects tensor per layer).
        return (hidden_states,)