import torch
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import numpy as np

# --- 1. Configuration ---

# Hugging Face Hub identifier of the model whose hidden states are analysed.
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

# Number of shuffled rows drawn from each dataset subset in main().
NUM_SAMPLES_PER_SUBSET = 5

# Use the GPU whenever PyTorch reports one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

def get_forward_pass_data_for_model(model, tokenizer, dataset, device, model_name_short, max_length=1024):
    """Run each text sample through ``model`` and collect every hidden state on the CPU.

    Args:
        model: Model invoked as ``model(**inputs, output_hidden_states=True)``;
            its output's ``hidden_states`` sequence is collected.
        tokenizer: Tokenizer called with ``return_tensors="pt"`` and truncation;
            its output is moved to ``device``.
        dataset: Iterable of mapping-like rows; only the ``'text'`` field is read.
        device: Device the tokenized inputs are moved to before the forward pass.
        model_name_short: Label used only in the progress message.
        max_length: Maximum number of tokens kept per sample (longer texts are
            truncated). Defaults to 1024, the value previously hard-coded.

    Returns:
        A list with one ``{'hidden_states': [...]}`` dict per processed sample,
        holding a CPU copy of every tensor in ``outputs.hidden_states``. Rows whose
        ``'text'`` is missing, empty, or not a ``str`` are skipped, as are texts
        that tokenize to one token or fewer.
    """
    all_forward_pass_data = []

    print(f"\n--- Collecting hidden states for {model_name_short} ---")
    for sample in tqdm(dataset, desc="Processing samples"):
        text = sample.get('text')
        if not text or not isinstance(text, str):
            continue

        inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True).to(device)

        # Skip inputs of a single token (or none).
        if inputs.input_ids.shape[1] <= 1:
            continue

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Copy off the accelerator so device memory is not retained per sample.
        all_forward_pass_data.append(
            {'hidden_states': [h.detach().cpu() for h in outputs.hidden_states]}
        )

        # Release this iteration's device tensors (hidden states, logits) now,
        # rather than keeping them alive during the next forward pass.
        del outputs

    return all_forward_pass_data




def analyze_and_plot_all(all_forward_pass_data, model_name_short, model_name_full):
    """Rank hidden-state feature dimensions by activation magnitude and print the top sets.

    For each index ``i`` into the per-sample ``hidden_states`` lists:

    - the mean absolute activation over the token axis is computed per sample;
    - the L2 norm of those per-sample means, taken across samples, gives index
      ``i``'s per-dimension importance.

    A dimension's global importance is the average of its importances over all
    indices. The top 50% and 75% of dimensions by global importance are then
    printed as sets.

    Args:
        all_forward_pass_data: List of dicts, each with a ``'hidden_states'`` list
            of tensors, as returned by ``get_forward_pass_data_for_model``. Each
            tensor is assumed to be ``(1, seq_len, hidden_size)``; the leading axis
            is squeezed and the mean is taken over the next one (TODO confirm for
            other producers).
        model_name_short: Currently unused; kept for interface compatibility.
        model_name_full: Currently unused; kept for interface compatibility.

    Returns:
        ``None`` if the input is empty or its structure cannot be inferred.
        Otherwise a dict with:

        - ``'layer_importances'``: float32 tensor of shape
          ``(num_layers, hidden_size)``;
        - ``'global_importance'``: float32 tensor of shape ``(hidden_size,)``;
        - ``'sorted_indices'``: dimension indices ordered by descending global
          importance;
        - ``'top_50_percent'`` and ``'top_75_percent'``: sets of dimension
          indices.
    """
    if not all_forward_pass_data:
        print("Error: 'all_forward_pass_data' is empty. Cannot calculate importance.")
        return None

    try:
        first_sample_hidden_states = all_forward_pass_data[0]['hidden_states']
        # Number of entries in hidden_states; for Hugging Face models this
        # presumably includes the embedding output as well (TODO confirm).
        num_layers = len(first_sample_hidden_states)
        hidden_size = first_sample_hidden_states[0].shape[-1]
        print(f"Detected {num_layers} layers and a hidden size of {hidden_size}.")
    except (IndexError, KeyError) as e:
        print(f"Error: Could not infer model structure from 'all_forward_pass_data': {e}")
        return None

    # float32 accumulator: one row of per-dimension importances per hidden-state index.
    layer_importances = torch.zeros(num_layers, hidden_size)

    print("\n--- Calculating Feature Importance Layer by Layer ---")
    for layer_idx in tqdm(range(num_layers), desc="Processing layers"):
        all_hidden_states_for_layer = [
            sample_data['hidden_states'][layer_idx].squeeze(0)
            for sample_data in all_forward_pass_data
            if layer_idx < len(sample_data['hidden_states'])
        ]

        if not all_hidden_states_for_layer:
            print(f"Warning: No hidden states found for layer {layer_idx}. Skipping.")
            continue

        # Upcast before reducing: the hidden states may be bfloat16 (main() loads
        # the model in bfloat16). Reducing in that dtype rounds every intermediate
        # to ~3 significant digits, which can create ties and perturb the ranking.
        mean_abs_activations_tensor = torch.stack(
            [torch.mean(torch.abs(hs.float()), dim=0) for hs in all_hidden_states_for_layer]
        )
        layer_importances[layer_idx] = torch.linalg.norm(mean_abs_activations_tensor, ord=2, dim=0)

    print("\n--- Calculating Global Importance ---")
    global_importance = torch.mean(layer_importances, dim=0)
    _, sorted_indices = torch.sort(global_importance, descending=True)

    # NOTE(review): despite this message, no chart is drawn in this function.
    print("\n--- Generating Global Importance Chart ---")

    num_to_keep_50 = hidden_size // 2
    top_50_percent_indices_set = set(sorted_indices[:num_to_keep_50].tolist())

    print("\n--- Top 50% Global Important Feature Dimensions ---")
    print(top_50_percent_indices_set)

    num_to_keep_75 = hidden_size * 3 // 4
    top_75_percent_indices_set = set(sorted_indices[:num_to_keep_75].tolist())

    print("\n--- Top 75% Global Important Feature Dimensions ---")
    print(top_75_percent_indices_set)

    print("\n--- Global Top Feature Dimension Info ---")
    print(f"Total feature dimensions: {hidden_size}")
    print(f"Number of Top 50% dimensions to keep: {len(top_50_percent_indices_set)}")
    print(f"Number of Top 75% dimensions to keep: {len(top_75_percent_indices_set)}")

    return {
        'layer_importances': layer_importances,
        'global_importance': global_importance,
        'sorted_indices': sorted_indices,
        'top_50_percent': top_50_percent_indices_set,
        'top_75_percent': top_75_percent_indices_set,
    }



