import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn.functional as F
from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
    DeepseekV2ForCausalLM,
    DeepseekV2MoE,
    DeepseekV2MoEGate
)
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import numpy as np
import time

# Location of the model checkpoint/tokenizer and of the JSON file that main() writes.
MODEL_PATH = "/Path/DeepSeek-V2-Lite"
OUTPUT_FILE = "./importance_score/expert_wise/deepseek_v2_norm_drop_loss_fre.json"

# Weights of the final score computed in calculate_importance_scores():
# final = ALPHA * normalized_loss + BETA * normalized_freq.
ALPHA = 0.5
BETA = 0.5

# Calibration data settings used by load_resources().
NUM_SAMPLES = 256  # upper bound on the number of calibration examples kept
BATCH_SIZE = 4  # DataLoader batch size
MAX_SEQUENCE_LENGTH = 10240  # examples are truncated / padded to this many tokens
MIN_SEQUENCE_LENGTH = 5120  # only examples with MORE tokens than this are kept

# Device the input batches are moved to before the model forward pass.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class ExpertImportanceCacheDataset:
    """CPU-side store of paired activations collected during calibration.

    ``Xs`` holds input activations and ``Zs`` holds the matching outputs.
    Every stored tensor is detached, moved to the CPU and flattened to
    2-D ``(tokens, features)``.
    """

    def __init__(self):
        self.Xs = []
        self.Zs = []

    @staticmethod
    def _to_cpu_2d(tensor):
        """Detach *tensor*, collapse a 3-D tensor to 2-D and copy it to the CPU."""
        if tensor.dim() == 3:
            # reshape (not view) so non-contiguous activations are accepted too.
            tensor = tensor.reshape(-1, tensor.size(-1))
        # The copy is deliberately blocking: a GPU -> CPU transfer issued with
        # non_blocking=True may return before the data has landed in host
        # memory, so a later read could observe incomplete values.
        return tensor.detach().to('cpu')

    def append(self, X=None, Z=None):
        """Store an input and/or output activation; ``None`` arguments are skipped."""
        if X is not None:
            self.Xs.append(self._to_cpu_2d(X))
        if Z is not None:
            self.Zs.append(self._to_cpu_2d(Z))

    def __len__(self):
        """Return the number of stored input tensors."""
        return len(self.Xs)

    def __iter__(self):
        """Iterate over ``(X, Z)`` pairs in insertion order."""
        return zip(self.Xs, self.Zs)

