# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from transformers import  AutoTokenizer
import os
from src.data_utils import get_dataset_by_length
from  src.model_utils import  load_safetensors_model,load_safetensors_model,print_model_summary,save_model
import torch
from transformers import AutoTokenizer
import torch
import torch.nn as nn
from typing import Any, Dict,  Sequence, Tuple
import torch
from torch import nn
from typing import Sequence, Tuple, Dict
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from typing import Optional
import os

from  src.eval_utils import eval
import torch
import torch.nn as nn
import os

import torch
from torch import nn
from typing import Sequence, Tuple, Dict


def get_Linears(model):
    """Collect every ``torch.nn.Linear`` submodule of *model*, printing each one.

    The decoder-layer index is parsed from the first path segment that follows
    ``'layers.'`` in the qualified module name (e.g. ``model.layers.3.mlp.up_proj``
    -> 3). Linear modules outside any ``layers.`` container (such as a head)
    get ``None`` as their index.

    Returns:
        list of ``(layer_index, qualified_name, module)`` tuples, in
        ``named_modules()`` traversal order.
    """
    found = []
    for qualified_name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        idx = (
            int(qualified_name.split('layers.')[1].split('.')[0])
            if 'layers.' in qualified_name
            else None
        )
        print(
            f"Layer Index: {idx}, Layer Name: {qualified_name}, Input Features: {module.in_features}, Output Features: {module.out_features}")
        found.append((idx, qualified_name, module))
    return found


def get_lm_logits(inps_, model):
    """
    Compute language model logits for a LLaMA-style model.

    Takes hidden states (usually the activations after the last decoder layer),
    prepends a batch dimension, applies the model's final normalization when one
    is present, then projects through ``model.lm_head`` to the vocabulary.

    Args:
        inps_ (torch.Tensor): Hidden states whose last dimension is hidden_size,
                              typically [seq_len, hidden_size].
        model (PreTrainedModel): Model object exposing ``model.model`` (optionally
                                 with a ``norm`` module) and ``model.lm_head``.

    Returns:
        torch.Tensor: Logits with a leading batch dimension of 1, i.e.
        [1, *inps_.shape[:-1], vocab_size] — [1, seq_len, vocab_size] for 2-D input.

    Notes:
        - The function adds a batch dimension internally (unsqueeze(0)), so a 3-D
          input [batch, seq_len, hidden] yields 4-D logits.
        - ``model.model.norm`` is applied only if the attribute exists and is not None.
        - The full logits tensor is printed (debug output).
    """
    hidden_states = inps_.unsqueeze(0)
    # Not every model wrapper exposes a final norm; a missing attribute is treated
    # the same as an explicit None instead of raising AttributeError.
    final_norm = getattr(model.model, "norm", None)
    if final_norm is not None:
        hidden_states = final_norm(hidden_states)
    lm_logits = model.lm_head(hidden_states)
    print("LLaMA logits:", lm_logits)
    return lm_logits


def get_layers(model):
    """Return the decoder layer stack stored at ``model.model.layers``.

    The object is returned as-is (no copy), so indexing it yields the live
    layer modules of *model*.
    """
    inner_model = model.model
    return inner_model.layers