def main():
    """Main entry point: build a small test set, collect hidden states, and analyze them.

    Steps:

    1. Take up to ``NUM_SAMPLES_PER_SUBSET`` rows from each of three subsets of
       ``nvidia/Nemotron-Pretraining-Dataset-sample``. Each subset is shuffled
       with seed 42 before sampling, and the combined set is shuffled again with
       the same seed.
    2. Run ``MODEL_NAME`` over the rows and collect its hidden states.
    3. Release the model and tokenizer, clearing the CUDA cache when available.
    4. Pass the collected data to ``analyze_and_plot_all``.

    Failures are reported with ``print`` instead of being raised.
    """
    print("--- Preparing Test Dataset ---")
    try:
        subset_names = ['Nemotron-CC-Diverse-QA', 'Nemotron-SFT-Code', 'Nemotron-SFT-General']
        test_subsets = []
        for name in subset_names:
            subset_dataset = load_dataset("nvidia/Nemotron-Pretraining-Dataset-sample", name, split="train", trust_remote_code=True)
            shuffled_subset = subset_dataset.shuffle(seed=42)
            test_count = min(NUM_SAMPLES_PER_SUBSET, len(shuffled_subset))
            test_subsets.append(shuffled_subset.select(range(test_count)))
        test_dataset = concatenate_datasets(test_subsets).shuffle(seed=42)
        print(f"Total size of the test dataset: {len(test_dataset)}")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    # Bind everything the finally-clause touches up front, so cleanup never
    # depends on how far the try-block got (no locals() introspection needed).
    model = None
    tokenizer = None
    forward_pass_data = None

    print(f"\n--- Processing Model: {MODEL_NAME} ---")
    model_name_short = MODEL_NAME.replace('/', '_')
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map=DEVICE,
            trust_remote_code=True
        ).eval()

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        forward_pass_data = get_forward_pass_data_for_model(
            model, tokenizer, test_dataset, DEVICE, model_name_short
        )

    except Exception as e:
        print(f"An error occurred while processing the model: {e}")
    finally:
        # Drop the model/tokenizer references before the analysis so their
        # memory can be reclaimed; both names are always bound at this point.
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if forward_pass_data:
        analyze_and_plot_all(
            forward_pass_data,
            model_name_short=model_name_short,
            model_name_full=MODEL_NAME
        )
    else:
        print("No forward pass data was collected, analysis cannot proceed.")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()