class ExpertImportanceWrapper(torch.nn.Module):
    """Drop-in replacement for a MoE block that can cache activations and
    re-run the block with one routed expert disabled.

    The wrapped block (``model_layer.mlp``) is expected to expose ``gate``,
    ``moe(flat_states, topk_idx, topk_weight)`` and ``shared_experts``.

    Attributes:
        model: the wrapped MoE block.
        cache_space: store for cached inputs (X) and block outputs (Z).
        cache_X / cache_Z: whether ``forward`` records inputs / outputs.
        expert_to_drop: index of the routed expert to disable, or ``None``.
    """

    def __init__(self, model_layer):
        super().__init__()
        self.model = model_layer.mlp
        self.cache_space = ExpertImportanceCacheDataset()
        self.cache_X = False
        self.cache_Z = False
        self.expert_to_drop = None

    def _manual_routing(self, flat_states, topk_method):
        """Compute ``(topk_idx, topk_weight)`` from the gate weights directly.

        Used when an expert is disabled (its logit is set to ``-inf`` before
        the softmax, so the remaining experts' probabilities are renormalised)
        or when the gate's ``topk_method`` is not one the gate is called for.
        """
        gate = self.model.gate
        num_tokens = flat_states.size(0)

        logits = F.linear(flat_states.to(torch.float32), gate.weight.to(torch.float32), None)
        if self.expert_to_drop is not None:
            logits[:, self.expert_to_drop] = -float('inf')
        scores = logits.softmax(dim=-1, dtype=torch.float32)

        if topk_method == "group_limited_greedy":
            num_group = gate.num_group
            # Keep only the experts that belong to the best `topk_group` groups.
            group_scores = scores.view(num_tokens, num_group, -1).max(dim=-1).values
            group_idx = torch.topk(group_scores, k=gate.topk_group, dim=-1, sorted=False)[1]
            group_mask = torch.zeros_like(group_scores)
            group_mask.scatter_(1, group_idx, 1)
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(num_tokens, num_group, gate.num_experts // num_group)
                .reshape(num_tokens, -1)
            )
            scores = scores.masked_fill(~score_mask.bool(), 0.0)

        topk_weight, topk_idx = torch.topk(scores, k=gate.top_k, dim=-1, sorted=False)
        return topk_idx, topk_weight * gate.routed_scaling_factor

    def forward(self, hidden_states: torch.Tensor):
        """Run the MoE block (routed experts + shared experts).

        Args:
            hidden_states: 2-D ``(tokens, hidden)`` or 3-D
                ``(batch, seq, hidden)`` activations.

        Returns:
            A 3-D tensor ``(batch, seq, hidden)``; a 2-D input is treated as a
            single batch, so its output has a leading dimension of 1.

        Raises:
            ValueError: if the input is neither 2-D nor 3-D.
        """
        if hidden_states.dim() == 2:
            hidden_states_3d = hidden_states.unsqueeze(0)
        elif hidden_states.dim() == 3:
            hidden_states_3d = hidden_states
        else:
            raise ValueError(f"Unexpected input dimensions: {hidden_states.shape}")

        batch_size, sequence_length, hidden_dim = hidden_states_3d.shape

        gate = self.model.gate
        gate_device = next(gate.parameters()).device
        if hidden_states_3d.device != gate_device:
            hidden_states_3d = hidden_states_3d.to(gate_device)

        residuals = hidden_states_3d
        flat_states = hidden_states_3d.reshape(-1, hidden_dim)

        topk_method = getattr(gate, 'topk_method', 'greedy')
        if self.expert_to_drop is not None or topk_method not in ("greedy", "group_limited_greedy"):
            topk_idx, topk_weight = self._manual_routing(flat_states, topk_method)
        else:
            topk_idx, topk_weight = gate(hidden_states_3d)

        moe_out_flat = self.model.moe(flat_states, topk_idx, topk_weight)
        moe_out_3d = moe_out_flat.view(batch_size, sequence_length, hidden_dim)

        final_hidden_states = moe_out_3d + self.model.shared_experts(residuals)

        # Exposed so callers can inspect the routing decision of the last call.
        self.last_topk_idx = topk_idx.view(batch_size, sequence_length, -1)
        self.last_topk_weight = topk_weight.view(batch_size, sequence_length, -1)

        if self.cache_X or self.cache_Z:
            # Z must be the value this method returns (routed + shared output):
            # evaluate_expert_importance() compares the cached Z against a
            # fresh forward() result, so caching only the routed part would
            # add the shared-expert output to every expert's loss.
            self.cache_space.append(
                X=(flat_states if self.cache_X else None),
                Z=(final_hidden_states.reshape(-1, hidden_dim) if self.cache_Z else None)
            )

        return final_hidden_states

    @torch.no_grad()
    def evaluate_expert_importance(self):
        """Measure the output change caused by disabling each routed expert.

        For every expert the cached inputs are replayed with that expert
        disabled and the Frobenius norms of (cached output - new output),
        computed in float64, are summed over all cached batches. Caching is
        switched off first so the replay does not grow the cache.

        Returns:
            dict mapping expert index -> accumulated loss.
        """
        self.cache_X = False
        self.cache_Z = False

        num_experts = getattr(self.model.gate, 'num_experts', None)
        if num_experts is None:
            num_experts = self.model.config.n_routed_experts

        # Loop invariant: the device the block's weights live on.
        wrapper_device = next(self.model.parameters()).device
        expert_losses = {}

        try:
            with torch.inference_mode():
                for expert_idx in range(num_experts):
                    self.expert_to_drop = expert_idx
                    loss = 0.0

                    for cached_input, cached_output in self.cache_space:
                        cached_input = cached_input.to(device=wrapper_device)
                        reference = cached_output.to(dtype=torch.float64, device=wrapper_device)

                        if cached_input.dim() == 2:
                            cached_input = cached_input.unsqueeze(0)

                        recomputed = self.forward(cached_input)
                        # Bring the result back to the cache's 2-D layout.
                        recomputed = recomputed.reshape(-1, recomputed.size(-1))

                        loss += torch.norm(reference - recomputed.to(torch.float64)).item()

                    expert_losses[expert_idx] = loss
        finally:
            # Never leave the block with an expert disabled, even on error.
            self.expert_to_drop = None

        return expert_losses

