import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["NCCL_P2P_DISABLE"] = "0"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT"] = "2.0"
import argparse
import itertools
import json
import sys
import time
from typing import Dict, List, Optional

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, MixtralForCausalLM


# Local checkpoint directory of the Mixtral model being analysed.
MODEL_PATH = "/Path/Mixtral-8x7B-v0.1"

# Results land in importance_score/layerwise/ relative to the working directory.
OUTPUT_DIR = os.path.join("importance_score", "layerwise")
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "mixtral_output_cosine_values.json")

# Calibration defaults. NUM_SAMPLES, BATCH_SIZE, SEQUENCE_LENGTH and DATASET_NAME
# are the defaults of the corresponding command-line flags in main().
NUM_SAMPLES = 128       # calibration samples (n_blocks_for_stat)
BATCH_SIZE = 8
SEQUENCE_LENGTH = 512   # tokens per sample (max_block_size)
MAX_SEQUENCES = 32      # upper bound on the number of *batches* consumed

DATASET_NAME = "wikitext"

# Device selection: CUDA when available, otherwise CPU with zero GPUs.
if torch.cuda.is_available():
    DEVICE = "cuda"
    NUM_GPUS = torch.cuda.device_count()
else:
    DEVICE = "cpu"
    NUM_GPUS = 0


def load_resources(dataset_name: str = DATASET_NAME, num_samples: int = NUM_SAMPLES, batch_size: int = BATCH_SIZE, seq_len: int = SEQUENCE_LENGTH):
    """Load the Mixtral model, its tokenizer and a calibration DataLoader.

    The project's ``data.build.build_calib_loader`` (imported from two directories
    above this file) is tried first. If importing or calling it fails for any
    reason, the function falls back to the first ``num_samples`` non-empty rows of
    the WikiText-2 (raw) train split, each tokenized to exactly ``seq_len`` tokens
    (truncated or padded).

    Args:
        dataset_name: Calibration dataset forwarded to ``build_calib_loader``.
            The fallback path only supports WikiText-2; a warning is printed when
            a different dataset was requested.
        num_samples: Number of calibration samples (``n_blocks_for_stat``).
        batch_size: DataLoader batch size.
        seq_len: Tokens per sample (``max_block_size``).

    Returns:
        ``(model, tokenizer, dataloader)``.
    """
    print("Loading model and tokenizer...")
    print(f"Using {NUM_GPUS} GPUs with device_map='auto'")
    model = MixtralForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        # Mixtral ships without a pad token; reuse EOS so fixed-length padding works.
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading and preparing dataset...")
    try:
        repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
        if repo_root not in sys.path:
            sys.path.append(repo_root)
        from data.build import build_calib_loader  # type: ignore
        dataloader = build_calib_loader(
            dataset=dataset_name,
            tokenizer=tokenizer,
            n_blocks_for_stat=num_samples,
            batch_size=batch_size,
            max_block_size=seq_len,
            num_workers=4,
        )
    except Exception as err:  # best-effort: any failure switches to the built-in fallback
        print(f"build_calib_loader unavailable ({err!r}); falling back to wikitext-2-raw-v1.")
        if dataset_name != "wikitext":
            print(f"WARNING: dataset '{dataset_name}' is not supported by the fallback loader; using wikitext instead.")
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        # WikiText contains many blank lines; they would become all-padding samples.
        dataset = dataset.filter(lambda ex: bool(ex["text"].strip()))
        # Select before tokenizing so only the rows actually used are processed.
        dataset = dataset.select(range(min(num_samples, len(dataset))))

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                max_length=seq_len,
                padding="max_length",
            )

        tokenized = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
        dataloader = DataLoader(tokenized, batch_size=batch_size, shuffle=False)

    return model, tokenizer, dataloader