@torch.no_grad()
def get_inps_llama_by_transformer(
        model,
        data,
        model_seqlen,
        device,
        offload_activations,
        layer_idx=0,
):
    """Capture the hidden states entering decoder layer ``layer_idx``.

    Every sample in *data* is run through the full *model*. A forward pre-hook
    on ``model.model.layers[layer_idx]`` records the first positional argument
    the layer receives into a preallocated buffer.

    Args:
        model: Model exposing ``model.model.layers``, ``model.config.hidden_size``
            and parameters (their dtype sets the buffer dtype).
        data: Sized iterable of token-id tensors. Each sample gets a batch
            dimension via ``unsqueeze(0)``; presumably 1-D ``[model_seqlen]``
            (TODO confirm against callers).
        model_seqlen: Sequence length used to size the capture buffer.
        device: Device the inputs are moved to for the forward pass.
        offload_activations: If True, the buffer lives in pinned CPU memory;
            otherwise it is allocated on *device*.
        layer_idx: Index of the decoder layer whose inputs are captured.

    Returns:
        ``(inps, {})``, where ``inps`` has shape
        ``[len(data), model_seqlen, hidden_size]``.

    Raises:
        AssertionError: If the hook did not fire exactly once per sample.
    """
    layers = get_layers(model)
    target_layer = layers[layer_idx]

    dtype = next(model.parameters()).dtype
    print('Data dtype:', dtype)
    nsamples = len(data)
    target_device = torch.device("cpu") if offload_activations else device

    inps = torch.zeros(
        (nsamples, model_seqlen, model.config.hidden_size),
        dtype=dtype,
        device=target_device,
        pin_memory=offload_activations,
    )

    cache = {"i": 0}

    def hook_fn(module, input):
        # Assumes hidden_states is passed positionally as the first argument,
        # presumably shaped [1, seq_len, hidden_size] here (batch size 1).
        hidden_states = input[0]
        idx = cache["i"]
        inps[idx] = hidden_states.detach().to(target_device)
        cache["i"] += 1

    hook = target_layer.register_forward_pre_hook(hook_fn)
    try:
        for i, batch_inp in enumerate(data):
            batch_inp = batch_inp.unsqueeze(0).to(device)  # Add batch dimension
            print(f"Processing batch {i}, batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")
            # All-ones attention mask: every token is treated as real (no padding).
            model(batch_inp, attention_mask=torch.ones_like(batch_inp))
    finally:
        # Remove the hook even when a forward pass fails; otherwise it stays
        # attached to the model and fires on every later forward call.
        hook.remove()

    assert cache["i"] == nsamples, f"Captured count mismatch: {cache['i']} vs {nsamples}"

    return inps, {}


@torch.no_grad()
def get_outs_llama_by_transformer(
        model,
        data,
        model_seqlen,
        device,
        offload_activations,
        layer_idx=0,
):
    """Capture the hidden states produced by decoder layer ``layer_idx``.

    Every sample in *data* is run through the full *model*. A forward hook on
    ``model.model.layers[layer_idx]`` stores ``output[0]`` into a preallocated
    buffer.

    Args:
        model: Model exposing ``model.model.layers``, ``model.config.hidden_size``
            and parameters (their dtype sets the buffer dtype).
        data: Sized iterable of token-id tensors. Each sample gets a batch
            dimension via ``unsqueeze(0)``; presumably 1-D ``[model_seqlen]``
            (TODO confirm against callers).
        model_seqlen: Sequence length used to size the capture buffer.
        device: Device the inputs are moved to for the forward pass.
        offload_activations: If True, the buffer lives in pinned CPU memory;
            otherwise it is allocated on *device*.
        layer_idx: Index of the decoder layer whose outputs are captured.

    Returns:
        ``(outs, {})``, where ``outs`` has shape
        ``[len(data), model_seqlen, hidden_size]``.

    Raises:
        AssertionError: If the hook did not fire exactly once per sample.
    """
    layers = get_layers(model)
    target_layer = layers[layer_idx]

    dtype = next(model.parameters()).dtype
    print('Data dtype:', dtype)
    nsamples = len(data)
    target_device = torch.device("cpu") if offload_activations else device

    outs = torch.zeros(
        (nsamples, model_seqlen, model.config.hidden_size),
        dtype=dtype,
        device=target_device,
        pin_memory=offload_activations,
    )

    cache = {"i": 0}

    def hook_fn(module, input, output):
        # If the layer returns a tuple, output[0] is its hidden states. If it
        # returns a bare tensor, output[0] indexes the batch dimension, which
        # with batch size 1 selects the same values.
        hidden_states = output[0]
        idx = cache["i"]
        outs[idx] = hidden_states.detach().to(target_device)
        cache["i"] += 1

    hook = target_layer.register_forward_hook(hook_fn)
    try:
        for i, batch_inp in enumerate(data):
            batch_inp = batch_inp.unsqueeze(0).to(device)  # Add batch dimension
            print(f"Processing batch {i}, batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")
            # All-ones attention mask: every token is treated as real (no padding).
            model(batch_inp, attention_mask=torch.ones_like(batch_inp))
    finally:
        # Remove the hook even when a forward pass fails; otherwise it stays
        # attached to the model and fires on every later forward call.
        hook.remove()

    assert cache["i"] == nsamples, f"Captured count mismatch: {cache['i']} vs {nsamples}"

    return outs, {}


