import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

import numpy as np
import argparse
import random
from tqdm import tqdm
import pickle
from sentence_transformers import SentenceTransformer
import json

# import necessary modules
from models import PruneLlama2ForCausalLM
from pruning import collect_info_reg_llama, help_functions_hn
from torch.optim.lr_scheduler import CosineAnnealingLR

from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, AutoModel
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
from torch.cuda import amp
from lib.dataset_loader import (
    load_mc_dataset, format_mc_example, evaluate_mc_example, calculate_perplexity, format_mc_prompt_with_ans, build_wikitext_ids, build_ptb_ids, sample_wikitext_sequences
)

def compute_stable_mse_loss(predicted, target, epsilon=1e-6):
    """Return the mean squared error between a clamped prediction and ``target``.

    ``predicted`` is limited to the range [-10, 10] before the difference is
    taken, and ``epsilon`` is added to every squared error before averaging,
    so a perfect prediction yields ``epsilon`` rather than exactly zero.
    """
    # Bound the prediction so a single extreme output cannot dominate the loss.
    bounded = predicted.clamp(min=-10, max=10)
    squared_error = (bounded - target) ** 2
    return torch.mean(squared_error + epsilon)

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pick, for every sequence, the hidden state of its last attended token.

    If the final column of ``attention_mask`` is set for every row (which is
    what a left-padded batch looks like) the last position is returned
    directly.  Otherwise each row's last attended position is derived from the
    number of attended tokens in that row.
    """
    every_row_ends_attended = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if every_row_ends_attended:
        return last_hidden_states[:, -1]

    # Index of the last attended token per row (count of attended tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]

class MaskRouter(nn.Module):
    """Score candidate pruning masks for an input with a LoRA-adapted encoder.

    The encoder found at ``sentence_model_path`` is loaded with ``AutoModel``
    and wrapped with LoRA adapters.  ``forward`` pools the hidden state of the
    last attended token, L2-normalises it and feeds it to either the
    classification head or the regression head; both heads emit
    ``num_clusters`` values, one per candidate mask.
    """

    def __init__(self, sentence_model_path, num_clusters, hidden_dim=None, device="cuda", use_classification=True, lora_rank=8):
        """Build the encoder, its LoRA adapters and the output heads.

        Args:
            sentence_model_path: Path or hub id passed to ``AutoModel`` and
                ``AutoTokenizer``.
            num_clusters: Number of outputs produced by each head.
            hidden_dim: Width of the hidden layers in the heads; ``None``
                makes each head a single linear layer.
            device: Device the encoder is moved to.
            use_classification: Selects which head ``forward`` uses and which
                head receives the custom weight initialisation.
            lora_rank: Rank ``r`` of the LoRA adapters.
        """
        super().__init__()

        self.device = device
        self.use_classification = use_classification

        # NOTE: trust_remote_code=True runs Python code shipped with the
        # checkpoint; only point sentence_model_path at a trusted source.
        self.model = AutoModel.from_pretrained(sentence_model_path, trust_remote_code=True, use_cache=False).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(sentence_model_path, padding_side='left')

        # Enabled before the LoRA wrapping below, when the encoder supports it.
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

        # Configure LoRA parameters
        lora_config = LoraConfig(
            r=lora_rank,  # LoRA rank
            lora_alpha=32,  # scaling factor
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="SEQ_CLS"
        )

        # Wrap the encoder with the LoRA adapters and report what is trainable.
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        # Width of the pooled embedding fed to the heads.
        self.embedding_dim = self.model.config.hidden_size

        if use_classification:
            if hidden_dim is None:
                self.classifier = nn.Linear(self.embedding_dim, num_clusters)
            else:
                self.classifier = nn.Sequential(
                    nn.Linear(self.embedding_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim, num_clusters)
                )

        # The regression head is constructed regardless of use_classification,
        # so its parameters are always part of this module.
        if hidden_dim is None:
            # Direct regression
            self.regressor = nn.Linear(self.embedding_dim, num_clusters)
        else:
            # Regressor with hidden layers
            self.regressor = nn.Sequential(
                nn.Linear(self.embedding_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, num_clusters)
            )

        self.num_clusters = num_clusters

        # Only the head that forward() will use gets the custom initialisation;
        # the other keeps PyTorch's default.
        if self.use_classification:
            self.classifier.apply(self.init_weights)
        else:
            self.regressor.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise a linear layer: Kaiming-normal weights, zero bias.

        Modules that are not ``nn.Linear`` are left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, texts):
        """Score every mask for an already-tokenised batch.

        Args:
            texts: Tokeniser output (a mapping) that provides the
                ``'input_ids'`` and ``'attention_mask'`` tensors.  Despite the
                parameter name, raw strings are not accepted.
        Returns:
            Tensor of shape (batch_size, num_clusters): unnormalised logits
            when ``use_classification`` is true, regression scores otherwise.
        """
        # The encoder pass and pooling run under CUDA autocast.
        with torch.cuda.amp.autocast():
            outputs = self.model(texts['input_ids'], attention_mask=texts['attention_mask'])

            embeddings = last_token_pool(outputs.last_hidden_state, texts['attention_mask'])

            # eps=1e-4 is the lower bound F.normalize applies to the norm
            # before dividing.
            embeddings = F.normalize(embeddings, p=2, dim=1, eps=1e-4)

        if self.use_classification:
            # Logits are returned without a softmax.
            return self.classifier(embeddings)
        return self.regressor(embeddings)

    def predict_best_mask(self, texts):
        """Return, per input, the index of the highest-scoring mask.

        Args:
            texts: Same tokenised batch that ``forward`` accepts.
        Returns:
            Integer tensor with one mask index per input.
        """
        # The argmax result carries no gradient, so skip building the
        # autograd graph for the encoder pass.
        with torch.no_grad():
            scores = self.forward(texts)
        return torch.argmax(scores, dim=-1)


class MaskRouterTrainer:
    def __init__(self, args):
        """Prepare data, models and optimisation state from parsed ``args``.

        The device falls back to CPU when CUDA is unavailable.  Setup then
        runs in a fixed order: training data, the LLM (only for the ``test``
        and ``individual_mask`` modes), the router, and finally the
        optimiser, scheduler and gradient scaler.
        """
        self.args = args
        device_name = args.device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_path)
        print(f"Using device: {self.device}")

        self.load_preprocessed_data()

        # Evaluation modes set up the LLM before the router is built.
        if args.mode in ("test", "individual_mask"):
            self.setup_llm_for_testing()

        self.setup_router()
        self.setup_training()
    
    def load_preprocessed_data(self):
        """Load preprocessed data"""
        print("Loading preprocessed data...")
        self.train_data = []
        num_samples = getattr(self.args, "num_samples_per_dataset", None)
        random.seed(self.args.seed)

        if self.args.train_datasets is None:
            all_data_path = os.path.join(self.args.data_dir, "all_datasets_processed.pkl")
            with open(all_data_path, 'rb') as f:
                data = pickle.load(f)
                if num_samples is not None and len(data) > num_samples:
                    data = random.sample(data, num_samples)
                self.train_data = data
        else:
            for dataset_name in self.args.train_datasets:
                all_data_path = os.path.join(self.args.data_dir, dataset_name + "_processed.pkl")
                if not os.path.exists(all_data_path):
                    raise FileNotFoundError(f"Preprocessed data not found at {all_data_path}")
                with open(all_data_path, 'rb') as f:
                    data = pickle.load(f)
                    if num_samples is not None and len(data) > num_samples:
                        data = random.sample(data, num_samples)
                    self.train_data.extend(data)
                print(f"Loaded {len(self.train_data)} training samples")
        
        # Truncate scores vector if needed
        if hasattr(self.args, 'num_masks_to_load') and self.args.num_masks_to_load is not None:
            print(f"Truncating scores vector to match num_masks_to_load={self.args.num_masks_to_load}")
            
            for sample in self.train_data:
                # Only keep first num_masks_to_load scores
                original_scores = sample['scores']
                sample['scores'] = original_scores[:self.args.num_masks_to_load]
            
            # Print truncated label distribution (for debugging)
            all_labels = [np.argmax(sample['scores']) for sample in self.train_data]
            from collections import Counter
            label_dist = Counter(all_labels)
            print(f"Truncated label distribution: {dict(sorted(label_dist.items()))}")
            print(f"Label range: [{min(all_labels)}, {max(all_labels)}]")

    def setup_router(self):
        """Create the ``MaskRouter`` from ``self.args`` and move it to ``self.device``.

        For the ``test``, ``individual_mask`` and ``train`` modes the LLM
        setup is triggered first (presumably so the representative masks are
        available; confirm against ``setup_llm_for_testing``).  The router's
        output size itself comes from ``args.num_clusters``.
        """
        print("Setting up router model...")

        if self.args.mode in ("test", "individual_mask", "train"):
            self.setup_llm_for_testing()

        router_options = {
            "sentence_model_path": self.args.sentence_model_path,
            "num_clusters": self.args.num_clusters,
            "hidden_dim": self.args.router_hidden_dim,
            "use_classification": self.args.use_classification,
            "device": self.device,
        }
        self.router = MaskRouter(**router_options).to(self.device)

        parameter_count = sum(p.numel() for p in self.router.parameters())
        print(f"Router embedding dim: {self.router.embedding_dim}")
        print(f"Router parameters: {parameter_count}")
    
    def setup_training(self):
        """Create the optimiser, learning-rate scheduler and gradient scaler.

        AdamW is applied to all router parameters with one learning rate.
        The scheduler is a linear warmup over the first 10% of steps followed
        by cosine decay.  Its horizon counts one step per batch, including a
        final partial batch, which is how ``train_router`` iterates.
        """
        print("Setting up training configuration...")

        self.optimizer = torch.optim.AdamW(
            self.router.parameters(),  # All parameters use same learning rate
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
            eps=1e-8
        )

        # Round up: the training loop also runs the last, smaller batch.
        # Flooring here made the schedule end early, and gave a zero-step
        # schedule whenever the data was smaller than one batch.
        batch_size = self.args.batch_size
        steps_per_epoch = (len(self.train_data) + batch_size - 1) // batch_size
        total_steps = steps_per_epoch * self.args.num_epochs
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=total_steps // 10,  # 10% warmup
            num_training_steps=total_steps
        )

        # Gradient scaler used by the mixed-precision training loop.
        self.scaler = amp.GradScaler()

        print("Optimizer: AdamW (single learning rate for all router parameters)")
        print("Scheduler: linear warmup (10% of steps) followed by cosine decay")
    
    def prepare_batch_data(self, batch_samples):
        """Prepare batch data"""
        texts = [sample['text'] for sample in batch_samples]
        
        # Convert score list to tensor
        scores_list = [sample['scores'] for sample in batch_samples]
        target_scores = torch.tensor(scores_list, dtype=torch.float32, device=self.device)
        
        return texts, target_scores

    def prepare_batch_data_classification(self, batch_samples):
        """Prepare batch data for classification task"""
        texts = [sample['text'] for sample in batch_samples]
        
        # Convert scores to labels (best mask index)
        labels = []
        for sample in batch_samples:
            scores = sample['scores']
            best_mask_idx = np.argmax(scores)
            labels.append(best_mask_idx)
        
        target_labels = torch.tensor(labels, dtype=torch.long, device=self.device)
        return texts, target_labels

    def train_router(self):
        """Train the router on ``self.train_data`` and save the final weights.

        Every epoch shuffles the samples, tokenises each mini-batch with the
        router's tokenizer and takes one mixed-precision optimiser step per
        batch (loss scaling, gradient-norm clipping at 1.0, then a scheduler
        step).  In classification mode the target is the index of each
        sample's highest score and the loss is cross-entropy; otherwise the
        full score vector is regressed with ``compute_stable_mse_loss``.
        The mean loss is printed after each epoch and the model is saved as
        ``router_final.pt`` at the end.
        """
        from collections import Counter

        use_classification = self.router.use_classification
        batch_size = self.args.batch_size
        num_samples = len(self.train_data)

        print("Starting router training...")
        print(f"Training mode: {'Classification' if use_classification else 'Regression'}")
        print("Training configuration:")
        print(f"  Epochs: {self.args.num_epochs}")
        print(f"  Batch size: {batch_size}")
        print(f"  Learning rate: {self.args.learning_rate}")
        print(f"  Weight decay: {self.args.weight_decay}")
        print("  Warmup: first 10% of scheduler steps (see setup_training)")

        self.router.train()

        global_step = 0

        # Targets are moved to the device once; texts are tokenised per batch.
        print("Preloading data to GPU...")
        preloaded_texts = [sample['text'] for sample in self.train_data]
        all_labels = [int(np.argmax(sample['scores'])) for sample in self.train_data]
        if use_classification:
            preloaded_targets = torch.tensor(all_labels, dtype=torch.long, device=self.device)
        else:
            preloaded_targets = torch.tensor(
                [sample['scores'] for sample in self.train_data],
                dtype=torch.float32,
                device=self.device
            )

        print("\n=== Pre-training debug ===")
        print(f"Label distribution: {dict(Counter(all_labels))}")

        for epoch in range(self.args.num_epochs):
            epoch_loss = 0.0
            num_batches = 0

            # Fresh random order for this epoch.
            indices = torch.randperm(num_samples)
            total_batches = (num_samples + batch_size - 1) // batch_size

            # The context manager closes the bar when the epoch ends.
            with tqdm(total=total_batches, desc=f"Epoch {epoch+1}/{self.args.num_epochs}") as pbar:
                for start_idx in range(0, num_samples, batch_size):
                    batch_indices = indices[start_idx:start_idx + batch_size]
                    batch_texts = [preloaded_texts[i] for i in batch_indices.tolist()]

                    batch_tokens = self.router.tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        max_length=2048,
                        return_tensors="pt"
                    ).to(self.device)
                    targets = preloaded_targets[batch_indices]

                    self.optimizer.zero_grad()

                    # Forward pass and loss under autocast.
                    with amp.autocast():
                        outputs = self.router(batch_tokens)
                        if use_classification:
                            loss = F.cross_entropy(outputs, targets)
                        else:
                            loss = compute_stable_mse_loss(outputs, targets)

                    # Scaled backward pass; gradients are unscaled before
                    # clipping so max_norm applies to their true magnitude.
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.router.parameters(), max_norm=1.0)

                    # Update parameters
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    if global_step % 100 == 0:
                        print(f"Current scale: {self.scaler.get_scale()}")

                    self.scheduler.step()

                    # Read the loss value once per step.
                    loss_value = loss.item()
                    epoch_loss += loss_value
                    num_batches += 1
                    global_step += 1

                    current_lr = self.scheduler.get_last_lr()[0]
                    pbar.update(1)
                    pbar.set_postfix({
                        'loss': f'{loss_value:.6f}',
                        'lr': f'{current_lr:.2e}',
                        'step': global_step
                    })

            mean_loss = epoch_loss / max(num_batches, 1)
            print(f"Epoch {epoch+1}/{self.args.num_epochs} average loss: {mean_loss:.6f}")

        # Save final model
        self.save_router("router_final.pt")
        print("Training completed!")
    
    def evaluate_on_test_set(self, test_datasets=None):
        """Compare the ungated LLM with the router-selected masked LLM.

        Args:
            test_datasets: Dataset names to evaluate.  ``None`` or ``["all"]``
                selects the built-in list.  ``'wikitext'`` and ``'ptb'`` are
                evaluated by perplexity; every other name is treated as a
                multiple-choice dataset.

        Returns:
            Dict mapping each multiple-choice dataset name to its metrics, or
            to ``{"error": message}`` if evaluating it raised.

        NOTE(review): perplexity figures for ``'wikitext'``/``'ptb'`` are only
        printed and never added to the returned dict.  This is kept as is
        because callers are not visible here; confirm whether they should be
        returned as well.
        """
        print("\n=== Test set evaluation ===")

        # Test all datasets
        if test_datasets is None or test_datasets == ["all"]:
            test_datasets = ['arc-e', 'arc-c', 'piqa', 'winogrande', 'hellaswag', 'boolq', 'openbookqa', 'wikitext', 'ptb']

        # Setup LLM model (for testing)
        self.setup_llm_for_testing()

        all_results = {}

        for dataset_name in test_datasets:
            print(f"\n--- Testing {dataset_name} ---")

            if dataset_name in ('wikitext', 'ptb'):
                self._report_routed_perplexity(dataset_name)
                continue

            # Best effort: one failing dataset must not abort the others.
            try:
                all_results[dataset_name] = self._evaluate_routed_multiple_choice(dataset_name)
            except Exception as e:
                print(f"Error testing {dataset_name}: {e}")
                all_results[dataset_name] = {"error": str(e)}

        return all_results

    def _route_and_apply_mask(self, texts):
        """Let the router pick a mask for ``texts`` and enable it on the LLM.

        Only the prediction for the first text decides the mask, so callers
        pass a single-element list.

        Returns:
            ``(best_mask_idx, predicted_scores)``: the chosen mask index and
            the router output for the whole batch.
        """
        with torch.no_grad():
            batch_tokens = self.router.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=2048,
                return_tensors="pt"
            ).to(self.device)

            predicted_scores = self.router(batch_tokens)
            best_mask_idx = torch.argmax(predicted_scores, dim=1)[0].item()

        best_mask = self.representative_masks[best_mask_idx]
        single_masks = self.convert_flat_mask_to_layer_masks(best_mask)
        self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
        self.hn_helper.set_gate_status(self.llm_model, use_gate=True)
        return best_mask_idx, predicted_scores

    def _report_routed_perplexity(self, dataset_name):
        """Print ungated and router-masked perplexity for 'wikitext' or 'ptb'.

        The values returned by ``calculate_perplexity`` are summed, divided
        by the number of predicted tokens and exponentiated (so they are
        presumably summed negative log-likelihoods; confirm against the
        helper).
        """
        if dataset_name == 'wikitext':
            input_ids = build_wikitext_ids(self.tokenizer, split="test")
        else:
            input_ids = build_ptb_ids(self.tokenizer, split="test")
        samples = sample_wikitext_sequences(input_ids,
                                            seqlen=2048,
                                            n=None,  # Get all inputs
                                            random_sample=False)

        print(f"{dataset_name} samples shape: {samples.shape}")

        # One sequence per step: the router's choice is per input.
        bs = 1
        original_nll, masked_nll, total_tokens = 0, 0, 0
        mask_usage_count = {}  # Count usage of each mask
        self.router.eval()

        for i in tqdm(range(0, samples.size(0), bs), desc=f"Processing {dataset_name} with router"):
            batch = samples[i : i + bs]  # shape [B, seqlen]

            # The router has its own tokenizer, so it is given decoded text.
            batch_texts = [
                self.tokenizer.decode(row, skip_special_tokens=True) for row in batch
            ]

            # Baseline: gates disabled.
            self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
            original_nll += calculate_perplexity(
                self.llm_model,
                batch,
                limit_length=2048,
                device=self.device,
            )

            best_mask_idx, _ = self._route_and_apply_mask(batch_texts)
            mask_usage_count[best_mask_idx] = mask_usage_count.get(best_mask_idx, 0) + 1

            # Same batch with the predicted mask active.
            masked_nll += calculate_perplexity(
                self.llm_model,
                batch,
                limit_length=2048,
                device=self.device,
            )

            total_tokens += (batch.size(0) * (batch.size(1) - 1))

        if total_tokens == 0:
            # Nothing was evaluated; avoid dividing by zero below.
            print(f"No evaluation sequences available for {dataset_name}; skipping perplexity")
            return

        original_ppl = torch.exp(original_nll / total_tokens)
        masked_ppl = torch.exp(masked_nll / total_tokens)

        print(f"Original perplexity: {original_ppl:.4f}")
        print(f"Masked {dataset_name} perplexity: {masked_ppl:.4f}")
        print(f"Mask usage distribution: {mask_usage_count}")

    def _evaluate_routed_multiple_choice(self, dataset_name):
        """Score one multiple-choice dataset with and without router masks.

        Returns:
            Dict with the sample count, plain and normalised accuracy for the
            ungated and the router-masked model, the accuracy drops, and a
            per-example record list under ``"detailed_results"``.
        """
        mc_dataset = load_mc_dataset(dataset_name, split="test")

        dataset_results = []
        original_correct = 0
        original_correct_norm = 0
        router_correct = 0
        router_correct_norm = 0

        self.router.eval()

        for test_idx, mc_sample in enumerate(tqdm(mc_dataset, desc=f"Testing {dataset_name}")):
            formatted_example = format_mc_example(mc_sample, dataset_name)
            formatted_example["dataset_name"] = dataset_name

            # 1. Test original model (no mask)
            self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
            original_result = evaluate_mc_example(
                self.llm_model, self.tokenizer, formatted_example,
                device=self.device, max_length=2048
            )

            # 2. Route on the question text, falling back to the context
            #    prefix when the question is missing or empty.
            full_context = formatted_example.get("question", "")
            if not full_context and "context_prefix" in formatted_example:
                full_context = formatted_example["context_prefix"]

            # 3. Predict the mask and enable it on the LLM.
            best_mask_idx, predicted_scores = self._route_and_apply_mask([full_context])

            # 4. Test pruned model
            pruned_result = evaluate_mc_example(
                self.llm_model, self.tokenizer, formatted_example,
                device=self.device, max_length=2048
            )

            # Collect results
            original_correct += original_result["is_correct"]
            original_correct_norm += original_result["is_correct_normalized"]
            router_correct += pruned_result["is_correct"]
            router_correct_norm += pruned_result["is_correct_normalized"]

            # Record detailed results
            dataset_results.append({
                "test_idx": test_idx,
                "selected_mask": best_mask_idx,
                "predicted_scores": predicted_scores[0].detach().cpu().numpy().tolist(),
                "original_correct": original_result["is_correct"],
                "original_correct_norm": original_result["is_correct_normalized"],
                "pruned_correct": pruned_result["is_correct"],
                "pruned_correct_norm": pruned_result["is_correct_normalized"],
                "label": original_result["label"],
                "original_prediction": original_result["prediction"],
                "original_prediction_norm": original_result["normalized_prediction"],
                "pruned_prediction": pruned_result["prediction"],
                "pruned_prediction_norm": pruned_result["normalized_prediction"]
            })

        # Dataset-level accuracy.  An empty dataset raises ZeroDivisionError
        # here, which the caller records as an error entry.
        num_samples = len(mc_dataset)
        original_acc = original_correct / num_samples
        original_acc_norm = original_correct_norm / num_samples
        router_acc = router_correct / num_samples
        router_acc_norm = router_correct_norm / num_samples

        results = {
            "num_samples": num_samples,
            "original_accuracy": original_acc,
            "original_accuracy_norm": original_acc_norm,
            "router_accuracy": router_acc,
            "router_accuracy_norm": router_acc_norm,
            "accuracy_drop": original_acc - router_acc,
            "accuracy_norm_drop": original_acc_norm - router_acc_norm,
            "detailed_results": dataset_results
        }

        print(f"{dataset_name} results:")
        print(f"  Sample count: {num_samples}")
        print(f"  Original accuracy: {original_acc:.4f}")
        print(f"  Router accuracy: {router_acc:.4f}")
        print(f"  Original accuracy (norm): {original_acc_norm:.4f}")
        print(f"  Router accuracy (norm): {router_acc_norm:.4f}")

        return results

    def evaluate_individual_masks(self, test_datasets=None):
        """Evaluate every representative mask on its own, without the router.

        Each mask in ``self.representative_masks`` is enabled on the LLM in
        turn and scored on every requested dataset: perplexity for
        ``'wikitext'``/``'ptb'``, accuracy for all other (multiple-choice)
        names.

        Args:
            test_datasets: Dataset names to evaluate; ``None`` selects the
                built-in list.

        Returns:
            Dict keyed ``"mask_<idx>"`` whose values map dataset names to a
            metrics dict (``"perplexity"``/``"total_samples"``, or
            accuracy figures and counts) or to ``{"error": message}``.
        """
        print("\n=== Individual mask testing ===")

        if test_datasets is None:
            test_datasets = ['arc-e', 'arc-c', 'piqa', 'winogrande', 'hellaswag', 'wikitext']

        # Setup LLM model (for testing)
        self.setup_llm_for_testing()

        all_mask_results = {}

        # Test each representative mask
        for mask_idx in range(len(self.representative_masks)):
            print(f"\n--- Testing mask {mask_idx} ---")

            current_mask = self.representative_masks[mask_idx]
            single_masks = self.convert_flat_mask_to_layer_masks(current_mask)

            # The mask stays active for all datasets of this iteration.
            self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
            self.hn_helper.set_gate_status(self.llm_model, use_gate=True)

            mask_results = {}

            for dataset_name in test_datasets:
                print(f"  Testing dataset: {dataset_name}")

                if dataset_name in ('wikitext', 'ptb'):
                    if dataset_name == 'wikitext':
                        input_ids = build_wikitext_ids(self.tokenizer, split="test")
                    else:
                        input_ids = build_ptb_ids(self.tokenizer, split="test")
                    samples = sample_wikitext_sequences(input_ids,
                                                    seqlen=2048,
                                                    n=None,  # Get all samples
                                                    random_sample=True)

                    print(f"    {dataset_name} samples shape: {samples.shape}")

                    # Accumulate calculate_perplexity's return values and the
                    # number of predicted tokens across all sequences.
                    bs = 1
                    total_nll, total_tokens = 0, 0

                    for i in tqdm(range(0, samples.size(0), bs), desc=f"    Processing {dataset_name} with mask {mask_idx}"):
                        batch = samples[i : i + bs]

                        nll = calculate_perplexity(
                            self.llm_model,
                            batch,
                            limit_length=2048,
                            device=self.device,
                        )

                        total_nll += nll
                        total_tokens += (batch.size(0) * (batch.size(1) - 1))

                    if total_tokens == 0:
                        # Nothing was evaluated; avoid dividing by zero.
                        print(f"    No evaluation sequences available for {dataset_name}")
                        mask_results[dataset_name] = {"error": "no evaluation sequences"}
                        continue

                    ppl = torch.exp(total_nll / total_tokens)
                    print(f"    Mask {mask_idx} perplexity on {dataset_name}: {ppl:.4f}")

                    mask_results[dataset_name] = {
                        "perplexity": ppl.item(),
                        "total_samples": samples.size(0)
                    }

                else:
                    # Best effort: a failing dataset is recorded, not fatal.
                    try:
                        # Load test set
                        mc_dataset = load_mc_dataset(dataset_name, split="test")

                        correct_count = 0
                        correct_norm_count = 0

                        test_samples = list(mc_dataset)

                        for mc_sample in tqdm(test_samples, desc=f"    Testing {dataset_name} with mask {mask_idx}"):
                            formatted_example = format_mc_example(mc_sample, dataset_name)
                            formatted_example["dataset_name"] = dataset_name

                            # Test current mask
                            result = evaluate_mc_example(
                                self.llm_model, self.tokenizer, formatted_example,
                                device=self.device, max_length=2048
                            )

                            correct_count += result["is_correct"]
                            correct_norm_count += result["is_correct_normalized"]

                        # An empty dataset raises ZeroDivisionError here and
                        # is reported through the except branch below.
                        num_samples = len(test_samples)
                        accuracy = correct_count / num_samples
                        accuracy_norm = correct_norm_count / num_samples

                        print(f"    Mask {mask_idx} accuracy on {dataset_name}: {accuracy:.4f}")
                        print(f"    Mask {mask_idx} normalized accuracy on {dataset_name}: {accuracy_norm:.4f}")

                        mask_results[dataset_name] = {
                            "num_samples": num_samples,
                            "accuracy": accuracy,
                            "accuracy_norm": accuracy_norm,
                            "correct_count": correct_count,
                            "correct_norm_count": correct_norm_count
                        }

                    except Exception as e:
                        print(f"    Error testing mask {mask_idx} on {dataset_name}: {e}")
                        mask_results[dataset_name] = {"error": str(e)}

            all_mask_results[f"mask_{mask_idx}"] = mask_results

        # Print summary results
        print("\n=== Mask performance summary ===")
        for mask_key, mask_result in all_mask_results.items():
            print(f"\n{mask_key}:")
            for dataset_name, dataset_result in mask_result.items():
                if "error" in dataset_result:
                    print(f"  {dataset_name}: Error")
                elif "perplexity" in dataset_result:
                    # Keyed on the stored metric, so 'ptb' results are
                    # handled as well as 'wikitext'.
                    print(f"  {dataset_name}: Perplexity = {dataset_result['perplexity']:.4f}")
                else:
                    print(f"  {dataset_name}: Accuracy = {dataset_result['accuracy']:.4f}")

        return all_mask_results


    def setup_llm_for_testing(self):
        """Prepare the LLM, tokenizer and mask helpers used by the test routines.

        The call is a no-op when ``self.llm_model`` already exists. Otherwise it
        loads the tokenizer and a float16 ``PruneLlama2ForCausalLM`` from
        ``self.args.model_path``, switches the model to eval mode with the KV
        cache disabled, collects the regularization structure info, builds the
        hypernetwork helper and, if ``self.representative_masks`` is not set
        yet, loads the representative masks.
        """
        if not hasattr(self, 'llm_model'):
            print("Setting up LLM model for testing...")
            model_path = self.args.model_path

            self.tokenizer = AutoTokenizer.from_pretrained(model_path)

            load_options = {
                'torch_dtype': torch.float16,
                'device_map': self.device,
            }
            self.llm_model = PruneLlama2ForCausalLM.from_pretrained(model_path, **load_options)
            self.llm_model.config.use_cache = False
            self.llm_model.eval()

            # Structure description of the prunable parameters, and the helper
            # that is built from it.
            self.param_reg = collect_info_reg_llama(
                self.llm_model,
                p=self.args.p,
                lam=self.args.lam,
            )
            self.hn_helper = help_functions_hn(self.param_reg.structures)

            # Representative masks are only loaded once.
            masks_ready = hasattr(self, 'representative_masks')
            if not masks_ready:
                self.load_representative_masks_for_testing()
    
    def load_representative_masks_for_testing(self):
        """Populate ``self.representative_masks`` for the mask/routing tests.

        Two sources are tried in order:

        1. The mask-combination results pickle. Its path defaults to the
           project location below and can be overridden with
           ``self.args.mask_combination_file``. When the file exists and
           contains ``'selected_mask_indices'``, all wikitext masks are
           regenerated via ``self.load_all_wikitext_masks()`` and the selected
           ones are kept (indices outside the valid range are skipped with a
           warning).
        2. ``clustering_results.pkl`` in the parent directory of
           ``self.args.data_dir``. For the ``self.args.num_clusters`` solution,
           the member with the highest mean similarity to the rest of its
           cluster becomes that cluster's representative.

        In both cases, when ``self.args.num_masks_to_load`` is set and smaller
        than the number of masks found, the list is truncated and
        ``self.args.num_clusters`` is updated to match.

        Raises:
            FileNotFoundError: If neither source is available.
        """
        print("Loading representative masks from mask combination results...")

        def _apply_mask_limit():
            """Truncate to ``num_masks_to_load`` masks and keep ``num_clusters`` in sync."""
            limit = self.args.num_masks_to_load
            if limit is not None and limit < len(self.representative_masks):
                print(f"Limiting mask count: from {len(self.representative_masks)} to {limit}")
                self.representative_masks = self.representative_masks[:limit]
                self.args.num_clusters = limit

        # First try to load mask combination analysis results.
        # Alternative location used previously: xxx/project/DISP/arc-e/mask_combination_results.pkl
        default_combination_file = "xxx/project/DynPrune/llama-2-7b/041/mask_combination_results.pkl"
        mask_combination_file = getattr(self.args, 'mask_combination_file', None) or default_combination_file

        if os.path.exists(mask_combination_file):
            print(f"Found mask combination results at {mask_combination_file}")
            # NOTE: pickle can execute arbitrary code while loading; only use trusted files.
            with open(mask_combination_file, 'rb') as f:
                mask_combination_data = pickle.load(f)

            if 'selected_mask_indices' in mask_combination_data:
                selected_mask_indices = mask_combination_data['selected_mask_indices']
                print(f"Found {len(selected_mask_indices)} selected mask indices: {selected_mask_indices}")

                # The indices refer to positions in the full list of wikitext masks.
                all_wikitext_masks = self.load_all_wikitext_masks()

                self.representative_masks = []
                for i, mask_idx in enumerate(selected_mask_indices):
                    # Reject negative indices too: they would silently wrap around
                    # and pick a mask from the end of the list.
                    if 0 <= mask_idx < len(all_wikitext_masks):
                        self.representative_masks.append(all_wikitext_masks[mask_idx])
                        print(f"Selected mask {i}: index {mask_idx}")
                    else:
                        print(f"Warning: mask index {mask_idx} out of range (max: {len(all_wikitext_masks)-1})")

                # num_masks_to_load takes priority; otherwise the user's
                # num_clusters setting is left untouched.
                _apply_mask_limit()

                print(f"Loaded {len(self.representative_masks)} representative masks from mask combination analysis")
                return

            print(f"Warning: {mask_combination_file} does not contain 'selected_mask_indices'")

        # If mask combination analysis results are not usable, fall back to clustering.
        print("Mask combination results not found, falling back to clustering results...")
        clustering_file = os.path.join(os.path.dirname(self.args.data_dir), "clustering_results.pkl")
        if not os.path.exists(clustering_file):
            raise FileNotFoundError(
                "Neither mask combination results nor clustering results found "
                f"(checked {mask_combination_file} and {clustering_file})"
            )

        # NOTE: pickle can execute arbitrary code while loading; only use trusted files.
        with open(clustering_file, 'rb') as f:
            clustering_data = pickle.load(f)

        masks = clustering_data['masks']
        similarity_matrix = clustering_data['similarity_matrix']
        clustering_results = clustering_data['clustering_results']

        # Use the clustering solution computed for the configured cluster count.
        # np.asarray makes the element-wise comparison below work for list labels too.
        cluster_labels = np.asarray(clustering_results[self.args.num_clusters]['labels'])

        # Rebuild representative masks: one per non-empty cluster.
        self.representative_masks = []
        for cluster_id in range(self.args.num_clusters):
            cluster_indices = np.where(cluster_labels == cluster_id)[0]

            if len(cluster_indices) == 0:
                continue

            if len(cluster_indices) == 1:
                representative_idx = cluster_indices[0]
            else:
                # Pick the member most similar, on average, to its own cluster.
                cluster_sim_matrix = similarity_matrix[cluster_indices][:, cluster_indices]
                avg_similarities = np.mean(cluster_sim_matrix, axis=1)
                representative_idx = cluster_indices[np.argmax(avg_similarities)]

            self.representative_masks.append(masks[representative_idx])

        _apply_mask_limit()

        print(f"Finally loaded {len(self.representative_masks)} representative masks")

    
    def convert_flat_mask_to_layer_masks(self, flat_mask):
        """Split a flat mask array into one float32 tensor per structure.

        Consecutive slices of ``flat_mask`` are cut in the order and with the
        sizes given by ``self.param_reg.structures``. Elements beyond the sum
        of all sizes are ignored.

        Args:
            flat_mask: NumPy array holding the concatenated per-layer masks.

        Returns:
            List of ``torch.float32`` tensors, one per entry of the structures.

        Raises:
            ValueError: If a structure size is not an integer or the mask is
                shorter than the structures require.
        """
        per_layer = []
        offset = 0

        for layer_no, width in enumerate(self.param_reg.structures):
            if isinstance(width, (int, np.integer)):
                stop = offset + width
                if stop > len(flat_mask):
                    raise ValueError(f"Mask too short: need {stop} elements, got {len(flat_mask)}")

                chunk = flat_mask[offset:stop].astype(np.float32)
                per_layer.append(torch.from_numpy(chunk))
                offset = stop
            else:
                raise ValueError(f"Expected integer for layer size, got {type(width)} at layer {layer_no}")

        return per_layer
    
    def load_all_wikitext_masks(self):
        """Regenerate one binary mask per stored wikitext sample.

        Builds a ``dyn_hypernetwork`` over ``self.param_reg.structures``
        (hyper-parameters come from ``self.args`` with the defaults shown
        below), loads its trained weights, reads the pickled wikitext
        LRP/activation records (capped at ``args.max_wikitext_samples``,
        default 200) and runs the hypernetwork's ``hard_output`` on each record.

        Returns:
            List with one flattened boolean NumPy mask per processed sample.

        Raises:
            FileNotFoundError: If the hypernetwork checkpoint or the LRP data
                file does not exist.
        """
        print("Loading all wikitext masks...")

        # Checkpoint of the trained wikitext hypernetwork.
        # Alternative location used previously: xxx/project/DISP/wikitext/final_hypernetwork.pt
        hn_ckpt_path = "xxx/project/DynPrune/llama-2-7b/041/hn/final_hypernetwork.pt"
        if not os.path.exists(hn_ckpt_path):
            raise FileNotFoundError(f"Wikitext hypernetwork not found: {hn_ckpt_path}")

        from pruning.dyn_hypernetwork import dyn_hypernetwork

        args = self.args
        hn_options = {
            't_structures': self.param_reg.structures,
            'lrp_scale': getattr(args, 'lrp_scale', 1.0),
            'base': getattr(args, 'base', 0.5),
            'T_start': getattr(args, 'T_start', 0.5),
            'T_end': getattr(args, 'T_end', 0.1),
            'target_sparsity': getattr(args, 'target_sparsity', 0.4),
            'hidden_dim': getattr(args, 'hidden_dim', 128),
        }
        hypernet = dyn_hypernetwork(**hn_options).to(self.device)

        # The checkpoint is either a bare state dict or wraps it under 'hypernetwork'.
        loaded = torch.load(hn_ckpt_path, map_location=self.device)
        hypernet.load_state_dict(loaded['hypernetwork'] if 'hypernetwork' in loaded else loaded)
        print("Wikitext hypernetwork loaded successfully")

        # Pickled per-sample LRP scores / activations.
        lrp_records_path = "xxx/project/DISP/wikitext/lrp_train_ppl.pkl"
        if not os.path.exists(lrp_records_path):
            raise FileNotFoundError(f"Wikitext LRP data not found: {lrp_records_path}")

        with open(lrp_records_path, 'rb') as fh:
            lrp_records = pickle.load(fh)

        # Cap the number of samples for efficiency (a falsy cap disables it).
        sample_cap = getattr(args, 'max_wikitext_samples', 200)
        if sample_cap and len(lrp_records) > sample_cap:
            lrp_records = lrp_records[:sample_cap]
            print(f"Limited wikitext sample count to: {sample_cap}")

        dataset = self.create_wikitext_dataset(lrp_records)
        print(f"Using {len(dataset)} wikitext samples to generate masks")

        generated_masks = []
        hypernet.eval()
        with torch.no_grad():
            for position in tqdm(range(len(dataset)), desc="Generating wikitext masks"):
                item = dataset[position]
                hard_mask = hypernet.hard_output(item['layer_activations'], item['input_lrp'])
                # Flatten the per-layer hard mask into a single boolean vector.
                generated_masks.append(self.convert_mask_to_binary(hard_mask))

        print(f"Generated {len(generated_masks)} wikitext masks")
        return generated_masks
    
    def create_wikitext_dataset(self, samples_data):
        """Wrap raw wikitext records in an indexable torch ``Dataset``.

        Each record is read through its ``"sample_id"``, ``"lrp"`` and
        ``"activations"`` entries. The per-layer entries are converted to float
        tensors, given a leading batch dimension when 1-D and optionally
        normalized. Normalization flags come from ``self.args.normalize_lrp``
        (default True) and ``self.args.normalize_activations`` (default False).

        Returns:
            A dataset whose items are dicts with ``'sample_ids'``,
            ``'layer_activations'`` and ``'input_lrp'`` moved to ``self.device``.
        """
        from torch.utils.data import Dataset

        class WikitextDataset(Dataset):
            """Eagerly pre-processed wikitext samples; tensors are moved to the device on access."""

            def __init__(self, samples_data, param_reg_structures, device, normalize_lrp=True, normalize_activations=False):
                self.device = device
                self.normalize_lrp = normalize_lrp
                self.normalize_activations = normalize_activations
                self.samples = []

                for record_no in tqdm(range(len(samples_data)), desc="Processing wikitext samples"):
                    record = samples_data[record_no]

                    # Token ids: arrays and plain sequences become long tensors,
                    # existing tensors are used as they are.
                    ids = record["sample_id"]
                    if isinstance(ids, np.ndarray):
                        ids = torch.from_numpy(ids).long()
                    elif not isinstance(ids, torch.Tensor):
                        ids = torch.tensor(ids).long()

                    lrp_scores = record["lrp"]
                    activations = record["activations"]

                    per_layer_activations = []
                    per_layer_lrp = []

                    for layer_no in range(len(param_reg_structures)):
                        # Layers missing from either list are skipped.
                        if layer_no >= len(lrp_scores) or layer_no >= len(activations):
                            continue

                        raw_activation = activations[layer_no]
                        raw_lrp = lrp_scores[layer_no]

                        activation = self._to_float_tensor(raw_activation)
                        lrp = self._to_float_tensor(raw_lrp)

                        # Give 1-D vectors a leading batch dimension.
                        if activation.dim() == 1:
                            activation = activation.unsqueeze(0)
                        if lrp.dim() == 1:
                            lrp = lrp.unsqueeze(0)

                        if self.normalize_lrp:
                            lrp = self.normalize_tensor_layerwise(lrp)
                        if self.normalize_activations:
                            activation = self.normalize_tensor_layerwise(activation)

                        per_layer_activations.append(activation)
                        per_layer_lrp.append(lrp)

                    # The ids get a batch dimension as well.
                    if ids.dim() == 1:
                        ids = ids.unsqueeze(0)

                    self.samples.append({
                        'sample_ids': ids,
                        'layer_activations': per_layer_activations,
                        'input_lrp': per_layer_lrp
                    })

            @staticmethod
            def _to_float_tensor(data):
                """Convert a NumPy array or array-like to a float tensor."""
                if isinstance(data, np.ndarray):
                    return torch.from_numpy(data).float()
                return torch.tensor(data).float()

            def normalize_tensor_layerwise(self, tensor, eps=1e-8):
                """Standardize absolute values along the last dimension.

                Takes ``abs(tensor)``, then subtracts the mean and divides by the
                (biased) standard deviation, clamped to at least ``eps``. Empty
                tensors are returned unchanged.
                """
                if tensor.numel() == 0:
                    return tensor

                magnitudes = torch.abs(tensor)

                # 1-D inputs are processed as a single-row batch and squeezed back.
                was_vector = magnitudes.dim() == 1
                if was_vector:
                    magnitudes = magnitudes.unsqueeze(0)

                center = magnitudes.mean(dim=-1, keepdim=True)
                spread = torch.clamp(
                    magnitudes.std(dim=-1, keepdim=True, unbiased=False),
                    min=eps,
                )
                result = (magnitudes - center) / spread

                return result.squeeze(0) if was_vector else result

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                stored = self.samples[idx]
                target = self.device
                return {
                    'sample_ids': stored['sample_ids'].to(target),
                    'layer_activations': [t.to(target) for t in stored['layer_activations']],
                    'input_lrp': [t.to(target) for t in stored['input_lrp']]
                }

        # Normalization switches are optional on args.
        use_lrp_norm = getattr(self.args, 'normalize_lrp', True)
        use_activation_norm = getattr(self.args, 'normalize_activations', False)
        return WikitextDataset(
            samples_data,
            self.param_reg.structures,
            self.device,
            use_lrp_norm,
            use_activation_norm,
        )
    
    def convert_mask_to_binary(self, mask_list):
        """Flatten a list of per-layer mask tensors into one boolean vector.

        Every tensor is thresholded at ``> 0.5``, moved to the CPU, flattened
        and the pieces are concatenated in list order.

        Returns:
            1-D NumPy array of dtype ``bool``.
        """
        flattened_layers = [
            (layer_mask > 0.5).bool().cpu().numpy().flatten().astype(np.bool_)
            for layer_mask in mask_list
        ]
        return np.concatenate(flattened_layers)
    
    def evaluate_router(self, test_data=None):
        """Evaluate the router and return ``(avg_loss, avg_accuracy)``.

        Args:
            test_data: Sequence of samples, each a mapping with ``'text'`` and
                ``'scores'``. When ``None``, up to 1000 samples are drawn at
                random from ``self.train_data`` and used as a validation set.

        Returns:
            Tuple ``(avg_loss, avg_accuracy)`` averaged per sample. In
            classification mode the loss is cross-entropy against the argmax
            of ``'scores'``; otherwise it is ``compute_stable_mse_loss``
            against the scores themselves. Accuracy compares the argmax of the
            router output with the argmax of the target scores. Both values
            are NaN when there is nothing to evaluate.

        The router is switched to eval mode for the evaluation and put back
        into train mode afterwards, even if a batch raises.
        """
        if test_data is None:
            # Use part of the training data as a validation set. random.sample
            # raises ValueError when asked for more items than available, so
            # the sample size is capped at the size of the training set.
            sample_size = min(1000, len(self.train_data))
            test_data = random.sample(self.train_data, sample_size)

        if len(test_data) == 0:
            # Avoid a ZeroDivisionError below; NaN never compares as "better".
            print("Warning: evaluate_router received no samples; returning NaN metrics")
            return float("nan"), float("nan")

        self.router.eval()
        total_samples = 0
        total_loss = 0.0
        correct_predictions = 0

        try:
            with torch.no_grad():
                for start in range(0, len(test_data), self.args.batch_size):
                    batch_samples = test_data[start:start + self.args.batch_size]
                    texts = [sample['text'] for sample in batch_samples]

                    # Tokenisation and the forward pass are the same for both modes.
                    batch_tokens = self.router.tokenizer(
                        texts,
                        padding=True,
                        truncation=True,
                        max_length=2048,
                        return_tensors="pt"
                    ).to(self.device)
                    outputs = self.router(batch_tokens)

                    if self.router.use_classification:
                        # Target is the index of the best-scoring mask.
                        true_best = torch.tensor(
                            [np.argmax(sample['scores']) for sample in batch_samples],
                            dtype=torch.long, device=self.device
                        )
                        loss = F.cross_entropy(outputs, true_best)
                    else:
                        target_scores = torch.tensor(
                            [sample['scores'] for sample in batch_samples],
                            dtype=torch.float32, device=self.device
                        )
                        loss = compute_stable_mse_loss(outputs, target_scores)
                        true_best = torch.argmax(target_scores, dim=1)

                    pred_best = torch.argmax(outputs, dim=1)

                    # The loss is a batch mean; weight it by the batch size so
                    # the final average is per sample.
                    total_loss += loss.item() * len(batch_samples)
                    correct_predictions += (pred_best == true_best).sum().item()
                    total_samples += len(batch_samples)
        finally:
            self.router.train()

        avg_loss = total_loss / total_samples
        avg_accuracy = correct_predictions / total_samples
        return avg_loss, avg_accuracy
    
    def save_router(self, filename):
        """Write a router checkpoint to ``<args.output_dir>/<filename>``.

        The checkpoint holds the LoRA adapter weights, the state of the active
        head (classifier in classification mode, regressor otherwise; both are
        stored under ``'classifier_state_dict'``), optimizer, scheduler and
        grad-scaler states, the full ``args`` object and ``num_clusters``.
        """
        output_dir = self.args.output_dir
        filepath = os.path.join(output_dir, filename)
        os.makedirs(output_dir, exist_ok=True)

        router = self.router
        lora_weights = get_peft_model_state_dict(router.model)
        if router.use_classification:
            head_weights = router.classifier.state_dict()
        else:
            head_weights = router.regressor.state_dict()

        checkpoint = {
            'model_state_dict': {
                'lora_state_dict': lora_weights,
                'classifier_state_dict': head_weights,
            },
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            # 'metadata': self.metadata,
            'args': self.args,
            'num_clusters': self.args.num_clusters,
        }

        torch.save(checkpoint, filepath)
        print(f"Router saved to {filepath}")

    def load_router(self, checkpoint_path):
        """Load router weights from a checkpoint in the layout written by ``save_router``.

        Expected layout: ``checkpoint['model_state_dict']`` with
        ``'lora_state_dict'`` (LoRA adapter weights) and
        ``'classifier_state_dict'`` (weights of the classifier head, or of the
        regressor head when the router is not in classification mode).
        Sections that are missing are skipped with a warning; optimizer,
        scheduler and scaler states are not restored here.

        Args:
            checkpoint_path: Path of the checkpoint file.
        """
        # The checkpoint written by save_router() contains non-tensor objects
        # (the args namespace), which the weights-only unpickler that
        # torch.load uses by default since PyTorch 2.6 refuses to load.
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Debug info
        print("Available keys in checkpoint:", list(checkpoint.keys()))
        model_state = checkpoint.get('model_state_dict')
        if model_state is None:
            print("Warning: checkpoint has no 'model_state_dict'; no router weights were loaded")
            model_state = {}
        else:
            print("Available keys in model_state_dict:", list(model_state.keys()))

        if 'lora_state_dict' in model_state:
            # Use PEFT's own loader so adapter key names are mapped correctly.
            from peft import set_peft_model_state_dict
            set_peft_model_state_dict(self.router.model, model_state['lora_state_dict'])
            print("LoRA weights loaded")
        elif model_state:
            print("Warning: no 'lora_state_dict' in checkpoint; LoRA weights left unchanged")

        if 'classifier_state_dict' in model_state:
            head_state = model_state['classifier_state_dict']

            if self.router.use_classification:
                self.router.classifier.load_state_dict(head_state)
                print("Classifier weights loaded")
            else:
                self.router.regressor.load_state_dict(head_state)
                print("Regressor weights loaded")
        elif model_state:
            print("Warning: no 'classifier_state_dict' in checkpoint; head weights left unchanged")

        print(f"Router loaded from {checkpoint_path}")

    def evaluate_mask_routing_matrix(self, test_datasets=None):
        """
        Evaluate the mask-routing matrix: how each mask performs on the samples routed to each path.

        For every dataset, each sample is first assigned a route (argmax of the
        router output for its question text). Every representative mask is then
        applied to the LLM and evaluated on every route group.

        Args:
            test_datasets: Dataset names to evaluate. Defaults to the seven
                multiple-choice benchmarks listed below.

        Returns:
            Dict mapping each dataset name to either
            ``{'accuracy_matrix', 'route_counts', 'total_samples'}`` or, when
            the dataset failed, ``{'error': <message>}``. ``accuracy_matrix``
            has shape ``num_masks x num_masks`` with
            ``num_masks = len(self.representative_masks)``; ``matrix[i][j]`` is
            the accuracy of mask ``i`` on the samples the router assigned to
            mask ``j`` and is NaN when that route group is empty.
        """
        print("\n=== Mask routing matrix evaluation ===")

        if test_datasets is None:
            test_datasets = ['arc-e', 'arc-c', 'piqa', 'winogrande', 'hellaswag', 'boolq', 'openbookqa']

        # Setup LLM model (for testing)
        self.setup_llm_for_testing()

        num_masks = len(self.representative_masks)

        all_results = {}

        for dataset_name in test_datasets:
            print(f"\n--- Testing dataset: {dataset_name} ---")

            try:
                # Load the evaluation samples (note: the "train" split is used here)
                mc_dataset = load_mc_dataset(dataset_name, split="train")

                # Step 1: Collect routing prediction results for all samples
                print("  Step 1: Collecting routing prediction results...")
                samples_by_route = [[] for _ in range(num_masks)]  # Group by routing prediction

                self.router.eval()

                for test_idx, mc_sample in enumerate(tqdm(mc_dataset, desc=f"  Routing prediction")):
                    formatted_example = format_mc_example(mc_sample, dataset_name)
                    formatted_example["dataset_name"] = dataset_name

                    # Get question text; fall back to the context prefix when
                    # the example has no (non-empty) question.
                    full_context = formatted_example.get("question", "")
                    if not full_context and "context_prefix" in formatted_example:
                        full_context = formatted_example["context_prefix"]

                    # Router prediction
                    with torch.no_grad():
                        batch_tokens = self.router.tokenizer(
                            [full_context],
                            padding=True,
                            truncation=True,
                            max_length=2048,
                            return_tensors="pt"
                        ).to(self.device)

                        predicted_scores = self.router(batch_tokens)
                        predicted_route = torch.argmax(predicted_scores, dim=1)[0].item()

                    # Group by routing prediction
                    samples_by_route[predicted_route].append((test_idx, formatted_example))

                # Print routing distribution
                route_counts = [len(samples) for samples in samples_by_route]
                print(f"  Routing distribution: {route_counts}")

                # Step 2: Test each mask on each routing group
                print("  Step 2: Testing mask matrix...")
                accuracy_matrix = np.zeros((num_masks, num_masks))

                try:
                    for mask_idx in range(num_masks):
                        print(f"    Testing mask {mask_idx}")

                        # Apply current mask
                        current_mask = self.representative_masks[mask_idx]
                        single_masks = self.convert_flat_mask_to_layer_masks(current_mask)
                        self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
                        self.hn_helper.set_gate_status(self.llm_model, use_gate=True)

                        # Test on each routing group
                        for route_idx in range(num_masks):
                            route_samples = samples_by_route[route_idx]
                            if len(route_samples) == 0:
                                accuracy_matrix[mask_idx][route_idx] = np.nan
                                continue

                            correct_count = 0
                            for _, formatted_example in route_samples:
                                result = evaluate_mc_example(
                                    self.llm_model, self.tokenizer, formatted_example,
                                    device=self.device, max_length=2048
                                )
                                correct_count += result["is_correct"]

                            accuracy_matrix[mask_idx][route_idx] = correct_count / len(route_samples)
                finally:
                    # Reset model state. This runs even when an evaluation
                    # raised, so a failure cannot leave the LLM gated by the
                    # last applied mask.
                    self.hn_helper.set_gate_status(self.llm_model, use_gate=False)

                all_results[dataset_name] = {
                    'accuracy_matrix': accuracy_matrix,
                    'route_counts': route_counts,
                    'total_samples': sum(route_counts)
                }

                # Print results
                self.print_routing_matrix_results(dataset_name, accuracy_matrix, route_counts)

            except Exception as e:
                # Best effort per dataset: record the failure and continue with the next one.
                print(f"  Error testing {dataset_name}: {e}")
                all_results[dataset_name] = {"error": str(e)}

        return all_results

    def print_routing_matrix_results(self, dataset_name, accuracy_matrix, route_counts):
        """Print a mask-by-route accuracy table followed by summary statistics.

        Rows are masks, columns are routes; NaN cells are shown as ``N/A``.
        When at least one cell is not NaN, the NaN-aware mean, max and min are
        printed, plus the mean of the non-NaN diagonal entries if there are any.
        """
        print(f"\n  {dataset_name} routing matrix results:")
        print(f"  Matrix dimensions: {accuracy_matrix.shape}")

        # Column header
        route_labels = [f"Route{col:2d} " for col in range(accuracy_matrix.shape[1])]
        print("    " + "".join(route_labels))

        # One line per mask
        for row_no in range(accuracy_matrix.shape[0]):
            cells = [
                "  N/A  " if np.isnan(value) else f"{value:.3f} "
                for value in accuracy_matrix[row_no]
            ]
            print("  " + f"Mask{row_no:2d} " + "".join(cells))

        print(f"  Routing distribution: {route_counts}")

        # Summary statistics are only meaningful with at least one real value.
        has_values = (~np.isnan(accuracy_matrix)).any()
        if not has_values:
            return

        print(f"  Average accuracy: {np.nanmean(accuracy_matrix):.4f}")
        print(f"  Highest accuracy: {np.nanmax(accuracy_matrix):.4f}")
        print(f"  Lowest accuracy: {np.nanmin(accuracy_matrix):.4f}")

        # Diagonal = each mask evaluated on the samples routed to it
        # (should ideally be higher).
        diagonal = np.diag(accuracy_matrix)
        diagonal_values = diagonal[~np.isnan(diagonal)]
        if len(diagonal_values) > 0:
            print(f"  Diagonal average accuracy: {np.mean(diagonal_values):.4f}")


    def draw_visualizations(self):
        """Generate all visualizations: collect the data, then draw the heatmap and the t-SNE plot.

        Output goes to ``self.args.viz_output_dir`` when set, otherwise to
        ``<self.args.output_dir>/visualizations``; the directory is created if
        needed.
        """
        # These names are not referenced directly in this method; presumably
        # imported here so a missing plotting dependency fails before the data
        # collection starts (TODO confirm).
        import matplotlib.pyplot as plt
        import seaborn as sns
        from sklearn.manifold import TSNE
        import pandas as pd
        from collections import defaultdict

        print("\n=== Starting visualization generation ===")

        # Resolve and create the output directory.
        viz_dir = self.args.viz_output_dir
        if not viz_dir:
            viz_dir = os.path.join(self.args.output_dir, "visualizations")
        os.makedirs(viz_dir, exist_ok=True)

        # 1. Gather per-dataset mask counts and embedding data.
        collected = self.collect_visualization_data()
        dataset_mask_counts, embeddings_data = collected

        # 2. Heatmap of the mask counts.
        self.create_heatmap(dataset_mask_counts, viz_dir)

        # 3. t-SNE plot of the embeddings.
        self.create_tsne_plot(embeddings_data, viz_dir)

        print(f"\nVisualization results saved to: {viz_dir}")

    def collect_visualization_data(self):
        """Run the router over sampled benchmark questions and gather plot inputs.

        For each of seven multiple-choice datasets the test split is loaded,
        down-sampled to at most ``self.args.viz_samples_per_dataset`` examples,
        and every example's question text (or its ``context_prefix`` when the
        question is empty) is passed through the router.

        Returns:
            tuple: ``(dataset_mask_counts, embeddings_data)`` where
                * ``dataset_mask_counts[dataset][mask_idx]`` is how many
                  examples of ``dataset`` had ``mask_idx`` as the arg-max score;
                * ``embeddings_data`` is a dict of parallel lists
                  (``embeddings``, ``dataset_labels``, ``mask_predictions``,
                  ``texts``) holding the first 100 processed examples of each
                  dataset.

        Any exception raised while handling a dataset is printed and that
        dataset is abandoned; results recorded before the failure are kept.
        """
        from collections import defaultdict
        import random

        mask_hits = defaultdict(lambda: defaultdict(int))
        tsne_records = {
            'embeddings': [],
            'dataset_labels': [],
            'mask_predictions': [],
            'texts': []
        }

        self.router.eval()

        for name in ('arc-e', 'arc-c', 'piqa', 'winogrande', 'hellaswag', 'boolq', 'openbookqa'):
            print(f"\nProcessing dataset: {name}")

            try:
                samples = list(load_mc_dataset(name, split="test"))

                # Down-sample large datasets to keep the pass short.
                limit = self.args.viz_samples_per_dataset
                if len(samples) > limit:
                    samples = random.sample(samples, limit)

                print(f"  Processing {len(samples)} samples")

                for position, raw_sample in enumerate(tqdm(samples, desc=f"  {name}")):
                    example = format_mc_example(raw_sample, name)

                    # Prefer the question; fall back to the context prefix.
                    prompt = example.get("question", "")
                    if not prompt and "context_prefix" in example:
                        prompt = example["context_prefix"]

                    with torch.no_grad():
                        encoded = self.router.tokenizer(
                            [prompt],
                            padding=True,
                            truncation=True,
                            max_length=2048,
                            return_tensors="pt"
                        ).to(self.device)

                        embeddings, scores = self.get_router_embeddings_and_scores(encoded)

                        # Highest-scoring mask for the single item in the batch.
                        chosen_mask = torch.argmax(scores, dim=1)[0].item()

                    mask_hits[name][chosen_mask] += 1

                    # Only the first 100 examples per dataset feed the t-SNE plot.
                    if position >= 100:
                        continue
                    tsne_records['embeddings'].append(embeddings[0].cpu().numpy())
                    tsne_records['dataset_labels'].append(name)
                    tsne_records['mask_predictions'].append(chosen_mask)
                    tsne_records['texts'].append(prompt[:100])  # keep a short preview only

            except Exception as e:
                # Best effort: report and move on to the next dataset.
                print(f"  Error processing {name}: {e}")

        return mask_hits, tsne_records

    def get_router_embeddings_and_scores(self, batch_tokens):
        """Return the router's sentence embeddings together with its mask scores.

        Args:
            batch_tokens: Mapping providing ``'input_ids'`` and
                ``'attention_mask'`` for the router's encoder.

        Returns:
            tuple: ``(embeddings, scores)``. ``embeddings`` are the
            last-token-pooled hidden states, L2-normalized along dim 1;
            ``scores`` is the output of the router's classifier head when
            ``self.router.use_classification`` is true, else of its regressor.
        """
        attention_mask = batch_tokens['attention_mask']

        # Only the encoder pass, pooling and normalization run under autocast;
        # the scoring head is applied afterwards, outside of it.
        with torch.cuda.amp.autocast():
            encoder_out = self.router.model(batch_tokens['input_ids'],
                                            attention_mask=attention_mask)
            pooled = last_token_pool(encoder_out.last_hidden_state, attention_mask)
            embeddings = F.normalize(pooled, p=2, dim=1, eps=1e-4)

        head = self.router.classifier if self.router.use_classification else self.router.regressor
        return embeddings, head(embeddings)

    def create_heatmap(self, dataset_mask_counts, viz_dir):
        """Create the dataset-vs-mask selection heatmap and its CSV.

        Each row is a dataset, each column one of ``self.args.num_clusters``
        masks, and each cell the percentage of that dataset's examples for
        which the mask was selected. Percentages are relative to all counts
        recorded for the dataset.

        Args:
            dataset_mask_counts: Mapping ``dataset -> {mask_idx: count}``.
            viz_dir: Directory that receives ``dataset_mask_heatmap.png`` and
                ``dataset_mask_counts.csv`` (the CSV holds the percentages).

        Returns:
            None. When ``dataset_mask_counts`` is empty nothing is written.
        """
        print("\nCreating heatmap...")

        # With no rows seaborn fails on the empty matrix, so bail out early
        # the same way the t-SNE plot does for empty input.
        if not dataset_mask_counts:
            print("  No dataset/mask counts available for heatmap")
            return

        import matplotlib.pyplot as plt
        import seaborn as sns
        import pandas as pd

        datasets = sorted(dataset_mask_counts)
        mask_ids = list(range(self.args.num_clusters))

        # One row of selection percentages per dataset.
        matrix = []
        for dataset in datasets:
            counts = dataset_mask_counts[dataset]
            total_count = sum(counts.values())
            matrix.append([
                (counts.get(mask_id, 0) / total_count * 100) if total_count > 0 else 0
                for mask_id in mask_ids
            ])

        df = pd.DataFrame(matrix, index=datasets, columns=[f'Mask {i}' for i in mask_ids])

        heatmap_path = os.path.join(viz_dir, 'dataset_mask_heatmap.png')
        fig = plt.figure(figsize=(12, 8))
        try:
            sns.heatmap(df, annot=True, fmt='.1f', cmap='YlOrRd',
                        cbar_kws={'label': 'Selection Percentage (%)'},
                        square=True, linewidths=0.5, linecolor='gray')

            plt.title('Dataset-Mask Selection Heatmap', fontsize=16, pad=20)
            plt.xlabel('Mask ID', fontsize=12)
            plt.ylabel('Dataset', fontsize=12)
            plt.tight_layout()

            plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)

        print(f"  Heatmap saved to: {heatmap_path}")

        # Save the plotted values alongside the image.
        df.to_csv(os.path.join(viz_dir, 'dataset_mask_counts.csv'))

    def create_tsne_plot(self, embeddings_data, viz_dir):
        """Create a two-panel t-SNE scatter plot of router embeddings.

        The left panel colors points by source dataset, the right panel by
        the mask the router predicted for them.

        Args:
            embeddings_data: Dict of parallel lists with keys ``'embeddings'``,
                ``'dataset_labels'`` and ``'mask_predictions'``.
            viz_dir: Directory that receives ``router_embeddings_tsne.png``.

        Returns:
            None. With fewer than two embeddings nothing is plotted.
        """
        print("\nCreating t-SNE plot...")

        # t-SNE needs perplexity < n_samples and perplexity > 0, so at least
        # two points are required.
        n_samples = len(embeddings_data['embeddings'])
        if n_samples < 2:
            print("  Not enough embedding data for t-SNE")
            return

        import matplotlib.pyplot as plt
        from sklearn.manifold import TSNE

        X = np.array(embeddings_data['embeddings'])
        dataset_labels = np.array(embeddings_data['dataset_labels'])
        mask_predictions = embeddings_data['mask_predictions']

        print(f"  t-SNE dimensionality reduction... (sample count: {X.shape[0]})")

        # Clamp the perplexity so small sample sets do not make scikit-learn
        # reject the fit. The iteration count is left at the library default
        # (1000, the value previously passed explicitly) because the keyword
        # was renamed from n_iter to max_iter across scikit-learn versions.
        perplexity = min(30, n_samples - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        X_tsne = tsne.fit_transform(X)

        dataset_colors = {
            'arc-e': '#1f77b4',
            'arc-c': '#ff7f0e',
            'piqa': '#2ca02c',
            'winogrande': '#d62728',
            'hellaswag': '#9467bd',
            'boolq': '#8c564b',
            'openbookqa': '#e377c2'
        }

        tsne_path = os.path.join(viz_dir, 'router_embeddings_tsne.png')
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        try:
            # Panel 1: one scatter series per dataset that has points.
            for dataset, color in dataset_colors.items():
                selected = dataset_labels == dataset
                if selected.any():
                    ax1.scatter(X_tsne[selected, 0], X_tsne[selected, 1],
                                c=color, label=dataset,
                                alpha=0.6, s=50, edgecolors='black', linewidth=0.5)

            ax1.set_title('t-SNE Visualization by Dataset', fontsize=14)
            ax1.set_xlabel('t-SNE Component 1')
            ax1.set_ylabel('t-SNE Component 2')
            ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            ax1.grid(True, alpha=0.3)

            # Panel 2: all points, colored by predicted mask index.
            scatter = ax2.scatter(X_tsne[:, 0], X_tsne[:, 1],
                                  c=mask_predictions, cmap='tab10',
                                  alpha=0.6, s=50, edgecolors='black', linewidth=0.5)

            ax2.set_title('t-SNE Visualization by Predicted Mask', fontsize=14)
            ax2.set_xlabel('t-SNE Component 1')
            ax2.set_ylabel('t-SNE Component 2')

            cbar = plt.colorbar(scatter, ax=ax2)
            cbar.set_label('Mask ID')
            cbar.set_ticks(range(self.args.num_clusters))
            ax2.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(tsne_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)

        print(f"  t-SNE plot saved to: {tsne_path}")


    def compute_mask_difference_matrix(self, metric='hamming'):
        """Compute the pairwise difference matrix between representative masks.

        Args:
            metric: Difference metric to use:
                - 'hamming': proportion of positions that differ
                - 'jaccard': 1 - |A∩B| / |A∪B| over positions equal to 1
                - 'cosine': 1 - cosine similarity
                - 'overlap': 1 - |A∩B| / min(|A|, |B|) over positions equal to 1
                - 'euclidean': Euclidean distance divided by sqrt(mask size)

        Returns:
            np.ndarray: ``(num_masks, num_masks)`` symmetric matrix with a zero
            diagonal.

        Raises:
            ValueError: If ``metric`` is not one of the supported names. The
                check happens before any mask is loaded or compared.
        """
        def _hamming(a, b):
            return np.mean(a != b)

        def _jaccard(a, b):
            intersection = np.sum((a == 1) & (b == 1))
            union = np.sum((a == 1) | (b == 1))
            return 1 - (intersection / union if union > 0 else 0)

        def _cosine(a, b):
            norm_product = np.linalg.norm(a) * np.linalg.norm(b)
            return 1 - (np.dot(a, b) / norm_product if norm_product > 0 else 0)

        def _overlap(a, b):
            intersection = np.sum((a == 1) & (b == 1))
            min_size = min(np.sum(a == 1), np.sum(b == 1))
            return 1 - (intersection / min_size if min_size > 0 else 0)

        def _euclidean(a, b):
            return np.linalg.norm(a - b) / np.sqrt(a.size)

        metric_fns = {
            'hamming': _hamming,
            'jaccard': _jaccard,
            'cosine': _cosine,
            'overlap': _overlap,
            'euclidean': _euclidean,
        }

        # Reject a bad metric up front: otherwise it would go unnoticed with
        # fewer than two masks, and would only fail after the model setup.
        if metric not in metric_fns:
            raise ValueError(f"Unsupported metric method: {metric}")
        distance = metric_fns[metric]

        print(f"Computing mask difference matrix (metric: {metric})")

        # Ensure representative masks are loaded
        if not hasattr(self, 'representative_masks'):
            self.setup_llm_for_testing()

        # Convert once instead of once per pair; flattening keeps the dot
        # product and size-based normalization well defined for any shape.
        masks = [np.asarray(m, dtype=np.float32).ravel() for m in self.representative_masks]
        num_masks = len(masks)
        difference_matrix = np.zeros((num_masks, num_masks))

        # Every metric is symmetric, so fill both triangles from one evaluation.
        for i in range(num_masks):
            for j in range(i + 1, num_masks):
                diff = distance(masks[i], masks[j])
                difference_matrix[i, j] = diff
                difference_matrix[j, i] = diff

        # Simple print matrix
        print(f"Difference matrix ({num_masks}x{num_masks}):")
        print("    " + "".join([f"M{j:2d} " for j in range(num_masks)]))
        for i in range(num_masks):
            row_str = f"M{i:2d} " + "".join([f"{difference_matrix[i][j]:4.2f} " for j in range(num_masks)])
            print(row_str)

        return difference_matrix




class LearnableMaskSet(nn.Module):
    """Trainable copies of the representative masks.

    Each representative mask becomes one ``nn.Parameter`` in ``self.masks``.
    Masks are read back through a temperature-scaled sigmoid, optionally
    hardened with a straight-through estimator.

    NOTE(review): a second class with this same name is defined further down
    in this file; at import time the later definition replaces this one.
    """

    def __init__(self, representative_masks, param_reg_structures, device, mask_init_strategy="original"):
        """Build one learnable parameter per representative mask.

        Args:
            representative_masks: Sequence of NumPy arrays (flat masks).
            param_reg_structures: Per-layer sizes used to split a flat mask.
            device: Device the parameters are created on.
            mask_init_strategy: ``"original"`` stores the mask values as-is;
                ``"logits"`` clips them to [0.01, 0.99] and stores
                ``log(p / (1 - p))`` (about -4.6 / +4.6 for 0 / 1 entries).

        Raises:
            ValueError: On an unknown strategy (raised while converting the
                first mask).
        """
        super().__init__()
        self.device = device
        self.param_reg_structures = param_reg_structures
        self.num_masks = len(representative_masks)

        learnable = []
        for source_mask in representative_masks:
            if mask_init_strategy == "original":
                # Raw mask values become the parameter directly.
                initial = source_mask.astype(np.float32)
            elif mask_init_strategy == "logits":
                # Clip away exact 0/1 so the logit stays finite.
                clipped = np.clip(source_mask.astype(np.float32), 0.01, 0.99)
                initial = np.log(clipped / (1 - clipped))
            else:
                raise ValueError(f"Unknown mask_init_strategy: {mask_init_strategy}")

            learnable.append(nn.Parameter(torch.from_numpy(initial).to(device)))

        # Registering through a ParameterList exposes them to optimizers.
        self.masks = nn.ParameterList(learnable)

        self.mask_init_strategy = mask_init_strategy
        print(f"Initialized {self.num_masks} learnable masks, strategy: {mask_init_strategy}")

    def get_binary_mask(self, mask_idx, temperature=1.0):
        """Return the soft mask ``sigmoid(raw / temperature)`` for ``mask_idx``.

        Raises:
            IndexError: If ``mask_idx`` is ``>= self.num_masks``.
        """
        if mask_idx >= self.num_masks:
            raise IndexError(f"Mask index {mask_idx} out of range (0-{self.num_masks-1})")

        return torch.sigmoid(self.masks[mask_idx] / temperature)

    def get_hard_binary_mask(self, mask_idx, temperature=0.1):
        """Return a 0/1 mask whose gradient flows through the soft mask.

        The forward value is ``soft > 0.5``; the detached correction term
        makes the backward pass see only the soft mask (straight-through).
        """
        soft = self.get_binary_mask(mask_idx, temperature=temperature)
        thresholded = (soft > 0.5).float()
        return soft + (thresholded - soft).detach()

    def convert_flat_mask_to_layer_masks(self, flat_mask):
        """Split ``flat_mask`` into consecutive per-layer slices.

        Slice lengths come from ``self.param_reg_structures``; each slice
        gets a leading dimension of size 1.
        """
        per_layer = []
        offset = 0
        for width in self.param_reg_structures:
            per_layer.append(flat_mask[offset:offset + width].unsqueeze(0))
            offset += width
        return per_layer


def compute_contrastive_loss_batch(model, input_ids_batch, original_examples_batch, tokenizer, device):
    """Average a per-sample loss over the batch.

    A sample whose entry in ``original_examples_batch`` is a dict containing
    ``'options'`` is scored with ``compute_contrastive_loss_single``; if that
    raises, the error is printed and the sample falls back to
    ``compute_fallback_loss``. Every other sample uses the fallback directly.

    Returns:
        Tensor: Mean of the stacked per-sample losses.
    """
    def _sample_loss(row):
        sample_ids = input_ids_batch[row:row + 1]
        example = original_examples_batch[row] if original_examples_batch else None

        # Contrastive scoring needs the multiple-choice options.
        if isinstance(example, dict) and 'options' in example:
            try:
                return compute_contrastive_loss_single(
                    model, sample_ids, tokenizer, device, example
                )
            except Exception as e:
                print(f"Contrastive learning calculation failed: {e}")

        # Plain language-model loss for everything else (and for failures).
        return compute_fallback_loss(model, sample_ids, device)

    per_sample = [_sample_loss(row) for row in range(input_ids_batch.size(0))]
    return torch.stack(per_sample).mean()


def _mean_continuation_log_prob(model, token_ids, prefix_len):
    """Mean next-token log-probability of ``token_ids`` after a prefix.

    The model runs under fp16 autocast. Log-probabilities of each token given
    its predecessors are gathered, and positions from ``prefix_len - 1``
    onward (the predictions for tokens following the prefix) are averaged.
    When ``prefix_len`` is 0 or leaves nothing to score, all positions are
    averaged instead.
    """
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = model(token_ids).logits

    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = token_ids[:, 1:].contiguous()

    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)

    if prefix_len > 0 and prefix_len - 1 < token_log_probs.shape[1]:
        token_log_probs = token_log_probs[:, prefix_len - 1:]

    return token_log_probs.mean()


def compute_contrastive_loss_single(model, input_ids, tokenizer, device, original_example):
    """Contrastive (cross-entropy over options) loss for one multiple-choice example.

    Every option is scored by the mean log-probability the model assigns to
    its continuation; those scores act as logits in a cross-entropy against
    the index of the correct option.

    Args:
        model: Callable returning an object with ``.logits`` for a batch of ids.
        input_ids: Unused; kept for interface compatibility.
        tokenizer: Tokenizer used to encode the prompts.
        device: Device the encoded prompts and the target are placed on.
        original_example: Dict with ``"options"`` and ``"label"`` plus either
            ``"context_prefix"``/``"target_suffix"`` (when ``"dataset_name"``
            contains "winogrande") or ``"question"``.

    Returns:
        Tensor: Scalar cross-entropy loss.

    Raises:
        KeyError: If a required field is missing from ``original_example``.
    """
    dataset_name = original_example.get("dataset_name", "")

    option_log_probs = []

    if "winogrande" in dataset_name.lower():
        ctx_pref = original_example["context_prefix"]
        tgt_suf = original_example["target_suffix"]
        options = original_example["options"]
        correct_idx = original_example["label"]

        for option in options:
            # The option fills the blank; only the shared suffix is scored.
            full_ctx = ctx_pref + option
            ids_full = tokenizer(full_ctx + tgt_suf,
                                 add_special_tokens=False,
                                 return_tensors="pt").input_ids.to(device)
            ctx_len = len(tokenizer(full_ctx, add_special_tokens=False).input_ids)
            option_log_probs.append(_mean_continuation_log_prob(model, ids_full, ctx_len))

    else:
        question = original_example["question"]
        options = original_example["options"]
        correct_idx = original_example["label"]

        # The question is the same for every option, so tokenize it once.
        question_len = len(tokenizer(question, add_special_tokens=True).input_ids)

        for option_content in options:
            full_text = f"{question} Answer: {option_content}"
            option_input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)
            option_log_probs.append(
                _mean_continuation_log_prob(model, option_input_ids, question_len)
            )

    # Treat the per-option scores as logits over the options.
    logits_tensor = torch.stack(option_log_probs)
    target = torch.tensor(correct_idx, device=device, dtype=torch.long)
    return F.cross_entropy(logits_tensor.unsqueeze(0), target.unsqueeze(0))


def compute_fallback_loss(model, input_ids, device):
    """Plain next-token language-model loss for ``input_ids``.

    The forward pass runs under fp16 autocast; the result is the negative
    mean log-probability of each token given the tokens before it.
    ``device`` is accepted for interface symmetry but not used.
    """
    with torch.cuda.amp.autocast(dtype=torch.float16):
        lm_logits = model(input_ids).logits

    # Position t predicts token t + 1, so drop the last logit and first label.
    next_token_targets = input_ids[:, 1:].contiguous().unsqueeze(-1)
    next_token_log_probs = F.log_softmax(lm_logits[:, :-1, :].contiguous(), dim=-1)
    picked = next_token_log_probs.gather(2, next_token_targets).squeeze(-1)

    return -picked.mean()

def get_training_temperature(epoch, total_epochs, start_temp=2.0, end_temp=0.5):
    """Cosine-annealed temperature for the given epoch.

    Starts at ``start_temp`` for ``epoch == 0`` and follows half a cosine
    period towards ``end_temp``, which is reached at
    ``epoch == total_epochs``. With ``total_epochs <= 1`` the result is
    always ``end_temp``.
    """
    if total_epochs > 1:
        phase = np.pi * epoch / total_epochs
        blend = (1 + np.cos(phase)) / 2  # 1 at the start, 0 at the end
        return end_temp + (start_temp - end_temp) * blend

    return end_temp

def finetune_masks(trainer, args):
    """
    Main function for finetuning representative masks

    Sets up a LearnableMaskSet, freezes the router and the LLM, builds an
    AdamW optimizer (and optionally a cosine LR scheduler) over the mask
    parameters, gathers the training samples, and enters the training loop.

    NOTE(review): the training loop below appears to be corrupted, most
    likely by a bad merge/paste of a router evaluation method. From the
    router call onward it references names that are never defined in this
    function (`self`, `target_scores`, `total_loss`, `correct_predictions`,
    `total_samples`), returns from inside the first batch iteration, and
    defines `save_router` / `load_router` as nested functions. As written,
    the function cannot complete an epoch when training samples exist. The
    intended mask-training step needs to be restored from version control.

    Args:
        trainer: MaskRouterTrainer instance
        args: Training parameters, needs to include mask finetuning related parameters

    Returns:
        The (untrained) LearnableMaskSet when no training samples are found.
    """
    print("\n=== Starting mask finetuning stage ===")
    
    # Ensure LLM model and representative masks are loaded
    if not hasattr(trainer, 'llm_model'):
        trainer.setup_llm_for_testing()
    
    # Create learnable mask set
    learnable_masks = LearnableMaskSet(
        trainer.representative_masks,
        trainer.param_reg.structures,
        trainer.device,
        mask_init_strategy=getattr(args, 'mask_init_strategy', 'original')
    ).to(trainer.device)
    
    # Freeze router parameters
    trainer.router.eval()
    for param in trainer.router.parameters():
        param.requires_grad = False
    print("Router parameters frozen")
    
    # Freeze LLM model parameters
    for param in trainer.llm_model.parameters():
        param.requires_grad = False
    
    # Setup optimizer (only optimize mask parameters)
    mask_optimizer = torch.optim.AdamW(
        learnable_masks.parameters(),
        lr=getattr(args, 'mask_finetune_lr', 1e-3),
        weight_decay=getattr(args, 'mask_finetune_weight_decay', 0.01)
    )
    
    # Setup learning rate scheduler
    # NOTE: mask_scheduler is only bound when args.mask_use_scheduler is truthy.
    mask_epochs = getattr(args, 'mask_finetune_epochs', 10)
    if getattr(args, 'mask_use_scheduler', True):
        mask_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mask_optimizer,
            T_max=mask_epochs,
            eta_min=getattr(args, 'mask_min_lr', 1e-5)
        )

    # Temperature range depends on how the mask parameters were initialized.
    if getattr(args, 'mask_init_strategy', 'original') == "logits":
        start_temperature = 2.0
        end_temperature = 0.5
    else:
        start_temperature = 1.5
        end_temperature = 0.3


    print(f"Temperature scheduling: {start_temperature} -> {end_temperature} (strategy: {getattr(args, 'mask_init_strategy', 'original')})")

    # Prepare training data - support supervised samples and WikiText samples
    supervised_samples = []
    wikitext_samples = []
    
    for sample in trainer.train_data:
        # Check if supervised sample (has labels and options for MC questions)
        if ('scores' in sample and 
            isinstance(sample.get('original_example'), dict) and
            'options' in sample['original_example']):
            supervised_samples.append(sample)
        # Check if WikiText sample
        elif ('scores' in sample and 
              sample.get('dataset_name') == 'wikitext' and
              'input_ids' in sample and
              len(sample['input_ids']) > 0):
            wikitext_samples.append(sample)
    
    print(f"Found {len(supervised_samples)} supervised samples for mask finetuning")
    print(f"Found {len(wikitext_samples)} WikiText samples for mask finetuning")

    # Combine all available training samples
    all_training_samples = supervised_samples + wikitext_samples
    
    if len(all_training_samples) == 0:
        print("Warning: No available training samples found, mask finetuning will be skipped")
        return learnable_masks
    
    # Setup gradient scaler
    scaler = GradScaler()
    
    # Training loop
    best_loss = float('inf')
    best_masks_state = None
    
    for epoch in range(mask_epochs):
        epoch_losses = []
        epoch_reg_losses = []
        epoch_ce_losses = []
        
        # Randomly shuffle data
        np.random.shuffle(all_training_samples)
        
        batch_size = getattr(args, 'mask_finetune_batch_size', 2)
        total_batches = (len(all_training_samples) + batch_size - 1) // batch_size
        
        progress_bar = tqdm(range(total_batches), desc=f"Mask finetuning Epoch {epoch+1}/{mask_epochs}")
        
        for batch_idx in progress_bar:
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(all_training_samples))
            batch_samples = all_training_samples[start_idx:end_idx]
            
            batch_ce_loss = 0
            batch_reg_loss = 0
            valid_samples = 0
            
            for sample in batch_samples:
                try:
                    # 1. Use router to predict best mask index
                    text = sample['text']
                    with torch.no_grad():
                        batch_tokens = trainer.router.tokenizer(
                            [text],
                            padding=True,
                            truncation=True,
                            max_length=2048,
                            return_tensors="pt"
                        ).to(trainer.device)

                        # NOTE(review): from here on the code looks like it was
                        # pasted from a router evaluation method. `self` and
                        # `target_scores` are not defined in this function, so
                        # this raises NameError, which the broad `except`
                        # below turns into a printed message.
                        predicted_scores = self.router(batch_tokens)
                        
                        # loss = F.mse_loss(predicted_scores, target_scores)
                        loss = compute_stable_mse_loss(predicted_scores, target_scores)
                        # NOTE(review): total_loss / correct_predictions /
                        # total_samples are never initialized in this function.
                        total_loss += loss.item() * len(batch_samples)
                        
                        pred_best = torch.argmax(predicted_scores, dim=1)
                        true_best = torch.argmax(target_scores, dim=1)
                        correct_predictions += (pred_best == true_best).sum().item()
                
                    total_samples += len(batch_samples)
                except Exception as e:
                    print(f"Error processing sample: {e}")
                    continue
        
            # NOTE(review): these lines sit inside the batch loop, so they run
            # after the first batch; the unbound accumulators make them fail,
            # and the `return` would otherwise exit with an (avg_loss,
            # avg_accuracy) tuple instead of the learnable masks. The mask
            # optimizer, scheduler and gradient scaler set up above are never
            # used.
            avg_loss = total_loss / total_samples
            avg_accuracy = correct_predictions / total_samples
            
            self.router.train()
            return avg_loss, avg_accuracy
    
        # NOTE(review): the two functions below take `self` and read like
        # trainer methods, but at this indentation they are local functions
        # defined inside the epoch loop and are not called anywhere in this
        # function. They probably belong on the trainer class - confirm.
        def save_router(self, filename):
            """Save router LoRA/head weights and optimizer, scheduler and scaler state to args.output_dir."""
            
            filepath = os.path.join(self.args.output_dir, filename)
            os.makedirs(self.args.output_dir, exist_ok=True)

            # The head is stored under 'classifier_state_dict' for both the
            # classifier and the regressor variant.
            model_state_dict = {
                'lora_state_dict': get_peft_model_state_dict(self.router.model),
                'classifier_state_dict': self.router.classifier.state_dict() if self.router.use_classification else self.router.regressor.state_dict()
            }
            
            save_dict = {
                'model_state_dict': model_state_dict,  
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'scaler_state_dict': self.scaler.state_dict(),
                #'metadata': self.metadata,
                'args': self.args,
                'num_clusters': self.args.num_clusters
            }
            
            torch.save(save_dict, filepath)
            print(f"Router saved to {filepath}")

        def load_router(self, checkpoint_path):
            """Load router LoRA and head weights from a checkpoint written by save_router."""
            
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            
            print("Available keys in checkpoint:", list(checkpoint.keys()))
            if 'model_state_dict' in checkpoint:
                print("Available keys in model_state_dict:", list(checkpoint['model_state_dict'].keys()))
            
            # Restore the LoRA adapter weights when present.
            if 'model_state_dict' in checkpoint and 'lora_state_dict' in checkpoint['model_state_dict']:
                lora_state_dict = checkpoint['model_state_dict']['lora_state_dict']
                
                from peft import set_peft_model_state_dict
                set_peft_model_state_dict(self.router.model, lora_state_dict)
                print("LoRA weights loaded")
            
            # Restore the scoring head; the same key serves both head types.
            if 'model_state_dict' in checkpoint and 'classifier_state_dict' in checkpoint['model_state_dict']:
                classifier_state_dict = checkpoint['model_state_dict']['classifier_state_dict']
                
                if self.router.use_classification:
                    self.router.classifier.load_state_dict(classifier_state_dict)
                    print("Classifier weights loaded")
                else:
                    self.router.regressor.load_state_dict(classifier_state_dict)
                    print("Regressor weights loaded")
            
            print(f"Router loaded from {checkpoint_path}")




class LearnableMaskSet(nn.Module):
    """Set of representative masks stored as trainable parameters.

    NOTE(review): a class with this same name is also defined earlier in this
    file; being the later definition, this one is what the module exposes.
    """

    def __init__(self, representative_masks, param_reg_structures, device, mask_init_strategy="original"):
        """Register one ``nn.Parameter`` per entry of ``representative_masks``.

        Args:
            representative_masks: Sequence of NumPy arrays (flat masks).
            param_reg_structures: Per-layer sizes used to split a flat mask.
            device: Device the parameters are created on.
            mask_init_strategy: ``"original"`` or ``"logits"``.

        Raises:
            ValueError: On an unknown strategy (raised while converting the
                first mask).
        """
        super().__init__()
        self.device = device
        self.param_reg_structures = param_reg_structures
        self.num_masks = len(representative_masks)

        self.masks = nn.ParameterList()
        for flat_mask in representative_masks:
            start_values = self._initial_values(flat_mask, mask_init_strategy)
            self.masks.append(nn.Parameter(torch.from_numpy(start_values).to(device)))

        self.mask_init_strategy = mask_init_strategy

    @staticmethod
    def _initial_values(flat_mask, strategy):
        """Return the float32 array a mask parameter starts from."""
        if strategy == "original":
            return flat_mask.astype(np.float32)
        if strategy == "logits":
            # Clip to [0.01, 0.99] so log(p / (1 - p)) stays finite.
            probs = np.clip(flat_mask.astype(np.float32), 0.01, 0.99)
            return np.log(probs / (1 - probs))
        raise ValueError(f"Unknown mask_init_strategy: {strategy}")

    def get_binary_mask(self, mask_idx, temperature=1.0):
        """Soft mask for ``mask_idx``: sigmoid of the parameter over ``temperature``.

        Raises:
            IndexError: If ``mask_idx`` is ``>= self.num_masks``.
        """
        if mask_idx >= self.num_masks:
            raise IndexError(f"Mask index {mask_idx} out of range (0-{self.num_masks-1})")

        scaled = self.masks[mask_idx] / temperature
        return torch.sigmoid(scaled)

    def get_hard_binary_mask(self, mask_idx, temperature=0.1):
        """Thresholded (``> 0.5``) mask that backpropagates through the soft mask."""
        soft = self.get_binary_mask(mask_idx, temperature=temperature)
        rounded = (soft > 0.5).float()

        # Straight-through: forward value is `rounded`, gradient is d(soft).
        return soft + (rounded - soft).detach()

    def convert_flat_mask_to_layer_masks(self, flat_mask):
        """Cut ``flat_mask`` into per-layer pieces sized by ``param_reg_structures``.

        Each piece gets a leading dimension of size 1.
        """
        pieces = []
        cursor = 0
        for size in self.param_reg_structures:
            stop = cursor + size
            pieces.append(flat_mask[cursor:stop].unsqueeze(0))
            cursor = stop
        return pieces


def compute_contrastive_loss_batch(model, input_ids_batch, original_examples_batch, tokenizer, device):
    """Mean per-sample loss over a batch.

    Samples whose original example is a dict with ``'options'`` use
    ``compute_contrastive_loss_single``; if that raises, or the sample has no
    such example, ``compute_fallback_loss`` is used instead. Failures of the
    contrastive path are not reported.

    NOTE(review): a function with this same name is also defined earlier in
    this file; being the later definition, this one is what the module exposes.
    """
    losses = []

    for row in range(input_ids_batch.size(0)):
        sample_ids = input_ids_batch[row:row + 1]
        example = original_examples_batch[row] if original_examples_batch else None

        needs_fallback = True
        if isinstance(example, dict) and 'options' in example:
            try:
                losses.append(
                    compute_contrastive_loss_single(model, sample_ids, tokenizer, device, example)
                )
                needs_fallback = False
            except Exception:
                # Deliberately quiet: the plain LM loss below takes over.
                needs_fallback = True

        if needs_fallback:
            losses.append(compute_fallback_loss(model, sample_ids, device))

    return torch.stack(losses).mean()


def _mean_continuation_log_prob(model, token_ids, prefix_len):
    """Return the mean log-probability of the tokens after the first ``prefix_len``.

    When ``prefix_len`` is not a usable split point (``<= 0`` or past the end
    of the sequence) the mean is taken over every predicted token instead.
    """
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = model(token_ids).logits

    # Position t of the logits predicts token t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = token_ids[:, 1:].contiguous()

    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)

    if prefix_len > 0 and prefix_len - 1 < token_log_probs.shape[1]:
        # Token ``prefix_len`` is predicted at shifted index ``prefix_len - 1``.
        token_log_probs = token_log_probs[:, prefix_len - 1:]

    return token_log_probs.mean()


def compute_contrastive_loss_single(model, input_ids, tokenizer, device, original_example):
    """Cross-entropy over answer options scored by mean token log-probability.

    Each option is scored with the mean log-probability the model assigns to
    the continuation tokens. The scores are then used as logits of a
    cross-entropy against ``original_example['label']``.

    * If ``original_example['dataset_name']`` contains "winogrande", the text
      is ``context_prefix + option + target_suffix`` and only the suffix
      tokens are scored.
    * Otherwise the text is ``"{question} Answer: {option}"`` and every token
      after the tokenized question is scored.

    Args:
        model: Callable returning an object with a ``.logits`` attribute.
        input_ids: Unused; kept so existing callers keep working.
        tokenizer: Callable tokenizer (called with ``return_tensors="pt"`` for
            model input and without it for length measurement).
        device: Device the token ids and the target are moved to.
        original_example: Dict with ``'options'`` and ``'label'`` plus either
            ``'question'`` or ``'context_prefix'``/``'target_suffix'``.

    Returns:
        Scalar loss tensor.

    Raises:
        KeyError: If a required key is missing from ``original_example``.
        ValueError: If ``original_example['options']`` is empty.
    """
    dataset_name = original_example.get("dataset_name", "")
    options = original_example["options"]
    correct_idx = original_example["label"]

    if len(options) == 0:
        # torch.stack([]) would otherwise fail with a much less helpful message.
        raise ValueError("original_example['options'] must not be empty")

    option_log_probs = []

    if "winogrande" in dataset_name.lower():
        ctx_pref = original_example["context_prefix"]
        tgt_suf = original_example["target_suffix"]

        for option in options:
            full_ctx = ctx_pref + option
            ids_full = tokenizer(full_ctx + tgt_suf,
                                 add_special_tokens=False,
                                 return_tensors="pt").input_ids.to(device)
            ctx_len = len(tokenizer(full_ctx, add_special_tokens=False).input_ids)
            option_log_probs.append(
                _mean_continuation_log_prob(model, ids_full, ctx_len)
            )
    else:
        question = original_example["question"]
        # The question is identical for every option, so tokenize it once.
        question_len = len(tokenizer(question, add_special_tokens=True).input_ids)

        for option_content in options:
            full_text = f"{question} Answer: {option_content}"
            option_input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)
            option_log_probs.append(
                _mean_continuation_log_prob(model, option_input_ids, question_len)
            )

    logits_tensor = torch.stack(option_log_probs)
    target = torch.tensor(correct_idx, device=device, dtype=torch.long)
    contrastive_loss = F.cross_entropy(logits_tensor.unsqueeze(0), target.unsqueeze(0))

    return contrastive_loss


def compute_fallback_loss(model, input_ids, device):
    """Mean next-token negative log-likelihood of ``input_ids`` under ``model``.

    ``device`` is accepted for signature parity with the other loss helpers
    and is not used here.
    """
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits = model(input_ids).logits

    # Logits at position t are scored against the token at position t + 1.
    predictions = logits[:, :-1, :].contiguous()
    next_tokens = input_ids[:, 1:].contiguous()

    picked = (
        F.log_softmax(predictions, dim=-1)
        .gather(2, next_tokens.unsqueeze(-1))
        .squeeze(-1)
    )
    return -picked.mean()

def get_training_temperature(epoch, total_epochs, start_temp=2.0, end_temp=0.5):
    """Cosine-interpolated temperature for ``epoch`` out of ``total_epochs``.

    Yields ``start_temp`` at epoch 0 and would reach ``end_temp`` at
    ``epoch == total_epochs``. With at most one epoch the schedule collapses
    to ``end_temp``.
    """
    if total_epochs <= 1:
        return end_temp

    angle = np.pi * epoch / total_epochs
    blend = (1 + np.cos(angle)) / 2  # 1 at the start, 0 at the end
    return end_temp + (start_temp - end_temp) * blend

def _mask_finetune_sample_losses(trainer, learnable_masks, sample, temperature):
    """Compute the losses one training sample contributes to mask fine-tuning.

    The frozen router picks a mask index for ``sample['text']`` (argmax of its
    scores). The soft version of that mask is installed as gate vectors on
    ``trainer.llm_model`` and a language-model loss is computed through the
    gated model:

    * wikitext samples with ``input_ids`` use the next-token loss;
    * samples whose ``original_example`` has ``'options'`` use the contrastive
      multiple-choice loss;
    * everything else uses the next-token loss on the tokenized text.

    Returns:
        ``(ce_loss, reg_loss, gate_mean)``. ``reg_loss`` is
        ``trainer.param_reg`` applied to the hard layer masks, and
        ``gate_mean`` is the mean soft gate value as a float (used for
        progress display only).
    """
    with torch.no_grad():
        batch_tokens = trainer.router.tokenizer(
            [sample['text']],
            padding=True,
            truncation=True,
            max_length=2048,
            return_tensors="pt"
        ).to(trainer.device)

        predicted_scores = trainer.router(batch_tokens)
        best_mask_idx = torch.argmax(predicted_scores, dim=1)[0].item()

    soft_mask = learnable_masks.get_binary_mask(best_mask_idx, temperature)
    layer_masks = learnable_masks.convert_flat_mask_to_layer_masks(soft_mask)

    trainer.hn_helper.set_gate_vectors(trainer.llm_model, layer_masks)
    trainer.hn_helper.set_gate_status(trainer.llm_model, use_gate=True)

    gate_mean = torch.stack([m.mean() for m in layer_masks]).mean().item()

    original_example = sample.get('original_example')
    dataset_name = sample.get('dataset_name', '')

    if dataset_name == 'wikitext' and 'input_ids' in sample and len(sample['input_ids']) > 0:
        input_ids = torch.tensor(sample['input_ids'], dtype=torch.long).unsqueeze(0).to(trainer.device)
        # Same next-token loss as the generic fallback, so reuse the helper.
        ce_loss = compute_fallback_loss(trainer.llm_model, input_ids, trainer.device)

    elif original_example and 'options' in original_example:
        # NOTE: compute_contrastive_loss_single re-tokenizes from the example
        # and does not read input_ids; it is still passed to honour the
        # function's signature.
        input_ids = trainer.tokenizer(
            original_example.get('question', ''),
            return_tensors="pt",
            max_length=2048,
            truncation=True
        ).input_ids.to(trainer.device)

        ce_loss = compute_contrastive_loss_single(
            trainer.llm_model, input_ids, trainer.tokenizer,
            trainer.device, original_example
        )
    else:
        input_ids = trainer.tokenizer(
            sample['text'],
            return_tensors="pt",
            max_length=2048,
            truncation=True
        ).input_ids.to(trainer.device)

        ce_loss = compute_fallback_loss(trainer.llm_model, input_ids, trainer.device)

    # The regulariser sees the thresholded mask; gradients still reach the
    # mask parameters through get_hard_binary_mask.
    hard_mask = learnable_masks.get_hard_binary_mask(best_mask_idx, temperature=temperature)
    hard_layer_masks = learnable_masks.convert_flat_mask_to_layer_masks(hard_mask)
    reg_loss = trainer.param_reg(hard_layer_masks)

    return ce_loss, reg_loss, gate_mean


def finetune_masks(trainer, args):
    """Fine-tune the representative masks while the router and the LLM stay frozen.

    A ``LearnableMaskSet`` is built from ``trainer.representative_masks`` and
    optimised with AdamW (optionally with cosine LR annealing). Each optimiser
    step minimises ``mean(ce_loss) + mask_reg_weight * mean(reg_loss)`` over a
    mini-batch of samples taken from ``trainer.train_data``. The sigmoid
    temperature follows ``get_training_temperature`` per epoch. The mask
    parameters of the epoch with the lowest mean loss are restored at the end.

    Side effects:
        * sets ``requires_grad = False`` on every parameter of
          ``trainer.router`` and ``trainer.llm_model``;
        * overwrites ``trainer.representative_masks[i]`` with boolean numpy
          arrays derived from the fine-tuned masks;
        * leaves the LLM gates disabled at the end of every epoch.

    Args:
        trainer: Object providing ``router``, ``llm_model``, ``tokenizer``,
            ``hn_helper``, ``param_reg``, ``train_data``,
            ``representative_masks`` and ``device``.
        args: Namespace; all ``mask_*`` options are read with ``getattr`` and
            fall back to built-in defaults when absent.

    Returns:
        The ``LearnableMaskSet`` holding the selected mask parameters.
    """
    if not hasattr(trainer, 'llm_model'):
        trainer.setup_llm_for_testing()

    mask_init_strategy = getattr(args, 'mask_init_strategy', 'original')

    learnable_masks = LearnableMaskSet(
        trainer.representative_masks,
        trainer.param_reg.structures,
        trainer.device,
        mask_init_strategy=mask_init_strategy
    ).to(trainer.device)

    # Only the parameters of learnable_masks are optimised.
    trainer.router.eval()
    for param in trainer.router.parameters():
        param.requires_grad = False

    for param in trainer.llm_model.parameters():
        param.requires_grad = False

    mask_optimizer = torch.optim.AdamW(
        learnable_masks.parameters(),
        lr=getattr(args, 'mask_finetune_lr', 1e-3),
        weight_decay=getattr(args, 'mask_finetune_weight_decay', 0.01)
    )

    mask_epochs = getattr(args, 'mask_finetune_epochs', 10)
    mask_scheduler = None
    if getattr(args, 'mask_use_scheduler', True):
        mask_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mask_optimizer,
            T_max=mask_epochs,
            eta_min=getattr(args, 'mask_min_lr', 1e-5)
        )

    if mask_init_strategy == "logits":
        start_temperature, end_temperature = 2.0, 0.5
    else:
        start_temperature, end_temperature = 1.5, 0.3

    # Loop-invariant hyper-parameters, read once.
    batch_size = getattr(args, 'mask_finetune_batch_size', 2)
    reg_weight = getattr(args, 'mask_reg_weight', 1.0)
    grad_clip = getattr(args, 'mask_grad_clip', 1.0)
    max_reported_errors = 3  # per epoch, to keep the log readable

    supervised_samples = []
    wikitext_samples = []

    for sample in trainer.train_data:
        if ('scores' in sample and
            isinstance(sample.get('original_example'), dict) and
            'options' in sample['original_example']):
            supervised_samples.append(sample)
        elif ('scores' in sample and
              sample.get('dataset_name') == 'wikitext' and
              'input_ids' in sample and
              len(sample['input_ids']) > 0):
            wikitext_samples.append(sample)

    all_training_samples = supervised_samples + wikitext_samples

    scaler = GradScaler()

    best_loss = float('inf')
    best_masks_state = None

    for epoch in range(mask_epochs):
        epoch_losses = []
        epoch_reg_losses = []
        epoch_ce_losses = []
        skipped_samples = 0

        # The temperature depends only on the epoch, so compute it once.
        temperature = get_training_temperature(epoch, mask_epochs, start_temperature, end_temperature)

        np.random.shuffle(all_training_samples)

        total_batches = (len(all_training_samples) + batch_size - 1) // batch_size

        progress_bar = tqdm(range(total_batches), desc=f"Epoch {epoch+1}/{mask_epochs}")

        for batch_idx in progress_bar:
            start_idx = batch_idx * batch_size
            batch_samples = all_training_samples[start_idx:start_idx + batch_size]

            batch_ce_loss = 0
            batch_reg_loss = 0
            valid_samples = 0

            for sample in batch_samples:
                try:
                    ce_loss, reg_loss, sample_gate_mean = _mask_finetune_sample_losses(
                        trainer, learnable_masks, sample, temperature
                    )
                    batch_ce_loss += ce_loss
                    batch_reg_loss += reg_loss
                except Exception as exc:
                    # Best-effort: one bad sample must not abort training, but
                    # the failure is reported so systematic errors are visible.
                    skipped_samples += 1
                    if skipped_samples <= max_reported_errors:
                        tqdm.write(
                            f"[finetune_masks] skipping sample "
                            f"({type(exc).__name__}: {exc})"
                        )
                    continue

                gate_mean = sample_gate_mean
                valid_samples += 1

            if valid_samples == 0:
                continue

            avg_ce_loss = batch_ce_loss / valid_samples
            avg_reg_loss = batch_reg_loss / valid_samples
            total_loss = avg_ce_loss + reg_weight * avg_reg_loss

            mask_optimizer.zero_grad()
            scaler.scale(total_loss).backward()

            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(mask_optimizer)
            torch.nn.utils.clip_grad_norm_(learnable_masks.parameters(), grad_clip)

            scaler.step(mask_optimizer)
            scaler.update()

            epoch_losses.append(total_loss.item())
            epoch_ce_losses.append(avg_ce_loss.item())
            epoch_reg_losses.append(avg_reg_loss.item())

            progress_bar.set_postfix({
                'Loss': f'{total_loss.item():.4f}',
                'CE': f'{avg_ce_loss.item():.4f}',
                'Reg': f'{avg_reg_loss.item():.4f}',
                'Keep': f'{gate_mean:.2f}',
                'Temp': f'{temperature:.2f}'
            })

        if mask_scheduler is not None:
            mask_scheduler.step()

        trainer.hn_helper.set_gate_status(trainer.llm_model, use_gate=False)

        if skipped_samples:
            print(f"Epoch {epoch+1}: skipped {skipped_samples} sample(s) that raised an exception")

        if epoch_losses:
            avg_epoch_loss = np.mean(epoch_losses)
            avg_ce_loss = np.mean(epoch_ce_losses)
            avg_reg_loss = np.mean(epoch_reg_losses)

            print(f"Epoch {epoch+1}: Loss={avg_epoch_loss:.4f}, CE={avg_ce_loss:.4f}, Reg={avg_reg_loss:.4f}")

            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                # detach() so the snapshot does not stay attached to autograd.
                best_masks_state = {
                    name: param.detach().clone()
                    for name, param in learnable_masks.named_parameters()
                }

    if best_masks_state is not None:
        for name, param in learnable_masks.named_parameters():
            param.data.copy_(best_masks_state[name])

    with torch.no_grad():
        final_temperature = get_training_temperature(mask_epochs-1, mask_epochs, start_temperature, end_temperature)

        updated_sparsities = []

        for i in range(learnable_masks.num_masks):
            hard_mask = learnable_masks.get_hard_binary_mask(i, temperature=final_temperature)
            trainer.representative_masks[i] = hard_mask.cpu().numpy().astype(bool)

            updated_sparsities.append(1 - hard_mask.mean().item())

        if updated_sparsities:
            avg_sparsity = np.mean(updated_sparsities)
            print(f"Mask fine-tuning finished: average sparsity = {avg_sparsity:.4f}")

    return learnable_masks





def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is wrong for argparse because ``bool("False")`` is ``True``;
    this converter maps the usual spellings to real booleans instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Command-line entry point.

    ``--mode train`` trains the router and then either fine-tunes the masks
    (``--enable_mask_finetuning``; results are saved to
    ``<output_dir>/finetuned_masks.pt``) or evaluates the router, followed by
    evaluation on the test datasets. ``--mode test`` loads a router
    checkpoint, optionally replaces the representative masks with fine-tuned
    ones found in ``output_dir``, and evaluates on wikitext. The other
    accepted modes have no handler here and only print a notice.
    """
    parser = argparse.ArgumentParser(description="")

    # model path
    parser.add_argument("--model_path", type=str, default="xxx/llms/meta/Llama-2-7B-hf")
    parser.add_argument("--sentence_model_path", type=str, default="xxx/llms/XX")
    parser.add_argument("--device", type=str, default="cuda:3")
    parser.add_argument("--data_dir", type=str, default="xxx/project/DynPrune/llama-2-7b/041") # xxx/project/DynPrune/llama-2-7b/041 # xxx/project/DISP/mixed/router
    parser.add_argument("--output_dir", type=str, default="xxx/project/DynPrune/llama-2-7b/041/router") # xxx/project/DynPrune/llama-2-7b/041/router # xxx/project/DISP/mixed/router
    parser.add_argument("--train_datasets", nargs='+', default=None) # wikitext
    parser.add_argument("--test_datasets", nargs='+', default=["all"]) # wikitext
    parser.add_argument("--use_classification", type=_str2bool, default=False,
                        help="")
    parser.add_argument("--num_samples_per_dataset", type=int, default=None)
    parser.add_argument("--seed", type=int, default=58)

    # Router Param
    parser.add_argument("--num_clusters", type=int, default=10)
    parser.add_argument("--router_hidden_dim", type=int, default=256)

    # Lora
    parser.add_argument("--lora_rank", type=int, default=8, help="")
    parser.add_argument("--lora_alpha", type=int, default=32, help="")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="")

    # optimisation
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=2e-4)  # 10^-5
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--save_every", type=int, default=10)
    parser.add_argument("--eval_every", type=int, default=10)

    # evaluation after training
    # NOTE: store_true combined with default=True means this flag cannot be
    # switched off from the command line.
    parser.add_argument("--test_after_training", action="store_true", default=True,
                        help="")
    parser.add_argument("--num_test_samples", type=int, default=100,
                        help="")

    # NOTE(review): the meaning of --p / --lam is not visible in this function;
    # confirm against MaskRouterTrainer.
    parser.add_argument("--p", type=float, default=0.8) # 0.6
    parser.add_argument("--lam", type=float, default=4.0)

    # NOTE(review): presumably pruning / hypernetwork settings - confirm.
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--lrp_scale", type=float, default=1.0)
    parser.add_argument("--base", type=float, default=0.5)
    parser.add_argument("--T_start", type=float, default=0.5)
    parser.add_argument("--T_end", type=float, default=0.1)
    parser.add_argument("--target_sparsity", type=float, default=0.2) # 0.4

    # data preparation
    parser.add_argument("--max_wikitext_samples", type=int, default=1000,
                        help="")
    parser.add_argument("--normalize_lrp", type=_str2bool, default=True,
                        help="")
    parser.add_argument("--normalize_activations", type=_str2bool, default=False,
                        help="")

    # run mode
    parser.add_argument("--mode", type=str, choices=["train", "test", "individual_mask", "draw", "visualize"], default="train")
    parser.add_argument("--router_checkpoint", type=str, default=None)

    # finetune masks
    parser.add_argument("--enable_mask_finetuning", action="store_true", default=False,
                        help="")
    parser.add_argument('--num_masks_to_load', type=int, default=None, help='')
    parser.add_argument("--mask_finetune_epochs", type=int, default=5,
                        help="")
    parser.add_argument("--mask_finetune_lr", type=float, default=1e-3,
                        help="")
    parser.add_argument("--mask_finetune_weight_decay", type=float, default=0.01,
                        help="")
    parser.add_argument("--mask_finetune_batch_size", type=int, default=1,
                        help="")
    parser.add_argument("--mask_init_strategy", type=str, choices=["original", "logits"], default="original",
                        help="")
    # NOTE: same store_true/default=True caveat as --test_after_training.
    parser.add_argument("--mask_use_scheduler", action="store_true", default=True,
                        help="")
    parser.add_argument("--mask_min_lr", type=float, default=1e-4,
                        help="")
    parser.add_argument("--mask_reg_weight", type=float, default=2.0,
                        help="")
    parser.add_argument("--mask_grad_clip", type=float, default=1.0)
    parser.add_argument("--mask_eval_every", type=int, default=5)

    # drawing
    parser.add_argument("--viz_samples_per_dataset", type=int, default=500)
    parser.add_argument("--viz_output_dir", type=str, default="./output")

    args = parser.parse_args()

    # A per-mask-count run gets its own sub-directory next to the router dir.
    if args.num_masks_to_load is not None:
        args.output_dir = os.path.join(
            os.path.dirname(args.output_dir),
            f"router/{args.num_masks_to_load}"
        )
        os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == "train":
        print("=" * 20)

        trainer = MaskRouterTrainer(args)
        trainer.train_router()

        if args.enable_mask_finetuning:
            finetuned_masks = finetune_masks(trainer, args)

            trainer.evaluate_on_test_set(
                test_datasets=args.test_datasets
            )

            # Make sure the target directory exists before saving.
            os.makedirs(args.output_dir, exist_ok=True)
            torch.save({
                'finetuned_masks': finetuned_masks.state_dict(),
                'representative_masks': trainer.representative_masks,
                'args': args
            }, os.path.join(args.output_dir, "finetuned_masks.pt"))
        else:
            final_loss, final_accuracy = trainer.evaluate_router()
            print(f"Final Validation Loss: {final_loss:.6f}")
            print(f"Final Validation Accuracy: {final_accuracy:.4f}")

            trainer.evaluate_on_test_set(
                test_datasets=args.test_datasets
            )

    elif args.mode == "test":
        # Honour --router_checkpoint; fall back to the previous hard-coded path.
        router_checkpoint = args.router_checkpoint or "xxx/project/DISP/mixed/router/router_final.pt"

        trainer = MaskRouterTrainer(args)
        trainer.load_router(router_checkpoint)

        finetuned_masks_path = os.path.join(args.output_dir, "finetuned_masks.pt")
        if os.path.exists(finetuned_masks_path):
            # The file is written by the train branch above and contains an
            # argparse.Namespace, which the weights_only=True default of
            # recent PyTorch refuses to unpickle.
            # SECURITY: weights_only=False runs pickle; only load trusted files.
            finetuned_checkpoint = torch.load(
                finetuned_masks_path,
                map_location=trainer.device,
                weights_only=False,
            )

            if 'representative_masks' in finetuned_checkpoint:
                trainer.representative_masks = finetuned_checkpoint['representative_masks']

        # NOTE: test mode evaluates wikitext only and ignores --test_datasets.
        # Other dataset names used here: 'arc-e', 'arc-c', 'piqa',
        # 'winogrande', 'hellaswag'.
        trainer.evaluate_on_test_set(
            test_datasets=['wikitext']
        )

    else:
        # The parser accepts these modes, but main() has no branch for them.
        print(f"Mode '{args.mode}' has no handler in main(); nothing to do.")




# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()