def _compute_cosine_similarity_matrix(expert_outputs: List[torch.Tensor]) -> float: 
    merged = [t for t in expert_outputs if t is not None and t.numel() > 0]
    num_experts = len(merged)
    if num_experts < 2:
        return 0.0

    cosine_vals: List[float] = []
    for i in range(num_experts):
        for j in range(i + 1, num_experts):
            X = merged[i]
            Y = merged[j]
            if X.dim() > 2:
                X = X.reshape(X.shape[0], -1)
            if Y.dim() > 2:
                Y = Y.reshape(Y.shape[0], -1)
            min_rows = min(X.shape[0], Y.shape[0])
            if min_rows < 1:
                continue
            X = X[:min_rows].to(torch.float32)
            Y = Y[:min_rows].to(torch.float32)
            cosine = torch.nn.functional.cosine_similarity(X, Y, dim=1).mean().item()
            cosine_vals.append(float(cosine))

    if not cosine_vals:
        return 0.0
    return float(np.mean(cosine_vals))


def compute_output_cosine_scores(model: MixtralForCausalLM, dataloader: DataLoader) -> Dict[int, float]:
    """Score each MoE layer by how similar its experts' outputs are.

    Every module exposing ``experts``, ``gate`` and ``top_k`` has its ``forward``
    temporarily replaced. The wrapper first runs the original forward and returns
    its output unchanged. It then feeds every token of the layer input through
    every expert and computes the mean pairwise cosine similarity between the
    expert outputs with ``_compute_cosine_similarity_matrix``.

    Per-batch values are combined as a token-weighted running mean, so expert
    activations are never accumulated across batches (keeping them all would need
    memory proportional to tokens x hidden x experts x layers).

    When a batch carries an ``attention_mask``, tokens with mask 0 (padding) are
    excluded. At most ``MAX_SEQUENCES`` batches are consumed. Original forwards
    are restored even if the forward pass raises.

    Returns:
        Mapping from layer index (parsed from ``model.layers.<i>...`` module names,
        otherwise discovery order) to score. Layers that processed no tokens are
        omitted.
    """
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)

    # Token-weighted running sums, keyed by MoE module name.
    cosine_sums: Dict[str, float] = {}
    token_counts: Dict[str, int] = {}
    layer_name_to_idx: Dict[str, int] = {}
    # Flattened boolean mask of non-padding tokens for the batch in flight; the
    # wrapper below reads it through the closure.
    token_mask: Optional[torch.Tensor] = None

    def _custom_moe_forward_cosine(self, hidden_states):
        original_output = self._original_forward(hidden_states)

        hidden_dim = hidden_states.shape[-1]
        tokens = hidden_states.reshape(-1, hidden_dim)
        # Only apply the mask when it lines up with this layer's token count.
        if token_mask is not None and token_mask.numel() == tokens.shape[0]:
            tokens = tokens[token_mask.to(tokens.device)]
        n_tokens = int(tokens.shape[0])
        if n_tokens == 0:
            return original_output

        with torch.no_grad():
            expert_outs = [expert(tokens) for expert in self.experts]
        batch_value = _compute_cosine_similarity_matrix(expert_outs)

        name = self._module_name
        cosine_sums[name] = cosine_sums.get(name, 0.0) + batch_value * n_tokens
        token_counts[name] = token_counts.get(name, 0) + n_tokens
        return original_output

    # (module, original bound forward, whether forward was an instance attribute)
    patched: List[tuple] = []
    try:
        for name, module in model.named_modules():
            if hasattr(module, "experts") and hasattr(module, "gate") and hasattr(module, "top_k"):
                had_instance_forward = "forward" in module.__dict__
                original_forward = module.forward
                module._original_forward = original_forward
                module._module_name = name
                module.forward = _custom_moe_forward_cosine.__get__(module, type(module))
                patched.append((module, original_forward, had_instance_forward))
                layer_name_to_idx[name] = int(name.split(".")[2]) if name.startswith("model.layers.") else len(layer_name_to_idx)

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(dataloader, desc="Collecting expert outputs for cosine")):
                if batch_idx >= MAX_SEQUENCES:
                    break
                if torch.cuda.is_available():
                    batch = {
                        k: (v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v)
                        for k, v in batch.items()
                    }
                mask = batch.get("attention_mask")
                token_mask = mask.reshape(-1).bool() if isinstance(mask, torch.Tensor) else None
                _ = model(**batch)
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        token_mask = None
        for module, original_forward, had_instance_forward in patched:
            if had_instance_forward:
                module.forward = original_forward
            else:
                # Drop the instance attribute so the class-level forward applies again.
                del module.forward
            for attr in ("_original_forward", "_module_name"):
                module.__dict__.pop(attr, None)

    layer_scores: Dict[int, float] = {}
    for layer_name, count in token_counts.items():
        if count <= 0:
            continue
        layer_idx = layer_name_to_idx.get(layer_name, len(layer_scores))
        layer_scores[layer_idx] = float(cosine_sums[layer_name] / count)

    return layer_scores


