import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
import argparse
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from transformers import AutoTokenizer

# import necessary modules (adjust according to your project structure)
from models import PruneLlama2ForCausalLM
from pruning import collect_info_reg_llama
from pruning.dyn_hypernetwork import dyn_hypernetwork
from lib.dataset_loader import read_lrp_file, read_mixed_lrp_file

class MaskSimilarityAnalyzer:
    def __init__(self, args):
        """Store the configuration, then build the model and load the samples.

        Args:
            args: namespace-like configuration object. This constructor reads
                ``args.device`` and the optional ``normalize_lrp`` /
                ``normalize_activations`` flags; ``setup_model`` and
                ``load_data`` read further attributes from it.
        """
        self.args = args

        # the requested device is only used when CUDA is available
        device_name = args.device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # normalization switches (defaults: LRP on, activations off)
        self.normalize_activations = getattr(args, 'normalize_activations', False)
        self.normalize_lrp = getattr(args, 'normalize_lrp', True)

        # model first, data second
        self.setup_model()
        self.load_data()

    def normalize_tensor_layerwise(self, tensor, eps=1e-8):
        """
        layer-wise tensor normalization (z-score normalization)
        applicable to LRP scores and activations
        
        Args:
            tensor (torch.Tensor): input tensor, shape [batch_size, layer_size] or [layer_size]
            eps (float): numerical stability parameter, prevent division by zero
            
        Returns:
            torch.Tensor: normalized tensor
        """
        if tensor.numel() == 0:
            return tensor
        
        # take absolute value (reasonable for activations too, since we care about importance magnitude)
        tensor = torch.abs(tensor)

        # save original shape
        original_shape = tensor.shape
        
        # if 1D tensor, add batch dimension for processing
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
            squeeze_later = True
        else:
            squeeze_later = False
        
        # calculate mean and std for each sample (on last dimension)
        mean = tensor.mean(dim=-1, keepdim=True)
        std = tensor.std(dim=-1, keepdim=True, unbiased=False)
        
        # prevent zero standard deviation
        std = torch.clamp(std, min=eps)
        
        # Z-score normalization
        normalized_tensor = (tensor - mean) / std
        
        # restore original shape
        if squeeze_later:
            normalized_tensor = normalized_tensor.squeeze(0)
        
        return normalized_tensor

    def setup_model(self):
        """load tokenizer and model, build the hypernetwork and optionally restore its weights

        Sets ``self.tokenizer``, ``self.model``, ``self.param_reg`` and
        ``self.hypernetwork``.

        Raises:
            ValueError: if the tokenizer has neither a pad token nor an EOS token
        """
        args = self.args
        print(f"Loading model: {self.args.model_path}")

        # tokenizer: fall back to the EOS token when no pad token is defined
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is None:
                raise ValueError("No pad_token or eos_token found in tokenizer")
            self.tokenizer.pad_token = self.tokenizer.eos_token
            print(f"set pad_token to eos_token: {self.tokenizer.pad_token}")

        # fp16 model placed on the selected device, use_cache off, eval mode
        self.model = PruneLlama2ForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.float16,
            device_map=self.device,
        )
        self.model.config.use_cache = False
        self.model.eval()

        # the collected structures are handed to the hypernetwork as t_structures
        self.param_reg = collect_info_reg_llama(self.model, p=args.p, lam=args.lam)

        print("Initializing hypernetwork...")
        hypernetwork_kwargs = dict(
            t_structures=self.param_reg.structures,
            lrp_scale=args.lrp_scale,
            base=args.base,
            T_start=args.T_start,
            T_end=args.T_end,
            target_sparsity=args.target_sparsity,
            hidden_dim=args.hidden_dim,
        )
        self.hypernetwork = dyn_hypernetwork(**hypernetwork_kwargs).to(self.device)

        # trained weights are only restored when a checkpoint is configured and
        # the mask-combination analysis was not requested
        if not args.hypernetwork_checkpoint or args.run_mask_combination_analysis:
            return

        print(f"Loading hypernetwork checkpoint: {self.args.hypernetwork_checkpoint}")
        checkpoint = torch.load(args.hypernetwork_checkpoint, map_location=self.device)
        # the checkpoint either wraps the weights under 'hypernetwork' or is the state dict itself
        state_dict = checkpoint['hypernetwork'] if 'hypernetwork' in checkpoint else checkpoint
        self.hypernetwork.load_state_dict(state_dict)
        print("Hypernetwork loaded successfully")
    
    def load_data(self):
        """read the LRP training samples into ``self.train_samples_data``

        The mixed reader is used when ``args.use_mixed_data`` exists and is truthy,
        otherwise the file at ``args.train_lrp_path`` is read. A truthy
        ``args.max_samples`` keeps only that many leading samples.
        """
        print("Loading training data...")

        if getattr(self.args, 'use_mixed_data', False):
            samples = read_mixed_lrp_file()
        else:
            samples = read_lrp_file(self.args.train_lrp_path)

        # a falsy limit (None / 0) keeps every sample
        limit = self.args.max_samples
        if limit:
            samples = samples[:limit]

        self.train_samples_data = samples
        print(f"Loaded {len(self.train_samples_data)} training samples")

    def create_train_dataset(self):
        """build an indexable dataset from ``self.train_samples_data``

        Every item of the returned dataset is a dict with:
            'sample_ids': the sample's "sample_id" entry (a batch dimension is added to 1D input)
            'layer_activations': list with one float tensor per structure
            'input_lrp': list with one float tensor per structure

        LRP scores / activations go through ``normalize_tensor_layerwise`` when
        ``self.normalize_lrp`` / ``self.normalize_activations`` are set. All
        tensors are prepared once up front and moved to ``self.device`` on item access.

        Returns:
            torch.utils.data.Dataset: dataset of per-sample dicts
        """
        from torch.utils.data import Dataset

        # one normalization implementation only: the dataset reuses the
        # analyzer's method instead of carrying a second copy of the same code
        normalize = self.normalize_tensor_layerwise

        def to_float_tensor(data):
            """convert array-like data to a float tensor with at least a batch dimension"""
            if isinstance(data, np.ndarray):
                tensor = torch.from_numpy(data).float()
            elif isinstance(data, torch.Tensor):
                # explicit copy; torch.tensor(tensor) does the same but emits a warning
                tensor = data.detach().clone().float()
            else:
                tensor = torch.tensor(data).float()
            return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

        class SimpleDataset(Dataset):
            """dataset of preprocessed per-sample tensors, moved to the target device on access"""

            def __init__(self, samples_data, param_reg_structures, device, normalize_lrp=True, normalize_activations=False):
                self.device = device
                self.normalize_lrp = normalize_lrp
                self.normalize_activations = normalize_activations
                self.samples = []

                for sample_data in tqdm(samples_data, desc="Processing samples"):
                    # ids: existing tensors are kept as they are, anything else becomes a long tensor
                    sample_ids = sample_data["sample_id"]
                    if isinstance(sample_ids, np.ndarray):
                        sample_ids = torch.from_numpy(sample_ids).long()
                    elif not isinstance(sample_ids, torch.Tensor):
                        sample_ids = torch.tensor(sample_ids).long()
                    if sample_ids.dim() == 1:
                        sample_ids = sample_ids.unsqueeze(0)

                    lrp_scores = sample_data["lrp"]
                    activations = sample_data["activations"]

                    # only structures that have both an LRP entry and an activation entry are used
                    n_layers = min(len(param_reg_structures), len(lrp_scores), len(activations))

                    layer_activations = []
                    input_lrp = []
                    for structure_idx in range(n_layers):
                        activation_tensor = to_float_tensor(activations[structure_idx])
                        lrp_tensor = to_float_tensor(lrp_scores[structure_idx])

                        if self.normalize_lrp:
                            lrp_tensor = self.normalize_tensor_layerwise(lrp_tensor)
                        if self.normalize_activations:
                            activation_tensor = self.normalize_tensor_layerwise(activation_tensor)

                        layer_activations.append(activation_tensor)
                        input_lrp.append(lrp_tensor)

                    self.samples.append({
                        'sample_ids': sample_ids,
                        'layer_activations': layer_activations,
                        'input_lrp': input_lrp
                    })

            def normalize_tensor_layerwise(self, tensor, eps=1e-8):
                """delegate to the analyzer's layer-wise z-score normalization"""
                return normalize(tensor, eps=eps)

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                sample = self.samples[idx]
                return {
                    'sample_ids': sample['sample_ids'].to(self.device),
                    'layer_activations': [act.to(self.device) for act in sample['layer_activations']],
                    'input_lrp': [lrp.to(self.device) for lrp in sample['input_lrp']]
                }

        return SimpleDataset(
            self.train_samples_data,
            self.param_reg.structures,
            self.device,
            self.normalize_lrp,
            self.normalize_activations,
        )


    def generate_masks_for_all_samples(self):
        """run the hypernetwork on every training sample and collect its hard masks

        Returns:
            tuple: ``(masks, layer_masks)`` where ``masks`` is ``np.array`` over the
            flattened boolean mask of each sample and ``layer_masks`` holds, per
            sample, the layer info list produced by ``convert_mask_to_binary``
        """
        dataset = self.create_train_dataset()
        flat_masks, per_sample_layer_info = [], []

        print("Generating masks for all training samples...")
        self.hypernetwork.eval()

        # no gradients are needed for mask generation
        with torch.no_grad():
            for sample_idx in tqdm(range(len(dataset))):
                item = dataset[sample_idx]

                # hard masks from the hypernetwork for this sample
                hard_mask = self.hypernetwork.hard_output(
                    item['layer_activations'],
                    item['input_lrp'],
                )

                # flatten to one boolean vector and keep the per-layer breakdown
                flat, layer_info = self.convert_mask_to_binary(hard_mask, return_layer_info=True)
                flat_masks.append(flat)
                per_sample_layer_info.append(layer_info)

        return np.array(flat_masks), per_sample_layer_info

    def load_similarity_matrix(self, filepath):
        """load saved similarity matrix"""
        try:
            similarity_matrix = np.load(filepath)
            print(f"Loaded similarity matrix from: {filepath}")
            print(f"Matrix shape: {similarity_matrix.shape}")
            return similarity_matrix
        except Exception as e:
            print(f"Failed to load similarity matrix from {filepath}: {e}")
            return None

    def convert_mask_to_binary(self, mask_list, return_layer_info=False):
        """convert mask list to a single binary vector"""
        binary_vectors = []
        layer_info = []
        
        for layer_idx, mask_tensor in enumerate(mask_list):
            # ensure mask is binary (0 or 1) and convert to boolean type
            binary_mask = (mask_tensor > 0.5).bool()
            binary_mask_np = binary_mask.cpu().numpy().flatten().astype(np.bool_)
            binary_vectors.append(binary_mask_np)
            
            if return_layer_info:
                # calculate sparsity rate statistics for this layer
                total_params = len(binary_mask_np)
                active_params = np.sum(binary_mask_np)
                sparsity_rate = 1.0 - (active_params / total_params)
                
                layer_info.append({
                    'layer_idx': layer_idx,
                    'total_params': total_params,
                    'active_params': active_params,
                    'sparsity_rate': sparsity_rate,
                    'mask': binary_mask_np
                })
        
        # concatenate masks for all layers
        concatenated_mask = np.concatenate(binary_vectors)
        
        if return_layer_info:
            return concatenated_mask, layer_info
        else:
            return concatenated_mask

    def compute_jaccard_similarity_matrix(self, masks):
        """compute Jaccard similarity between all mask pairs

        Args:
            masks: sequence of binary mask vectors, one per sample

        Returns:
            np.ndarray: symmetric float matrix of shape (n_samples, n_samples)

        Side effects:
            When ``args.similarity_matrix_path`` is set, the matrix is also written
            to exactly that path. Saving is best effort: a failure is printed and
            the computed matrix is still returned.
        """
        n_samples = len(masks)
        similarity_matrix = np.zeros((n_samples, n_samples))

        print("Computing Jaccard similarity matrix...")
        for i in tqdm(range(n_samples)):
            # only the upper triangle is computed; the matrix is symmetric
            for j in range(i, n_samples):
                jaccard_sim = self.jaccard_similarity(masks[i], masks[j])
                similarity_matrix[i, j] = jaccard_sim
                similarity_matrix[j, i] = jaccard_sim

        # save similarity matrix
        save_path = getattr(self.args, 'similarity_matrix_path', None)
        if save_path:
            try:
                # a bare filename has no directory part; os.makedirs('') would raise
                parent_dir = os.path.dirname(save_path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)
                # write through a file handle: np.save(path, ...) appends ".npy" to
                # names without that extension, and run_analysis checks for the
                # cached matrix under the configured path itself
                with open(save_path, 'wb') as f:
                    np.save(f, similarity_matrix)
                print(f"Similarity matrix saved to: {self.args.similarity_matrix_path}")
            except Exception as e:
                print(f"Failed to save similarity matrix: {e}")

        return similarity_matrix

    def jaccard_similarity(self, mask1, mask2):
        """compute Jaccard similarity between two binary masks"""
        # ensure masks are boolean type
        mask1 = mask1.astype(np.bool_)
        mask2 = mask2.astype(np.bool_)
        
        intersection = np.sum(mask1 & mask2)
        union = np.sum(mask1 | mask2)
        
        if union == 0:
            return 1.0  # both masks are all zeros
        
        return intersection / union
    
    def perform_clustering(self, similarity_matrix, n_clusters_range=None, specified_clusters=None):
        """perform clustering analysis based on similarity matrix"""
        # convert similarity matrix to distance matrix
        # distance_matrix = 1 - similarity_matrix
        features = similarity_matrix
        
        if n_clusters_range is None:
            n_clusters_range = range(2, min(21, len(similarity_matrix) // 2))
        
        # try different numbers of clusters
        results = {}
        
        '''
        print("Performing clustering analysis...")
        for n_clusters in n_clusters_range:
            # use hierarchical clustering
            clustering = AgglomerativeClustering(
                n_clusters=n_clusters,
                metric='precomputed',
                linkage='average'
            )
            cluster_labels = clustering.fit_predict(distance_matrix)
            
            # calculate silhouette score
            silhouette_avg = silhouette_score(distance_matrix, cluster_labels, metric='precomputed')
            
            results[n_clusters] = {
                'labels': cluster_labels,
                'silhouette_score': silhouette_avg
            }
            
            print(f"n_clusters={n_clusters}, silhouette_score={silhouette_avg:.4f}")
        
        '''
        print("Performing balanced K-means clustering analysis...")
        for n_clusters in n_clusters_range:
            
            # perform K-means clustering
            kmeans = KMeans(
                n_clusters=n_clusters, 
                init='k-means++',  # can also use custom initialization
                random_state=58,
                n_init=10,
                max_iter=300
            )
            
            cluster_labels = kmeans.fit_predict(features)
            
            # calculate silhouette score (requires original distance matrix)
            distance_matrix = 1 - similarity_matrix
            silhouette_avg = silhouette_score(distance_matrix, cluster_labels, metric='precomputed')
            
            results[n_clusters] = {
                'labels': cluster_labels,
                'silhouette_score': silhouette_avg,
                'cluster_centers': kmeans.cluster_centers_,
                'inertia': kmeans.inertia_
            }
            
            # print size of each cluster
            cluster_sizes = [np.sum(cluster_labels == i) for i in range(n_clusters)]
            print(f"n_clusters={n_clusters}, silhouette_score={silhouette_avg:.4f}, "
                f"sizes={cluster_sizes}")
            
        best_n_clusters = max(results.keys(), key=lambda k: results[k]['silhouette_score'])
        print(f"Best number of clusters: {best_n_clusters}")
        return results, best_n_clusters
    
    def visualize_results(self, similarity_matrix, clustering_results, best_n_clusters, layer_stats=None):
        """plot the similarity / clustering overview figure

        Panels:
            1. heatmap of the similarity matrix
            2. silhouette score per evaluated cluster count
            3. similarity matrix reordered by the labels of ``best_n_clusters``
            4. t-SNE embedding coloured by cluster (needs more than 3 samples)
            5. (only with ``layer_stats``) per-layer sparsity mean ± std and a
               boxplot of the first 20 layers

        The figure is saved as ``mask_similarity_analysis.png`` inside
        ``args.output_dir`` when that is set, then shown with ``plt.show()``.

        Args:
            similarity_matrix (np.ndarray): square pairwise similarity matrix
            clustering_results (dict): cluster count -> dict with 'labels' and 'silhouette_score'
            best_n_clusters (int): key of ``clustering_results`` to visualize
            layer_stats (dict | None): layer index -> dict with 'mean', 'std', 'sparsity_rates'
        """
        n_rows = 2 if layer_stats is None else 3
        fig, axes = plt.subplots(n_rows, 2, figsize=(15, 6 * n_rows))

        n_samples = len(similarity_matrix)
        best_labels = clustering_results[best_n_clusters]['labels']
        tsne_ax = axes[1, 1]
        perplexity = min(30, n_samples - 1)

        def draw_embedding(embedding):
            """scatter a 2D embedding coloured by cluster and add a per-cluster legend"""
            scatter = tsne_ax.scatter(embedding[:, 0], embedding[:, 1],
                                      c=best_labels, cmap='tab10', s=50)
            tsne_ax.set_title(f't-SNE Visualization ({best_n_clusters} Clusters)')
            tsne_ax.set_xlabel('t-SNE 1')
            tsne_ax.set_ylabel('t-SNE 2')

            for cluster_id in range(best_n_clusters):
                cluster_size = int(np.sum(best_labels == cluster_id))
                if cluster_size > 0:
                    # legend colours come from the scatter's own label -> colour
                    # mapping, so they match the plotted points for any cluster count
                    tsne_ax.scatter([], [], color=scatter.to_rgba(cluster_id),
                                    label=f'Cluster {cluster_id} ({cluster_size})')
            tsne_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # 1. similarity matrix heatmap
        sns.heatmap(similarity_matrix, annot=False, cmap='viridis', ax=axes[0, 0])
        axes[0, 0].set_title('Jaccard Similarity Matrix')
        axes[0, 0].set_xlabel('Sample Index')
        axes[0, 0].set_ylabel('Sample Index')

        # 2. silhouette score vs number of clusters
        n_clusters_list = list(clustering_results.keys())
        silhouette_scores = [clustering_results[k]['silhouette_score'] for k in n_clusters_list]

        axes[0, 1].plot(n_clusters_list, silhouette_scores, 'bo-')
        axes[0, 1].set_xlabel('Number of Clusters')
        axes[0, 1].set_ylabel('Silhouette Score')
        axes[0, 1].set_title('Silhouette Score vs Number of Clusters')
        axes[0, 1].grid(True)

        # 3. similarity matrix with samples grouped by cluster label
        sorted_indices = np.argsort(best_labels)
        sorted_similarity = similarity_matrix[sorted_indices][:, sorted_indices]

        sns.heatmap(sorted_similarity, annot=False, cmap='viridis', ax=axes[1, 0])
        axes[1, 0].set_title(f'Similarity Matrix (Sorted by {best_n_clusters} Clusters)')
        axes[1, 0].set_xlabel('Sample Index (Sorted)')
        axes[1, 0].set_ylabel('Sample Index (Sorted)')

        # 4. t-SNE visualization
        if n_samples > 3:
            try:
                # precomputed distances: symmetrized, with a zero diagonal
                distance_matrix = 1 - similarity_matrix
                distance_matrix = (distance_matrix + distance_matrix.T) / 2
                np.fill_diagonal(distance_matrix, 0)

                tsne = TSNE(n_components=2,
                            metric='precomputed',
                            init='random',  # random initialization instead of PCA
                            random_state=42,
                            perplexity=perplexity)
                draw_embedding(tsne.fit_transform(distance_matrix))
            except Exception as e:
                # report the first failure, then fall back to the similarity rows as features
                print(f"t-SNE on precomputed distances failed ({e}); retrying on similarity rows")
                try:
                    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
                    draw_embedding(tsne.fit_transform(similarity_matrix))
                except Exception as e2:
                    tsne_ax.text(0.5, 0.5, f't-SNE failed: {str(e2)}',
                                 ha='center', va='center', transform=tsne_ax.transAxes)
                    tsne_ax.set_title('t-SNE Visualization (Failed)')
        else:
            tsne_ax.text(0.5, 0.5, 'Not enough samples for t-SNE',
                         ha='center', va='center', transform=tsne_ax.transAxes)
            tsne_ax.set_title('t-SNE Visualization (Insufficient Data)')

        # 5. layer-wise sparsity statistics (if available)
        if layer_stats is not None:
            # 5a. mean sparsity and std for each layer
            layer_indices = list(layer_stats.keys())
            means = [layer_stats[i]['mean'] for i in layer_indices]
            stds = [layer_stats[i]['std'] for i in layer_indices]

            axes[2, 0].errorbar(layer_indices, means, yerr=stds, marker='o', capsize=3)
            axes[2, 0].set_xlabel('Layer Index')
            axes[2, 0].set_ylabel('Sparsity Rate')
            axes[2, 0].set_title('Layer-wise Sparsity Rate (Mean ± Std)')
            axes[2, 0].grid(True, alpha=0.3)

            # 5b. boxplot of the sparsity distribution (first 20 layers only, to avoid crowding)
            display_layers = min(20, len(layer_indices))
            sparsity_data = [layer_stats[i]['sparsity_rates'] for i in layer_indices[:display_layers]]

            bp = axes[2, 1].boxplot(sparsity_data, patch_artist=True)
            axes[2, 1].set_xlabel('Layer Index')
            axes[2, 1].set_ylabel('Sparsity Rate')
            axes[2, 1].set_title(f'Sparsity Rate Distribution (First {display_layers} Layers)')
            axes[2, 1].grid(True, alpha=0.3)

            # one colour per box
            colors = plt.cm.viridis(np.linspace(0, 1, display_layers))
            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

        plt.tight_layout()

        # save image
        output_dir = getattr(self.args, 'output_dir', None)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            plt.savefig(os.path.join(output_dir, 'mask_similarity_analysis.png'),
                        dpi=300, bbox_inches='tight')

        plt.show()
    
    def analyze_layer_sparsity_statistics(self, layer_masks):
        """summarize, per layer, how the sparsity rate varies across samples

        Args:
            layer_masks: one entry per sample; each entry is a list with one dict
                per layer that carries at least 'sparsity_rate' and 'total_params'
                (the layer info returned by ``convert_mask_to_binary``). The
                first sample defines the number of layers.

        Returns:
            dict: layer index -> dict with 'sparsity_rates', 'total_params',
            'mean', 'std', 'var', 'min', 'max', 'median' and 'range'.
            Empty when there is nothing to analyze.
        """
        print("\n=== Layer-wise Sparsity Analysis ===")

        layer_stats = {}
        if not layer_masks:
            print("No layer masks to analyze")
            return layer_stats

        reference_sample = layer_masks[0]

        for layer_idx in range(len(reference_sample)):
            # sparsity rate of this layer in every sample
            layer_sparsity_rates = [sample[layer_idx]['sparsity_rate'] for sample in layer_masks]
            layer_total_params = reference_sample[layer_idx]['total_params']

            mean_sparsity = np.mean(layer_sparsity_rates)
            std_sparsity = np.std(layer_sparsity_rates)
            min_sparsity = np.min(layer_sparsity_rates)
            max_sparsity = np.max(layer_sparsity_rates)

            layer_stats[layer_idx] = {
                'sparsity_rates': layer_sparsity_rates,
                'total_params': layer_total_params,
                'mean': mean_sparsity,
                'std': std_sparsity,
                'var': np.var(layer_sparsity_rates),
                'min': min_sparsity,
                'max': max_sparsity,
                'median': np.median(layer_sparsity_rates),
                'range': max_sparsity - min_sparsity
            }

            print(f"Layer {layer_idx:2d} (params: {layer_total_params:6d}): "
                  f"mean={mean_sparsity:.4f}, std={std_sparsity:.4f}, "
                  f"min={min_sparsity:.4f}, max={max_sparsity:.4f}, "
                  f"range={max_sparsity - min_sparsity:.4f}")

        if not layer_stats:
            # samples without layers: no overall statistics to report
            return layer_stats

        # calculate overall statistics
        all_means = [stats['mean'] for stats in layer_stats.values()]
        all_stds = [stats['std'] for stats in layer_stats.values()]
        all_ranges = [stats['range'] for stats in layer_stats.values()]

        print(f"\n=== Overall Layer Statistics ===")
        print(f"Average sparsity across layers: {np.mean(all_means):.4f} ± {np.std(all_means):.4f}")
        print(f"Average std across layers: {np.mean(all_stds):.4f}")
        print(f"Average range across layers: {np.mean(all_ranges):.4f}")
        print(f"Layer with highest variability: {np.argmax(all_stds)} (std={np.max(all_stds):.4f})")
        print(f"Layer with lowest variability: {np.argmin(all_stds)} (std={np.min(all_stds):.4f})")

        # the statistics are consumed by the caller (plots, saved results)
        return layer_stats
        
    def analyze_cluster_statistics(self, similarity_matrix, clustering_results, best_n_clusters):
        """analyze clustering statistics"""
        best_labels = clustering_results[best_n_clusters]['labels']
        
        print(f"\n=== Cluster Analysis Results ===")
        print(f"Best number of clusters: {best_n_clusters}")
        print(f"Best silhouette score: {clustering_results[best_n_clusters]['silhouette_score']:.4f}")
        
        # calculate statistics for each cluster
        for cluster_id in range(best_n_clusters):
            cluster_mask = (best_labels == cluster_id)
            cluster_indices = np.where(cluster_mask)[0]
            cluster_size = len(cluster_indices)
            
            if cluster_size > 1:
                # calculate intra-cluster similarity
                cluster_similarities = similarity_matrix[cluster_mask][:, cluster_mask]
                avg_intra_similarity = np.mean(cluster_similarities[np.triu_indices_from(cluster_similarities, k=1)])
                
                print(f"\nCluster {cluster_id}:")
                print(f"  Size: {cluster_size} samples")
                print(f"  Samples: {cluster_indices.tolist()}")
                print(f"  Average intra-cluster similarity: {avg_intra_similarity:.4f}")
            else:
                print(f"\nCluster {cluster_id}:")
                print(f"  Size: {cluster_size} samples (singleton)")
                print(f"  Samples: {cluster_indices.tolist()}")
        
        # calculate overall statistics
        overall_avg_similarity = np.mean(similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)])
        print(f"\nOverall average similarity: {overall_avg_similarity:.4f}")
        
        return {
            'best_n_clusters': best_n_clusters,
            'best_silhouette_score': clustering_results[best_n_clusters]['silhouette_score'],
            'cluster_labels': best_labels,
            'overall_avg_similarity': overall_avg_similarity
        }

    def save_clustering_results(self, results):
        """save clustering results for Router"""
        output_path = self.args.clustering_results_save_path
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        try:
            with open(output_path, 'wb') as f:
                pickle.dump(results, f)
            print(f"Clustering results saved to: {output_path}")
        except Exception as e:
            print(f"Failed to save clustering results: {e}")

    def run_analysis(self):
        """run the complete mask similarity analysis

        Steps: generate masks, compute layer-wise sparsity statistics, load or
        compute the Jaccard similarity matrix, cluster, report, visualize and
        save everything through ``save_clustering_results``.

        The reported / visualized cluster count is ``args.specified_clusters``
        when present (default 20). If that count was not evaluated by
        ``perform_clustering``, the count with the best silhouette score is used
        instead.

        Returns:
            dict: 'masks', 'layer_masks', 'layer_stats', 'similarity_matrix',
            'clustering_results' and 'statistics'
        """
        print("Starting mask similarity analysis...")

        # 1. generate masks for all samples
        masks, layer_masks = self.generate_masks_for_all_samples()
        print(f"Generated masks shape: {masks.shape}")

        # 2. layer-wise sparsity statistics (forwarded to the plots and the saved results)
        layer_stats = self.analyze_layer_sparsity_statistics(layer_masks)

        # 3. reuse a saved similarity matrix when it fits the current data.
        # NOTE: only the shape is validated; a file computed from different
        # masks of the same size is accepted as it is.
        similarity_matrix = None
        matrix_path = getattr(self.args, 'similarity_matrix_path', None)
        if matrix_path and os.path.exists(matrix_path):
            similarity_matrix = self.load_similarity_matrix(matrix_path)

            expected_shape = (len(masks), len(masks))
            if similarity_matrix is not None and similarity_matrix.shape != expected_shape:
                print(f"Warning: Saved matrix shape {similarity_matrix.shape} doesn't match current data shape {len(masks)}. Recomputing...")
                similarity_matrix = None

        if similarity_matrix is None:
            similarity_matrix = self.compute_jaccard_similarity_matrix(masks)

        # 4. perform clustering analysis
        clustering_results, best_n_clusters = self.perform_clustering(similarity_matrix)

        # requested cluster count; small datasets never evaluate 20 clusters, so
        # fall back to the best evaluated count instead of failing on a missing key
        specified_clusters = getattr(self.args, 'specified_clusters', None) or 20
        if specified_clusters not in clustering_results:
            print(f"Warning: {specified_clusters} clusters were not evaluated; "
                  f"using best number of clusters ({best_n_clusters}) instead")
            specified_clusters = best_n_clusters

        # 5. analyze clustering statistics
        stats = self.analyze_cluster_statistics(similarity_matrix, clustering_results, specified_clusters)

        # 6. visualize results (including layer-wise sparsity when available)
        self.visualize_results(similarity_matrix, clustering_results, specified_clusters, layer_stats)

        results = {
            'masks': masks,
            'layer_masks': layer_masks,
            'layer_stats': layer_stats,
            'similarity_matrix': similarity_matrix,
            'clustering_results': clustering_results,
            'statistics': stats
        }

        self.save_clustering_results(results)

        return results
    
    def run_mask_combination_analysis(self):
        """run mask combination analysis"""
        print("Starting mask combination analysis...")
        
        # 1. generate all wikitext masks
        all_wikitext_masks = self.generate_all_wikitext_masks()
        print(f"Generated {len(all_wikitext_masks)} wikitext masks")
        
        # 2. find minimal mask combination
        mask_combination_results = self.find_minimal_mask_combination(
            all_wikitext_masks,
            arc_e_coverage_threshold=getattr(self.args, 'arc_e_coverage_threshold', 0.95),
            max_mask_combination_size=getattr(self.args, 'max_mask_combination_size', 10),
            use_optimized_algorithm=getattr(self.args, 'use_optimized_algorithm', True)
        )
        
        if mask_combination_results:
            # save mask combination results
            if hasattr(self.args, 'mask_combination_save_path') and self.args.mask_combination_save_path:
                os.makedirs(os.path.dirname(self.args.mask_combination_save_path), exist_ok=True)
                with open(self.args.mask_combination_save_path, 'wb') as f:
                    pickle.dump(mask_combination_results, f)
                print(f"Mask combination results saved to: {self.args.mask_combination_save_path}")
        
        return mask_combination_results

    def generate_all_wikitext_masks(self):
        """Generate one binary mask per wikitext training sample.

        A hypernetwork is built from ``self.param_reg.structures`` and the
        hypernetwork settings on ``self.args``, restored from the wikitext
        checkpoint, and queried through ``hard_output`` for each sample.
        ``self.args.train_lrp_path`` and ``self.train_samples_data`` are
        pointed at the wikitext data for the duration of the call and restored
        afterwards, also when an error occurs.

        Returns:
            list: the ``convert_mask_to_binary`` output for each sample.

        Raises:
            FileNotFoundError: if the wikitext hypernetwork checkpoint is missing.
        """
        print("Generating all wikitext masks...")

        # locate the wikitext hypernetwork checkpoint
        wikitext_hypernetwork_path = "xxx/project/DISP/wikitext/final_hypernetwork.pt"
        if not os.path.exists(wikitext_hypernetwork_path):
            raise FileNotFoundError(f"Wikitext hypernetwork not found: {wikitext_hypernetwork_path}")

        # build the hypernetwork with the configured hyper-parameters
        from pruning.dyn_hypernetwork import dyn_hypernetwork
        hypernetwork_kwargs = dict(
            t_structures=self.param_reg.structures,
            lrp_scale=self.args.lrp_scale,
            base=self.args.base,
            T_start=self.args.T_start,
            T_end=self.args.T_end,
            target_sparsity=self.args.target_sparsity,
            hidden_dim=self.args.hidden_dim,
        )
        wikitext_hypernetwork = dyn_hypernetwork(**hypernetwork_kwargs).to(self.device)

        # the checkpoint either wraps the weights under 'hypernetwork' or is the state dict itself
        checkpoint = torch.load(wikitext_hypernetwork_path, map_location=self.device)
        state_dict = checkpoint['hypernetwork'] if 'hypernetwork' in checkpoint else checkpoint
        wikitext_hypernetwork.load_state_dict(state_dict)
        print("Wikitext hypernetwork loaded successfully")

        wikitext_masks = []

        # remember the current data so it can be put back afterwards
        saved_lrp_path = self.args.train_lrp_path
        saved_samples = self.train_samples_data

        try:
            # switch to the wikitext data
            self.args.train_lrp_path = "xxx/project/DISP/wikitext/lrp_train_ppl.pkl"
            self.load_data()

            # optionally cap the number of wikitext samples for efficiency
            sample_cap = getattr(self.args, 'max_wikitext_samples', None)
            if sample_cap:
                max_samples = min(sample_cap, len(self.train_samples_data))
                self.train_samples_data = self.train_samples_data[:max_samples]
                print(f"Limiting wikitext sample count to: {max_samples}")

            wikitext_dataset = self.create_train_dataset()
            print(f"Generating masks using {len(wikitext_dataset)} wikitext samples")

            wikitext_hypernetwork.eval()
            with torch.no_grad():
                for idx in tqdm(range(len(wikitext_dataset)), desc="Generating wikitext masks"):
                    sample = wikitext_dataset[idx]
                    # hard mask from the hypernetwork, then flattened to a binary vector
                    hard_mask = wikitext_hypernetwork.hard_output(
                        sample['layer_activations'],
                        sample['input_lrp'],
                    )
                    wikitext_masks.append(self.convert_mask_to_binary(hard_mask))
        finally:
            # restore the original data
            self.args.train_lrp_path = saved_lrp_path
            self.train_samples_data = saved_samples

        return wikitext_masks
    
    def find_minimal_mask_combination(self, all_wikitext_masks, arc_e_coverage_threshold=0.95, max_mask_combination_size=10, use_optimized_algorithm=True):
        """
        search for a small set of wikitext masks such that most evaluation samples
        are predicted correctly by at least one mask in the set

        the evaluation pool is drawn from the training splits of arc-e, arc-c,
        hellaswag, piqa and winogrande (``args.per_dataset_sample_num`` samples
        each, 100 when the argument is absent)

        Args:
            all_wikitext_masks (list): list of all wikitext masks
            arc_e_coverage_threshold (float): coverage at which the search stops, default 0.95
            max_mask_combination_size (int): upper bound on the number of selected masks
            use_optimized_algorithm (bool): use ``optimized_mask_selection`` when True,
                ``greedy_mask_selection`` otherwise

        Returns:
            dict: output of ``analyze_mask_combination_results``, or None when no
            evaluation samples were loaded
        """
        print(f"\n=== Finding minimal mask combination (coverage threshold: {arc_e_coverage_threshold}) ===")

        # 1. load the multiple-choice evaluation pool
        multi_dataset_names = ["arc-e", "arc-c", "hellaswag", "piqa", "winogrande"]
        per_dataset_sample_num = getattr(self.args, "per_dataset_sample_num", 100)
        arc_e_dataset = self.load_multi_mc_datasets(multi_dataset_names, per_dataset_sample_num)
        if not arc_e_dataset:
            print("Failed to load arc-e dataset")
            return None

        print(f"Loaded {len(arc_e_dataset)} arc-e samples")
        print(f"Using {len(all_wikitext_masks)} wikitext masks")

        # 2. score every mask on every sample
        sample_mask_results = self.evaluate_masks_on_arc_e(arc_e_dataset, all_wikitext_masks)

        # 3. pick the selection routine and run it
        select_masks = (
            self.optimized_mask_selection
            if use_optimized_algorithm
            else self.greedy_mask_selection
        )
        optimal_mask_combination = select_masks(
            sample_mask_results,
            arc_e_coverage_threshold,
            max_mask_combination_size,
            all_wikitext_masks,
        )

        # 4. summarise the outcome
        return self.analyze_mask_combination_results(
            optimal_mask_combination,
            sample_mask_results,
            arc_e_dataset,
            all_wikitext_masks,
        )
    
    def load_arc_e_dataset(self):
        """load and format the arc-e training split

        Returns:
            list | None: one formatted example per record, each tagged with
            ``dataset_name`` and ``original_example``; None when anything fails
            (the error is printed)
        """
        try:
            from lib.dataset_loader import load_mc_dataset, format_mc_example

            def tag_example(example):
                # format one raw record and attach its provenance
                formatted_example = format_mc_example(example, "arc-e")
                formatted_example["dataset_name"] = "arc-e"
                formatted_example["original_example"] = example
                return formatted_example

            return [tag_example(example) for example in load_mc_dataset("arc-e", split="train")]
        except Exception as e:
            print(f"Failed to load arc-e dataset: {e}")
            return None

    def load_multi_mc_datasets(self, dataset_names, per_dataset_sample_num=100):
        """
        draw a random subset from the training split of each named multiple
        choice dataset and return all formatted samples in a single list

        Args:
            dataset_names: iterable of dataset names understood by ``load_mc_dataset``
            per_dataset_sample_num (int): samples drawn per dataset (capped at the dataset size)

        Returns:
            list: formatted examples, each tagged with ``dataset_name`` and ``original_example``
        """
        from lib.dataset_loader import load_mc_dataset, format_mc_example

        merged = []
        for name in dataset_names:
            ds = load_mc_dataset(name, split="train")
            # sample without replacement, never asking for more than the dataset holds
            draw_count = min(per_dataset_sample_num, len(ds))
            chosen = np.random.choice(len(ds), size=draw_count, replace=False)
            for position in chosen:
                raw = ds[int(position)]
                entry = format_mc_example(raw, name)
                entry["dataset_name"] = name
                entry["original_example"] = raw
                merged.append(entry)
        return merged

    def evaluate_masks_on_arc_e(self, arc_e_dataset, all_wikitext_masks):
        """
        evaluate prediction effect of all wikitext masks on each arc-e sample
        optimized version: prioritize looping through each wikitext mask to reduce set_gate_vectors calls
        
        Args:
            arc_e_dataset (list): arc-e dataset
            all_wikitext_masks (list): list of all wikitext masks
            
        Returns:
            dict: prediction results for each sample under each mask
        """
        print("Evaluating all wikitext masks on arc-e dataset...")
        
        # check for cached results
        cache_path = getattr(self.args, 'mask_evaluation_cache_path', None)
        if cache_path and os.path.exists(cache_path):
            try:
                with open(cache_path, 'rb') as f:
                    cached_results = pickle.load(f)
                print(f"Loading evaluation results from cache: {cache_path}")
                return cached_results
            except Exception as e:
                print(f"Failed to load cache: {e}")
        
        # set model for evaluation
        self.setup_llm_for_evaluation()
        
        # initialize result dictionary
        sample_mask_results = {sample_idx: {} for sample_idx in range(len(arc_e_dataset))}
        
        # optimized: prioritize looping through each wikitext mask, then test all arc-e samples
        for mask_idx in tqdm(range(len(all_wikitext_masks)), desc="Evaluating wikitext masks"):
            try:
                # apply current mask (set only once)
                current_mask = all_wikitext_masks[mask_idx]
                single_masks = self.convert_flat_mask_to_layer_masks(current_mask)
                self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
                self.hn_helper.set_gate_status(self.llm_model, use_gate=True)
                
                # batch test all arc-e samples
                try:
                    batch_results = self._evaluate_batch_samples(arc_e_dataset)
                    for sample_idx, result in enumerate(batch_results):
                        sample_mask_results[sample_idx][mask_idx] = result
                except Exception as e:
                    print(f"Mask {mask_idx} batch evaluation failed: {e}")
                    # if batch evaluation fails, fallback to single evaluation
                    for sample_idx, formatted_example in enumerate(arc_e_dataset):
                        try:
                            result = self._evaluate_single_sample(formatted_example)
                            sample_mask_results[sample_idx][mask_idx] = result
                        except Exception as e2:
                            print(f"Mask {mask_idx}, sample {sample_idx} evaluation failed: {e2}")
                            sample_mask_results[sample_idx][mask_idx] = self._get_default_result(formatted_example)
                
                # restore model state (restore after each mask test)
                self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
                
            except Exception as e:
                print(f"Mask {mask_idx} setting failed: {e}")
                # if mask setting fails, set default result for all samples
                for sample_idx in range(len(arc_e_dataset)):
                    sample_mask_results[sample_idx][mask_idx] = self._get_default_result(arc_e_dataset[sample_idx])
        
        # save cached results
        if cache_path:
            try:
                os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                with open(cache_path, 'wb') as f:
                    pickle.dump(sample_mask_results, f)
                print(f"Evaluation results cached to: {cache_path}")
            except Exception as e:
                print(f"Failed to save cache: {e}")
        
        return sample_mask_results
    
    def _evaluate_single_sample(self, formatted_example):
        """evaluate one formatted example with the current model state

        Returns:
            dict: the ``is_correct``, ``is_correct_normalized``, ``prediction``
            and ``label`` entries of the ``evaluate_mc_example`` result
        """
        from lib.dataset_loader import evaluate_mc_example

        outcome = evaluate_mc_example(
            self.llm_model,
            self.tokenizer,
            formatted_example,
            device=self.device,
            max_length=2048,
        )

        # keep only the fields the mask selection code consumes
        reported_fields = ('is_correct', 'is_correct_normalized', 'prediction', 'label')
        return {field: outcome[field] for field in reported_fields}
    
    def _evaluate_batch_samples(self, formatted_examples):
        """evaluate a list of formatted examples with the current model state

        uses ``evaluate_mc_examples_batch`` with ``args.mask_evaluation_batch_size``,
        except for a batch size of 1 where each example goes through
        ``evaluate_mc_example`` individually

        Returns:
            list: one dict per example with ``is_correct``,
            ``is_correct_normalized``, ``prediction`` and ``label``
        """
        batch_size = self.args.mask_evaluation_batch_size
        if batch_size != 1:
            from lib.dataset_loader import evaluate_mc_examples_batch
            results = evaluate_mc_examples_batch(
                self.llm_model, self.tokenizer, formatted_examples,
                device=self.device,
                max_length=2048,
                batch_size=batch_size
            )
        else:
            from lib.dataset_loader import evaluate_mc_example
            results = []
            for example in formatted_examples:
                results.append(
                    evaluate_mc_example(
                        self.llm_model, self.tokenizer, example,
                        device=self.device, max_length=2048
                    )
                )

        # keep only the fields the mask selection code consumes
        reported_fields = ('is_correct', 'is_correct_normalized', 'prediction', 'label')
        return [
            {field: result[field] for field in reported_fields}
            for result in results
        ]
    
    def _get_default_result(self, formatted_example):
        """get default evaluation result"""
        return {
            'is_correct': False,
            'is_correct_normalized': False,
            'prediction': -1,
            'label': formatted_example.get('label', -1)
        }
    
    def setup_llm_for_evaluation(self):
        """set up LLM model for evaluation - fixed version"""
        if not hasattr(self, 'hn_helper'):
            from pruning import help_functions_hn
            self.hn_helper = help_functions_hn(self.param_reg.structures)
        
        # fixed: do not reload model, directly use the existing model instance
        if not hasattr(self, 'llm_model'):
            if hasattr(self, 'model'):
                # if model already exists, use it directly
                self.llm_model = self.model
                print("✅ Reusing existing model instance")
            else:
                # if no model, load new one
                print("⚠️  No existing model loaded, loading...")
                from models import PruneLlama2ForCausalLM
                self.llm_model = PruneLlama2ForCausalLM.from_pretrained(
                    self.args.model_path,
                    torch_dtype=torch.float16,
                    device_map=self.device
                )
                self.llm_model.config.use_cache = False
                self.llm_model.eval()
                # also set model attribute
                self.model = self.llm_model
    
    def convert_flat_mask_to_layer_masks(self, flat_mask):
        """convert flat mask to layer masks"""
        layer_masks = []
        start_idx = 0
        
        for layer_size in self.param_reg.structures:
            end_idx = start_idx + layer_size
            layer_mask = flat_mask[start_idx:end_idx]
            layer_mask_tensor = torch.from_numpy(layer_mask.astype(np.float32))
            layer_masks.append(layer_mask_tensor)
            start_idx = end_idx
        
        return layer_masks
    
    def greedy_mask_selection(self, sample_mask_results, coverage_threshold, max_combination_size, all_wikitext_masks):
        """
        use greedy algorithm to select minimal mask combination

        in every round the mask that correctly predicts the most still-uncovered
        samples is added; the search stops when the coverage threshold is
        reached, the size limit is hit, or no remaining mask adds coverage

        Args:
            sample_mask_results (dict): ``{sample_idx: {mask_idx: result}}`` with
                sample indices ``0..n-1``; each result has an ``is_correct`` entry
            coverage_threshold (float): coverage threshold
            max_combination_size (int): maximum combination size
            all_wikitext_masks (list): list of all wikitext masks (only its length is used)

        Returns:
            dict: ``selected_masks`` (in selection order), ``coverage``,
            ``covered_samples`` (count), ``total_samples`` and
            ``uncovered_samples`` (sorted indices)
        """
        print("Using greedy algorithm to select minimal mask combination...")

        num_samples = len(sample_mask_results)
        num_masks = len(all_wikitext_masks)

        # for each mask, the set of samples it predicts correctly
        mask_coverage = {
            mask_idx: {
                sample_idx
                for sample_idx, sample_results in sample_mask_results.items()
                if sample_results[mask_idx]['is_correct']
            }
            for mask_idx in range(num_masks)
        }

        # greedy selection; nothing is covered before the first mask is chosen
        # (starting from a full set would report 100% coverage after one pick)
        selected_masks = []
        covered_samples = set()
        remaining_samples = set(range(num_samples))

        while len(selected_masks) < max_combination_size and remaining_samples:
            best_mask = None
            best_coverage = 0

            # find the mask that covers the most uncovered samples
            # (strict '>' keeps the lowest index on ties)
            for mask_idx in range(num_masks):
                if mask_idx in selected_masks:
                    continue

                new_coverage = len(mask_coverage[mask_idx] & remaining_samples)
                if new_coverage > best_coverage:
                    best_coverage = new_coverage
                    best_mask = mask_idx

            # no remaining mask helps any uncovered sample
            if best_mask is None:
                break

            # add best mask
            selected_masks.append(best_mask)
            newly_covered = mask_coverage[best_mask] & remaining_samples
            covered_samples |= newly_covered
            remaining_samples -= newly_covered

            current_coverage = len(covered_samples) / num_samples
            print(f"Selecting mask {best_mask}: added {len(newly_covered)} samples, "
                  f"total coverage: {current_coverage:.4f} ({len(covered_samples)}/{num_samples})")

            # check if coverage threshold is met
            if current_coverage >= coverage_threshold:
                print(f"Coverage threshold {coverage_threshold} met")
                break

        return {
            'selected_masks': selected_masks,
            # an empty evaluation set has no coverage rather than a division error
            'coverage': len(covered_samples) / num_samples if num_samples else 0.0,
            'covered_samples': len(covered_samples),
            'total_samples': num_samples,
            'uncovered_samples': sorted(remaining_samples)
        }
    
    def optimized_mask_selection(self, sample_mask_results, coverage_threshold, max_combination_size, all_wikitext_masks):
        """
        greedy mask selection on a boolean mask-by-sample coverage matrix

        in every round the mask that correctly predicts the most still-uncovered
        samples is added (lowest index wins ties); the search stops when the
        coverage threshold is reached, the size limit is hit, or no mask adds
        coverage

        Args:
            sample_mask_results (dict): ``{sample_idx: {mask_idx: result}}``; sample
                indices are used directly as matrix columns
            coverage_threshold (float): coverage threshold
            max_combination_size (int): maximum combination size
            all_wikitext_masks (list): list of all wikitext masks (only its length is used)

        Returns:
            dict: ``selected_masks``, ``coverage``, ``covered_samples`` (count),
            ``total_samples`` and ``uncovered_samples`` (indices)
        """
        print("Using optimized mask selection algorithm...")

        num_samples = len(sample_mask_results)
        num_masks = len(all_wikitext_masks)

        # coverage matrix: entry [mask, sample] is True when the mask gets the sample right
        mask_coverage_matrix = np.zeros((num_masks, num_samples), dtype=bool)
        for sample_idx, sample_results in sample_mask_results.items():
            for mask_idx in range(num_masks):
                if sample_results[mask_idx]['is_correct']:
                    mask_coverage_matrix[mask_idx, sample_idx] = True

        selected_masks = []
        covered_samples = np.zeros(num_samples, dtype=bool)
        remaining_samples = np.ones(num_samples, dtype=bool)

        while len(selected_masks) < max_combination_size and np.any(remaining_samples):
            # newly covered sample count per mask; an already selected mask
            # scores zero because its samples have left remaining_samples
            gains = (mask_coverage_matrix & remaining_samples).sum(axis=1)
            if gains.size == 0 or gains.max() == 0:
                break

            # argmax returns the first maximum, i.e. the lowest mask index on ties
            best_mask = int(gains.argmax())
            selected_masks.append(best_mask)

            newly_covered = mask_coverage_matrix[best_mask] & remaining_samples
            covered_samples |= newly_covered
            remaining_samples &= ~newly_covered

            current_coverage = np.sum(covered_samples) / num_samples
            print(f"Selecting mask {best_mask}: added {np.sum(newly_covered)} samples, "
                  f"total coverage: {current_coverage:.4f} ({np.sum(covered_samples)}/{num_samples})")

            # stop as soon as the requested coverage is reached
            if current_coverage >= coverage_threshold:
                print(f"Coverage threshold {coverage_threshold} met")
                break

        summary = {}
        summary['selected_masks'] = selected_masks
        summary['coverage'] = np.sum(covered_samples) / num_samples
        summary['covered_samples'] = int(np.sum(covered_samples))
        summary['total_samples'] = num_samples
        summary['uncovered_samples'] = np.where(remaining_samples)[0].tolist()
        return summary
    
    def analyze_mask_combination_results(self, optimal_combination, sample_mask_results, arc_e_dataset, all_wikitext_masks):
        """analyze mask combination results"""
        print("\n=== Mask Combination Analysis Results ===")
        
        selected_masks = optimal_combination['selected_masks']
        coverage = optimal_combination['coverage']
        covered_samples = optimal_combination['covered_samples']
        total_samples = optimal_combination['total_samples']
        
        print(f"Number of selected masks: {len(selected_masks)}")
        print(f"Selected mask indices: {selected_masks}")
        print(f"Coverage: {coverage:.4f} ({covered_samples}/{total_samples})")
        
        # analyze performance of each selected mask individually
        mask_performance = {}
        for mask_idx in selected_masks:
            correct_count = 0
            for sample_idx, sample_results in sample_mask_results.items():
                if sample_results[mask_idx]['is_correct']:
                    correct_count += 1
            
            accuracy = correct_count / total_samples
            mask_performance[mask_idx] = {
                'accuracy': accuracy,
                'correct_count': correct_count,
                'total_count': total_samples
            }
            
            print(f"Mask {mask_idx}: Accuracy {accuracy:.4f} ({correct_count}/{total_samples})")
        
        # analyze uncovered samples
        uncovered_samples = optimal_combination['uncovered_samples']
        if uncovered_samples:
            print(f"\nNumber of uncovered samples: {len(uncovered_samples)}")
            print(f"Uncovered sample indices: {uncovered_samples[:10]}{'...' if len(uncovered_samples) > 10 else ''}")
        
        # calculate collaborative effect of combined masks
        combination_accuracy = self.calculate_combination_accuracy(
            selected_masks, sample_mask_results, total_samples
        )
        
        results = {
            'optimal_combination': optimal_combination,
            'mask_performance': mask_performance,
            'combination_accuracy': combination_accuracy,
            'selected_mask_indices': selected_masks,
            'coverage_rate': coverage,
            'uncovered_samples': uncovered_samples,
            'total_wikitext_masks': len(all_wikitext_masks)
        }
        
        return results
    
    def calculate_combination_accuracy(self, selected_masks, sample_mask_results, total_samples):
        """fraction of samples that at least one selected mask predicts correctly

        Args:
            selected_masks: mask indices forming the combination
            sample_mask_results (dict): ``{sample_idx: {mask_idx: result}}``
            total_samples (int): denominator of the returned ratio

        Returns:
            float: solved samples divided by ``total_samples``
        """
        # any() stops at the first mask that solves the sample
        solved = sum(
            1
            for per_mask in sample_mask_results.values()
            if any(per_mask[mask_idx]['is_correct'] for mask_idx in selected_masks)
        )
        return solved / total_samples