@torch.no_grad()
def get_inps_llama_by_linear(
        model: nn.Module,
        data: Sequence[torch.Tensor],
        model_seqlen: int,
        device: torch.device,
        offload_activations: bool,
        linear_layer_idx: int = 0,  # Flat index: decoder_layer * 7 + projection (0~223 for 32 layers)
) -> Tuple[torch.Tensor, Dict]:
    """Capture the input activations of one projection inside a LLaMA decoder layer.

    ``linear_layer_idx`` is a flat index over the seven projections of every
    decoder layer, enumerated as ``q_proj, k_proj, v_proj, o_proj, gate_proj,
    up_proj, down_proj``. The decoder layer is ``linear_layer_idx // 7`` and the
    projection is ``linear_layer_idx % 7``. The valid range comes from
    ``len(model.model.layers)`` (0~223 for a 32-layer model).

    Every sample is run through the full model. A forward pre-hook on the
    selected projection records its first positional input on the CPU.

    Args:
        model: LLaMA-style model whose decoder layers expose ``self_attn`` and
            ``mlp`` with the projection attributes listed above.
        data: Token-id tensors. 1-D samples get a batch dimension, and
            non-``long`` samples are cast to ``long``.
        model_seqlen: Samples longer than this are truncated to it.
        device: Device the selected decoder layer, projection and inputs are
            moved to.
        offload_activations: If True, the result is returned on CPU;
            otherwise on *device*.
        linear_layer_idx: Flat projection index as described above.

    Returns:
        ``(inps, {})``. ``inps`` stacks one capture per forward pass; if the
        stacked batch dimension (dim 1) has size 1 it is squeezed away.

    Raises:
        AssertionError: If ``linear_layer_idx`` is out of range.
        RuntimeError: If the hook captured nothing.
    """
    linear_names = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    attn_linear_names = ("q_proj", "k_proj", "v_proj", "o_proj")
    n_linears_per_layer = len(linear_names)

    layers = get_layers(model)
    # Derive the depth from the model instead of hard-coding 32 (LLaMA-7B), so
    # models of other sizes can be probed too.
    n_layers = len(layers)
    total_linears = n_layers * n_linears_per_layer

    assert 0 <= linear_layer_idx < total_linears, f"linear_layer_idx should be in range 0~{total_linears - 1}"

    layer_idx, linear_pos = divmod(linear_layer_idx, n_linears_per_layer)
    linear_name = linear_names[linear_pos]

    target_layer = layers[layer_idx].to(device)
    parent = target_layer.self_attn if linear_name in attn_linear_names else target_layer.mlp
    target_module = getattr(parent, linear_name).to(device)

    inps_list = []
    forward_args = {}

    def hook_fn(module, input):
        print(f"Hook captured activation, shape={input[0].shape}, dtype={input[0].dtype}")
        inps_list.append(input[0].detach().cpu())

    hook_handle = target_module.register_forward_pre_hook(hook_fn)
    try:
        for idx, batch_inp in enumerate(data):
            print(f"Processing batch {idx}")

            if batch_inp.dim() == 1:
                batch_inp = batch_inp.unsqueeze(0)  # [seq_len] -> [1, seq_len]

            batch_inp = batch_inp.to(device)
            print(f"batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")

            if batch_inp.dtype != torch.long:
                print(f"Converting batch_inp to long")
                batch_inp = batch_inp.long()

            seq_len = batch_inp.size(1)
            print(f"Sequence length: {seq_len}")

            max_seq_len = model_seqlen
            if seq_len > max_seq_len:
                print(f"Warning: sequence length {seq_len} exceeds model max length {max_seq_len}, truncating")
                batch_inp = batch_inp[:, :max_seq_len]

            # All-ones mask: assumes the inputs contain no padding.
            attention_mask = torch.ones(batch_inp.shape, dtype=torch.long, device=device)
            print(f"attention_mask.shape={attention_mask.shape}, dtype={attention_mask.dtype}")

            try:
                outputs = model(input_ids=batch_inp, attention_mask=attention_mask)
                if hasattr(outputs, "last_hidden_state"):
                    print(f"Model output last_hidden_state.shape={outputs.last_hidden_state.shape}")
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    print(f"Model output first element shape={outputs[0].shape}")
            except Exception as e:
                print(f"Error during model call: {e}")
                raise
    finally:
        # Detach the hook even when a forward pass fails; otherwise it stays on
        # the model and keeps appending to this list on every later forward call.
        hook_handle.remove()

    if len(inps_list) == 0:
        raise RuntimeError(
            "No activations captured, confirm hook registered properly and model forward runs successfully")

    inps = torch.stack(inps_list, dim=0)

    # Drop the per-forward batch dimension when it is 1.
    if inps.size(1) == 1:
        inps = inps.squeeze(1)

    target_device = torch.device("cpu") if offload_activations else device
    inps = inps.to(target_device)

    return inps, forward_args



