import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
os.environ["NCCL_P2P_DISABLE"] = "0"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT"] = "2.0"
import sys
import argparse
import torch
from transformers import MixtralForCausalLM, AutoTokenizer, default_data_collator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import numpy as np
from typing import Dict, List, Tuple
import time
from collections import defaultdict
import itertools

# Location of the Mixtral checkpoint and of the JSON report written by main().
MODEL_PATH = "/Path/Mixtral-8x7B-v0.1"
OUTPUT_FILE = "mixtral_mutual_information_values_c4.json"

# Calibration defaults; main() exposes the first three as CLI options.
NUM_SAMPLES = 128        # number of calibration sequences kept
BATCH_SIZE = 8           # sequences per forward pass
SEQUENCE_LENGTH = 512    # tokens per sequence
# NOTE: collect_outputs() compares this against the *batch* index, so it caps
# the number of forward passes (batches), not individual sequences.
MAX_SEQUENCES = 32

# Calibration corpus used when --dataset is not given ("c4" or "wikitext").
DATASET_NAME = "c4"

# Device used for the MI statistics, and how many GPUs are visible.
if torch.cuda.is_available():
    DEVICE = "cuda"
    NUM_GPUS = torch.cuda.device_count()
else:
    DEVICE = "cpu"
    NUM_GPUS = 0