def main():
    """Command-line entry point for the mask analysis tool.

    Parses the command-line arguments, builds a ``MaskSimilarityAnalyzer`` and
    runs either the mask combination analysis
    (``--run_mask_combination_analysis``) or the full similarity analysis.

    Boolean-valued options (``--use_mixed_data``, ``--normalize_lrp``,
    ``--normalize_activations``) take an explicit value such as ``true`` or
    ``false``. ``--no_optimized_algorithm`` switches the optimized selection
    algorithm off.
    """

    def str2bool(value):
        """Parse a textual boolean; argparse's ``type=bool`` treats any non-empty string as True."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Mask similarity analysis and clustering")

    # model parameters
    parser.add_argument("--model_path", type=str, default="xxx/llms/meta/Llama-2-7B-hf",
                        help="Model path")
    parser.add_argument("--device", type=str, default="cuda:3",
                        help="Computing device")
    parser.add_argument("--train_lrp_path", type=str, default="xxx/project/DISP/wikitext/lrp_train_ppl.pkl",
                        help="Training LRP score file path")
    parser.add_argument("--hypernetwork_checkpoint", type=str, default="xxx/project/DynPrune/llama-2-7b/041/hn/final_hypernetwork.pt", # xxx/project/DISP/arc-e/final_hypernetwork.pt
                        help="Trained hypernetwork checkpoint path")
    parser.add_argument("--use_mixed_data", default=False, type=str2bool,
                        help="Whether to use mixed data")
    parser.add_argument("--similarity_matrix_path", type=str,
                        default="xxx/project/DISP/arc-e/similarity_matrix.npy")
    parser.add_argument("--normalize_lrp", type=str2bool, default=True,
                        help="Whether to normalize LRP scores layer-wise")
    parser.add_argument("--normalize_activations", type=str2bool, default=False, help="Whether to normalize activations layer-wise")

    # hypernetwork parameters
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--lrp_scale", type=float, default=1.0)
    parser.add_argument("--base", type=float, default=0.5)
    parser.add_argument("--T_start", type=float, default=0.5)
    parser.add_argument("--T_end", type=float, default=0.1)
    parser.add_argument("--target_sparsity", type=float, default=0.4)

    # pruning parameters
    parser.add_argument("--p", type=float, default=0.6)
    parser.add_argument("--lam", type=float, default=4.0)

    # analysis parameters
    parser.add_argument("--max_samples", type=int, default=None,
                        help="Maximum training sample count")
    parser.add_argument("--output_dir", type=str, default="./mask_analysis_results",
                        help="Output directory for results")
    parser.add_argument("--clustering_results_save_path", type=str,
                        default="xxx/project/DISP/arc-e/clustering_results.pkl")

    # mask combination analysis parameters
    parser.add_argument("--find_minimal_mask_combination", action="store_true", default=False,
                        help="Whether to find minimal mask combination")
    parser.add_argument("--arc_e_coverage_threshold", type=float, default=0.99,
                        help="Coverage threshold for arc-e dataset")
    parser.add_argument("--max_mask_combination_size", type=int, default=10,
                        help="Maximum mask combination size")
    parser.add_argument("--mask_combination_save_path", type=str,
                        default="xxx/project/DynPrune/llama-2-7b/041/mask_combination_results.pkl", # xxx/project/DISP/arc-e/mask_combination_results.pkl
                        help="Path to save mask combination results")
    parser.add_argument("--use_optimized_algorithm", action="store_true", default=True,
                        help="Whether to use optimized mask selection algorithm")
    # The flag above defaults to True and therefore could never be disabled;
    # this companion flag writes False to the same destination.
    parser.add_argument("--no_optimized_algorithm", dest="use_optimized_algorithm", action="store_false",
                        help="Use the plain greedy mask selection algorithm instead of the optimized one")
    parser.add_argument("--mask_evaluation_cache_path", type=str,
                        default="xxx/project/DynPrune/llama-2-7b/041/mask_evaluation_cache.pkl",
                        help="Path to cache mask evaluation results")

    # run mode selection
    parser.add_argument("--run_mask_combination_analysis", action="store_true", default=False,
                        help="Run mask combination analysis instead of full similarity analysis")
    parser.add_argument("--max_wikitext_samples", type=int, default=1000,
                        help="Maximum number of samples to use when generating wikitext masks")
    parser.add_argument("--mask_evaluation_batch_size", type=int, default=32,
                        help="Batch size for mask evaluation")
    parser.add_argument("--per_dataset_sample_num", type=int, default=128, help="Number of samples to sample per multiple choice dataset")

    args = parser.parse_args()

    # create analyzer
    analyzer = MaskSimilarityAnalyzer(args)

    # select run mode based on parameters
    if args.run_mask_combination_analysis:
        # run mask combination analysis
        results = analyzer.run_mask_combination_analysis()
        print("\nMask combination analysis completed successfully!")
    else:
        # run complete mask similarity analysis
        results = analyzer.run_analysis()
        print("\nAnalysis completed successfully!")
        print(f"Results saved to: {args.output_dir}")


# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()