@torch.no_grad()
def get_outs_llama_by_linear(
        model: nn.Module,
        data: Sequence[torch.Tensor],
        model_seqlen: int,
        device: torch.device,
        offload_activations: bool,
        linear_layer_idx: int = 0,  # Flat index: decoder_layer * 7 + projection (0~223 for 32 layers)
) -> Tuple[torch.Tensor, Dict]:
    """Capture the output activations of one projection inside a LLaMA decoder layer.

    ``linear_layer_idx`` is a flat index over the seven projections of every
    decoder layer, enumerated as ``q_proj, k_proj, v_proj, o_proj, gate_proj,
    up_proj, down_proj``. The decoder layer is ``linear_layer_idx // 7`` and the
    projection is ``linear_layer_idx % 7``. The valid range comes from
    ``len(model.model.layers)`` (0~223 for a 32-layer model).

    Every sample is run through the full model. A forward hook on the selected
    projection records its output on the CPU.

    Args:
        model: LLaMA-style model whose decoder layers expose ``self_attn`` and
            ``mlp`` with the projection attributes listed above.
        data: Token-id tensors. 1-D samples get a batch dimension, and
            non-``long`` samples are cast to ``long``.
        model_seqlen: Samples longer than this are truncated to it.
        device: Device the selected decoder layer, projection and inputs are
            moved to.
        offload_activations: If True, the result is returned on CPU;
            otherwise on *device*.
        linear_layer_idx: Flat projection index as described above.

    Returns:
        ``(outs, {})``. ``outs`` stacks one capture per forward pass; if the
        stacked batch dimension (dim 1) has size 1 it is squeezed away.

    Raises:
        AssertionError: If ``linear_layer_idx`` is out of range.
        RuntimeError: If the hook captured nothing.
    """
    linear_names = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    attn_linear_names = ("q_proj", "k_proj", "v_proj", "o_proj")
    n_linears_per_layer = len(linear_names)

    layers = get_layers(model)
    # Derive the depth from the model instead of hard-coding 32 (LLaMA-7B), so
    # models of other sizes can be probed too.
    n_layers = len(layers)
    total_linears = n_layers * n_linears_per_layer

    assert 0 <= linear_layer_idx < total_linears, f"linear_layer_idx should be in range 0~{total_linears - 1}"

    layer_idx, linear_pos = divmod(linear_layer_idx, n_linears_per_layer)
    linear_name = linear_names[linear_pos]

    print(f"Capturing linear layer index: {linear_layer_idx}, belongs to Transformer layer {layer_idx} '{linear_name}'",
          flush=True)

    target_layer = layers[layer_idx].to(device)
    parent = target_layer.self_attn if linear_name in attn_linear_names else target_layer.mlp
    target_module = getattr(parent, linear_name).to(device)

    outs_list = []
    forward_args = {}

    def hook_fn(module, input, output):
        print(f"Hook captured output, shape={output.shape}, dtype={output.dtype}")
        outs_list.append(output.detach().cpu())

    hook_handle = target_module.register_forward_hook(hook_fn)
    try:
        for idx, batch_inp in enumerate(data):
            print(f"Processing batch {idx}")

            if batch_inp.dim() == 1:
                batch_inp = batch_inp.unsqueeze(0)  # [seq_len] -> [1, seq_len]

            batch_inp = batch_inp.to(device)
            print(f"batch_inp.shape={batch_inp.shape}, dtype={batch_inp.dtype}")

            if batch_inp.dtype != torch.long:
                print(f"Converting batch_inp to long")
                batch_inp = batch_inp.long()

            seq_len = batch_inp.size(1)
            print(f"Sequence length: {seq_len}")

            max_seq_len = model_seqlen
            if seq_len > max_seq_len:
                print(f"Warning: sequence length {seq_len} exceeds model max length {max_seq_len}, truncating")
                batch_inp = batch_inp[:, :max_seq_len]

            # All-ones mask: assumes the inputs contain no padding.
            attention_mask = torch.ones(batch_inp.shape, dtype=torch.long, device=device)
            print(f"attention_mask.shape={attention_mask.shape}, dtype={attention_mask.dtype}")

            try:
                outputs = model(input_ids=batch_inp, attention_mask=attention_mask)
                if hasattr(outputs, "last_hidden_state"):
                    print(f"Model output last_hidden_state.shape={outputs.last_hidden_state.shape}")
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    print(f"Model output first element shape={outputs[0].shape}")
            except Exception as e:
                print(f"Error during model call: {e}")
                raise
    finally:
        # Detach the hook even when a forward pass fails; otherwise it stays on
        # the model and keeps appending to this list on every later forward call.
        hook_handle.remove()

    if len(outs_list) == 0:
        raise RuntimeError("No outputs captured, ensure hook registered properly and model forward runs successfully")

    print(f"Captured outputs count: {len(outs_list)}")
    print(f"Sample captured output shape: {outs_list[0].shape}")

    outs = torch.stack(outs_list, dim=0)
    print(f"Stacked outs shape={outs.shape}")

    # Drop the per-forward batch dimension when it is 1.
    if outs.size(1) == 1:
        outs = outs.squeeze(1)

    target_device = torch.device("cpu") if offload_activations else device
    outs = outs.to(target_device)
    print("Capture completed")

    return outs, forward_args


