import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch
import torch.nn.functional as F
from transformers import MixtralForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import numpy as np
import time

# Checkpoint directory passed to from_pretrained for both the model and the
# tokenizer. NOTE(review): "/Path/..." looks like a placeholder - set it to a
# real local checkpoint before running.
MODEL_PATH = "/Path/Mixtral-8x7B-v0.1"
# JSON file that main() writes the per-layer, per-expert scores to.
OUTPUT_FILE = "./importance_score/expert_wise/mixtral_norm_drop_loss_fre.json"

# Blend weights for the final score:
#   final = ALPHA * normalized drop-one reconstruction loss
#         + BETA  * normalized activation frequency
# (both signals are min-max normalized within each layer).
ALPHA = 0.5
BETA = 0.5

# Calibration data: Hugging Face dataset name / config, train split.
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
NUM_SAMPLES = 256  # upper bound on the number of calibration sequences
BATCH_SIZE = 4  # calibration sequences per forward pass
SEQUENCE_LENGTH = 1024  # tokens per calibration sequence

# Device that calibration input batches are moved to before each forward pass.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class ExpertImportanceCacheDataset:
    """Host-memory store of paired MoE-block inputs (X) and outputs (Z).

    Every tensor is flattened to 2-D ``(tokens, hidden)`` when it arrives as
    3-D. It is then copied to CPU, so the activations of a whole calibration
    pass do not occupy GPU memory. Iterating yields ``(X, Z)`` pairs in
    insertion order; ``len()`` counts cached inputs.
    """

    def __init__(self):
        self.Xs = []  # cached block inputs, CPU tensors
        self.Zs = []  # cached block outputs, CPU tensors

    @staticmethod
    def _to_host(tensor):
        """Return a detached CPU copy of *tensor*, flattening 3-D to 2-D first."""
        if tensor.dim() == 3:
            # reshape (unlike view) also accepts non-contiguous inputs.
            tensor = tensor.reshape(-1, tensor.size(-1))
        # Deliberately a blocking copy: a non_blocking CUDA->CPU copy is
        # written into page-locked (pinned) host memory and may still be in
        # flight when the CPU tensor is read. For a cache that holds a full
        # calibration pass, it would also pin a very large amount of RAM.
        return tensor.detach().to('cpu')

    def append(self, X=None, Z=None):
        """Cache an input *X* and/or an output *Z*; ``None`` arguments are skipped."""
        if X is not None:
            self.Xs.append(self._to_host(X))
        if Z is not None:
            self.Zs.append(self._to_host(Z))

    def __len__(self):
        return len(self.Xs)

    def __iter__(self):
        return zip(self.Xs, self.Zs)

