import os
import torch
from transformers import AutoTokenizer
from transformers.models.deepseek_v2.modeling_deepseek_v2 import DeepseekV2ForCausalLM
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import numpy as np
from typing import Dict, List, Tuple
import time
from collections import defaultdict

MODEL_PATH = "/Path/DeepSeek-V2-Lite"
OUTPUT_FILE = "deepseek_v2_mutual_information_values.json"

# --- Calibration data -------------------------------------------------------
# Number of length-filtered examples kept for calibration.
NUM_SAMPLES = 128
# Examples per forward pass.
BATCH_SIZE = 8
# Token budget per example: longer texts are truncated, shorter ones padded.
MAX_SEQUENCE_LENGTH = 10240
# Only texts with strictly more tokens than this (special tokens excluded) are used.
MIN_SEQUENCE_LENGTH = 5120
# Upper bound on the number of *batches* run through the model. With the values
# above the loader yields at most NUM_SAMPLES / BATCH_SIZE = 16 batches.
MAX_SEQUENCES = 32

# --- Runtime environment ----------------------------------------------------
DEVICE = "cpu"
NUM_GPUS = 0
if torch.cuda.is_available():
    DEVICE = "cuda"
    NUM_GPUS = torch.cuda.device_count()


class ExpertOutputCollector:
    """Record one summary vector per expert call for every routed expert.

    A forward hook is registered on each module in ``layer.mlp.experts`` for
    every decoder layer that has such an attribute. Whenever an expert runs,
    its output is flattened to ``(leading_dim, -1)``, averaged over the leading
    dimension in float32, and appended (on the CPU) to
    ``expert_outputs[layer_idx][expert_idx]``.

    NOTE(review): an expert only gets a new entry when it is actually called,
    so the k-th vectors of two experts come from the same forward pass only if
    both experts ran on every pass. Confirm this holds for the calibration data.
    """

    def __init__(self, model: "DeepseekV2ForCausalLM"):
        self.model = model
        self.num_layers = model.config.num_hidden_layers
        # expert_outputs[layer_idx][expert_idx] -> list of per-call summary vectors.
        self.expert_outputs = defaultdict(lambda: defaultdict(list))
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        # Used only for logging when the experts container has no len().
        self.num_experts = getattr(model.config, "n_routed_experts", 64)
        # Number of hook calls whose output could not be summarised.
        self.num_hook_errors = 0
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to every expert of every layer exposing ``mlp.experts``."""
        for layer_idx, layer in enumerate(self.model.model.layers):
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts'):
                experts = layer.mlp.experts
                try:
                    num_layer_experts = len(experts)
                except TypeError:
                    num_layer_experts = self.num_experts
                print(f"Registering hooks for layer {layer_idx}: {num_layer_experts} experts")
                for expert_idx, expert_module in enumerate(experts):
                    handle = expert_module.register_forward_hook(self._make_hook(layer_idx, expert_idx))
                    self.hooks.append(handle)

    def _make_hook(self, layer_idx: int, expert_idx: int):
        """Return a forward hook that stores a float32 CPU summary of the expert output."""
        def hook_fn(module, inputs, output):
            # Some modules return tuples; the first element is taken as the activation.
            if isinstance(output, (tuple, list)):
                output = output[0] if output else None
            if not isinstance(output, torch.Tensor):
                return
            try:
                out = output.detach()
                if out.dim() >= 2:
                    # reshape (unlike view) also works for non-contiguous outputs;
                    # accumulating in float32 avoids rounding the mean to bf16.
                    summary = out.reshape(out.size(0), -1).mean(dim=0, dtype=torch.float32)
                else:
                    summary = out.to(torch.float32)
                self.expert_outputs[layer_idx][expert_idx].append(summary.to('cpu'))
            except Exception as exc:  # best effort: never break the model's forward pass
                self.num_hook_errors += 1
                if self.num_hook_errors == 1:
                    print(
                        f"Warning: failed to record output of layer {layer_idx} expert {expert_idx}: {exc} "
                        "(further hook errors are only counted in num_hook_errors)"
                    )
        return hook_fn

    def get_layer_outputs(self, layer_idx: int):
        """Return the ``{expert_idx: [summary, ...]}`` mapping for a layer, or ``{}``."""
        return self.expert_outputs.get(layer_idx, {})

    def cleanup(self):
        """Remove every registered hook; already-collected outputs are kept."""
        for h in self.hooks:
            h.remove()
        self.hooks.clear()


def load_resources():
    """Load the model, the tokenizer and a calibration DataLoader.

    The calibration set consists of the first ``NUM_SAMPLES`` examples of the
    local C4 shard whose text has strictly more than ``MIN_SEQUENCE_LENGTH``
    tokens (special tokens excluded). Each is truncated/padded to
    ``MAX_SEQUENCE_LENGTH``.

    Returns:
        ``(model, tokenizer, dataloader)``. Batches are dicts holding
        ``input_ids`` and ``attention_mask`` tensors.
    """
    print("Loading model and tokenizer...")
    model = DeepseekV2ForCausalLM.from_pretrained(
        MODEL_PATH,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token.
        # NOTE(review): padded positions are masked for attention, but they are
        # presumably still routed through the MoE experts and thus enter the
        # collected expert means. Confirm this is acceptable.
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading and preparing dataset...")
    dataset = load_dataset('json', data_files={'train': './data/c4-train.00000-of-01024.json.gz'}, trust_remote_code=True)['train']

    def long_enough(batch):
        # Batched tokenization is much faster than one tokenizer call per row.
        encodings = tokenizer(batch["text"], add_special_tokens=False)
        return [len(ids) > MIN_SEQUENCE_LENGTH for ids in encodings["input_ids"]]

    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=MAX_SEQUENCE_LENGTH,
            padding="max_length",
            return_tensors="pt"
        )

    long_texts = dataset.filter(long_enough, batched=True)
    num_selected = min(NUM_SAMPLES, len(long_texts))
    if num_selected < NUM_SAMPLES:
        print(f"Warning: only {num_selected} examples exceed {MIN_SEQUENCE_LENGTH} tokens "
              f"(requested {NUM_SAMPLES}).")

    # Select first, then tokenize: only the rows actually used get padded to
    # MAX_SEQUENCE_LENGTH, instead of the whole filtered corpus.
    calibration_dataset = long_texts.select(range(num_selected)).map(
        tokenize_function,
        batched=True,
        remove_columns=["text"]
    )
    calibration_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    dataloader = DataLoader(calibration_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return model, tokenizer, dataloader

def collect_outputs(model, dataloader) -> ExpertOutputCollector:
    """Run calibration batches through ``model`` while recording expert outputs.

    At most ``MAX_SEQUENCES`` batches are processed, under ``torch.no_grad()``
    with the KV cache disabled. The returned collector still has its hooks
    attached; the caller is responsible for ``collector.cleanup()``.

    Raises:
        Whatever the forward pass raises. The hooks are removed before the
        exception propagates, so the model is left unmodified on failure.
    """
    print("Initializing expert output collector...")
    collector = ExpertOutputCollector(model)
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(dataloader, desc="Collecting expert outputs")):
                if batch_idx >= MAX_SEQUENCES:
                    break
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                _ = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=False,
                    output_attentions=False,
                    use_cache=False,
                    return_dict=True,
                )
                if torch.cuda.is_available():
                    # Release cached blocks between the very long sequences.
                    torch.cuda.empty_cache()
    except BaseException:
        # Without this the hooks would stay attached to the model after a failure.
        collector.cleanup()
        raise
    return collector


def _stack_expert_samples(
    layer_outputs: Dict[int, List[torch.Tensor]]
) -> Tuple[List[int], List[torch.Tensor], int]:
    """Stack each expert's recorded vectors into an equally-sized sample matrix.

    Args:
        layer_outputs: ``{expert_idx: [vector, ...]}`` for one layer. ``None``
            entries are ignored.

    Returns:
        ``(expert_indices, expert_mats, min_samples)``:
        - ``expert_indices`` is sorted.
        - ``expert_mats[k]`` stacks the first ``min_samples`` vectors of expert
          ``expert_indices[k]`` along dim 0.
        - ``min_samples`` is the smallest per-expert sample count.

        ``([], [], 0)`` is returned for an empty input. If any expert has fewer
        than 2 samples, ``(expert_indices, [], 0)`` is returned.
    """
    if not layer_outputs:
        return [], [], 0
    expert_indices = sorted(layer_outputs.keys())
    samples = {e: [v for v in layer_outputs[e] if v is not None] for e in expert_indices}
    min_samples = min(len(vectors) for vectors in samples.values())
    if min_samples < 2:
        return expert_indices, [], 0

    # Pinned host memory only speeds up later host->GPU copies, and pinning
    # raises when no accelerator is present, so pin only when CUDA exists.
    pin = torch.cuda.is_available()
    expert_mats: List[torch.Tensor] = []
    for e in expert_indices:
        mat = torch.stack(samples[e][:min_samples], dim=0)
        expert_mats.append(mat.pin_memory() if pin else mat)

    return expert_indices, expert_mats, min_samples


def build_correlation_matrix(layer_outputs: Dict[int, List[torch.Tensor]]) -> torch.Tensor:
    """Build a symmetric expert-by-expert correlation matrix for one layer.

    For each expert pair, a Pearson correlation across samples is computed
    separately for every feature where both experts have non-zero sample
    standard deviation. Those per-feature correlations are then averaged.
    The result is clamped to ``[-1, 1]``; non-finite values and pairs with no
    usable feature (or mismatched shapes) give 0. The diagonal is 1, and
    ``1e-6`` is added to it.

    Returns:
        ``(1, 1)`` zeros for an empty input. ``(E, E)`` zeros if fewer than 2
        aligned samples are available. Otherwise the ``(E, E)`` float32 matrix,
        ordered like ``sorted(layer_outputs)``.
    """
    if not layer_outputs:
        return torch.zeros(1, 1)
    expert_indices, expert_mats, min_samples = _stack_expert_samples(layer_outputs)
    num_experts = len(expert_indices)
    if min_samples < 2 or not expert_mats:
        size = max(num_experts, 1)
        return torch.zeros(size, size)

    # Per-expert statistics do not depend on the pair: compute them once.
    centered: List[torch.Tensor] = []
    stds: List[torch.Tensor] = []
    for mat in expert_mats:
        m = mat.to(torch.float32)
        centered.append(m - m.mean(dim=0))
        stds.append(m.std(dim=0))  # unbiased, i.e. divides by (n - 1)

    R = torch.eye(num_experts, dtype=torch.float32)
    for i in range(num_experts):
        for j in range(i + 1, num_experts):
            corr = 0.0
            if centered[i].shape == centered[j].shape:
                valid = (stds[i] > 0) & (stds[j] > 0)
                if bool(valid.any()):
                    # Per-sample products (not products of means, which are ~0),
                    # normalised by (n - 1) to match torch.std.
                    cov = (centered[i][:, valid] * centered[j][:, valid]).sum(dim=0) / (min_samples - 1)
                    per_feature = cov / (stds[i][valid] * stds[j][valid])
                    corr = float(per_feature.mean().item())
            if not np.isfinite(corr):
                corr = 0.0
            corr = max(-1.0, min(1.0, corr))
            R[i, j] = corr
            R[j, i] = corr
    R = 0.5 * (R + R.T)
    R = R + 1e-6 * torch.eye(num_experts, dtype=R.dtype)
    return R


def gaussian_cca_pairwise_mi(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-6, ridge: float = 1e-4) -> float:
    """Estimate the Gaussian mutual information between two sample matrices via CCA.

    ``X`` and ``Y`` are ``(n_samples, n_features)`` matrices with paired rows.
    Both are centred and their covariances are ridge-regularised. The canonical
    correlations ``rho_k`` are the singular values of
    ``Sxx^{-1/2} Sxy Syy^{-1/2}``, clamped to ``[0, 1 - eps]``. The result is
    ``-0.5 * sum_k log(max(1 - rho_k^2, eps))``.

    Returns 0.0 when either input is not 2-D, the row counts differ, or there
    are fewer than 2 rows. Negative or NaN results become 0.0.

    NOTE(review): when the feature count is not well below the sample count,
    sample canonical correlations approach 1, and the value is then driven
    mostly by ``ridge``/``eps`` rather than by the data.
    """
    shapes_ok = X.dim() == 2 and Y.dim() == 2 and X.size(0) == Y.size(0) and X.size(0) >= 2
    if not shapes_ok:
        return 0.0

    x32 = X.to(torch.float32)
    y32 = Y.to(torch.float32)
    denom = x32.size(0) - 1 + eps
    x_c = x32 - x32.mean(dim=0, keepdim=True)
    y_c = y32 - y32.mean(dim=0, keepdim=True)

    cov_xx = (x_c.T @ x_c) / denom
    cov_yy = (y_c.T @ y_c) / denom
    cov_xy = (x_c.T @ y_c) / denom
    cov_xx = cov_xx + ridge * torch.eye(cov_xx.size(0), dtype=cov_xx.dtype, device=x32.device)
    cov_yy = cov_yy + ridge * torch.eye(cov_yy.size(0), dtype=cov_yy.dtype, device=y32.device)

    whiteners = []
    for cov in (cov_xx, cov_yy):
        # Symmetric inverse square root, with eigenvalues floored at 1e-6.
        vals, vecs = torch.linalg.eigh(cov.to(torch.float32))
        scale = torch.diag(torch.pow(torch.clamp(vals, min=1e-6), -0.5))
        whiteners.append(vecs @ scale @ vecs.T)
    whiten_x, whiten_y = whiteners

    cross = whiten_x @ cov_xy @ whiten_y
    try:
        rho = torch.linalg.svdvals(cross)
    except RuntimeError:
        _, rho, _ = torch.linalg.svd(cross, full_matrices=False)
    rho = torch.clamp(rho, 0.0, 1.0 - eps)
    mi = -0.5 * torch.log(torch.clamp(1.0 - rho * rho, min=eps)).sum().item()
    if np.isfinite(mi) and mi >= 0:
        return mi
    return float(max(0.0, mi))


def gaussian_cca_pairwise_mi_tensor(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-6, ridge: float = 1e-4) -> torch.Tensor:
    """Tensor-returning variant of the ridge-regularised CCA mutual-information estimate.

    Uses the same formula as the float version,
    ``-0.5 * sum_k log(max(1 - rho_k^2, eps))`` over canonical correlations
    ``rho_k`` clamped to ``[0, 1 - eps]``. The result is returned as a 0-dim
    float32 tensor and never synced to the host. Non-finite results become 0
    and the result is clamped to be non-negative.

    Invalid inputs (not 2-D, mismatched row counts, fewer than 2 rows) give a
    0-dim zero. It is placed on ``X``'s device if ``X`` is on CUDA, else on
    ``Y``'s device if ``Y`` is on CUDA, else on the CPU.
    """
    if X.is_cuda:
        fallback_device = X.device
    elif Y.is_cuda:
        fallback_device = Y.device
    else:
        fallback_device = torch.device("cpu")

    usable = X.dim() == 2 and Y.dim() == 2 and X.size(0) == Y.size(0) and X.size(0) >= 2
    if not usable:
        return torch.zeros((), device=fallback_device, dtype=torch.float32)

    x32 = X.to(torch.float32)
    y32 = Y.to(torch.float32)
    denom = x32.size(0) - 1 + eps

    def _centre(mat: torch.Tensor) -> torch.Tensor:
        return mat - mat.mean(dim=0, keepdim=True)

    def _ridge_cov(centred: torch.Tensor, device: torch.device) -> torch.Tensor:
        raw = (centred.T @ centred) / denom
        return raw + ridge * torch.eye(raw.size(0), dtype=raw.dtype, device=device)

    def _inv_sqrt(mat: torch.Tensor, floor: float = 1e-6) -> torch.Tensor:
        vals, vecs = torch.linalg.eigh(mat.to(torch.float32))
        scale = torch.diag(torch.pow(torch.clamp(vals, min=floor), -0.5))
        return vecs @ scale @ vecs.T

    x_c = _centre(x32)
    y_c = _centre(y32)
    cov_xx = _ridge_cov(x_c, x32.device)
    cov_yy = _ridge_cov(y_c, y32.device)
    cov_xy = (x_c.T @ y_c) / denom

    cross = _inv_sqrt(cov_xx) @ cov_xy @ _inv_sqrt(cov_yy)
    try:
        rho = torch.linalg.svdvals(cross)
    except RuntimeError:
        _, rho, _ = torch.linalg.svd(cross, full_matrices=False)
    rho = torch.clamp(rho, 0.0, 1.0 - eps)
    total = -0.5 * torch.log(torch.clamp(1.0 - rho * rho, min=eps)).sum()
    total = torch.where(torch.isfinite(total), total, torch.zeros_like(total))
    return torch.clamp(total, min=0.0)


def mutual_information_sum_from_R(R: torch.Tensor, eps: float = 1e-8) -> float:
    """Sum the bivariate-Gaussian MI ``-0.5 * log(1 - r_ij^2)`` over the upper triangle of ``R``.

    Each off-diagonal ``r_ij`` (``i < j``) is clamped to
    ``[-1 + eps, 1 - eps]``, and ``1 - r_ij^2`` is floored at ``eps`` before
    the log is taken. Matrices with at most one element, or fewer than two
    rows, give 0.0.
    """
    if R.numel() <= 1 or R.size(0) < 2:
        return 0.0
    size = R.size(0)

    def _pair_mi(row: int, col: int) -> float:
        r = float(R[row, col].item())
        r = max(-1.0 + eps, min(1.0 - eps, r))
        return float(-0.5 * np.log(max(eps, 1.0 - r * r)))

    total = 0.0
    for row in range(size):
        for col in range(row + 1, size):
            total += _pair_mi(row, col)
    return float(total)


def _whitened_sample_basis(
    mat: torch.Tensor, eps: float, ridge: float, min_eig: float = 1e-6
) -> torch.Tensor:
    """Return an ``(n, n)`` float64 basis ``K`` summarising one expert's whitened samples.

    ``mat`` is treated as ``(n_samples, n_features)``; trailing dims are
    flattened. Define:

    - ``Xc``: the column-centred samples, with thin SVD ``Xc = U diag(s) V^T``;
    - ``c = n - 1 + eps``;
    - ``S = Xc^T Xc / c + ridge * I``.

    Then ``Xc S^{-1/2} / sqrt(c) = K V^T``, where
    ``K = U diag(s / sqrt(c * max(s^2 / c + ridge, min_eig)))``. Here
    ``S^{-1/2}`` uses the same eigenvalue floor as the per-pair CCA code.
    ``K`` is zero-padded to ``n`` columns; zero columns only add zero
    canonical correlations.
    """
    x = mat.reshape(mat.size(0), -1).to(torch.float64)
    n = x.size(0)
    xc = x - x.mean(dim=0, keepdim=True)
    u, s, _ = torch.linalg.svd(xc, full_matrices=False)
    c = n - 1 + eps
    # Eigenvalues of S along the row space of xc, floored like the per-pair inv_sqrt.
    eig = torch.clamp(s * s / c + ridge, min=min_eig)
    basis = u * (s / torch.sqrt(c * eig))
    if basis.size(1) < n:
        basis = torch.nn.functional.pad(basis, (0, n - basis.size(1)))
    return basis


def _pairwise_mi_matrix(
    expert_mats: List[torch.Tensor], eps: float = 1e-6, ridge: float = 1e-4
) -> torch.Tensor:
    """Return an ``(E, E)`` float64 matrix whose strict upper triangle holds pairwise CCA MI.

    Entry ``[a, b]`` (``a < b``) is ``-0.5 * sum_k log(max(1 - rho_k^2, eps))``.
    The ``rho_k`` are the singular values of ``Sxx^{-1/2} Sxy Syy^{-1/2}`` for
    ``expert_mats[a]``/``expert_mats[b]``, clamped to ``[0, 1 - eps]``. This is
    the same quantity as ``gaussian_cca_pairwise_mi``.

    Why this is cheaper: that matrix equals ``V_a (K_a^T K_b) V_b^T`` with
    orthonormal ``V``, so its singular values are those of the small ``n x n``
    product ``K_a^T K_b``. Each expert is therefore whitened once, and no
    feature-sized (d x d) eigendecomposition is needed per pair.
    Non-finite pair values become 0; everything else in the matrix is 0.
    All matrices must have the same number of rows.
    """
    num = len(expert_mats)
    pair_mi = torch.zeros(num, num, dtype=torch.float64)
    if num < 2:
        return pair_mi
    bases = torch.stack([_whitened_sample_basis(m, eps, ridge) for m in expert_mats])  # (E, n, n)
    for a in range(num - 1):
        # K_a^T K_b for every b > a in one batched product.
        cross = torch.einsum("ni,bnj->bij", bases[a], bases[a + 1:])
        rho = torch.clamp(torch.linalg.svdvals(cross), 0.0, 1.0 - eps)
        mi = -0.5 * torch.log(torch.clamp(1.0 - rho * rho, min=eps)).sum(dim=-1)
        mi = torch.where(torch.isfinite(mi), mi, torch.zeros_like(mi))
        pair_mi[a, a + 1:] = torch.clamp(mi, min=0.0)
    return pair_mi


def compute_layer_mi_scores(
    model: "DeepseekV2ForCausalLM",
    collector: "ExpertOutputCollector",
    *,
    eps: float = 1e-6,
    ridge: float = 1e-4,
) -> Dict[int, float]:
    """Score each layer by the sum of pairwise CCA mutual information between its experts.

    Layer 0 is skipped and is absent from the result. Layers with no recorded
    outputs, or with fewer than 2 aligned samples per expert, score 0.0.
    The computation runs on the CPU in float64 and needs no GPU.

    Args:
        model: provides ``config.num_hidden_layers``.
        collector: source of the per-expert sample vectors.
        eps, ridge: CCA clamping and regularisation constants (keyword-only).

    Returns:
        ``{layer_idx: mi_sum}``.

    NOTE(review): rows are paired by position (k-th recorded call of each
    expert); this assumes every expert ran on every forward pass.
    """
    scores: Dict[int, float] = {}
    num_layers = model.config.num_hidden_layers
    for layer_idx in tqdm(range(num_layers), desc="Computing MI scores"):
        if layer_idx == 0:
            continue
        layer_outputs = collector.get_layer_outputs(layer_idx)
        if not layer_outputs:
            scores[layer_idx] = 0.0
            continue
        _, expert_mats, min_samples = _stack_expert_samples(layer_outputs)
        if min_samples < 2 or not expert_mats:
            scores[layer_idx] = 0.0
            continue
        pair_mi = _pairwise_mi_matrix(expert_mats, eps=eps, ridge=ridge)
        scores[layer_idx] = float(pair_mi.sum().item())
    return scores


def normalize_and_format_results(scores: Dict[int, float], model: "DeepseekV2ForCausalLM") -> Dict[str, Dict[str, float]]:
    """Min-max normalise per-layer scores and build the JSON-ready result dict.

    Layers missing from ``scores`` count as 0.0. If all scores are equal, every
    layer gets a normalised value of 0.5.

    ``layer_importance_rank`` is 1 for the highest normalised score; ties are
    broken by layer index. ``importance_category`` is "high" above 0.7,
    "medium" above 0.3, and "low" otherwise.

    Returns:
        ``{"layer_<i>": {...}}`` for every ``i < model.config.num_hidden_layers``.
    """
    num_layers = model.config.num_hidden_layers
    raw = np.array([scores.get(i, 0.0) for i in range(num_layers)], dtype=np.float64)
    if raw.size and raw.max() > raw.min():
        norm = (raw - raw.min()) / (raw.max() - raw.min())
    else:
        norm = np.full_like(raw, 0.5)

    # order[k] is the layer holding place k; invert it so ranks[i] is layer i's place.
    order = np.argsort(-norm, kind="stable")
    ranks = np.empty(num_layers, dtype=np.int64)
    ranks[order] = np.arange(1, num_layers + 1)

    result: Dict[str, Dict[str, float]] = {}
    for i in range(num_layers):
        value = float(norm[i])
        result[f"layer_{i}"] = {
            "mutual_information_sum": float(raw[i]),
            "normalized_importance": value,
            "final_score": value,
            "layer_importance_rank": int(ranks[i]),
            "importance_category": "high" if value > 0.7 else ("medium" if value > 0.3 else "low"),
        }
    return result


def main():
    """Run the layer-importance pipeline end to end and print a summary.

    The steps are:
    1. load the model and data;
    2. collect expert outputs;
    3. compute per-layer MI scores;
    4. normalise them and write ``OUTPUT_FILE`` as JSON.

    If any step raises, the error and traceback are printed and the function
    returns without printing the summary.
    """
    start = time.time()
    print("=" * 60)
    print("DeepSeek-V2-Lite Sum of Mutual Information from Correlation Matrix-based Layer Importance")
    print("=" * 60)
    print(f"Using device: {DEVICE}")
    print(f"Number of GPUs: {NUM_GPUS}")
    print(f"Model path: {MODEL_PATH}")
    print(f"Num samples: {NUM_SAMPLES} | Batch size: {BATCH_SIZE} | Seq len: {MAX_SEQUENCE_LENGTH} | Max sequences: {MAX_SEQUENCES}")

    try:
        print("\n" + "=" * 40)
        print("Step 1: Loading Resources")
        print("=" * 40)
        model, tokenizer, dataloader = load_resources()

        print("\n" + "=" * 40)
        print("Step 2: Collecting Expert Outputs")
        print("=" * 40)
        collector = collect_outputs(model, dataloader)

        print("\n" + "=" * 40)
        print("Step 3: Computing Layer MI Scores")
        print("=" * 40)
        try:
            mi_scores = compute_layer_mi_scores(model, collector)
        finally:
            # Detach the hooks from the model even if scoring fails.
            collector.cleanup()

        print("\n" + "=" * 40)
        print("Step 4: Processing Results")
        print("=" * 40)
        formatted = normalize_and_format_results(mi_scores, model)

        print(f"Saving importance scores to {OUTPUT_FILE}...")
        with open(OUTPUT_FILE, 'w') as f:
            json.dump(formatted, f, indent=4)

        elapsed = time.time() - start
        print(f"\nCompleted successfully in {elapsed:.2f} seconds!")

    except Exception as e:
        print(f"\nError occurred: {e}")
        import traceback
        traceback.print_exc()
        return

    _print_summary(formatted)


def _print_summary(formatted: Dict[str, Dict[str, float]]) -> None:
    """Print sample scores, aggregate statistics, category counts and the top/bottom layers."""
    if not formatted:
        # np.min/np.max would raise on an empty list.
        print("\nNo layers to summarise.")
        return

    print("\nSample Layer Importance Scores:")
    for i in range(min(8, len(formatted))):
        lk = f"layer_{i}"
        s = formatted[lk]
        print(f"  {lk}:")
        print(f"    Mutual Information Sum: {s['mutual_information_sum']:.6f}")
        print(f"    Normalized Importance: {s['normalized_importance']:.4f}")
        print(f"    Importance Rank: {s['layer_importance_rank']}")
        print(f"    Category: {s['importance_category']}")

    all_mi = [d['mutual_information_sum'] for d in formatted.values()]
    all_norm = [d['normalized_importance'] for d in formatted.values()]
    print("\nOverall Statistics:")
    print(f"  Total layers: {len(formatted)}")
    print(f"  MI sum - Mean: {np.mean(all_mi):.6f}, Std: {np.std(all_mi):.6f}")
    print(f"  MI sum - Min: {np.min(all_mi):.6f}, Max: {np.max(all_mi):.6f}")
    print(f"  Normalized scores - Mean: {np.mean(all_norm):.4f}, Std: {np.std(all_norm):.4f}")

    high = [k for k, v in formatted.items() if v['importance_category'] == 'high']
    med = [k for k, v in formatted.items() if v['importance_category'] == 'medium']
    low = [k for k, v in formatted.items() if v['importance_category'] == 'low']
    print("\nLayer Importance Distribution:")
    print(f"  High importance ({len(high)} layers): {high}")
    print(f"  Medium importance ({len(med)} layers): {med}")
    print(f"  Low importance ({len(low)} layers): {low}")

    sorted_layers = sorted(formatted.items(), key=lambda x: x[1]['normalized_importance'], reverse=True)
    print("\nTop 5 Most Important Layers:")
    for i, (lk, s) in enumerate(sorted_layers[:5]):
        print(f"  {i+1}. {lk}: {s['normalized_importance']:.4f}")

    print("\nTop 5 Least Important Layers:")
    # Walk the tail backwards so that entry 1 is the least important layer.
    for i, (lk, s) in enumerate(reversed(sorted_layers[-5:])):
        print(f"  {i+1}. {lk}: {s['normalized_importance']:.4f}")


if __name__ == "__main__":
    main()