def detect_problematic_channels(activations: torch.Tensor, threshold: float = 6.0):
    """
    Flag channels whose mean magnitude is far above the tensor-wide mean magnitude.

    A channel (an index along the LAST dimension) is reported when the mean of
    its absolute values exceeds ``threshold`` times the mean absolute value of
    all elements. All leading dimensions are pooled together.

    Input:
      activations: torch.Tensor of rank >= 1 whose last dimension indexes
                   channels, e.g. [nsamples, seq_len, hidden_size] or
                   [tokens, hidden_size]
      threshold: float, threshold multiple, default 6

    Returns:
      problematic_channels: List[int], ascending indices of channels exceeding the threshold
      channel_means: torch.Tensor, shape [hidden_size], mean value per channel
      channel_vars: torch.Tensor, shape [hidden_size], population variance per
                    channel (unbiased=False)
    """
    # Pool every leading dimension: [..., hidden_size] -> [N, hidden_size].
    hidden_size = activations.shape[-1]
    flattened = activations.reshape(-1, hidden_size)

    channel_means = flattened.mean(dim=0)
    channel_vars = flattened.var(dim=0, unbiased=False)

    # |x| is needed twice below; compute it once instead of per use.
    magnitudes = flattened.abs()
    channel_abs_means = magnitudes.mean(dim=0)  # [hidden_size]
    global_abs_mean = magnitudes.mean()

    problematic_mask = channel_abs_means > (threshold * global_abs_mean)
    problematic_channels = problematic_mask.nonzero(as_tuple=False).squeeze(1).tolist()

    return problematic_channels, channel_means, channel_vars

