import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn.functional as F
import numpy as np
import argparse
from tqdm import tqdm
import pickle
import json

# import necessary modules
from models import PruneLlama2ForCausalLM
from pruning import collect_info_reg_llama, help_functions_hn
from lib.dataset_loader import (
    load_mc_dataset, format_mc_example, format_mc_prompt_with_ans,
    build_wikitext_ids, sample_wikitext_sequences, calculate_perplexity_with_label, 
    calculate_perplexity, evaluate_mc_example
)
from transformers import AutoTokenizer



class DatasetPreprocessor:
    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # load LLM model
        self.setup_llm_model()
        
        # load clustering results and representative masks
        self.load_representative_masks()
        
    def setup_llm_model(self):
        """Load tokenizer and prunable LLM, then build the pruning helpers.

        Reads ``model_path``, ``p`` and ``lam`` from ``self.args`` and sets
        ``self.tokenizer``, ``self.llm_model``, ``self.param_reg`` and
        ``self.hn_helper``.
        """
        cfg = self.args
        model_path = cfg.model_path
        print(f"Loading LLM model: {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Half-precision weights, placed directly on the selected device.
        model = self.llm_model = PruneLlama2ForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map=self.device,
        )
        model.config.use_cache = False
        model.eval()

        # Regularisation bookkeeping; its ``structures`` attribute is what the
        # gate helper is initialised with.
        self.param_reg = collect_info_reg_llama(model, p=cfg.p, lam=cfg.lam)
        self.hn_helper = help_functions_hn(self.param_reg.structures)

    def validate_mask_size(self, flat_mask):
        """Check that a flattened mask has exactly ``sum(structures)`` entries.

        Prints a short report (layer count, first five layer sizes, expected
        and actual length).

        Args:
            flat_mask: any sized object holding the flattened mask.

        Returns:
            ``True`` when the sizes agree.

        Raises:
            ValueError: when ``len(flat_mask)`` differs from the expected total.
        """
        layer_sizes = self.param_reg.structures
        expected = sum(layer_sizes)
        actual = len(flat_mask)
        # Only the first five sizes are shown; mark the truncation if any.
        truncated = '...' if len(layer_sizes) > 5 else ''

        report_lines = (
            "Mask size validation:",
            f"  Number of layers: {len(layer_sizes)}",
            f"  Layer sizes: {layer_sizes[:5]}{truncated}",
            f"  Expected total size: {expected}",
            f"  Actual mask size: {actual}",
        )
        for line in report_lines:
            print(line)

        if actual == expected:
            return True
        raise ValueError(f"Mask size mismatch: expected {expected}, got {actual}")

    def load_representative_masks(self):
        """load mask combination analysis results and get selected representative masks"""
        print("Loading mask combination results and selected representative masks...")
        
        # first try to load mask combination analysis results
        mask_combination_file = "xxx/project/DynPrune/llama-2-7b/041/mask_combination_results.pkl"
        if os.path.exists(mask_combination_file):
            print(f"Found mask combination results at {mask_combination_file}")
            with open(mask_combination_file, 'rb') as f:
                mask_combination_data = pickle.load(f)
            
            # check if contains necessary fields
            if 'selected_mask_indices' in mask_combination_data:
                selected_mask_indices = mask_combination_data['selected_mask_indices']
                print(f"Found {len(selected_mask_indices)} selected mask indices: {selected_mask_indices}")
                
                # need to load all wikitext masks
                all_wikitext_masks = self.load_all_wikitext_masks()
                
                # extract representative masks based on selected indices
                self.representative_masks = []
                self.cluster_assignments = {}
                
                for i, mask_idx in enumerate(selected_mask_indices):
                    if mask_idx < len(all_wikitext_masks):
                        self.representative_masks.append(all_wikitext_masks[mask_idx])
                        self.cluster_assignments[i] = {
                            'representative_idx': mask_idx,
                            'member_indices': [mask_idx],  # single mask
                            'size': 1,
                            'source': 'mask_combination'
                        }
                        print(f"Selected mask {i}: index {mask_idx}")
                    else:
                        print(f"Warning: mask index {mask_idx} out of range (max: {len(all_wikitext_masks)-1})")
                
                print(f"Loaded {len(self.representative_masks)} representative masks from mask combination analysis")
                return
        
        # if mask combination results not found, fall back to original clustering method
        print("Mask combination results not found, falling back to clustering results...")
        clustering_file = os.path.join(self.args.clustering_results_path, "clustering_results.pkl")
        if not os.path.exists(clustering_file):
            raise FileNotFoundError(f"Neither mask combination results nor clustering results found")
        
        with open(clustering_file, 'rb') as f:
            clustering_data = pickle.load(f)
        
        masks = clustering_data['masks']
        similarity_matrix = clustering_data['similarity_matrix']
        clustering_results = clustering_data['clustering_results']
        
        # use specified number of clusters
        cluster_labels = clustering_results[self.args.num_clusters]['labels']
        
        # select representative masks for each cluster
        self.representative_masks = []
        self.cluster_assignments = {}
        
        for cluster_id in range(self.args.num_clusters):
            cluster_indices = np.where(cluster_labels == cluster_id)[0]
            
            if len(cluster_indices) == 0:
                continue
            
            # select cluster center as representative
            if len(cluster_indices) == 1:
                representative_idx = cluster_indices[0]
            else:
                cluster_sim_matrix = similarity_matrix[cluster_indices][:, cluster_indices]
                avg_similarities = np.mean(cluster_sim_matrix, axis=1)
                best_idx_in_cluster = np.argmax(avg_similarities)
                representative_idx = cluster_indices[best_idx_in_cluster]
            
            self.representative_masks.append(masks[representative_idx])
            self.cluster_assignments[cluster_id] = {
                'representative_idx': representative_idx,
                'member_indices': cluster_indices.tolist(),
                'size': len(cluster_indices),
                'source': 'clustering'
            }
            
            print(f"Cluster {cluster_id}: {len(cluster_indices)} samples, "
                  f"representative: {representative_idx}")
        
        print(f"Selected {len(self.representative_masks)} representative masks from clustering")
    
    def load_all_wikitext_masks(self):
        """Generate one flattened binary mask per stored wikitext sample.

        A ``dyn_hypernetwork`` is built for ``self.param_reg.structures``, its
        trained weights are loaded, and ``hard_output`` is evaluated on the
        layer activations / LRP scores of every stored sample.

        Both input locations default to the project paths below and can be
        overridden with ``args.wikitext_hypernetwork_path`` and
        ``args.wikitext_lrp_path``. Hypernetwork hyper-parameters and
        ``max_wikitext_samples`` (default 200) are read from ``self.args`` with
        fallbacks.

        Returns:
            list with one 1-D boolean NumPy array per sample used.

        Raises:
            FileNotFoundError: if the hypernetwork checkpoint or the LRP data
                file does not exist.
        """
        print("Loading all wikitext masks...")

        wikitext_hypernetwork_path = (
            getattr(self.args, 'wikitext_hypernetwork_path', None)
            or "xxx/project/DynPrune/llama-2-7b/041/hn/final_hypernetwork.pt"
        )
        wikitext_lrp_path = (
            getattr(self.args, 'wikitext_lrp_path', None)
            or "xxx/project/DISP/wikitext/lrp_train_ppl.pkl"
        )

        # Check both inputs up front so a missing data file is reported before
        # the hypernetwork is built and its checkpoint loaded.
        if not os.path.exists(wikitext_hypernetwork_path):
            raise FileNotFoundError(f"Wikitext hypernetwork not found: {wikitext_hypernetwork_path}")
        if not os.path.exists(wikitext_lrp_path):
            raise FileNotFoundError(f"Wikitext LRP data not found: {wikitext_lrp_path}")

        # build the hypernetwork with the same structure list as the LLM gates
        from pruning.dyn_hypernetwork import dyn_hypernetwork
        wikitext_hypernetwork = dyn_hypernetwork(
            t_structures=self.param_reg.structures,
            lrp_scale=getattr(self.args, 'lrp_scale', 1.0),
            base=getattr(self.args, 'base', 0.5),
            T_start=getattr(self.args, 'T_start', 0.5),
            T_end=getattr(self.args, 'T_end', 0.1),
            target_sparsity=getattr(self.args, 'target_sparsity', 0.2),
            hidden_dim=getattr(self.args, 'hidden_dim', 128)
        ).to(self.device)

        # the checkpoint is either a bare state dict or wraps it under 'hypernetwork'
        checkpoint = torch.load(wikitext_hypernetwork_path, map_location=self.device)
        if 'hypernetwork' in checkpoint:
            wikitext_hypernetwork.load_state_dict(checkpoint['hypernetwork'])
        else:
            wikitext_hypernetwork.load_state_dict(checkpoint)
        print("Wikitext hypernetwork loaded successfully")

        # NOTE: pickle.load can execute arbitrary code; trusted files only.
        with open(wikitext_lrp_path, 'rb') as f:
            wikitext_samples_data = pickle.load(f)

        # limit wikitext sample count for efficiency (a falsy limit disables it)
        max_wikitext_samples = getattr(self.args, 'max_wikitext_samples', 200)
        if max_wikitext_samples and len(wikitext_samples_data) > max_wikitext_samples:
            wikitext_samples_data = wikitext_samples_data[:max_wikitext_samples]
            print(f"Limiting wikitext sample count to: {max_wikitext_samples}")

        wikitext_dataset = self.create_wikitext_dataset(wikitext_samples_data)
        print(f"Using {len(wikitext_dataset)} wikitext samples to generate masks")

        wikitext_masks = []
        wikitext_hypernetwork.eval()
        with torch.no_grad():
            for idx in tqdm(range(len(wikitext_dataset)), desc="Generating wikitext masks"):
                sample = wikitext_dataset[idx]

                # use wikitext hypernetwork to generate hard mask
                mask = wikitext_hypernetwork.hard_output(
                    sample['layer_activations'],
                    sample['input_lrp']
                )

                # flatten the per-layer output into one boolean vector
                wikitext_masks.append(self.convert_mask_to_binary(mask))

        print(f"Generated {len(wikitext_masks)} wikitext masks")
        return wikitext_masks
    
    def create_wikitext_dataset(self, samples_data):
        """Wrap raw wikitext records in an in-memory ``Dataset``.

        Every record is read through its ``"sample_id"``, ``"lrp"`` and
        ``"activations"`` entries. For each layer index that exists in
        ``self.param_reg.structures`` and in both per-layer collections, the
        activations and LRP scores are converted to float tensors, 1-D tensors
        get a leading batch axis, and normalisation is applied according to
        ``args.normalize_lrp`` (default ``True``) and
        ``args.normalize_activations`` (default ``False``).

        Args:
            samples_data: indexable, sized collection of record dicts.

        Returns:
            A ``Dataset`` whose items are dicts with ``'sample_ids'``,
            ``'layer_activations'`` and ``'input_lrp'``, moved to
            ``self.device`` on access.
        """
        from torch.utils.data import Dataset

        def as_float_tensor(values):
            """Return an ndarray or other array-like as a float tensor."""
            if isinstance(values, np.ndarray):
                return torch.from_numpy(values).float()
            return torch.tensor(values).float()

        def with_batch_axis(tensor):
            """Prepend a batch axis to 1-D tensors; return others unchanged."""
            return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

        class WikitextDataset(Dataset):
            """Pre-processed wikitext records kept on the CPU until accessed."""

            def __init__(self, samples_data, param_reg_structures, device, normalize_lrp=True, normalize_activations=False):
                self.device = device
                self.normalize_lrp = normalize_lrp
                self.normalize_activations = normalize_activations
                self.samples = []

                for idx in tqdm(range(len(samples_data)), desc="Processing wikitext samples"):
                    record = samples_data[idx]

                    # token ids: ndarray / tensor / anything torch.tensor accepts
                    token_ids = record["sample_id"]
                    if isinstance(token_ids, np.ndarray):
                        token_ids = torch.from_numpy(token_ids).long()
                    elif not isinstance(token_ids, torch.Tensor):
                        token_ids = torch.tensor(token_ids).long()

                    lrp_scores = record["lrp"]
                    activations = record["activations"]

                    layer_activations = []
                    input_lrp = []
                    for layer in range(len(param_reg_structures)):
                        # layers missing from either collection are skipped
                        if layer >= len(lrp_scores) or layer >= len(activations):
                            continue

                        activation_tensor = with_batch_axis(as_float_tensor(activations[layer]))
                        lrp_tensor = with_batch_axis(as_float_tensor(lrp_scores[layer]))

                        if self.normalize_lrp:
                            lrp_tensor = self.normalize_tensor_layerwise(lrp_tensor)
                        if self.normalize_activations:
                            activation_tensor = self.normalize_tensor_layerwise(activation_tensor)

                        layer_activations.append(activation_tensor)
                        input_lrp.append(lrp_tensor)

                    self.samples.append({
                        'sample_ids': with_batch_axis(token_ids),
                        'layer_activations': layer_activations,
                        'input_lrp': input_lrp
                    })

            def normalize_tensor_layerwise(self, tensor, eps=1e-8):
                """Standardise ``|tensor|`` along its last axis.

                Empty tensors are returned untouched. The standard deviation
                (population form) is clamped to at least ``eps`` before the
                division.
                """
                if tensor.numel() == 0:
                    return tensor

                magnitude = torch.abs(tensor)
                was_vector = magnitude.dim() == 1
                if was_vector:
                    magnitude = magnitude.unsqueeze(0)

                centre = magnitude.mean(dim=-1, keepdim=True)
                spread = torch.clamp(
                    magnitude.std(dim=-1, keepdim=True, unbiased=False), min=eps
                )
                standardised = (magnitude - centre) / spread

                return standardised.squeeze(0) if was_vector else standardised

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                stored = self.samples[idx]
                target = self.device
                return {
                    'sample_ids': stored['sample_ids'].to(target),
                    'layer_activations': [act.to(target) for act in stored['layer_activations']],
                    'input_lrp': [lrp.to(target) for lrp in stored['input_lrp']]
                }

        # normalisation switches come from the run configuration
        use_lrp_norm = getattr(self.args, 'normalize_lrp', True)
        use_activation_norm = getattr(self.args, 'normalize_activations', False)
        return WikitextDataset(samples_data, self.param_reg.structures, self.device,
                               use_lrp_norm, use_activation_norm)
    
    def convert_mask_to_binary(self, mask_list):
        """convert mask list to a single binary vector"""
        binary_vectors = []
        
        for layer_idx, mask_tensor in enumerate(mask_list):
            # ensure mask is binary (0 or 1) and convert to boolean type
            binary_mask = (mask_tensor > 0.5).bool()
            binary_mask_np = binary_mask.cpu().numpy().flatten().astype(np.bool_)
            binary_vectors.append(binary_mask_np)
        
        # concatenate masks from all layers
        concatenated_mask = np.concatenate(binary_vectors)
        return concatenated_mask
    
    def convert_flat_mask_to_layer_masks(self, flat_mask):
        layer_masks = []
        start_idx = 0
        
        for i, layer_size in enumerate(self.param_reg.structures):
            # each element in t_structures is an integer, representing the number of parameters in that layer
            if not isinstance(layer_size, (int, np.integer)):
                raise ValueError(f"Expected integer for layer size, got {type(layer_size)} at layer {i}")
            
            # extract mask for this layer from the flattened mask
            end_idx = start_idx + layer_size
            if end_idx > len(flat_mask):
                raise ValueError(f"Mask too short: need {end_idx} elements, got {len(flat_mask)}")
            
            layer_mask = flat_mask[start_idx:end_idx]
            
            # convert to tensor, shape [layer_size]
            layer_mask_tensor = torch.from_numpy(layer_mask.astype(np.float32))
            layer_masks.append(layer_mask_tensor)
            
            start_idx = end_idx
        
        return layer_masks
    
    def calculate_mask_scores_for_mc_sample(self, formatted_example, dataset_name):
        """calculate scores for multiple choice sample across all representative masks, using contrastive learning approach"""
        scores = []
        
        for mask_idx, mask in enumerate(self.representative_masks):
            # convert and apply mask
            single_masks = self.convert_flat_mask_to_layer_masks(mask)
            self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
            self.hn_helper.set_gate_status(self.llm_model, use_gate=True)
            
            try:
                # calculate score using contrastive learning approach
                score = self._compute_contrastive_score(formatted_example, dataset_name)
                scores.append(score)
                
            except Exception as e:
                raise e
            
            finally:
                # restore model state
                self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
        
        return scores
    
    def _compute_contrastive_score(self, original_example, dataset_name):
        """calculate contrastive score, based on compute_contrastive_loss logic"""
        if original_example is None or 'options' not in original_example:
            raise ValueError("No options found in the example")
        
        try:
            dataset_name = original_example.get("dataset_name", "")
            
            # handle dataset type
            if "winogrande" in dataset_name.lower():
                ctx_pref = original_example["context_prefix"]
                tgt_suf = original_example["target_suffix"]
                options = original_example["options"]  # [" option1", " option2"]
                correct_idx = original_example["label"]
                
                option_log_probs = []
                for option in options:
                    # build full sequence
                    full_ctx = ctx_pref + option
                    ids_full = self.tokenizer(full_ctx + tgt_suf,
                                           add_special_tokens=False,
                                           return_tensors="pt").input_ids.to(self.device)
                    ctx_len = len(self.tokenizer(full_ctx, add_special_tokens=False).input_ids)
                    
                    # forward pass
                    with torch.cuda.amp.autocast(dtype=torch.float16):
                        logits = self.llm_model(ids_full).logits
                    
                    # calculate log probabilities
                    shift_logits = logits[:, :-1, :].contiguous()
                    shift_labels = ids_full[:, 1:].contiguous()
                    
                    # calculate token-level log probabilities
                    log_probs = F.log_softmax(shift_logits, dim=-1)
                    token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
                    
                    # only calculate average log probability for the target part (from ctx_len)
                    if ctx_len > 0 and ctx_len - 1 < token_log_probs.shape[1]:
                        target_log_probs = token_log_probs[:, ctx_len-1:]
                    else:
                        target_log_probs = token_log_probs
                    
                    # calculate average
                    avg_log_prob = target_log_probs.mean()
                    option_log_probs.append(avg_log_prob)
            
            else:
                # other datasets: ARC, HellaSwag, PIQA, etc.
                question = original_example["question"]
                options = original_example["options"]
                correct_idx = original_example["label"]
                
                option_log_probs = []
                for option_content in options:
                    # build full context
                    full_text = f"{question} Answer: {option_content}"
                    option_input_ids = self.tokenizer(full_text, return_tensors="pt").input_ids.to(self.device)
                    
                    # forward pass
                    with torch.cuda.amp.autocast(dtype=torch.float16):
                        logits = self.llm_model(option_input_ids).logits
                    
                    # calculate log probability for the answer part
                    question_len = len(self.tokenizer(question, add_special_tokens=True).input_ids)
                    
                    # calculate log probabilities
                    shift_logits = logits[:, :-1, :].contiguous()
                    shift_labels = option_input_ids[:, 1:].contiguous()
                    
                    log_probs = F.log_softmax(shift_logits, dim=-1)
                    token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
                    
                    # only calculate average log probability for the answer part
                    if question_len > 0 and question_len - 1 < token_log_probs.shape[1]:
                        answer_log_probs = token_log_probs[:, question_len-1:]
                    else:
                        answer_log_probs = token_log_probs
                    
                    # calculate average
                    avg_log_prob = answer_log_probs.mean()
                    option_log_probs.append(avg_log_prob)
            
            # stack tensors
            logits_tensor = torch.stack(option_log_probs)
            
            # calculate contrastive loss
            target = torch.tensor(correct_idx, device=self.device, dtype=torch.long)
            
            # use cross entropy to calculate loss
            contrastive_loss = F.cross_entropy(logits_tensor.unsqueeze(0), target.unsqueeze(0))
            
            # return negative loss as score (lower loss, higher score)
            return -contrastive_loss.item()
            
        except Exception as e:
            print(f"Contrastive score calculation failed: {e}")
            return float('-inf')
    
    def calculate_mask_scores_for_sample(self, text, input_ids, label_pos=None):
        """calculate scores for a sample across all representative masks"""
        scores = []
        
        for mask_idx, mask in enumerate(self.representative_masks):
            # if mask_idx == 0:
            #     self.validate_mask_size(mask)
            # convert and apply mask
            single_masks = self.convert_flat_mask_to_layer_masks(mask)
            self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
            self.hn_helper.set_gate_status(self.llm_model, use_gate=True)
            
            try:
                if label_pos is not None:
                    # labeled data: only calculate PPL for the answer part
                    loss = calculate_perplexity_with_label(
                        self.llm_model, input_ids.unsqueeze(0).to(self.device), 
                        label_pos=label_pos, device=self.device
                    )
                else:
                    # unlabeled data: calculate overall PPL
                    loss = calculate_perplexity_with_label(
                        self.llm_model, input_ids.unsqueeze(0).to(self.device), 
                        label_pos=None, device=self.device
                    )
                
                # use negative log likelihood as score (lower is better, so take negative)
                score = -loss.item()
                scores.append(score)
                
            except Exception as e:
                print(f"Error calculating score for mask {mask_idx}: {e}")
                scores.append(float('-inf'))
            
            finally:
                # restore model state
                self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
        
        return scores
        
    def process_wikitext_dataset(self):
        """Score WikiText sequences under every representative mask.

        The outer loop runs over masks so each mask is installed once and then
        applied to all samples. ``args.wikitext_samples`` selects how many
        sequences are drawn (``None``: all, without random sampling);
        ``args.seq_length`` is the sequence length.

        Returns:
            list of dicts with ``'text'``, ``'input_ids'``, ``'scores'`` (one
            negated loss per mask, ``-inf`` where scoring failed),
            ``'data_type'`` and ``'dataset_name'``.
        """
        print("Processing WikiText dataset (optimized)...")

        wikitext_ids = build_wikitext_ids(self.tokenizer, split="train")

        requested = self.args.wikitext_samples
        if requested is None:
            print("Using all available WikiText data")
        else:
            print(f"Sampling {requested} WikiText sequences")
        # random sampling is used exactly when a sample count was requested
        samples = sample_wikitext_sequences(
            wikitext_ids,
            seqlen=self.args.seq_length,
            n=requested,
            random_sample=requested is not None
        )

        print(f"Total WikiText samples: {len(samples)}")

        print("Pre-computing text representations...")
        sample_texts = [
            self.tokenizer.decode(sample, skip_special_tokens=True)
            for sample in samples
        ]

        # one score list per sample, filled mask by mask
        all_scores = [[] for _ in range(len(samples))]

        for mask_idx, mask in enumerate(tqdm(self.representative_masks, desc="Processing masks")):
            print(f"Processing mask {mask_idx}/{len(self.representative_masks)}")

            # install the current mask once for all samples
            layer_masks = self.convert_flat_mask_to_layer_masks(mask)
            self.hn_helper.set_gate_vectors(self.llm_model, layer_masks)
            self.hn_helper.set_gate_status(self.llm_model, use_gate=True)

            try:
                for sample_idx, sample in enumerate(samples):
                    try:
                        loss = calculate_perplexity_with_label(
                            self.llm_model, sample.unsqueeze(0).to(self.device),
                            label_pos=None, device=self.device
                        )
                        score = -loss.item()
                    except Exception as e:
                        print(f"Error calculating score for sample {sample_idx} with mask {mask_idx}: {e}")
                        score = float('-inf')
                    all_scores[sample_idx].append(score)
            finally:
                # restore model state
                self.hn_helper.set_gate_status(self.llm_model, use_gate=False)

        return [
            {
                'text': text,
                'input_ids': sample.tolist(),
                'scores': scores,
                'data_type': 'wikitext',
                'dataset_name': 'wikitext'
            }
            for sample, text, scores in zip(samples, sample_texts, all_scores)
        ]


    def process_mc_dataset(self, dataset_name):
        """Score a multiple-choice dataset under every representative mask.

        The train split is loaded and, when ``args.mc_samples_per_dataset`` is
        set and smaller than the dataset, randomly sub-sampled without
        replacement. Each example is formatted once; masks are then applied
        one at a time to all examples.

        Args:
            dataset_name: name passed to ``load_mc_dataset`` and
                ``format_mc_example``.

        Returns:
            list of sample dicts (``'text'``, ``'input_ids'``, ``'label_pos'``,
            ``'scores'``, ``'data_type'``, ``'dataset_name'``,
            ``'original_example'``), or ``[]`` if processing fails.
        """
        print(f"Processing {dataset_name} dataset (optimized)...")

        try:
            dataset = load_mc_dataset(dataset_name, split="train")

            # limit sample count
            cap = self.args.mc_samples_per_dataset
            if cap is not None and len(dataset) > cap:
                print(f"Sampling {cap} from {len(dataset)} samples in {dataset_name}")
                chosen = np.random.choice(len(dataset), cap, replace=False)
                dataset = dataset.select(chosen)
            else:
                print(f"Using all {len(dataset)} samples from {dataset_name}")

            print("Pre-processing all samples...")
            formatted_examples = []
            sample_texts = []

            for raw_example in tqdm(dataset, desc=f"Pre-processing {dataset_name}"):
                entry = format_mc_example(raw_example, dataset_name)
                entry["dataset_name"] = dataset_name
                formatted_examples.append(entry)

                # text is the question, or the context prefix when there is none
                text = entry.get("question", "")
                if not text and "context_prefix" in entry:
                    text = entry["context_prefix"]
                sample_texts.append(text)

            # one score list per example, filled mask by mask
            all_scores = [[] for _ in range(len(formatted_examples))]

            for mask_idx, mask in enumerate(tqdm(self.representative_masks, desc=f"Processing masks for {dataset_name}")):
                print(f"Processing mask {mask_idx}/{len(self.representative_masks)} for {dataset_name}")

                # install the current mask once for all examples
                layer_masks = self.convert_flat_mask_to_layer_masks(mask)
                self.hn_helper.set_gate_vectors(self.llm_model, layer_masks)
                self.hn_helper.set_gate_status(self.llm_model, use_gate=True)

                try:
                    for sample_idx, entry in enumerate(formatted_examples):
                        try:
                            score = self._compute_contrastive_score(entry, dataset_name)
                        except Exception as e:
                            print(f"Error calculating score for sample {sample_idx} with mask {mask_idx}: {e}")
                            score = float('-inf')
                        all_scores[sample_idx].append(score)
                finally:
                    # restore model state
                    self.hn_helper.set_gate_status(self.llm_model, use_gate=False)

            return [
                {
                    'text': text,
                    'input_ids': [],  # contrastive learning does not need full input_ids
                    'label_pos': None,  # contrastive learning does not need label_pos
                    'scores': scores,
                    'data_type': 'mc',
                    'dataset_name': dataset_name,
                    'original_example': entry
                }
                for entry, text, scores in zip(formatted_examples, sample_texts, all_scores)
            ]

        except Exception as e:
            # best effort: a dataset that cannot be processed yields no samples
            print(f"Failed to process {dataset_name}: {e}")
            return []
    
    def save_dataset_samples(self, samples, dataset_name):
        """save dataset samples"""
        output_path = os.path.join(self.args.output_dir, f"{dataset_name}_processed.pkl")
        os.makedirs(self.args.output_dir, exist_ok=True)
        
        try:
            with open(output_path, 'wb') as f:
                pickle.dump(samples, f)
            print(f"Saved {len(samples)} samples to {output_path}")
        except Exception as e:
            print(f"Failed to save {dataset_name} samples: {e}")

    def convert_numpy_types(self, obj):
        """recursively convert numpy types to Python native types for JSON serialization"""
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {key: self.convert_numpy_types(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [self.convert_numpy_types(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(self.convert_numpy_types(item) for item in obj)
        else:
            return obj

    def test_representative_masks_quality(self, test_datasets=None):
        """optimized representative mask quality test - mask priority"""
        print("\n=== Test representative mask quality (optimized) ===")
        
        if test_datasets is None:
            test_datasets = ['arc-e', 'arc-c', 'piqa', 'winogrande', 'hellaswag', 'wikitext']
        
        results = {}
        
        for dataset_name in test_datasets:
            print(f"\n--- Testing {dataset_name} ---")
            
            if dataset_name == 'wikitext':
                # optimized WikiText dataset processing
                input_ids = build_wikitext_ids(self.tokenizer, split="test")
                samples = sample_wikitext_sequences(input_ids,
                                                    seqlen=self.args.seq_length,
                                                    n=None,
                                                    random_sample=True)
                
                print(f"WikiText samples shape: {samples.shape}")
                
                # initialize PPL lists for each sample
                all_sample_ppls = [[] for _ in range(samples.size(0))]
                
                # iterate through each representative mask
                for mask_idx in tqdm(range(len(self.representative_masks)), desc="Processing masks"):
                    # apply current mask
                    current_mask = self.representative_masks[mask_idx]
                    single_masks = self.convert_flat_mask_to_layer_masks(current_mask)
                    self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
                    self.hn_helper.set_gate_status(self.llm_model, use_gate=True)
                    
                    try:
                        # batch process all samples
                        for sample_idx in range(samples.size(0)):
                            sample = samples[sample_idx:sample_idx+1]
                            
                            # calculate PPL
                            nll = calculate_perplexity(
                                self.llm_model,
                                sample,
                                limit_length=self.args.seq_length,
                                device=self.device,
                            )
                            ppl = torch.exp(nll / (sample.size(0) * (sample.size(1) - 1)))
                            all_sample_ppls[sample_idx].append(ppl.item())
                    
                    finally:
                        # restore model state
                        self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
                
                # calculate results
                sample_min_ppls = [min(ppls) for ppls in all_sample_ppls]
                total_nll = sum(np.log(min_ppl) * (samples.size(1) - 1) for min_ppl in sample_min_ppls)
                total_tokens = samples.size(0) * (samples.size(1) - 1)
                overall_ppl = torch.exp(total_nll / total_tokens)
                
                results[dataset_name] = {
                    'type': 'wikitext',
                    'num_samples': samples.size(0),
                    'overall_ppl': float(overall_ppl),
                    'avg_min_ppl': float(np.mean(sample_min_ppls)),
                    'std_min_ppl': float(np.std(sample_min_ppls)),
                    'min_ppl': float(np.min(sample_min_ppls)),
                    'max_ppl': float(np.max(sample_min_ppls))
                }
                
                print(f"   Overall PPL: {overall_ppl:.4f}")
                print(f"   Average min PPL: {np.mean(sample_min_ppls):.4f}")
                
            else:
                # optimized multiple choice dataset processing
                try:
                    dataset = load_mc_dataset(dataset_name, split="test")
                    print(f"Test sample count: {len(dataset)}")
                    
                    # pre-process all samples
                    formatted_examples = []
                    for mc_sample in dataset:
                        formatted_example = format_mc_example(mc_sample, dataset_name)
                        formatted_example["dataset_name"] = dataset_name
                        formatted_examples.append(formatted_example)
                    
                    # initialize result lists for each sample
                    all_sample_results = [[] for _ in range(len(formatted_examples))]
                    
                    # iterate through each representative mask
                    for mask_idx in tqdm(range(len(self.representative_masks)), desc=f"Processing{dataset_name} masks"):
                        # apply current mask
                        current_mask = self.representative_masks[mask_idx]
                        single_masks = self.convert_flat_mask_to_layer_masks(current_mask)
                        self.hn_helper.set_gate_vectors(self.llm_model, single_masks)
                        self.hn_helper.set_gate_status(self.llm_model, use_gate=True)
                        
                        try:
                            # batch process all samples
                            for sample_idx, formatted_example in enumerate(formatted_examples):
                                # test current mask
                                result = evaluate_mc_example(
                                    self.llm_model, self.tokenizer, formatted_example,
                                    device=self.device, max_length=2048
                                )
                                
                                mask_correct = result["is_correct"]
                                all_sample_results[sample_idx].append(mask_correct)
                        
                        finally:
                            # restore model state
                            self.hn_helper.set_gate_status(self.llm_model, use_gate=False)
                    
                    # calculate results
                    correct_samples = sum(1 for sample_results in all_sample_results if any(sample_results))
                    overall_accuracy = correct_samples / len(dataset)
                    
                    # calculate individual accuracy for each mask
                    mask_accuracies = []
                    for mask_idx in range(len(self.representative_masks)):
                        mask_correct_count = sum(1 for sample_results in all_sample_results if sample_results[mask_idx])
                        mask_accuracy = mask_correct_count / len(dataset)
                        mask_accuracies.append(mask_accuracy)
                    
                    results[dataset_name] = {
                        'type': 'mc',
                        'num_samples': len(dataset),
                        'overall_accuracy': float(overall_accuracy),
                        'correct_samples': correct_samples,
                        'mask_accuracies': mask_accuracies,
                        'avg_mask_accuracy': float(np.mean(mask_accuracies)),
                        'std_mask_accuracy': float(np.std(mask_accuracies)),
                        'best_mask_idx': int(np.argmax(mask_accuracies)),
                        'best_mask_accuracy': float(np.max(mask_accuracies))
                    }
                    
                    print(f"   Overall accuracy: {overall_accuracy:.4f} ({correct_samples}/{len(dataset)})")
                    print(f"   Average mask accuracy: {np.mean(mask_accuracies):.4f}")
                    print(f"   Best mask accuracy: {np.max(mask_accuracies):.4f}")
                    
                except Exception as e:
                    print(f"   Error testing {dataset_name}: {e}")
                    results[dataset_name] = {"error": str(e)}
        
        # print summary
        print("\n=== Representative Mask Quality Summary ===")
        valid_results = {k: v for k, v in results.items() if "error" not in v}
        
        if valid_results:
            print(f"Number of test datasets: {len(valid_results)}")
            
            # calculate average performance
            if any(r['type'] == 'wikitext' for r in valid_results.values()):
                wikitext_results = [r for r in valid_results.values() if r['type'] == 'wikitext']
                avg_wikitext_ppl = np.mean([r['overall_ppl'] for r in wikitext_results])
                print(f"Average overall PPL for WikiText: {avg_wikitext_ppl:.4f}")
            
            if any(r['type'] == 'mc' for r in valid_results.values()):
                mc_results = [r for r in valid_results.values() if r['type'] == 'mc']
                avg_mc_accuracy = np.mean([r['overall_accuracy'] for r in mc_results])
                print(f"Average overall accuracy for MC datasets: {avg_mc_accuracy:.4f}")
        
        return results


    def process_all_datasets(self):
        """optimized total dataset processing controller"""
        print("Starting optimized dataset preprocessing...")
        
        all_samples = []
        
        # process WikiText
        if self.args.use_wikitext:
            wikitext_samples = self.process_wikitext_dataset()
            all_samples.extend(wikitext_samples)
            self.save_dataset_samples(wikitext_samples, "wikitext")
        
        # process multiple choice datasets
        mc_datasets = ['arc-e', 'arc-c', 'piqa', 'winogrande', 'hellaswag']
        
        for dataset_name in mc_datasets:
            if self.args.use_mc_data:
                mc_samples = self.process_mc_dataset(dataset_name)
                all_samples.extend(mc_samples)
                self.save_dataset_samples(mc_samples, dataset_name)
        
        # save merged version of all samples
        self.save_dataset_samples(all_samples, "all_datasets")
        
        # save metadata
        metadata = {
            'num_clusters': self.args.num_clusters,
            'representative_masks_info': self.cluster_assignments,
            'total_samples': len(all_samples),
            'datasets': {
                'wikitext': len([s for s in all_samples if s['dataset_name'] == 'wikitext']),
                'arc-e': len([s for s in all_samples if s['dataset_name'] == 'arc-e']),
                'arc-c': len([s for s in all_samples if s['dataset_name'] == 'arc-c']),
                'piqa': len([s for s in all_samples if s['dataset_name'] == 'piqa']),
                'winogrande': len([s for s in all_samples if s['dataset_name'] == 'winogrande']),
                'hellaswag': len([s for s in all_samples if s['dataset_name'] == 'hellaswag'])
            }
        }

        metadata = self.convert_numpy_types(metadata)
        
        metadata_path = os.path.join(self.args.output_dir, "preprocessing_metadata.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        print(f"\nOptimized preprocessing completed!")
        print(f"Total samples: {len(all_samples)}")
        print(f"Results saved to: {self.args.output_dir}")
        
        return all_samples


def _parse_bool_arg(value):
    """Parse a command-line boolean such as "true"/"false" into a ``bool``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``,
    so any non-empty string would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # the empty string stays falsy: under the old ``type=bool`` it was the
    # only spelling that produced False
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Command-line entry point.

    Parses the arguments, builds a ``DatasetPreprocessor`` and either tests the
    representative masks (``--test_masks``) or runs the full preprocessing.
    """
    parser = argparse.ArgumentParser(description="Dataset Preprocessor")

    # model path
    parser.add_argument("--model_path", type=str, default="xxx/llms/meta/Llama-2-7B-hf")
    parser.add_argument("--device", type=str, default="cuda:3")
    parser.add_argument("--clustering_results_path", type=str, default="xxx/project/DISP/mixed")
    parser.add_argument("--output_dir", type=str, default="xxx/project/DynPrune/llama-2-7b/041")

    # clustering parameters
    parser.add_argument("--num_clusters", type=int, default=10)

    # data parameters
    parser.add_argument("--use_wikitext", action="store_true", default=False)
    parser.add_argument("--use_mc_data", action="store_true", default=False)
    parser.add_argument("--wikitext_samples", type=int, default=None)
    parser.add_argument("--mc_samples_per_dataset", type=int, default=None)
    parser.add_argument("--seq_length", type=int, default=2048)

    # model parameters
    parser.add_argument("--p", type=float, default=0.6)
    parser.add_argument("--lam", type=float, default=4.0)

    # hypernetwork parameters (for wikitext mask generation)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument("--lrp_scale", type=float, default=1.0)
    parser.add_argument("--base", type=float, default=0.5)
    parser.add_argument("--T_start", type=float, default=0.5)
    parser.add_argument("--T_end", type=float, default=0.1)
    parser.add_argument("--target_sparsity", type=float, default=0.4)

    # wikitext mask generation parameters
    parser.add_argument("--max_wikitext_samples", type=int, default=1000,
                        help="Maximum number of samples to use for wikitext mask generation")
    # explicit boolean parsing so that "--normalize_lrp False" really disables it
    parser.add_argument("--normalize_lrp", type=_parse_bool_arg, default=True,
                        help="Whether to normalize LRP scores layer-wise")
    parser.add_argument("--normalize_activations", type=_parse_bool_arg, default=False,
                        help="Whether to normalize activations layer-wise")

    # test parameters
    parser.add_argument("--test_masks", action="store_true", default=False,
                        help="Test representative mask quality")
    parser.add_argument("--test_datasets", type=str, nargs='+',
                        default=['arc-e'],
                        help="List of datasets to test")

    args = parser.parse_args()

    preprocessor = DatasetPreprocessor(args)

    if args.test_masks:
        # test representative mask quality
        results = preprocessor.test_representative_masks_quality(test_datasets=args.test_datasets)

        print(results)
    else:
        # normal data preprocessing; results are persisted by the preprocessor
        preprocessor.process_all_datasets()


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()