import os
# Process-wide environment settings. They are applied here, ahead of the
# `import torch` further down, so the values are already present in the
# environment when the heavier libraries are imported.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "NCCL_P2P_DISABLE": "0",
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_USE_CUDA_DSA": "1",
        "TOKENIZERS_PARALLELISM": "false",
        "PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT": "2.0",
    }
)
import sys
import json
import time
import argparse
from typing import Dict, List

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import MixtralForCausalLM, AutoTokenizer
from tqdm import tqdm
import numpy as np


# --- Model and output locations -------------------------------------------
# NOTE(review): MODEL_PATH looks like a placeholder; point it at a real
# checkpoint directory before running.
MODEL_PATH = "/Path/Mixtral-8x7B-v0.1"
OUTPUT_DIR = os.path.join("importance_score", "layer_wise")
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "mixtral_output_outlier_values.json")

# --- Calibration defaults (also used as CLI argument defaults) -------------
DATASET_NAME = "wikitext"
NUM_SAMPLES = 128
BATCH_SIZE = 8
SEQUENCE_LENGTH = 512
# Compared against the dataloader's batch index, so this caps the number of
# batches that are processed.
MAX_SEQUENCES = 32

# --- Hardware detection ----------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
    NUM_GPUS = torch.cuda.device_count()
else:
    DEVICE = "cpu"
    NUM_GPUS = 0


def _build_wikitext_fallback_loader(tokenizer, num_samples: int, batch_size: int, seq_len: int) -> DataLoader:
    """Build a calibration DataLoader directly from wikitext-2-raw-v1 (train split).

    Rows whose text is empty or whitespace-only are dropped first, so they do
    not turn into sequences that consist of padding only. Only the first
    ``num_samples`` remaining rows are tokenized, each padded/truncated to
    ``seq_len`` tokens.
    """
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split='train')
    dataset = dataset.filter(lambda example: bool(example["text"].strip()))
    # Select before tokenizing so only the rows that are actually used get mapped.
    calib = dataset.select(range(min(num_samples, len(dataset))))

    def tokenize_function(examples):
        # No return_tensors here: set_format below converts the columns to torch.
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=seq_len,
            padding="max_length",
        )

    tokenized = calib.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return DataLoader(tokenized, batch_size=batch_size, shuffle=False)


def load_resources(dataset_name: str = DATASET_NAME, num_samples: int = NUM_SAMPLES, batch_size: int = BATCH_SIZE, seq_len: int = SEQUENCE_LENGTH):
    """Load the Mixtral model, its tokenizer and a calibration dataloader.

    The model is loaded from ``MODEL_PATH`` in bfloat16 with
    ``device_map="auto"``. The dataloader is built with the project's
    ``data.build.build_calib_loader`` when that import and call succeed;
    otherwise a wikitext-2 based fallback is used and the reason is printed.

    Args:
        dataset_name: Calibration dataset name passed to ``build_calib_loader``.
            The fallback path always uses wikitext and prints a warning when a
            different dataset was requested.
        num_samples: Passed as ``n_blocks_for_stat``; in the fallback, the
            number of text rows to keep.
        batch_size: Batch size of the returned dataloader.
        seq_len: Passed as ``max_block_size``; in the fallback, the
            padded/truncated token length.

    Returns:
        A ``(model, tokenizer, dataloader)`` tuple.
    """
    print("Loading model and tokenizer...")
    print(f"Using {NUM_GPUS} GPUs with device_map='auto'")
    model = MixtralForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        # The fallback pads to max_length, which needs a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading and preparing dataset...")
    try:
        # The project loader lives two directories above this file; make that
        # directory importable before importing from it.
        repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
        if repo_root not in sys.path:
            sys.path.append(repo_root)
        from data.build import build_calib_loader
        dataloader = build_calib_loader(
            dataset=dataset_name,
            tokenizer=tokenizer,
            n_blocks_for_stat=num_samples,
            batch_size=batch_size,
            max_block_size=seq_len,
            num_workers=4,
        )
    except Exception as exc:
        # Deliberate best-effort: any failure of the project loader triggers the
        # fallback, but the cause is reported instead of being swallowed.
        print(
            f"Project calibration loader unavailable ({type(exc).__name__}: {exc}); "
            "falling back to wikitext-2-raw-v1."
        )
        if dataset_name != "wikitext":
            print(
                f"Warning: requested dataset '{dataset_name}' is not supported by the fallback; "
                "using wikitext instead."
            )
        dataloader = _build_wikitext_fallback_loader(tokenizer, num_samples, batch_size, seq_len)

    return model, tokenizer, dataloader