class ExpertImportanceWrapper(torch.nn.Module):
    """Replacement for a decoder layer's ``block_sparse_moe`` module.

    It re-implements the sparse-MoE forward pass on top of the wrapped block's
    ``gate``, ``experts``, ``top_k`` and ``num_experts``, so that

    * one expert can be removed from routing via ``expert_to_drop``, and
    * the block's inputs and outputs can be cached in ``cache_space`` and
      replayed later by :meth:`evaluate_expert_importance`.

    Args:
        model_layer: a decoder layer exposing a ``block_sparse_moe`` attribute.
    """

    def __init__(self, model_layer):
        super().__init__()
        self.model = model_layer.block_sparse_moe
        self.cache_space = ExpertImportanceCacheDataset()
        # When True, forward() appends its flattened input / output to cache_space.
        self.cache_X = False
        self.cache_Z = False
        # Index of the expert excluded from routing, or None to route normally.
        self.expert_to_drop = None

    def forward(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Route every token to its top-k experts, optionally excluding one expert.

        Args:
            hidden_states: ``(batch, seq, hidden)`` or ``(tokens, hidden)``.

        Returns:
            ``(output, router_logits)``. ``output`` has the same shape as
            ``hidden_states``. ``router_logits`` is ``(tokens, num_experts)``;
            when ``expert_to_drop`` is set, the dropped expert's column is ``-inf``.

        Raises:
            ValueError: if ``hidden_states`` is neither 2-D nor 3-D.
        """
        if hidden_states.dim() not in (2, 3):
            raise ValueError(f"Unexpected input dimensions: {hidden_states.shape}")
        input_shape = hidden_states.shape
        hidden_dim = input_shape[-1]
        hidden_states = hidden_states.reshape(-1, hidden_dim)

        gate_device = next(self.model.gate.parameters()).device
        if hidden_states.device != gate_device:
            hidden_states = hidden_states.to(gate_device)

        router_logits = self.model.gate(hidden_states)
        if self.expert_to_drop is not None:
            # A -inf logit gives the dropped expert zero routing probability.
            router_logits[:, self.expert_to_drop] = -float('inf')

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.model.top_k, dim=-1)
        # Renormalise so each token's selected-expert weights sum to 1.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (hidden_states.shape[0], hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # expert_mask[e] is (top_k, tokens): which token slots chose expert e.
        expert_mask = F.one_hot(
            selected_experts, num_classes=self.model.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.model.num_experts):
            slot, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                continue
            expert_out = self.model.experts[expert_idx](hidden_states[token_idx])
            expert_out = expert_out * routing_weights[token_idx, slot, None]
            final_hidden_states.index_add_(0, token_idx, expert_out.to(hidden_states.dtype))

        if self.cache_X or self.cache_Z:
            self.cache_space.append(
                X=(hidden_states if self.cache_X else None),
                Z=(final_hidden_states if self.cache_Z else None),
            )

        # Give the output the same rank as the input (3-D -> 3-D, 2-D -> 2-D).
        return final_hidden_states.reshape(*input_shape[:-1], hidden_dim), router_logits

    @torch.no_grad()
    def evaluate_expert_importance(self):
        """Score each expert by how much removing it changes this block's output.

        For every expert ``e``, each cached input is replayed through
        :meth:`forward` with ``expert_to_drop = e``. The result is compared
        with the cached (unmodified) output. Caching is switched off first, so
        the replay does not grow the cache. ``expert_to_drop`` is reset to
        ``None`` afterwards, even on error.

        Returns:
            dict mapping expert index -> sum over cached batches of the
            Frobenius norm ``||Z - Z_dropped||``, computed in float64.
        """
        self.cache_X = False
        self.cache_Z = False

        device = next(self.model.parameters()).device
        expert_losses = {}
        try:
            with torch.inference_mode():
                for expert_idx in range(self.model.num_experts):
                    self.expert_to_drop = expert_idx
                    # Accumulate on the device and read back once per expert
                    # rather than synchronising after every cached batch.
                    loss = torch.zeros((), dtype=torch.float64, device=device)
                    for cached_in, cached_out in self.cache_space:
                        cached_in = cached_in.to(device=device, non_blocking=True)
                        reference = cached_out.to(
                            dtype=torch.float64, device=device, non_blocking=True)
                        dropped_out, _ = self.forward(cached_in)
                        dropped_out = dropped_out.reshape(reference.shape).to(torch.float64)
                        loss += torch.norm(reference - dropped_out)
                    expert_losses[expert_idx] = loss.item()
        finally:
            # Never leave the wrapper routing with an expert removed.
            self.expert_to_drop = None

        return expert_losses

def load_resources():
    """Load the Mixtral model, its tokenizer, and a calibration DataLoader.

    The calibration set is built as follows. Non-blank lines of the
    ``DATASET_NAME`` / ``DATASET_CONFIG`` train split are tokenized in order
    into one continuous token stream. That stream is cut into at most
    ``NUM_SAMPLES`` full ``SEQUENCE_LENGTH``-token blocks, each starting with
    the tokenizer's BOS token when it has one.

    Dataset lines vary in length and may be blank. Padding each line
    separately to ``SEQUENCE_LENGTH`` can make pad tokens the majority of
    every batch. The MoE blocks route every position regardless of
    ``attention_mask``, so those pad tokens would skew both the routing
    statistics and the cached activations.

    Returns:
        ``(model, tokenizer, dataloader)``. Each dataloader batch is a dict
        with ``input_ids`` and an all-ones ``attention_mask``, both of shape
        ``(batch, SEQUENCE_LENGTH)``.

    Raises:
        ValueError: if the dataset yields fewer tokens than one block needs.
    """
    print("Loading model and tokenizer...")
    model = MixtralForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading and preparing dataset...")
    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split='train')

    # Each block = optional BOS + `body_len` tokens from the text stream.
    prefix = [] if tokenizer.bos_token_id is None else [tokenizer.bos_token_id]
    body_len = SEQUENCE_LENGTH - len(prefix)
    if body_len <= 0:
        raise ValueError(f"SEQUENCE_LENGTH={SEQUENCE_LENGTH} leaves no room for text tokens")
    tokens_needed = NUM_SAMPLES * body_len

    # Tokenize line by line and stop as soon as enough tokens are collected.
    stream = []
    for text in dataset['text']:
        if not text.strip():
            continue
        stream.extend(tokenizer(text, add_special_tokens=False)['input_ids'])
        if len(stream) >= tokens_needed:
            break

    num_blocks = min(NUM_SAMPLES, len(stream) // body_len)
    if num_blocks == 0:
        raise ValueError(
            f"Dataset produced only {len(stream)} tokens; "
            f"at least {body_len} are needed for one calibration block"
        )

    input_ids = torch.tensor(
        [prefix + stream[i * body_len:(i + 1) * body_len] for i in range(num_blocks)],
        dtype=torch.long,
    )
    attention_mask = torch.ones_like(input_ids)
    calibration_dataset = [
        {'input_ids': ids, 'attention_mask': mask}
        for ids, mask in zip(input_ids, attention_mask)
    ]
    dataloader = DataLoader(calibration_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return model, tokenizer, dataloader

def _install_importance_wrappers(model, num_layers):
    """Swap each layer's ``block_sparse_moe`` for a caching wrapper; ``None`` where a layer has none."""
    layer_wrappers = []
    for layer_idx in range(num_layers):
        layer = model.model.layers[layer_idx]
        if not hasattr(layer, 'block_sparse_moe'):
            layer_wrappers.append(None)
            continue
        wrapper = ExpertImportanceWrapper(layer)
        wrapper.cache_X = True
        wrapper.cache_Z = True
        layer_wrappers.append(wrapper)
        layer.block_sparse_moe = wrapper
    return layer_wrappers


def _restore_moe_blocks(model):
    """Put the original MoE block back wherever an ExpertImportanceWrapper is installed."""
    for layer in model.model.layers:
        block = getattr(layer, 'block_sparse_moe', None)
        if isinstance(block, ExpertImportanceWrapper):
            layer.block_sparse_moe = block.model


def _collect_expert_frequencies(model, dataloader, layer_wrappers, num_experts):
    """Run the calibration data through *model*, counting top-k expert selections per layer.

    The same forward passes also fill each wrapper's activation cache.
    Returns one ``{expert_idx: count}`` dict per entry of *layer_wrappers*.
    """
    expert_frequencies = [{i: 0 for i in range(num_experts)} for _ in layer_wrappers]

    def create_hook(layer_idx, top_k):
        def hook(module, inputs, output):
            if not (isinstance(output, tuple) and len(output) >= 2):
                return
            router_logits = output[1]
            # Each token is processed by its top_k experts, so count every
            # selected expert, not only the highest-scoring one.
            selected = torch.topk(router_logits, top_k, dim=-1).indices.reshape(-1)
            counts = torch.bincount(selected, minlength=num_experts).tolist()
            for expert_idx, count in enumerate(counts):
                expert_frequencies[layer_idx][expert_idx] += count
        return hook

    print("Calculating expert activation frequencies and collecting cache data...")
    model.eval()

    hooks = [
        wrapper.register_forward_hook(create_hook(layer_idx, wrapper.model.top_k))
        for layer_idx, wrapper in enumerate(layer_wrappers)
        if wrapper is not None
    ]
    try:
        start_time = time.time()
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Frequency and Cache Data Collection"):
                inputs = {k: v.to(DEVICE) for k, v in batch.items()
                          if k in ('input_ids', 'attention_mask')}
                # The logits are not needed; the hooks and wrapper caches do the work.
                model(**inputs)
        print(f"Frequency calculation completed time: {time.time() - start_time}")
    finally:
        for handle in hooks:
            handle.remove()

    return expert_frequencies


def _compute_expert_losses(layer_wrappers):
    """Run drop-one-expert evaluation for every wrapper (``{}`` for layers without one).

    Each MoE block is evaluated on the device it already occupies, so the
    model's device placement is left untouched.
    """
    print("Calculating expert reconstruction losses...")
    expert_losses = []
    for wrapper in tqdm(layer_wrappers, desc="Expert Loss Calculation"):
        if wrapper is None:
            expert_losses.append({})
            continue
        expert_losses.append(wrapper.evaluate_expert_importance())
        # The cached activations are no longer needed; release the host memory.
        wrapper.cache_space = ExpertImportanceCacheDataset()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return expert_losses


def calculate_expert_frequencies_and_losses(model, dataloader):
    """Collect per-expert activation counts and drop-one reconstruction losses.

    Each decoder layer's ``block_sparse_moe`` is temporarily replaced by an
    ``ExpertImportanceWrapper``. One pass over *dataloader* then fills the
    wrappers' activation caches. The same pass counts, per layer, how often
    each expert is among the router's top-k choices. Every wrapper then replays
    its cache with one expert dropped at a time. The original MoE blocks are
    restored afterwards, also when an error occurs, and are left on their
    original devices.

    Args:
        model: Mixtral causal LM whose ``model.model.layers`` are decoder layers.
        dataloader: yields dicts containing ``input_ids`` and ``attention_mask``.

    Returns:
        ``(expert_frequencies, expert_losses)``: lists with one entry per layer.
        ``expert_frequencies[l]`` maps expert index -> selection count.
        ``expert_losses[l]`` maps expert index -> reconstruction loss, or is
        ``{}`` for layers without an MoE block.
    """
    num_layers = model.config.num_hidden_layers
    num_experts = model.config.num_local_experts

    print("Preparing layer wrappers...")
    try:
        layer_wrappers = _install_importance_wrappers(model, num_layers)
        expert_frequencies = _collect_expert_frequencies(
            model, dataloader, layer_wrappers, num_experts)
        expert_losses = _compute_expert_losses(layer_wrappers)
    finally:
        _restore_moe_blocks(model)

    return expert_frequencies, expert_losses

def calculate_importance_scores(model, expert_frequencies, expert_losses):
    """Blend reconstruction loss and activation frequency into one score per expert.

    Each signal is min-max normalised within its own layer; a layer whose
    values are all equal gets 0.5 everywhere. The two normalised signals are
    combined as ``ALPHA * normalized_loss + BETA * normalized_freq``.

    Args:
        model: only ``model.config.num_hidden_layers`` and
            ``model.config.num_local_experts`` are read.
        expert_frequencies: per-layer mapping expert index -> activation count.
        expert_losses: per-layer mapping expert index -> reconstruction loss.
            Missing layers or experts count as 0.0.

    Returns:
        ``{'layer_<i>': {'expert_<j>': {...}}}``. Each inner dict holds
        ``final_score``, ``raw_loss_score``, ``normalized_loss``,
        ``frequency`` and ``normalized_freq``.
    """
    n_layers = model.config.num_hidden_layers
    n_experts = model.config.num_local_experts

    print("Processing reconstruction loss scores...")

    def raw_loss(li, ei):
        if li < len(expert_losses) and ei in expert_losses[li]:
            return expert_losses[li][ei]
        return 0.0

    loss_scores_tensor = torch.tensor(
        [[raw_loss(li, ei) for ei in range(n_experts)] for li in range(n_layers)],
        dtype=torch.float32,
    )
    freq_scores_tensor = torch.tensor(
        [[expert_frequencies[li][ei] for ei in range(n_experts)] for li in range(n_layers)],
        dtype=torch.float32,
    )

    def minmax(row):
        # Constant rows carry no ranking information: map them to the midpoint.
        lo, hi = row.min(), row.max()
        if not hi > lo:
            return torch.full_like(row, 0.5)
        return (row - lo) / (hi - lo)

    normalized_loss = torch.zeros_like(loss_scores_tensor)
    normalized_freq = torch.zeros_like(freq_scores_tensor)
    for li in range(n_layers):
        normalized_loss[li] = minmax(loss_scores_tensor[li])
        normalized_freq[li] = minmax(freq_scores_tensor[li])

    final_scores = ALPHA * normalized_loss + BETA * normalized_freq

    # Output fields in the order they appear in the JSON.
    fields = (
        ('final_score', final_scores),
        ('raw_loss_score', loss_scores_tensor),
        ('normalized_loss', normalized_loss),
        ('frequency', freq_scores_tensor),
        ('normalized_freq', normalized_freq),
    )
    return {
        f'layer_{li}': {
            f'expert_{ei}': {name: values[li, ei].item() for name, values in fields}
            for ei in range(n_experts)
        }
        for li in range(n_layers)
    }

def print_score_summary(importance_scores):
    """Print sample layer-0 scores plus global and per-layer final-score statistics.

    Args:
        importance_scores: ``{'layer_<i>': {'expert_<j>': {...}}}`` as produced
            by ``calculate_importance_scores``.
    """
    print("\nSample scores for Layer 0:")
    layer0 = importance_scores.get('layer_0', {})
    for expert_idx in range(min(4, len(layer0))):
        expert_key = f'expert_{expert_idx}'
        scores = layer0[expert_key]
        print(f"  {expert_key}:")
        print(f"    Final Score: {scores['final_score']:.4f}")
        print(f"    Raw Loss Score: {scores['raw_loss_score']:.4f}")
        print(f"    Frequency: {scores['frequency']}")

    all_final_scores = []
    layer_stats = {}
    for layer_key, layer_data in importance_scores.items():
        layer_scores = [expert_data['final_score'] for expert_data in layer_data.values()]
        if not layer_scores:
            # np.min/np.max raise on empty input; nothing to report for this layer.
            continue
        all_final_scores.extend(layer_scores)
        layer_stats[layer_key] = {
            'mean': np.mean(layer_scores),
            'std': np.std(layer_scores),
            'min': np.min(layer_scores),
            'max': np.max(layer_scores)
        }

    if not all_final_scores:
        print("\nNo scores to summarize.")
        return

    print("\nOverall Statistics:")
    print(f"  Total experts: {len(all_final_scores)}")
    print(f"  Global mean final score: {np.mean(all_final_scores):.4f}")
    print(f"  Global std final score: {np.std(all_final_scores):.4f}")
    print(f"  Global min final score: {np.min(all_final_scores):.4f}")
    print(f"  Global max final score: {np.max(all_final_scores):.4f}")

    print("\nPer-Layer Statistics (after layer-wise normalization):")
    for layer_key, stats in layer_stats.items():
        print(f"  {layer_key}:")
        print(f"    Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}")
        print(f"    Min: {stats['min']:.4f}, Max: {stats['max']:.4f}")


def main():
    """Score all experts, write the scores to OUTPUT_FILE as JSON, and print a summary."""
    print(f"Using device: {DEVICE}")
    print(f"Model path: {MODEL_PATH}")
    print(f"Alpha: {ALPHA}, Beta: {BETA}")
    start_time = time.time()
    print(f"Start time: {start_time}")
    model, _tokenizer, dataloader = load_resources()
    print(f"Load resources time: {time.time() - start_time}")
    expert_frequencies, expert_losses = calculate_expert_frequencies_and_losses(model, dataloader)
    print(f"Calculate expert frequencies and losses time: {time.time() - start_time}")
    importance_scores = calculate_importance_scores(model, expert_frequencies, expert_losses)
    print(f"Calculate importance scores time: {time.time() - start_time}")
    print(f"Saving importance scores to {OUTPUT_FILE}...")
    # OUTPUT_FILE may point into sub-directories that do not exist yet; create
    # them so the write does not fail after the (long) scoring run.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(importance_scores, f, indent=4)
    print(f"Save importance scores time: {time.time() - start_time}")
    print("Done!")

    print_score_summary(importance_scores)


if __name__ == "__main__":
    main()