def normalize_and_format_results(values: Dict[int, float], num_layers: int) -> Dict[str, Dict[str, float]]:
    """Min-max normalise per-layer values and package them for JSON output.

    Args:
        values: Raw score per layer index; indices missing from the mapping are
            treated as ``0.0``.
        num_layers: Number of layers to report (``layer_0`` .. ``layer_{num_layers-1}``).

    Returns:
        ``{"layer_<i>": {...}}`` with:

        - the raw value;
        - the min-max normalised value, used as both ``normalized_importance``
          and ``final_score``;
        - a 1-based rank, where 1 is the highest normalised value and ties keep
          ascending layer order;
        - a category: ``high`` (> 0.7), ``medium`` (> 0.3) or ``low``.

        If all values are equal, every layer is normalised to 0.5. An empty dict
        is returned when ``num_layers <= 0``.

    NOTE(review): a larger raw cosine value maps to a *higher* importance score
    here; confirm that direction is what downstream consumers expect.
    """
    if num_layers <= 0:
        return {}

    raw = np.array([values.get(i, 0.0) for i in range(num_layers)], dtype=np.float64)
    lo, hi = raw.min(), raw.max()
    if hi > lo:
        norm = (raw - lo) / (hi - lo)
    else:
        norm = np.full_like(raw, 0.5)

    # order[k] is the layer holding rank k+1; a stable sort makes tie ranks deterministic.
    order = np.argsort(-norm, kind="stable")
    ranks = np.empty(num_layers, dtype=np.int64)
    ranks[order] = np.arange(1, num_layers + 1)

    result: Dict[str, Dict[str, float]] = {}
    for i in range(num_layers):
        score = float(norm[i])
        result[f"layer_{i}"] = {
            "output_cosine_value": float(raw[i]),
            "normalized_importance": score,
            "final_score": score,
            "layer_importance_rank": int(ranks[i]),
            "importance_category": "high" if score > 0.7 else ("medium" if score > 0.3 else "low"),
        }
    return result