class ExpertOutputCollector:
    """Record a per-call summary of every MoE expert's output.

    A forward hook is attached to each module in
    ``layer.block_sparse_moe.experts`` for every layer in ``model.model.layers``
    that has one. Each time an expert runs, the hook averages its output over
    the leading dimension. It stores the resulting vector, as float32 on the
    CPU, in ``expert_outputs[layer_idx][expert_idx]``. Call :meth:`cleanup` to
    detach the hooks when collection is finished.
    """

    def __init__(self, model: "MixtralForCausalLM"):
        self.model = model
        self.num_layers = model.config.num_hidden_layers
        # {layer_idx: {expert_idx: [avg_vector_per_forward_call, ...]}}
        self.expert_outputs = defaultdict(lambda: defaultdict(list))
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        # Fallback expert count, used only for logging when len(experts) fails.
        self.num_experts = getattr(model.config, "num_local_experts", 8)
        # (layer_idx, expert_idx) pairs whose hook already reported a failure,
        # so a broken hook warns once instead of on every batch.
        self._failed_hooks = set()
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to every expert of every MoE layer."""
        for layer_idx, layer in enumerate(self.model.model.layers):
            if hasattr(layer, 'block_sparse_moe') and hasattr(layer.block_sparse_moe, 'experts'):
                experts = layer.block_sparse_moe.experts
                try:
                    num_layer_experts = len(experts)
                except Exception:
                    num_layer_experts = self.num_experts
                print(f"Registering hooks for layer {layer_idx}: {num_layer_experts} experts")
                for expert_idx, expert_module in enumerate(experts):
                    handle = expert_module.register_forward_hook(self._make_hook(layer_idx, expert_idx))
                    self.hooks.append(handle)

    @staticmethod
    def _summarize(output: torch.Tensor) -> torch.Tensor:
        """Average ``output`` over its leading dimension into a float32 CPU tensor.

        2-D outputs reduce over dim 0. Higher-rank outputs are first flattened
        to ``[size(0), -1]``. 0-/1-D outputs are returned unchanged apart from
        the dtype/device conversion.
        """
        # Upcast before reducing so bf16 activations are averaged and stored in float32.
        out = output.detach().float()
        if out.dim() >= 2:
            # reshape (unlike view) also handles non-contiguous outputs.
            out = out.reshape(out.size(0), -1).mean(dim=0)
        return out.to("cpu")

    def _make_hook(self, layer_idx: int, expert_idx: int):
        """Build the forward hook that records outputs of one (layer, expert) pair."""
        def hook_fn(module, inputs, output):
            # Unwrap tuple/list outputs; the first element is taken as the tensor.
            if isinstance(output, (tuple, list)):
                output = output[0] if len(output) > 0 else None
            if not isinstance(output, torch.Tensor):
                return
            # A call with no tokens yields an empty tensor whose mean is NaN;
            # recording it would poison every downstream statistic.
            if output.numel() == 0:
                return
            try:
                avg = self._summarize(output)
            except RuntimeError as err:
                key = (layer_idx, expert_idx)
                if key not in self._failed_hooks:
                    self._failed_hooks.add(key)
                    print(f"Warning: could not record output of layer {layer_idx} expert {expert_idx}: {err}")
                return
            self.expert_outputs[layer_idx][expert_idx].append(avg)
        return hook_fn

    def get_layer_outputs(self, layer_idx: int):
        """Return ``{expert_idx: [vectors]}`` for one layer (live mapping; ``{}`` if nothing was recorded)."""
        return self.expert_outputs.get(layer_idx, {})

    def cleanup(self):
        """Remove all registered hooks; recorded outputs are kept."""
        for h in self.hooks:
            h.remove()
        self.hooks.clear()


def _load_model_and_tokenizer():
    """Load the bf16 Mixtral checkpoint (device_map="auto") and its tokenizer."""
    print("Loading model and tokenizer...")
    print(f"Using {NUM_GPUS} GPUs with device_map='auto'")
    model = MixtralForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        # Padding (used on the wikitext path) needs a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _group_texts(examples, seq_len: int):
    """Concatenate a tokenized batch and cut it into blocks of exactly ``seq_len`` tokens."""
    concatenated = {k: list(itertools.chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Always drop the remainder, which yields nothing when the batch is shorter
    # than seq_len. default_data_collator stacks the blocks into one tensor and
    # fails on ragged lengths.
    total_length = (total_length // seq_len) * seq_len
    result = {
        k: [t[i: i + seq_len] for i in range(0, total_length, seq_len)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


def _build_wikitext_loader(tokenizer, num_samples: int, batch_size: int, seq_len: int) -> DataLoader:
    """Build a DataLoader with one padded/truncated wikitext-2 line per sequence."""
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split='train')
    # Blank lines would become sequences made of nothing but padding
    # (padding="max_length"); keep only lines that contain text.
    dataset = dataset.filter(lambda example: bool(example["text"].strip()))

    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            return_tensors="pt",
        )

    tokenized = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    # Same convention as the c4 path: num_samples <= 0 keeps everything.
    if num_samples > 0:
        tokenized = tokenized.select(range(min(num_samples, len(tokenized))))
    return DataLoader(tokenized, batch_size=batch_size, shuffle=False)


def _load_c4_split():
    """Return the C4 train split (local shard if present), falling back to wikitext-2."""
    data_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "data", "c4-train.00000-of-01024.json.gz")
    )
    try:
        if os.path.exists(data_path):
            dataset = load_dataset('json', data_files={'train': data_path}, trust_remote_code=True)
            return dataset['train']
        return load_dataset('c4', 'en', split='train', trust_remote_code=True)
    except Exception as err:
        print(f"Loading C4 failed ({err}), falling back to wikitext")
        return load_dataset("wikitext", "wikitext-2-raw-v1", split='train')


def _build_c4_loader(tokenizer, num_samples: int, batch_size: int, seq_len: int) -> DataLoader:
    """Build a DataLoader of ``seq_len``-token blocks cut from concatenated C4 documents."""
    train_split = _load_c4_split()
    sampled = train_split
    if num_samples > 0:
        try:
            # Take num_samples * 16 shuffled documents, presumably so that enough
            # tokens remain to fill num_samples full blocks.
            sampled = train_split.shuffle(seed=42).select(range(min(num_samples * 16, len(train_split))))
        except Exception as err:
            print(f"Sampling failed ({err}); using the full split")
            sampled = train_split

    def tokenize_function(examples):
        # Use the first column when the corpus has no "text" field.
        text_key = 'text' if 'text' in examples else list(examples.keys())[0]
        return tokenizer(examples[text_key])

    tokenized = sampled.map(
        tokenize_function,
        batched=True,
        remove_columns=list(sampled.features),
    )
    lm_dataset = tokenized.map(_group_texts, batched=True, fn_kwargs={"seq_len": seq_len})
    if 0 < num_samples < len(lm_dataset):
        lm_dataset = lm_dataset.select(range(num_samples))
    return DataLoader(
        lm_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=default_data_collator,
    )


def load_resources(dataset_name: str = DATASET_NAME, num_samples: int = NUM_SAMPLES, batch_size: int = BATCH_SIZE, seq_len: int = SEQUENCE_LENGTH):
    """Load the Mixtral model, its tokenizer and a calibration DataLoader.

    Args:
        dataset_name: ``"wikitext"`` (one padded sequence per non-empty
            wikitext-2 line) or ``"c4"`` (documents concatenated and cut into
            ``seq_len``-token blocks).
        num_samples: Maximum number of calibration sequences; ``<= 0`` keeps all.
        batch_size: Sequences per batch.
        seq_len: Tokens per sequence.

    Returns:
        ``(model, tokenizer, dataloader)``.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, or if no calibration
            sequences could be built.
    """
    builders = {"wikitext": _build_wikitext_loader, "c4": _build_c4_loader}
    # Validate before loading the (very large) model so a typo fails fast.
    if dataset_name not in builders:
        print(f"Unsupported dataset: {dataset_name}")
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    model, tokenizer = _load_model_and_tokenizer()

    print("Loading and preparing dataset...")
    dataloader = builders[dataset_name](tokenizer, num_samples, batch_size, seq_len)
    if len(dataloader.dataset) == 0:
        raise ValueError(f"No calibration sequences could be built from {dataset_name!r} (seq_len={seq_len})")
    return model, tokenizer, dataloader


def collect_outputs(model, dataloader, max_batches: int = MAX_SEQUENCES) -> ExpertOutputCollector:
    """Run calibration batches through ``model`` while recording expert outputs.

    Args:
        model: Model whose experts are hooked by ``ExpertOutputCollector``.
        dataloader: Yields dicts containing ``input_ids`` and ``attention_mask``.
        max_batches: Stop after this many batches. The default is
            ``MAX_SEQUENCES``, which, despite its name, counts batches.

    Returns:
        The collector with its hooks still attached; the caller must call
        ``cleanup()`` once the outputs have been consumed.

    Raises:
        Whatever the forward pass raises. The hooks are removed first, so the
        model is not left instrumented.
    """
    print("Initializing expert output collector...")
    collector = ExpertOutputCollector(model)
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(dataloader, desc="Collecting expert outputs")):
                if batch_idx >= max_batches:
                    break
                # Inputs are passed as yielded by the loader (CPU); with
                # device_map="auto" the dispatch hooks presumably move them
                # to the right device. Verify if the loading code changes.
                _ = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    output_hidden_states=False,
                    output_attentions=False,
                    use_cache=False,
                    return_dict=True,
                )
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    except BaseException:
        # The caller never receives the collector on failure, so detach here.
        collector.cleanup()
        raise
    return collector


def _stack_expert_samples(
    layer_outputs: Dict[int, List[torch.Tensor]]
) -> Tuple[List[int], List[torch.Tensor], int]: 
    if not layer_outputs:
        return [], [], 0
    expert_indices = sorted(layer_outputs.keys())
    samples_per_expert = [len([v for v in layer_outputs[e] if v is not None]) for e in expert_indices]
    if not samples_per_expert or min(samples_per_expert) < 2:
        return expert_indices, [], 0
    min_samples = min(samples_per_expert)

    expert_mats: List[torch.Tensor] = []
    for e in expert_indices:
        vectors = [v for v in layer_outputs[e] if v is not None][:min_samples]
        if not vectors:
            expert_mats.append(torch.zeros(min_samples, 1))
        else:
            mat = torch.stack(vectors, dim=0)  # [min_samples, feat]
            expert_mats.append(mat)

    return expert_indices, expert_mats, min_samples


def build_correlation_matrix(layer_outputs: Dict[int, List[torch.Tensor]]) -> torch.Tensor:
    """Build a symmetric expert-by-expert correlation matrix for one layer.

    Each expert contributes a ``[min_samples, feat]`` matrix of recorded output
    vectors (see ``_stack_expert_samples``). For each pair of experts, the
    Pearson correlation across samples is computed separately for every
    feature that varies in both experts. Those per-feature correlations are
    then averaged.

    Returns:
        A ``[num_experts, num_experts]`` float32 tensor. The diagonal is
        ``1 + 1e-6`` (a small jitter) and off-diagonal entries lie in
        ``[-1, 1]``. Returns a ``1x1`` zero tensor for empty input, and an
        all-zero square tensor when fewer than two aligned samples exist.
    """
    if not layer_outputs:
        return torch.zeros(1, 1)
    expert_indices, expert_mats, min_samples = _stack_expert_samples(layer_outputs)
    num_experts = len(expert_indices)
    if min_samples < 2 or not expert_mats:
        size = max(num_experts, 1)
        return torch.zeros(size, size)

    # Per-expert statistics do not depend on the pair; compute them once in float32.
    mats = [m.float() for m in expert_mats]
    centered = [m - m.mean(dim=0) for m in mats]
    stds = [m.std(dim=0) for m in mats]  # unbiased (n - 1)

    R = torch.zeros(num_experts, num_experts, dtype=torch.float32)
    for i in range(num_experts):
        R[i, i] = 1.0
        for j in range(i + 1, num_experts):
            valid = (stds[i] > 0) & (stds[j] > 0)
            if valid.sum() == 0:
                corr = 0.0
            else:
                # Per-feature covariance with the same n - 1 divisor as std, so
                # cov / (std_i * std_j) is exactly the Pearson correlation.
                cov = (centered[i][:, valid] * centered[j][:, valid]).sum(dim=0) / (min_samples - 1)
                per_feature = cov / (stds[i][valid] * stds[j][valid])
                corr = float(per_feature.mean().item())
            # A NaN/inf would otherwise be clamped to +1 by min/max below.
            if not np.isfinite(corr):
                corr = 0.0
            corr = float(max(-1.0, min(1.0, corr)))
            R[i, j] = corr
            R[j, i] = corr
    R = 0.5 * (R + R.T)
    R = R + 1e-6 * torch.eye(num_experts, dtype=R.dtype)
    return R


def gaussian_cca_pairwise_mi(
    X: torch.Tensor,
    Y: torch.Tensor,
    eps: float = 1e-6,
    ridge: float = 1e-4,
    device=None,
    use_float64: bool = False,
) -> float:
    """Estimate I(X; Y) under a joint-Gaussian model via canonical correlations.

    With ``rho_k`` the singular values of
    ``Sxx^{-1/2} Sxy Syy^{-1/2}`` (ridge-regularised covariances), the
    estimate is ``-1/2 * sum_k log(1 - rho_k^2)``.

    NOTE: if both centred inputs have rank ``n - 1`` (typical when
    ``n - 1 <= dx, dy``), the unregularised canonical correlations are all 1.
    The result is then governed by ``ridge`` and the ``1 - eps`` clamp rather
    than by the data.

    Args:
        X: ``[n, dx]`` samples, one observation per row.
        Y: ``[n, dy]`` samples, paired row-by-row with ``X``.
        eps: Added to ``n - 1``; ``rho`` is clamped to ``1 - eps`` and
            ``1 - rho^2`` to at least ``eps``.
        ridge: Added to the diagonals of ``Sxx`` and ``Syy``.
        device: Where to compute. By default this is X's device if X is on
            CUDA, otherwise CUDA if available, otherwise CPU.
        use_float64: Compute in float64 instead of float32.

    Returns:
        A non-negative float. Returns 0.0 for non-2-D inputs, mismatched row
        counts, fewer than two rows, or a NaN result.
    """
    shapes_ok = X.dim() == 2 and Y.dim() == 2 and X.size(0) == Y.size(0) and X.size(0) >= 2
    if not shapes_ok:
        return 0.0

    if device is not None:
        target_device = device
    elif X.is_cuda:
        target_device = X.device
    elif torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    target_dtype = torch.float64 if use_float64 else torch.float32

    xs = X.to(device=target_device, dtype=target_dtype, non_blocking=True)
    ys = Y.to(device=target_device, dtype=target_dtype, non_blocking=True)
    divisor = xs.size(0) - 1 + eps
    x_dev = xs - xs.mean(dim=0, keepdim=True)
    y_dev = ys - ys.mean(dim=0, keepdim=True)

    def regularised_cov(a: torch.Tensor) -> torch.Tensor:
        cov = (a.T @ a) / divisor
        return cov + ridge * torch.eye(cov.size(0), dtype=cov.dtype, device=target_device)

    def whitener(cov: torch.Tensor, floor: float = 1e-6) -> torch.Tensor:
        # Symmetric inverse square root via eigendecomposition, flooring tiny eigenvalues.
        vals, vecs = torch.linalg.eigh(cov)
        vals = torch.clamp(vals, min=floor)
        return vecs @ torch.diag(torch.pow(vals, -0.5)) @ vecs.T

    cross = (x_dev.T @ y_dev) / divisor
    whitened = whitener(regularised_cov(x_dev)) @ cross @ whitener(regularised_cov(y_dev))

    try:
        rho = torch.linalg.svdvals(whitened)
    except RuntimeError:
        rho = torch.linalg.svd(whitened, full_matrices=False)[1]
    rho = torch.clamp(rho, 0.0, 1.0 - eps)
    log_terms = torch.log(torch.clamp(1.0 - rho * rho, min=eps))
    mi = float((-0.5 * log_terms.sum()).detach().to("cpu").item())
    if np.isfinite(mi) and mi >= 0:
        return mi
    return float(max(0.0, mi))


def gaussian_differential_entropy(
    X: torch.Tensor,
    eps: float = 1e-6,
    ridge: float = 1e-4,
    device=None,
    use_float64: bool = False,
) -> float:
    """Differential entropy of a Gaussian fitted to the rows of ``X``.

    Computes ``h = 1/2 * (log det Sigma + d * log(2 * pi * e))``, where
    ``Sigma`` is the sample covariance (divided by ``n - 1 + eps``) plus
    ``ridge`` on the diagonal. Eigenvalues are floored at ``eps`` before the
    log-determinant.

    Args:
        X: ``[n, d]`` samples, one observation per row.
        eps: Added to ``n - 1`` and used as the eigenvalue floor.
        ridge: Diagonal regulariser added to ``Sigma``.
        device: Where to compute. By default this is X's device if X is on
            CUDA, otherwise CUDA if available, otherwise CPU.
        use_float64: Compute in float64 instead of float32.

    Returns:
        The entropy in nats. Returns 0.0 for non-2-D input, fewer than two
        rows, zero columns, or a non-finite result.
    """
    if not (X.dim() == 2 and X.size(0) >= 2 and X.size(1) >= 1):
        return 0.0

    if device is not None:
        target_device = device
    elif X.is_cuda:
        target_device = X.device
    elif torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    target_dtype = torch.float64 if use_float64 else torch.float32

    samples = X.to(device=target_device, dtype=target_dtype, non_blocking=True)
    n_obs, n_dim = samples.shape
    centered = samples - samples.mean(dim=0, keepdim=True)
    cov = (centered.T @ centered) / (n_obs - 1 + eps)
    cov = cov + ridge * torch.eye(n_dim, dtype=cov.dtype, device=target_device)

    # Log-determinant from the (floored) eigenvalues is more stable than det().
    try:
        spectrum = torch.linalg.eigvalsh(cov)
    except RuntimeError:
        spectrum = torch.linalg.eigh(cov)[0]
    log_det = torch.log(torch.clamp(spectrum, min=eps)).sum()
    const = float(n_dim) * float(np.log(2.0 * np.pi * np.e))
    entropy = float((0.5 * (log_det + const)).detach().to("cpu").item())
    return entropy if np.isfinite(entropy) else 0.0


def mutual_information_sum_from_R(R: torch.Tensor, eps: float = 1e-8) -> float:
    """Sum Gaussian pairwise MI ``-1/2 * log(1 - r_ij^2)`` over the upper triangle of ``R``.

    Each correlation is clamped to ``[-1 + eps, 1 - eps]`` and ``1 - r^2`` is
    floored at ``eps``, so perfectly correlated pairs give a large but finite
    term.

    Returns:
        The sum as a Python float; 0.0 when ``R`` has at most one element or
        fewer than two rows.
    """
    if R.numel() <= 1 or R.size(0) < 2:
        return 0.0
    total = 0.0
    # combinations() visits (0, 1), (0, 2), ..., (1, 2), ...: row-major upper triangle.
    for row, col in itertools.combinations(range(R.size(0)), 2):
        r = float(R[row, col].item())
        r = max(-1.0 + eps, min(1.0 - eps, r))
        arg = max(eps, 1.0 - r * r)
        total += float(-0.5 * np.log(arg))
    return float(total)


def compute_layer_mi_scores(model: MixtralForCausalLM, collector: ExpertOutputCollector) -> Tuple[Dict[int, float], Dict[int, List[List[float]]]]:
    """Score every layer by the summed pairwise Gaussian-CCA MI between its experts.

    For each layer, the recorded expert vectors are aligned with
    ``_stack_expert_samples``. ``gaussian_cca_pairwise_mi`` is then evaluated
    on every expert pair.

    Returns:
        ``(scores, matrices)``. ``scores[layer]`` is the sum of the
        off-diagonal pairwise MI values. ``matrices[layer]`` is the full
        symmetric MI matrix, whose diagonal holds each expert's CCA with
        itself and is not part of the score. Layers without at least two
        aligned samples get a score of 0.0 and an empty matrix.
    """
    scores: Dict[int, float] = {}
    matrices: Dict[int, List[List[float]]] = {}
    for layer_idx in tqdm(range(model.config.num_hidden_layers), desc="Computing MI scores"):
        layer_outputs = collector.get_layer_outputs(layer_idx)
        stacked = _stack_expert_samples(layer_outputs) if layer_outputs else None
        if stacked is None or stacked[2] < 2 or not stacked[1]:
            scores[layer_idx] = 0.0
            matrices[layer_idx] = []
            continue

        expert_indices, expert_mats, _ = stacked
        size = len(expert_indices)
        mi_matrix = [[0.0] * size for _ in range(size)]

        # Diagonal: CCA of each expert with itself (driven by the ridge/eps clamps).
        for k, mat in enumerate(expert_mats):
            mi_matrix[k][k] = float(gaussian_cca_pairwise_mi(mat, mat, device=DEVICE))

        layer_score = 0.0
        for a, b in itertools.combinations(range(size), 2):
            value = gaussian_cca_pairwise_mi(expert_mats[a], expert_mats[b], device=DEVICE)
            layer_score += value
            mi_matrix[a][b] = mi_matrix[b][a] = float(value)

        scores[layer_idx] = layer_score
        matrices[layer_idx] = mi_matrix
    return scores, matrices


def normalize_and_format_results(scores: Dict[int, float], mi_matrices: Dict[int, List[List[float]]], model: "MixtralForCausalLM") -> Dict[str, Dict[str, float]]:
    """Min-max normalise per-layer MI scores and build the JSON-ready report.

    Args:
        scores: ``{layer_idx: mi_sum}``; missing layers count as 0.0.
        mi_matrices: ``{layer_idx: pairwise MI matrix}``, copied into the report.
        model: Only ``model.config.num_hidden_layers`` is read.

    Returns:
        ``{"layer_<i>": {...}}`` with the raw MI sum, the normalised score (0.5
        for every layer when all raw sums are equal), a 1-based importance
        rank (1 = highest normalised score; ties go to the lower layer
        index), a high/medium/low category (> 0.7 / > 0.3 / otherwise) and the
        pairwise matrix. Returns an empty dict when the model has no layers.
    """
    num_layers = model.config.num_hidden_layers
    if num_layers <= 0:
        return {}
    raw = np.array([scores.get(i, 0.0) for i in range(num_layers)], dtype=np.float64)
    lo, hi = raw.min(), raw.max()
    if hi > lo:
        norm = (raw - lo) / (hi - lo)
    else:
        norm = np.full_like(raw, 0.5)

    # argsort gives the *layer index* at each rank position; invert it so that
    # ranks[i] is the rank *of* layer i. A stable sort makes ties deterministic.
    order = np.argsort(-norm, kind="stable")
    ranks = np.empty(num_layers, dtype=np.int64)
    ranks[order] = np.arange(1, num_layers + 1)

    result: Dict[str, Dict[str, float]] = {}
    for i in range(num_layers):
        score = float(norm[i])
        result[f"layer_{i}"] = {
            "mutual_information_sum": float(raw[i]),
            "normalized_importance": score,
            "final_score": score,
            "layer_importance_rank": int(ranks[i]),
            "importance_category": "high" if score > 0.7 else ("medium" if score > 0.3 else "low"),
            "pairwise_mi_matrix": mi_matrices.get(i, []),
        }
    return result


def _print_summary(formatted: Dict[str, Dict]) -> None:
    """Print sample per-layer scores, aggregate statistics and the top/bottom layers."""
    if not formatted:
        print("\nNo layers to summarize.")
        return

    print("\nSample Layer Importance Scores:")
    for i in range(min(8, len(formatted))):
        lk = f"layer_{i}"
        s = formatted[lk]
        print(f"  {lk}:")
        print(f"    Mutual Information Sum: {s['mutual_information_sum']:.6f}")
        print(f"    Normalized Importance: {s['normalized_importance']:.4f}")
        print(f"    Importance Rank: {s['layer_importance_rank']}")
        print(f"    Category: {s['importance_category']}")

    all_mi = [d['mutual_information_sum'] for d in formatted.values()]
    all_norm = [d['normalized_importance'] for d in formatted.values()]
    print("\nOverall Statistics:")
    print(f"  Total layers: {len(formatted)}")
    print(f"  MI sum - Mean: {np.mean(all_mi):.6f}, Std: {np.std(all_mi):.6f}")
    print(f"  MI sum - Min: {np.min(all_mi):.6f}, Max: {np.max(all_mi):.6f}")
    print(f"  Normalized scores - Mean: {np.mean(all_norm):.4f}, Std: {np.std(all_norm):.4f}")

    high = [k for k, v in formatted.items() if v['importance_category'] == 'high']
    med = [k for k, v in formatted.items() if v['importance_category'] == 'medium']
    low = [k for k, v in formatted.items() if v['importance_category'] == 'low']
    print("\nLayer Importance Distribution:")
    print(f"  High importance ({len(high)} layers): {high}")
    print(f"  Medium importance ({len(med)} layers): {med}")
    print(f"  Low importance ({len(low)} layers): {low}")

    sorted_layers = sorted(formatted.items(), key=lambda x: x[1]['normalized_importance'], reverse=True)
    print("\nTop 5 Most Important Layers:")
    for i, (lk, s) in enumerate(sorted_layers[:5]):
        print(f"  {i+1}. {lk}: {s['normalized_importance']:.4f}")

    # Least important first, so that "1." really is the least important layer.
    print("\nTop 5 Least Important Layers:")
    for i, (lk, s) in enumerate(reversed(sorted_layers[-5:])):
        print(f"  {i+1}. {lk}: {s['normalized_importance']:.4f}")


def main():
    """CLI entry point: collect expert outputs, score layers by pairwise expert MI, write JSON.

    Any exception during loading, collection, scoring or saving is printed
    with its traceback; the function then returns without printing the summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default=DATASET_NAME, choices=["wikitext", "c4"], help="calibration dataset")
    parser.add_argument("--num_samples", type=int, default=NUM_SAMPLES, help="mapped to n_blocks_for_stat")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
    parser.add_argument("--seq_len", type=int, default=SEQUENCE_LENGTH, help="mapped to max_block_size")
    parser.add_argument("--output", type=str, default=OUTPUT_FILE, help="path of the JSON report to write")
    args = parser.parse_args()
    start = time.time()

    def banner(title: str) -> None:
        print("\n" + "=" * 40)
        print(title)
        print("=" * 40)

    print("=" * 60)
    print("Mixtral Sum of Mutual Information from Correlation Matrix-based Layer Importance")
    print("=" * 60)
    print(f"Using device: {DEVICE}")
    print(f"Number of GPUs: {NUM_GPUS}")
    print(f"Model path: {MODEL_PATH}")
    print(f"Dataset: {args.dataset}")
    # Report the effective (CLI) settings rather than the module defaults.
    print(f"Num samples: {args.num_samples} | Batch size: {args.batch_size} | Seq len: {args.seq_len} | Max sequences: {MAX_SEQUENCES}")

    try:
        banner("Step 1: Loading Resources")
        model, _tokenizer, dataloader = load_resources(args.dataset, args.num_samples, args.batch_size, args.seq_len)

        banner("Step 2: Collecting Expert Outputs")
        collector = collect_outputs(model, dataloader)

        banner("Step 3: Computing Layer MI Scores")
        try:
            mi_scores, mi_matrices = compute_layer_mi_scores(model, collector)
        finally:
            collector.cleanup()

        banner("Step 4: Processing Results")
        formatted = normalize_and_format_results(mi_scores, mi_matrices, model)

        print(f"Saving importance scores to {args.output}...")
        with open(args.output, 'w') as f:
            json.dump(formatted, f, indent=4)

        elapsed = time.time() - start
        print(f"\nCompleted successfully in {elapsed:.2f} seconds!")

    except Exception as e:
        print(f"\nError occurred: {e}")
        import traceback
        traceback.print_exc()
        return

    _print_summary(formatted)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()


