from transformers import  AutoTokenizer
import os
from data_utils import get_dataset_by_length
from  model_utils import  load_safetensors_model
import torch
from transformers import AutoTokenizer
import torch
import torch.nn as nn
from typing import Any, Dict,  Sequence, Tuple
import torch
from torch import nn
from typing import Sequence, Tuple, Dict
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from typing import Optional
import os
from  model_utils import  load_safetensors_model,print_model_summary,save_model
from  eval_utils import eval
import torch
import torch.nn as nn
import os
from examples.quant_llama import get_Linears



def get_lm_logits(inps_, model, *, verbose=False):
    """
    Project hidden states to vocabulary logits using a LLaMA-style model's head.

    The hidden states (typically the activations leaving the last decoder layer)
    get a leading batch dimension via ``unsqueeze(0)``. They are passed through
    ``model.model.norm`` when the model has such a module, and are then projected
    by ``model.lm_head`` to the vocabulary dimension.

    Args:
        inps_ (torch.Tensor): Hidden states, e.g. shaped [seq_len, hidden_size].
        model: Model exposing ``model.lm_head`` and, optionally, a final
            normalization module at ``model.model.norm``.
        verbose (bool): If True, print the computed logits. This is off by default
            because printing a full vocab-sized tensor forces a device->host copy
            and floods stdout.

    Returns:
        torch.Tensor: ``lm_head(norm(inps_.unsqueeze(0)))``. For a
        [seq_len, hidden_size] input this is [1, seq_len, vocab_size]. A 3-D input
        also gains an extra leading dimension, because ``unsqueeze(0)`` is always applied.

    Notes:
        - The normalization applied is whatever module ``model.model.norm`` holds.
          It is skipped when that attribute is missing or ``None``.
    """
    hidden_states = inps_.unsqueeze(0)
    # getattr tolerates inner models that have no `norm` attribute at all.
    final_norm = getattr(model.model, "norm", None)
    if final_norm is not None:
        hidden_states = final_norm(hidden_states)
    lm_logits = model.lm_head(hidden_states)
    if verbose:
        print("LLaMA logits:", lm_logits)
    return lm_logits


def get_layers(model):
    """Return the decoder-layer container of a LLaMA-style model.

    Args:
        model: Model whose transformer blocks live at ``model.model.layers``.

    Returns:
        The ``model.model.layers`` object itself (no copy is made).
    """
    inner = model.model
    return inner.layers


@torch.no_grad()
def get_inps_llama_by_transformer(
    model,
    data,
    model_seqlen,
    device,
    offload_activations,
    layer_idx=0,
):
    """Capture the hidden states fed into decoder layer ``layer_idx`` for every sample.

    A forward pre-hook on ``get_layers(model)[layer_idx]`` records that layer's
    ``hidden_states`` input while each sample in ``data`` is run through the model.

    Args:
        model: LLaMA-style model exposing ``model.model.layers`` and
            ``model.config.hidden_size``.
        data: Sequence of token-id tensors. Each one gets a batch dimension via
            ``unsqueeze(0)`` (presumably 1-D [seq_len]; TODO confirm against callers).
            Each sample must yield exactly ``model_seqlen`` positions to fit the buffer.
        model_seqlen: Sequence length used to size the output buffer.
        device: Device the token ids are moved to before the forward pass.
        offload_activations: If truthy, captured activations are stored on CPU
            (pinned when CUDA is available); otherwise on ``device``.
        layer_idx: Index of the decoder layer whose input is captured.

    Returns:
        Tuple ``(inps, {})``. ``inps`` has shape
        ``[len(data), model_seqlen, hidden_size]`` and the dtype of the model's first
        parameter. The empty dict is returned for interface compatibility.

    Raises:
        AssertionError: if the hook did not fire exactly once per sample.
    """
    layers = get_layers(model)
    target_layer = layers[layer_idx]

    dtype = next(model.parameters()).dtype
    print('Data dtype:', dtype)
    nsamples = len(data)
    target_device = torch.device("cpu") if offload_activations else device

    inps = torch.zeros(
        (nsamples, model_seqlen, model.config.hidden_size),
        dtype=dtype,
        device=target_device,
        # Pinned host memory needs a CUDA runtime; requesting it without one raises.
        pin_memory=bool(offload_activations) and torch.cuda.is_available(),
    )

    cache = {"i": 0}

    def hook_fn(module, args, kwargs):
        # hidden_states is normally the decoder layer's first positional argument;
        # fall back to the keyword form in case it is passed by name.
        hidden_states = args[0] if args else kwargs["hidden_states"]
        inps[cache["i"]] = hidden_states.detach().to(target_device)
        cache["i"] += 1

    hook = target_layer.register_forward_pre_hook(hook_fn, with_kwargs=True)
    try:
        for i, batch_inp in enumerate(data):
            batch_inp = batch_inp.unsqueeze(0).to(device)  # Add batch dimension
            print(f"Processing batch {i}, batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")
            # All-ones mask: samples are assumed to contain no padding.
            model(batch_inp, attention_mask=torch.ones_like(batch_inp))
    finally:
        # Always detach the hook, even if a forward pass fails, so the model stays clean.
        hook.remove()

    assert cache["i"] == nsamples, f"Captured count mismatch: {cache['i']} vs {nsamples}"

    return inps, {}


@torch.no_grad()
def get_outs_llama_by_transformer(
    model,
    data,
    model_seqlen,
    device,
    offload_activations,
    layer_idx=0,
):
    """Capture the hidden states produced by decoder layer ``layer_idx`` for every sample.

    A forward hook on ``get_layers(model)[layer_idx]`` records that layer's output
    hidden states while each sample in ``data`` is run through the model.

    Args:
        model: LLaMA-style model exposing ``model.model.layers`` and
            ``model.config.hidden_size``.
        data: Sequence of token-id tensors. Each one gets a batch dimension via
            ``unsqueeze(0)`` (presumably 1-D [seq_len]; TODO confirm against callers).
            Each sample must yield exactly ``model_seqlen`` positions to fit the buffer.
        model_seqlen: Sequence length used to size the output buffer.
        device: Device the token ids are moved to before the forward pass.
        offload_activations: If truthy, captured activations are stored on CPU
            (pinned when CUDA is available); otherwise on ``device``.
        layer_idx: Index of the decoder layer whose output is captured.

    Returns:
        Tuple ``(outs, {})``. ``outs`` has shape
        ``[len(data), model_seqlen, hidden_size]`` and the dtype of the model's first
        parameter. The empty dict is returned for interface compatibility.

    Raises:
        AssertionError: if the hook did not fire exactly once per sample.
    """
    layers = get_layers(model)
    target_layer = layers[layer_idx]

    dtype = next(model.parameters()).dtype
    print('Data dtype:', dtype)
    nsamples = len(data)
    target_device = torch.device("cpu") if offload_activations else device

    outs = torch.zeros(
        (nsamples, model_seqlen, model.config.hidden_size),
        dtype=dtype,
        device=target_device,
        # Pinned host memory needs a CUDA runtime; requesting it without one raises.
        pin_memory=bool(offload_activations) and torch.cuda.is_available(),
    )

    cache = {"i": 0}

    def hook_fn(module, input, output):
        # Decoder layers may return a tuple (hidden_states, ...) or the bare tensor;
        # take the hidden states explicitly in either case.
        hidden_states = output[0] if isinstance(output, (tuple, list)) else output
        outs[cache["i"]] = hidden_states.detach().to(target_device)
        cache["i"] += 1

    hook = target_layer.register_forward_hook(hook_fn)
    try:
        for i, batch_inp in enumerate(data):
            batch_inp = batch_inp.unsqueeze(0).to(device)  # Add batch dimension
            print(f"Processing batch {i}, batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")
            # All-ones mask: samples are assumed to contain no padding.
            model(batch_inp, attention_mask=torch.ones_like(batch_inp))
    finally:
        # Always detach the hook, even if a forward pass fails, so the model stays clean.
        hook.remove()

    assert cache["i"] == nsamples, f"Captured count mismatch: {cache['i']} vs {nsamples}"

    return outs, {}



@torch.no_grad()
def get_outs_llama_by_linear(
    model: nn.Module,
    data: Sequence[torch.Tensor],
    model_seqlen: int,
    device: torch.device,
    offload_activations: bool,
    linear_layer_idx: int = 0,  # flat index over (decoder layer, projection) pairs
) -> Tuple[torch.Tensor, Dict]:
    """Capture the outputs of one projection Linear inside a LLaMA-style decoder layer.

    ``linear_layer_idx`` is a flat, layer-major index. Index ``7 * layer + k`` selects
    projection ``k`` of decoder layer ``layer``. The projections are ordered
    q/k/v/o (under ``self_attn``), then gate/up/down (under ``mlp``).
    The number of decoder layers is read from the model.

    Each sample in ``data`` is run through the full model with an all-ones attention
    mask. A forward hook on the selected module copies its output to CPU.

    Args:
        model: LLaMA-style model exposing ``model.model.layers``.
        data: Sequence of token-id tensors. 1-D samples get a batch dimension added.
        model_seqlen: Samples longer than this are truncated to it.
        device: Device the inputs, the selected decoder layer and the module are moved to.
        offload_activations: If True, the result is returned on CPU, otherwise on ``device``.
        linear_layer_idx: Flat projection index in ``[0, 7 * num_layers)``.

    Returns:
        ``(outs, forward_args)``. ``outs`` stacks the captured outputs along a new
        leading dimension, with a size-1 batch dimension squeezed away.
        ``forward_args`` is always an empty dict.

    Raises:
        AssertionError: if ``linear_layer_idx`` is out of range.
        RuntimeError: if the hook captured no outputs.
    """
    linear_names = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    attn_names = {"q_proj", "k_proj", "v_proj", "o_proj"}
    n_linears_per_layer = len(linear_names)

    layers = get_layers(model)
    # Derive the depth from the model instead of assuming a 32-layer checkpoint.
    n_layers = len(layers)
    total_linears = n_layers * n_linears_per_layer

    assert 0 <= linear_layer_idx < total_linears, f"linear_layer_idx should be in range 0~{total_linears-1}"

    layer_idx, proj_idx = divmod(linear_layer_idx, n_linears_per_layer)
    linear_name = linear_names[proj_idx]

    print(f"Capturing linear layer index: {linear_layer_idx}, belonging to Transformer layer {layer_idx} '{linear_name}'", flush=True)

    # Note: this moves the whole selected decoder layer to `device` (in place).
    target_layer = layers[layer_idx].to(device)
    parent = target_layer.self_attn if linear_name in attn_names else target_layer.mlp
    target_module = getattr(parent, linear_name).to(device)

    outs_list = []
    forward_args = {}

    def hook_fn(module, input, output):
        print(f"Hook captured output, shape={output.shape}, dtype={output.dtype}")
        outs_list.append(output.detach().cpu())  # Capture output

    hook_handle = target_module.register_forward_hook(hook_fn)
    try:
        for idx, batch_inp in enumerate(data):
            print(f"Processing batch {idx}")

            if batch_inp.dim() == 1:
                batch_inp = batch_inp.unsqueeze(0)  # Add batch dimension, shape becomes [1, seq_len]

            batch_inp = batch_inp.to(device)
            print(f"batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")

            if batch_inp.dtype != torch.long:
                print(f"Converting batch_inp to long")
                batch_inp = batch_inp.long()

            seq_len = batch_inp.size(1)
            print(f"Sequence length: {seq_len}")

            # Truncate over-long samples to the requested length.
            max_seq_len = model_seqlen
            if seq_len > max_seq_len:
                print(f"Warning: sequence length {seq_len} exceeds model max length {max_seq_len}, truncating")
                batch_inp = batch_inp[:, :max_seq_len]

            # All-ones mask: samples are assumed to contain no padding.
            attention_mask = torch.ones(batch_inp.shape, dtype=torch.long, device=device)
            print(f"attention_mask.shape={attention_mask.shape}, dtype={attention_mask.dtype}")

            try:
                outputs = model(input_ids=batch_inp, attention_mask=attention_mask)
                if hasattr(outputs, "last_hidden_state"):
                    print(f"Model output last_hidden_state.shape={outputs.last_hidden_state.shape}")
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    print(f"Model output first element shape={outputs[0].shape}")
            except Exception as e:
                print(f"Error during model call: {e}")
                raise
    finally:
        # Detach the hook on every exit path so a failed forward doesn't leave it installed.
        hook_handle.remove()

    if not outs_list:
        raise RuntimeError("No outputs captured, please ensure hook is correctly registered and model forward runs successfully")

    print(f"Number of outputs captured: {len(outs_list)}")
    print(f"Example captured output shape: {outs_list[0].shape}")

    # torch.stack requires every captured output to have the same shape.
    outs = torch.stack(outs_list, dim=0)
    print(f"Stacked outs shape={outs.shape}")

    # Remove batch dimension if size 1
    if outs.size(1) == 1:
        outs = outs.squeeze(1)

    target_device = torch.device("cpu") if offload_activations else device
    outs = outs.to(target_device)
    print("Capture completed")

    return outs, forward_args


def detect_problematic_channels(activations: torch.Tensor, threshold: float = 6.0):
    """
    Flag channels whose mean |activation| is far above the global mean |activation|.

    Channels are indexed by the last dimension of ``activations``. All leading
    dimensions are pooled together, so both [nsamples, seq_len, hidden_size] and
    [tokens, hidden_size] inputs work. A channel is flagged when its mean absolute
    value exceeds ``threshold`` times the mean absolute value over all elements.

    Input:
      activations: torch.Tensor with channels on the last axis,
        e.g. [nsamples, seq_len, hidden_size]
      threshold: float, multiple of the global mean absolute value a channel must
        exceed (default 6.0)

    Returns:
      problematic_channels: List[int], ascending indices of flagged channels
      channel_means: torch.Tensor, shape [hidden_size], mean per channel
      channel_vars: torch.Tensor, shape [hidden_size], population variance
        (unbiased=False) per channel
    """
    hidden_size = activations.shape[-1]
    flattened = activations.reshape(-1, hidden_size)  # [num_rows, hidden_size]

    channel_means = flattened.mean(dim=0)
    channel_vars = flattened.var(dim=0, unbiased=False)

    # Compute |x| once and reuse it for both the per-channel and global reductions.
    abs_vals = flattened.abs()
    channel_abs_means = abs_vals.mean(dim=0)
    global_abs_mean = abs_vals.mean()

    problematic_mask = channel_abs_means > threshold * global_abs_mean
    problematic_channels = problematic_mask.nonzero(as_tuple=False).squeeze(1).tolist()

    return problematic_channels, channel_means, channel_vars


def compute_trace_covariance(inps):
    """Return trace(C), the sum of per-feature population variances of ``inps``.

    The last dimension is the feature dimension. All leading dimensions
    (e.g. [nsamples, seq_len, hidden_size]) are pooled into observations.
    trace(C) equals the sum of the covariance diagonal, i.e. of the per-feature
    variances, so the full D x D covariance matrix is never formed.

    Args:
        inps: torch.Tensor of activations with features on the last axis.

    Returns:
        float: trace of the population (unbiased=False) covariance.
    """
    D = inps.shape[-1]
    flattened = inps.reshape(-1, D)  # [num_observations, D]
    # Reduce in at least float32. float16 variances of outlier features can exceed
    # the float16 maximum (65504) and overflow to inf. float64 input keeps float64.
    flattened = flattened.to(torch.promote_types(flattened.dtype, torch.float32))

    var_per_dim = torch.var(flattened, dim=0, unbiased=False)  # [D]
    return torch.sum(var_per_dim).item()


# Example usage (disabled): load a model, then compute per-layer weight-norm and
# activation trace(C) metrics. Model paths are elided as '...'.
# import os
# import torch
# from transformers import AutoTokenizer

# # Set GPU to use
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# print("Loading model")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model_path = '...'

# model = load_safetensors_model(model_path)
# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
# model.to(device)
# model_seqlen = model.config.max_position_embeddings

# # 32 Transformer layers
# layers = model.model.layers
# L = len(layers)

# layer_attribute_map = {
#     "self_attn.q_proj": 1,
#     "self_attn.k_proj": 2,
#     "self_attn.v_proj": 3,
#     "self_attn.o_proj": 4,
#     "mlp.gate_proj": 5,
#     "mlp.up_proj": 6,
#     "mlp.down_proj": 7,
# }

# layer_norm_sums = []

# for i, layer in enumerate(layers):
#     norm_sq_sum = 0.0
#     for name, param in layer.named_parameters():
#         if name.endswith('weight'):
#             for key in layer_attribute_map.keys():
#                 if key in name:
#                     W = param.data
#                     frob_norm_sq = torch.norm(W, p='fro').item() ** 2
#                     norm_sq_sum += frob_norm_sq
#                     print(f"Layer {i} {name} norm_sq: {frob_norm_sq:.6f}")
#                     break
#     layer_norm_sums.append(norm_sq_sum)
#     print(f"Layer {i} total norm_sq_sum: {norm_sq_sum:.6f}")

# import math

# M_log = []
# for i in range(L):
#     if i == L - 1:
#         M_log.append(0.0)
#     else:
#         log_sum = 0.0
#         for k in range(i + 1, L):
#             val = layer_norm_sums[k]
#             if val <= 0:
#                 val = 1e-10
#             log_sum += math.log(val)
#         M_log.append(log_sum)

# with open('beta3.txt', 'w') as f:
#     for i, val in enumerate(M_log):
#         f.write(f"Layer {i}: {val}\n")

# print("Computation completed, results saved to beta3.txt")


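# # Example: compute trace(C) of captured activations for each linear layer.
# # NOTE(review): `get_inps_llama_by_linear`, used below, does not appear to be defined or
# # imported in this file (only `get_outs_llama_by_linear` is); fix the call before enabling.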
# # dataset = "wikitext-2"
# # data = get_dataset_by_length(dataset, tokenizer)

# # sample_0_to_5 = data[0:40]
# # sample_16_to_23 = data[441:561]
# # data = torch.cat((sample_0_to_5, sample_16_to_23), dim=0)
# # print("Shape:", data.shape)
# # seq_len = data.size(1)

# # start_layer = 0
# # end_layer = 224

# # with open('beta2.txt', 'w') as f:
# #     for layer_id in range(start_layer, end_layer):
# #         inps, forward_args = get_inps_llama_by_linear(model, data, seq_len, device, False, layer_id)
# #         trace_val = compute_trace_covariance(inps)
# #         print(f"Layer {layer_id}: input shape={inps.shape}, trace(C)={trace_val:.6f}")
# #         f.write(f"{layer_id} {trace_val:.6f}\n")