def load_resources():
    """Load the model, its tokenizer and a calibration DataLoader.

    The calibration set is built from the local C4 shard: documents with more
    than ``MIN_SEQUENCE_LENGTH`` tokens are kept, the first ``NUM_SAMPLES`` of
    them are selected, and only those are tokenized (truncated and padded to
    ``MAX_SEQUENCE_LENGTH``).

    Returns:
        ``(model, tokenizer, dataloader)``.

    Raises:
        ValueError: if no document passes the length filter.
    """
    print("Loading model and tokenizer...")
    model = DeepseekV2ForCausalLM.from_pretrained(
        MODEL_PATH,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        # Padding to max_length below requires a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading and preparing dataset...")

    # The packaged 'json' builder runs no repository code, so no
    # trust_remote_code flag is needed here.
    dataset = load_dataset(
        'json',
        data_files={'train': './data/c4-train.00000-of-01024.json.gz'},
    )['train']

    def num_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False).input_ids)

    def keep_example(example) -> bool:
        return num_tokens(example["text"]) > MIN_SEQUENCE_LENGTH

    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=MAX_SEQUENCE_LENGTH,
            padding="max_length",
            return_tensors="pt"
        )

    long_documents = dataset.filter(keep_example)
    if len(long_documents) == 0:
        raise ValueError(
            f"No document with more than {MIN_SEQUENCE_LENGTH} tokens was found "
            "in the calibration data."
        )

    # Select BEFORE tokenizing: every example is padded independently to
    # max_length, so the selected rows are identical, but only NUM_SAMPLES
    # documents get tokenized and padded instead of the whole filtered set.
    calibration_raw = long_documents.select(range(min(NUM_SAMPLES, len(long_documents))))
    calibration_dataset = calibration_raw.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"]
    )
    calibration_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    dataloader = DataLoader(calibration_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return model, tokenizer, dataloader

def calculate_expert_frequencies_and_losses(model, dataloader):
    """Collect per-expert routing counts and drop-one reconstruction losses.

    Every ``DeepseekV2MoE`` block is temporarily replaced by an
    ``ExpertImportanceWrapper`` that caches its inputs/outputs. The
    calibration batches are run once to count how often each expert is
    selected, then each wrapper replays its cache with one expert disabled at
    a time. The original blocks are put back before returning, also on error.

    Args:
        model: causal LM exposing ``model.model.layers`` and a config with
            ``num_hidden_layers`` / ``n_routed_experts``.
        dataloader: yields dict batches with ``input_ids`` and (optionally)
            ``attention_mask``.

    Returns:
        ``(expert_frequencies, expert_losses)``: two lists with one entry per
        layer. ``expert_frequencies[l]`` maps expert index -> selection count
        (all zeros for non-MoE layers); ``expert_losses[l]`` maps expert index
        -> loss and is empty for non-MoE layers.
    """
    num_layers = model.config.num_hidden_layers
    num_experts = model.config.n_routed_experts

    expert_counts = torch.zeros(num_layers, num_experts, dtype=torch.long)
    layer_wrappers = []

    print("Preparing layer wrappers...")
    for layer_idx in range(num_layers):
        layer = model.model.layers[layer_idx]
        if hasattr(layer, 'mlp') and isinstance(layer.mlp, DeepseekV2MoE):
            wrapper = ExpertImportanceWrapper(layer)
            wrapper.cache_X = True
            wrapper.cache_Z = True
            layer_wrappers.append(wrapper)
            layer.mlp = wrapper
        else:
            layer_wrappers.append(None)

    try:
        print("Calculating expert activation frequencies and collecting cache data...")
        model.eval()

        start_time = time.time()
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Frequency and Cache Data Collection"):
                inputs = {k: v.to(DEVICE) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}

                # Padded positions are routed like any other token; mask them
                # out so they are not counted as expert activations.
                # NOTE: the activations cached by the wrappers are not masked here.
                token_mask = None
                if 'attention_mask' in batch:
                    token_mask = batch['attention_mask'].bool().cpu()

                # The KV cache is never used here; disabling it saves memory.
                model(**inputs, use_cache=False)

                for layer_idx, wrapper in enumerate(layer_wrappers):
                    if wrapper is None:
                        continue
                    selected = getattr(wrapper, 'last_topk_idx', None)
                    if selected is None:
                        continue
                    # Shape (batch, seq, top_k). Count ALL selected experts:
                    # the gate returns them unsorted, so slot 0 alone is an
                    # arbitrary member of the top-k set.
                    selected = selected.cpu()
                    if token_mask is not None:
                        selected = selected[token_mask]
                    expert_counts[layer_idx] += torch.bincount(
                        selected.reshape(-1), minlength=num_experts
                    )
        print(f"Frequency calculation completed time: {time.time() - start_time}")

        print("Calculating expert reconstruction losses...")
        expert_losses = []
        for wrapper in tqdm(layer_wrappers, desc="Expert Loss Calculation"):
            if wrapper is None:
                expert_losses.append({})
                continue
            # Each wrapper is evaluated on the device its weights already
            # occupy, so no layer-index -> GPU mapping is assumed and the
            # model's weights are not moved around.
            expert_losses.append(wrapper.evaluate_expert_importance())
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        # Always give the model its original MoE blocks back.
        for layer_idx, wrapper in enumerate(layer_wrappers):
            if wrapper is not None:
                model.model.layers[layer_idx].mlp = wrapper.model

    expert_frequencies = [
        {expert_idx: int(count) for expert_idx, count in enumerate(layer_counts.tolist())}
        for layer_counts in expert_counts
    ]

    return expert_frequencies, expert_losses

def calculate_importance_scores(model, expert_frequencies, expert_losses):
    """Combine reconstruction losses and routing frequencies into one score.

    Both quantities are min-max normalised independently inside every layer
    (a layer whose values are all equal gets 0.5 everywhere) and then mixed as
    ``ALPHA * normalized_loss + BETA * normalized_freq``.

    Args:
        model: object whose ``config`` provides ``num_hidden_layers`` and
            ``n_routed_experts``.
        expert_frequencies: per-layer mapping expert index -> count.
        expert_losses: per-layer mapping expert index -> loss; a missing
            layer or expert counts as a loss of 0.0.

    Returns:
        ``{'layer_<l>': {'expert_<e>': {...}}}`` where each expert record
        holds ``final_score``, ``raw_loss_score``, ``normalized_loss``,
        ``frequency`` and ``normalized_freq``.
    """
    cfg = model.config
    layer_count = cfg.num_hidden_layers
    expert_count = cfg.n_routed_experts

    print("Processing reconstruction loss scores...")

    def loss_of(layer, expert):
        # Layers without recorded losses (or missing experts) contribute 0.0.
        if layer < len(expert_losses) and expert in expert_losses[layer]:
            return expert_losses[layer][expert]
        return 0.0

    def min_max_per_layer(values):
        # Row-wise min-max scaling; constant rows map to 0.5.
        scaled = torch.zeros_like(values)
        for row_idx in range(layer_count):
            row = values[row_idx]
            low, high = row.min(), row.max()
            if high > low:
                scaled[row_idx] = (row - low) / (high - low)
            else:
                scaled[row_idx] = torch.full_like(row, 0.5)
        return scaled

    loss_rows = [
        [loss_of(layer, expert) for expert in range(expert_count)]
        for layer in range(layer_count)
    ]
    freq_rows = [
        [expert_frequencies[layer][expert] for expert in range(expert_count)]
        for layer in range(layer_count)
    ]

    loss_scores_tensor = torch.tensor(loss_rows, dtype=torch.float32)
    freq_scores_tensor = torch.tensor(freq_rows, dtype=torch.float32)

    normalized_loss = min_max_per_layer(loss_scores_tensor)
    normalized_freq = min_max_per_layer(freq_scores_tensor)

    final_scores = ALPHA * normalized_loss + BETA * normalized_freq

    # Field name -> nested Python lists, in the order the fields are emitted.
    tables = {
        'final_score': final_scores.tolist(),
        'raw_loss_score': loss_scores_tensor.tolist(),
        'normalized_loss': normalized_loss.tolist(),
        'frequency': freq_scores_tensor.tolist(),
        'normalized_freq': normalized_freq.tolist(),
    }

    result = {}
    for layer in range(layer_count):
        result[f'layer_{layer}'] = {
            f'expert_{expert}': {
                field: table[layer][expert] for field, table in tables.items()
            }
            for expert in range(expert_count)
        }

    return result

def main():
    """Run the whole pipeline: load resources, score the experts, save the
    scores as JSON to ``OUTPUT_FILE`` and print summary statistics."""
    print(f"Using device: {DEVICE}")
    print(f"Model path: {MODEL_PATH}")
    print(f"Alpha: {ALPHA}, Beta: {BETA}")

    model, tokenizer, dataloader = load_resources()

    expert_frequencies, expert_losses = calculate_expert_frequencies_and_losses(model, dataloader)

    importance_scores = calculate_importance_scores(model, expert_frequencies, expert_losses)

    print(f"Saving importance scores to {OUTPUT_FILE}...")
    # Create the target directory first so the scores are not lost to a
    # missing-directory error after the computation has finished.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(importance_scores, f, indent=4)

    print("Done!")

    print("\nSample scores for Layer 0:")
    for expert_idx in range(min(4, len(importance_scores['layer_0']))):
        expert_key = f'expert_{expert_idx}'
        scores = importance_scores['layer_0'][expert_key]
        print(f"  {expert_key}:")
        print(f"    Final Score: {scores['final_score']:.4f}")
        print(f"    Raw Loss Score: {scores['raw_loss_score']:.4f}")
        print(f"    Frequency: {scores['frequency']}")

    # Gather the final scores globally and per layer for the summary below.
    all_final_scores = []
    layer_stats = {}

    for layer_key, layer_data in importance_scores.items():
        layer_scores = [expert_data['final_score'] for expert_data in layer_data.values()]
        all_final_scores.extend(layer_scores)

        layer_stats[layer_key] = {
            'mean': np.mean(layer_scores),
            'std': np.std(layer_scores),
            'min': np.min(layer_scores),
            'max': np.max(layer_scores)
        }

    total_experts = len(all_final_scores)

    print(f"\nOverall Statistics:")
    print(f"  Total experts: {total_experts}")
    print(f"  Global mean final score: {np.mean(all_final_scores):.4f}")
    print(f"  Global std final score: {np.std(all_final_scores):.4f}")
    print(f"  Global min final score: {np.min(all_final_scores):.4f}")
    print(f"  Global max final score: {np.max(all_final_scores):.4f}")

    print(f"\nPer-Layer Statistics (after layer-wise normalization):")
    for layer_key, stats in layer_stats.items():
        print(f"  {layer_key}:")
        print(f"    Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}")
        print(f"    Min: {stats['min']:.4f}, Max: {stats['max']:.4f}")


if __name__ == "__main__":
    main()