# # Set GPU to use
# # os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# # print("Loading model")
# # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # model_path = '...'
# # model = load_safetensors_model(model_path)
# # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
# # model.to(device)
# # model_seqlen = model.config.max_position_embeddings
# # dataset = "wikitext-2"
# # data = get_dataset_by_length(dataset, tokenizer)  # shape like [142, 2048]
# #
# # # TODO: Select samples based on indices in file '.../greater_than_0.03.txt'
# # filename = '.../greater_than_0.03.txt'
# # with open(filename, 'r') as f:
# #     lines = f.readlines()
# # indexes = [int(line.strip()) for line in lines]
# # selected_tensors = []
# # for i in indexes:
# #     if 0 <= i < len(data):
# #         sample = data[i]
# #         selected_tensors.append(sample)
# # data = torch.stack(selected_tensors, dim=0)  # concat along dim 0
#
#
# # # sample_0_to_5 = data[0:6]  # samples 0 to 5 (inclusive start, exclusive end)
# # # sample_16_to_23 = data[16:24]
# # # data = torch.cat((sample_0_to_5, sample_16_to_23), dim=0)
# # print("Data shape:", data.shape)  # Expected 2-D [nsamples, seq_len], not a flattened torch.Size([88064])
# # seq_len = data.size(1)
#
# # for layer_id in range(118, 224):  # linear layers from 0 to 223, start from 118
# #     # Get inputs of specified linear layer
# #     # inps, forward_args = get_inps_llama_by_linear(model, data, seq_len, device, False, layer_id)
# #     # inps_path = f'.../activation/linear_inps/inps_linear_{layer_id}.pt'
# #     # torch.save(inps, inps_path)
# #     # print(f"Saved inputs of layer {layer_id}, shape: {inps.shape}")
#
# #     # Get outputs of specified linear layer
# #     outs, forward_args = get_outs_llama_by_linear(model, data, seq_len, device, False, layer_id)
# #     outs_path = f'.../activation/linear_out/outs_linear_{layer_id}.pt'
# #     torch.save(outs, outs_path)
# #     print(f"Saved outputs of layer {layer_id}, shape: {outs.shape}")
#
# # # inps, forward_args = get_inps_llama_by_transformer(model, data, seq_len , device, False, 223)
# # # torch.save(inps, '.../activation/transformer_inps/inps_transformer_0.pt')
# # inps, forward_args = get_inps_llama_by_linear(model, data, seq_len , device, False, 223)
# # torch.save(inps, '.../activation/linear_inps/inps_linear_223.pt')
# # # print("inps.shape:", inps.shape)  # [nsamples, seq_len, hidden_size]
#
# # outs, forward_args = get_outs_llama_by_linear(model, data, seq_len , device, False, 223)
# # torch.save(outs, '.../activation/linear_out/outs_linear_223.pt')
#
# # # outs, forward_args = get_outs_llama_by_transformer(model, data, seq_len , device, False, 0)
# # # torch.save(outs, '.../activation/transformer_out/outs_transformer_0.pt')
#
# # print("outs.shape:", outs.shape)  # [nsamples, seq_len, hidden_size]
# # TODO: save the input and output activations of linear layers 0-223 to their corresponding files
#
#
# # TODO: file classification
# # Save normal input activations of linear layers to inps_linear folder, e.g., inps_linear_0.pt and inps_linear_6.pt
# # Save quantized input activations of linear layers (obtained in real-time)
# # Save normal output activations of linear layers to outs_linear folder, e.g., outs_linear_0.pt and outs_linear_6.pt
# # Save normal input activations of transformer layers to inps_transformer folder, e.g., inps_transformer_6.pt
# # Save quantized input activations of transformer layers (obtained in real-time)
# # Save normal output activations of transformer layers to outs_transformer folder, e.g., outs_transformer_6.pt
#
#
# # linear_layers = get_Linears(model)
# # for layer_index, name, layer in linear_layers:
# #     # Skip layers named lm_head
# #     if 'lm_head' in name:
# #         print(f"Ignoring layer: {name}")
# #         continue
# #     print(f"Processing layer: {name}")
# #
# #     layer_attribute_map = {
# #         "self_attn.q_proj": 1,
# #         "self_attn.k_proj": 2,
# #         "self_attn.v_proj": 3,
# #         "self_attn.o_proj": 4,
# #         "mlp.gate_proj": 5,
# #         "mlp.up_proj": 6,
# #         "mlp.down_proj": 7,
# #     }
# #     attribute_name = name.split('.')[-2] + '.' + name.split('.')[-1]  # e.g. self_attn.q_proj
# #     layer_attribute = layer_attribute_map.get(attribute_name, 0)
# #     # print(f"Attribute index: {layer_attribute}")
# #     if layer_attribute == 7 and layer_index == 0:
# #         layer_shape = layer.weight.shape
# #         w = layer.weight.clone().detach()
# #         # Get bias
# #         b = layer.bias.clone().detach() if layer.bias is not None else torch.zeros(layer_shape[0], device=w.device)
#
# #         # Compute linear layer output
# #         predicted_outs = torch.matmul(inps, w.T) + b
# #         print("predicted_outs.shape:", predicted_outs.shape)
# #         # Method 1: convert outs to predicted_outs dtype
# #         outs_float = outs.to(predicted_outs.dtype)
# #         # TODO: delete outs
# #         # Delete outs to free memory
# #         del outs
# #         print("outs deleted.")
#
# #         # TODO: compare first 20 elements of predicted_outs and outs_float
# #         # Use torch.allclose with proper tolerance
# #         are_outputs_close = torch.allclose(predicted_outs[:20], outs_float[:20], atol=1e-6)
#
# #         # Or directly compare first 20 elements
# #         direct_comparison = torch.equal(predicted_outs[:20], outs_float[:20])
#
# #         # Print comparison results
# #         print("First 20 elements equal (allclose):", are_outputs_close)
# #         print("First 20 elements equal (direct):", direct_comparison)