def _print_summary(formatted: Dict[str, Dict[str, float]]) -> None:
    """Print sample layers, aggregate statistics, category buckets and rankings.

    ``formatted`` is the ``"layer_<i>"`` -> score-dict mapping produced by
    ``normalize_and_format_results``.
    """
    if not formatted:
        print("\nNo layer results to summarize.")
        return

    # Sample printout of the first few layers.
    print("\nSample Layer Output-Cosine Scores:")
    for i in range(min(8, len(formatted))):
        lk = f"layer_{i}"
        s = formatted[lk]
        print(f"  {lk}:")
        print(f"    Output-Cosine Value: {s['output_cosine_value']:.6f}")
        print(f"    Normalized Importance: {s['normalized_importance']:.4f}")
        print(f"    Importance Rank: {s['layer_importance_rank']}")
        print(f"    Category: {s['importance_category']}")

    all_vals = [d['output_cosine_value'] for d in formatted.values()]
    all_norm = [d['normalized_importance'] for d in formatted.values()]
    print("\nOverall Statistics:")
    print(f"  Total layers: {len(formatted)}")
    print(f"  Value - Mean: {np.mean(all_vals):.6f}, Std: {np.std(all_vals):.6f}")
    print(f"  Value - Min: {np.min(all_vals):.6f}, Max: {np.max(all_vals):.6f}")
    print(f"  Normalized scores - Mean: {np.mean(all_norm):.4f}, Std: {np.std(all_norm):.4f}")

    high = [k for k, v in formatted.items() if v['importance_category'] == 'high']
    med = [k for k, v in formatted.items() if v['importance_category'] == 'medium']
    low = [k for k, v in formatted.items() if v['importance_category'] == 'low']
    print("\nLayer Importance Distribution:")
    print(f"  High importance ({len(high)} layers): {high}")
    print(f"  Medium importance ({len(med)} layers): {med}")
    print(f"  Low importance ({len(low)} layers): {low}")

    sorted_layers = sorted(formatted.items(), key=lambda x: x[1]['normalized_importance'], reverse=True)
    print("\nTop 5 Most Important Layers:")
    for i, (lk, s) in enumerate(sorted_layers[:5], start=1):
        print(f"  {i}. {lk}: {s['normalized_importance']:.4f}")

    print("\nTop 5 Least Important Layers:")
    # Number from the bottom so that entry 1 is the least important layer.
    for i, (lk, s) in enumerate(reversed(sorted_layers[-5:]), start=1):
        print(f"  {i}. {lk}: {s['normalized_importance']:.4f}")


def main():
    """Parse CLI flags, compute layer-wise output-cosine scores and save them as JSON.

    Results are written to ``OUTPUT_FILE`` and a summary is printed. If any step
    fails, the traceback is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default=DATASET_NAME, choices=["wikitext", "c4"], help="calibration dataset")
    parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="mapped to n_blocks_for_stat")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--seq_len", type=int, default=SEQUENCE_LENGTH, help="mapped to max_block_size")
    args = parser.parse_args()
    start = time.perf_counter()
    print("=" * 60)
    print("Mixtral Layer-wise Output-Cosine Metric")
    print("=" * 60)
    print(f"Using device: {DEVICE}")
    print(f"Number of GPUs: {NUM_GPUS}")
    print(f"Model path: {MODEL_PATH}")
    # Report the values actually in effect (CLI overrides), not the module defaults.
    print(f"Dataset: {args.dataset}")
    print(f"Num samples: {args.num_samples} | Batch size: {args.batch_size} | Seq len: {args.seq_len} | Max sequences: {MAX_SEQUENCES}")

    try:
        print("\n" + "=" * 40)
        print("Step 1: Loading Resources")
        print("=" * 40)
        model, _, dataloader = load_resources(args.dataset, args.num_samples, args.batch_size, args.seq_len)

        print("\n" + "=" * 40)
        print("Step 2: Collecting Expert Outputs")
        print("=" * 40)
        layer_values = compute_output_cosine_scores(model, dataloader)

        print("\n" + "=" * 40)
        print("Step 3: Processing Results")
        print("=" * 40)
        num_layers = model.config.num_hidden_layers
        formatted = normalize_and_format_results(layer_values, num_layers)

        os.makedirs(OUTPUT_DIR, exist_ok=True)
        print(f"Saving output-cosine scores to {OUTPUT_FILE}...")
        with open(OUTPUT_FILE, 'w') as f:
            json.dump(formatted, f, indent=4)

        elapsed = time.perf_counter() - start
        print(f"\nCompleted successfully in {elapsed:.2f} seconds!")

    except Exception as e:
        print(f"\nError occurred: {e}")
        import traceback
        traceback.print_exc()
        # Signal failure to shells / schedulers instead of exiting with status 0.
        sys.exit(1)

    _print_summary(formatted)


if __name__ == "__main__":
    # Run the layer-wise output-cosine analysis only when executed as a script.
    main()