def compute_output_outlier_scores(model: MixtralForCausalLM, dataloader: DataLoader) -> Dict[int, float]:
    """Collect a per-layer output-outlier score from the model's MoE blocks.

    Every module that exposes ``experts``, ``gate`` and ``top_k`` has its
    ``forward`` temporarily wrapped. On each call the wrapper runs the original
    forward, then (without gradients) routes the flattened hidden states with a
    top-k over the gate logits, re-runs each selected expert on its tokens and
    adds ``max over columns of (max |out| / mean |out|)`` (reductions along
    dim 0) to that expert's running sum. The original forwards are restored
    and the helper attributes removed even if the run fails.

    Args:
        model: Model whose MoE modules are instrumented. It is switched to eval
            mode and all of its parameters get ``requires_grad=False``.
        dataloader: Iterable of dict batches that are passed to the model as
            keyword arguments; a ``"labels"`` entry is dropped. At most
            ``MAX_SEQUENCES`` batches are processed.

    Returns:
        Mapping from layer index to the mean, over that layer's experts, of the
        accumulated scores. The index is parsed from module names that start
        with ``model.layers.``; other modules get a sequential index.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    layer_expert_sums: Dict[str, torch.Tensor] = {}
    layer_name_to_idx: Dict[str, int] = {}

    def _custom_moe_forward(self, hidden_states):
        original_output = self._original_forward(hidden_states)

        # reshape (not view) so non-contiguous activations are handled too.
        flat_states = hidden_states.reshape(-1, hidden_states.shape[-1])

        with torch.no_grad():
            router_logits = self.gate(flat_states)
            _, routing_indices = torch.topk(router_logits, self.top_k, dim=-1)

            sums = layer_expert_sums.get(self._module_name)
            if sums is None:
                sums = torch.zeros(len(self.experts), dtype=torch.float32, device='cpu')
                layer_expert_sums[self._module_name] = sums

            for expert_idx, expert in enumerate(self.experts):
                # Tokens that have this expert among their top-k choices.
                mask = (routing_indices == expert_idx).any(dim=-1)
                if not mask.any():
                    continue
                # assumes expert(x) accepts an (N, hidden_dim) tensor and returns
                # a 2-D tensor -- TODO confirm for the installed transformers version.
                expert_output = expert(flat_states[mask])
                # Statistics in float32 so the ratio is not rounded to the model dtype.
                abs_output = expert_output.float().abs()
                denom = abs_output.mean(dim=0).clamp_min(1e-12)
                ratio = abs_output.max(dim=0).values / denom
                sums[expert_idx] += float(ratio.max().item())

        return original_output

    use_cuda = torch.cuda.is_available()
    original_forwards = {}
    try:
        for name, module in model.named_modules():
            if hasattr(module, 'experts') and hasattr(module, 'gate') and hasattr(module, 'top_k'):
                original_forwards[name] = module.forward
                module._original_forward = module.forward
                module._module_name = name
                module.forward = _custom_moe_forward.__get__(module, type(module))
                layer_name_to_idx[name] = int(name.split('.')[2]) if name.startswith('model.layers.') else len(layer_name_to_idx)

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(dataloader, desc="Collecting output-outlier metric")):
                # MAX_SEQUENCES is compared against the batch index, i.e. it caps batches.
                if batch_idx >= MAX_SEQUENCES:
                    break
                inputs = {}
                for key, value in batch.items():
                    if key == "labels":
                        # Labels are not needed for the forward pass; skip the copy.
                        continue
                    if use_cuda and torch.is_tensor(value):
                        value = value.cuda(non_blocking=True)
                    inputs[key] = value
                _ = model(**inputs)
                if use_cuda:
                    torch.cuda.empty_cache()
    finally:
        for name, module in model.named_modules():
            if name in original_forwards:
                module.forward = original_forwards[name]
                # Remove the helper attributes so no patch state stays on the model.
                for attr in ("_original_forward", "_module_name"):
                    if hasattr(module, attr):
                        delattr(module, attr)

    layer_scores: Dict[int, float] = {}
    for layer_name, vec in layer_expert_sums.items():
        value = float(vec.mean().item()) if vec.numel() else 0.0
        layer_idx = layer_name_to_idx.get(layer_name, len(layer_scores))
        layer_scores[layer_idx] = value

    return layer_scores


def normalize_and_format_results(values: Dict[int, float], num_layers: int) -> Dict[str, Dict[str, float]]:
    """Min-max normalize per-layer values and build the per-layer report.

    Args:
        values: Mapping from layer index to raw score; layers missing from the
            mapping count as ``0.0``.
        num_layers: Number of layers to report (indices ``0 .. num_layers - 1``).

    Returns:
        A dict keyed ``"layer_<i>"``. Each entry holds the raw value, the
        min-max normalized value (stored as both ``normalized_importance`` and
        ``final_score``), the 1-based rank (1 = highest normalized value, ties
        keep layer order) and a category: ``"high"`` above 0.7, ``"medium"``
        above 0.3, otherwise ``"low"``. When all raw values are equal every
        normalized value is 0.5. Returns an empty dict when ``num_layers`` is
        not positive.
    """
    if num_layers <= 0:
        # min()/max() of an empty array would raise; there is nothing to report.
        return {}

    raw = np.array([values.get(i, 0.0) for i in range(num_layers)], dtype=np.float64)
    lowest, highest = raw.min(), raw.max()
    if highest > lowest:
        norm = (raw - lowest) / (highest - lowest)
    else:
        norm = np.full_like(raw, 0.5)

    # Stable sort keeps tied layers in index order; inverting the permutation
    # gives every layer's rank in one pass instead of a search per layer.
    order = np.argsort(-norm, kind="stable")
    rank_of = np.empty(num_layers, dtype=np.int64)
    rank_of[order] = np.arange(1, num_layers + 1)

    result: Dict[str, Dict[str, float]] = {}
    for i in range(num_layers):
        score = float(norm[i])
        if score > 0.7:
            category = "high"
        elif score > 0.3:
            category = "medium"
        else:
            category = "low"
        result[f"layer_{i}"] = {
            "output_outlier_value": float(raw[i]),
            "normalized_importance": score,
            "final_score": score,
            "layer_importance_rank": int(rank_of[i]),
            "importance_category": category,
        }
    return result


def _print_summary(formatted: Dict[str, Dict[str, float]]) -> None:
    """Print sample rows, aggregate statistics and rankings for the formatted scores."""
    if not formatted:
        # The numpy min/max reductions below would raise on empty input.
        print("\nNo layer scores to summarize.")
        return

    print("\nSample Layer Output-Outlier Scores:")
    for i in range(min(8, len(formatted))):
        lk = f"layer_{i}"
        s = formatted[lk]
        print(f"  {lk}:")
        print(f"    Output-Outlier Value: {s['output_outlier_value']:.6f}")
        print(f"    Normalized Importance: {s['normalized_importance']:.4f}")
        print(f"    Importance Rank: {s['layer_importance_rank']}")
        print(f"    Category: {s['importance_category']}")

    all_vals = [d['output_outlier_value'] for d in formatted.values()]
    all_norm = [d['normalized_importance'] for d in formatted.values()]
    print("\nOverall Statistics:")
    print(f"  Total layers: {len(formatted)}")
    print(f"  Value - Mean: {np.mean(all_vals):.6f}, Std: {np.std(all_vals):.6f}")
    print(f"  Value - Min: {np.min(all_vals):.6f}, Max: {np.max(all_vals):.6f}")
    print(f"  Normalized scores - Mean: {np.mean(all_norm):.4f}, Std: {np.std(all_norm):.4f}")

    by_category = {"high": [], "medium": [], "low": []}
    for lk, s in formatted.items():
        by_category[s['importance_category']].append(lk)
    high, med, low = by_category["high"], by_category["medium"], by_category["low"]
    print("\nLayer Importance Distribution:")
    print(f"  High importance ({len(high)} layers): {high}")
    print(f"  Medium importance ({len(med)} layers): {med}")
    print(f"  Low importance ({len(low)} layers): {low}")

    sorted_layers = sorted(formatted.items(), key=lambda x: x[1]['normalized_importance'], reverse=True)
    print("\nTop 5 Most Important Layers:")
    for i, (lk, s) in enumerate(sorted_layers[:5]):
        print(f"  {i+1}. {lk}: {s['normalized_importance']:.4f}")

    print("\nTop 5 Least Important Layers:")
    # Reverse the tail of the descending list so entry 1 is the least important layer.
    for i, (lk, s) in enumerate(reversed(sorted_layers[-5:])):
        print(f"  {i+1}. {lk}: {s['normalized_importance']:.4f}")


def main():
    """Command-line entry point: compute, save and summarize layer-wise scores.

    Parses ``--dataset``, ``--num_samples``, ``--batch_size`` and ``--seq_len``,
    loads the model and calibration data, collects the output-outlier scores,
    writes the normalized report to ``OUTPUT_FILE`` as JSON and prints a
    summary. Any exception raised while loading, scoring or saving is printed
    with its traceback and the function returns without printing the summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default=DATASET_NAME, choices=["wikitext", "c4"], help="calibration dataset")
    parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="mapped to n_blocks_for_stat")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--seq_len", type=int, default=SEQUENCE_LENGTH, help="mapped to max_block_size")
    args = parser.parse_args()
    start = time.time()
    print("=" * 60)
    print("Mixtral Layer-wise Output-Outlier Metric")
    print("=" * 60)
    print(f"Using device: {DEVICE}")
    print(f"Number of GPUs: {NUM_GPUS}")
    print(f"Model path: {MODEL_PATH}")
    # Report the values actually in effect (parsed arguments), not the module defaults.
    print(f"Num samples: {args.num_samples} | Batch size: {args.batch_size} | Seq len: {args.seq_len} | Max sequences: {MAX_SEQUENCES}")

    try:
        print("\n" + "=" * 40)
        print("Step 1: Loading Resources")
        print("=" * 40)
        model, tokenizer, dataloader = load_resources(args.dataset, args.num_samples, args.batch_size, args.seq_len)

        print("\n" + "=" * 40)
        print("Step 2: Collecting Output-Outlier Scores")
        print("=" * 40)
        layer_values = compute_output_outlier_scores(model, dataloader)

        print("\n" + "=" * 40)
        print("Step 3: Processing Results")
        print("=" * 40)
        num_layers = model.config.num_hidden_layers
        formatted = normalize_and_format_results(layer_values, num_layers)

        os.makedirs(OUTPUT_DIR, exist_ok=True)
        print(f"Saving output-outlier scores to {OUTPUT_FILE}...")
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(formatted, f, indent=4)

        elapsed = time.time() - start
        print(f"\nCompleted successfully in {elapsed:.2f} seconds!")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback and stop.
        print(f"\nError occurred: {e}")
        import traceback
        traceback.print_exc()
        return

    _print_summary(formatted)


if __name__ == "__main__":
    main()
