"""
从已标注 mismatch 的 huggingface 数据集中提取新小模型的 hidden states 和 logits
不重新计算 mismatch，而是保留原有标注，只添加新模型的 hidden states 和 logits

注意：假定原模型和新模型使用相同的 tokenizer
"""

import json
import os
import argparse
import signal
import sys
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from datasets import Dataset, DatasetDict, load_from_disk, Features, Value, Sequence, concatenate_datasets
import pandas as pd
import numpy as np
import random
from hr2r.utils.sampling import sample_token
import multiprocessing as mp

# Global variable to track running processes (worker processes that the
# signal handler and main() must stop on shutdown).
running_processes = []

def signal_handler(signum, frame):
    """Stop all tracked worker processes, then exit.

    Registered for SIGINT/SIGTERM in ``main``. Exits with the conventional
    ``128 + signum`` status so shells and schedulers can tell the run was
    interrupted rather than completed successfully.

    Args:
        signum: Number of the received signal.
        frame: Current stack frame (unused; required by the ``signal`` API).
    """
    print(f"\nReceived signal {signum}, cleaning up processes...")

    # Ask every live worker to stop first so they shut down in parallel,
    # then wait for each one and escalate to kill() if it does not exit.
    alive = [p for p in running_processes if p.is_alive()]
    for p in alive:
        print(f"Terminating process {p.pid}...")
        p.terminate()
    for p in alive:
        p.join(timeout=60)
        if p.is_alive():
            print(f"Force killing process {p.pid}...")
            p.kill()
            p.join()

    print("Cleanup completed, exiting...")
    sys.exit(128 + int(signum))

def load_model(model_name, device_id=0):
    """Load ``model_name`` as a float16 causal LM on ``cuda:{device_id}`` in eval mode.

    Args:
        model_name: Name or path accepted by ``from_pretrained``.
        device_id: CUDA device index used for ``device_map``.

    Returns:
        The loaded model, or ``None`` if anything failed (the error is printed
        instead of raised so callers can bail out).
    """
    try:
        config = AutoConfig.from_pretrained(model_name)
        loaded = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            device_map=f"cuda:{device_id}",
            torch_dtype=torch.float16,
        )
        loaded.eval()
        print(f"Model {model_name} loaded successfully on GPU {device_id}!")
    except Exception as err:
        print(f"Error loading model on GPU {device_id}: {err}")
        return None
    return loaded


def load_annotated_dataset(dataset_path):
    """Load a pre-annotated HuggingFace dataset previously saved to disk.

    Args:
        dataset_path: Directory passed to ``load_from_disk``.

    Returns:
        The ``Dataset`` or ``DatasetDict`` exactly as stored, or ``None`` if
        loading fails (the error is printed, not raised).
    """
    print(f"Loading annotated dataset from {dataset_path}")

    try:
        loaded = load_from_disk(dataset_path)

        if not isinstance(loaded, DatasetDict):
            print(f"Loaded dataset with {len(loaded)} samples")
            return loaded

        # DatasetDict: report the split count and the total rows across splits.
        n_rows = sum(len(part) for part in loaded.values())
        print(f"Loaded DatasetDict with {len(loaded)} splits and {n_rows} total samples")
        return loaded

    except Exception as err:
        print(f"Error loading dataset: {err}")
        return None

def as_single_dataset(dataset_or_dict):
    """Flatten a DatasetDict into one Dataset by concatenating its splits in order.

    Any other input (e.g. a plain ``Dataset``) is returned unchanged.
    """
    if not isinstance(dataset_or_dict, DatasetDict):
        return dataset_or_dict
    return concatenate_datasets(list(dataset_or_dict.values()))

def split_indices(indices, num_splits):
    """Split ``indices`` into ``num_splits`` contiguous, nearly equal parts.

    The first ``len(indices) % num_splits`` parts get one extra element, so
    part sizes differ by at most one. Parts are empty when there are fewer
    indices than splits.

    Args:
        indices: Any sliceable sequence (list, tuple, range, ...).
        num_splits: Number of parts to produce; must be >= 1.

    Returns:
        A list of exactly ``num_splits`` slices of ``indices``.

    Raises:
        ValueError: If ``num_splits`` is less than 1 (previously this was an
            opaque ZeroDivisionError for 0 and a silent ``[]`` for negatives).
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")
    base, rem = divmod(len(indices), num_splits)
    splits = []
    start = 0
    for i in range(num_splits):
        end = start + base + (1 if i < rem else 0)
        splits.append(indices[start:end])
        start = end
    return splits

def process_dataset_with_new_model(dataset, new_model_name, device_id=0, test_ratio=0.1, random_num=None, chunk_size=200000, save_entropy=False):
    """Run the new model over annotated samples and collect per-token features.

    Each row of ``dataset`` is one sample (one ``data_id``) holding its full
    token sequence in ``real_token``, a ``mask`` whose first ``1`` marks where
    the kept region starts, and its ``mismatch`` annotation (one value per
    token, or a single scalar). Only tokens from the first ``mask == 1``
    position onward are emitted, one output row per token:

    - ``small_token``: ``sample_token`` output at temperature 0.0 (presumably greedy)
    - ``real_token`` / ``real_text``: the original token id and its decoded text
    - ``token_id``: position counted from the start of the kept region
    - ``mismatch``: copied from the original annotation (not recomputed)
    - ``small_last_hidden_states``: last-layer hidden state
    - ``small_logits`` / ``small_indices``: top-100 logits and their vocab ids
    - ``small_entropy`` (only if ``save_entropy``): entropy over the full vocabulary

    Samples with no ``mask == 1`` position (or nothing at/after it) produce no
    rows and are skipped without running the model.
    Output format matches slm_prefill.py when save_hidden_logit=True.

    Args:
        dataset: A single ``datasets.Dataset`` with the columns above.
        new_model_name: Model/tokenizer name or path for ``from_pretrained``.
        device_id: CUDA device index to load the model on.
        test_ratio: Fraction of token rows placed in the ``test`` split.
        random_num: If a positive int smaller than ``len(dataset)``, process a
            reproducible (seed 42) random subset of that many samples;
            otherwise every sample is processed.
        chunk_size: Currently unused; a Dataset chunk is flushed after every
            sample to bound peak memory.
        save_entropy: Also compute and store ``small_entropy``.

    Returns:
        A ``DatasetDict`` with ``train``/``test`` splits, or ``None`` if the
        model failed to load or no sample produced any rows.
    """

    # Load new model and tokenizer
    new_model = load_model(new_model_name, device_id)
    if new_model is None:
        return None

    tokenizer = AutoTokenizer.from_pretrained(new_model_name)

    print(f"Processing {len(dataset)} samples with {new_model_name}")

    # Same sampling policy as the multi-GPU path: seed 42, and fall back to
    # all samples when random_num is not smaller than the dataset. A private
    # Random instance keeps the global RNG untouched.
    data_iter_dataset = dataset
    if random_num is not None and random_num > 0:
        if random_num < len(dataset):
            sampled_indices = random.Random(42).sample(range(len(dataset)), random_num)
            data_iter_dataset = dataset.select(sampled_indices)
        else:
            print(f"random_num ({random_num}) >= available samples ({len(dataset)}), using all samples")

    # Per-token row buffers (slm_prefill column layout); drained by flush_buffer().
    all_data_ids = []
    all_small_tokens = []
    all_real_tokens = []
    all_real_texts = []
    all_token_ids = []
    all_mismatches = []
    all_hidden_states = []
    all_top_logits = []
    all_top_logits_indices = []
    all_small_entropy = [] if save_entropy else None
    datasets_chunks = []

    # Explicit features skip type inference and speed up Arrow writing
    features_dict = {
        'data_id': Value('string'),
        'small_token': Value('int32'),
        'real_token': Value('int32'),
        'real_text': Value('string'),
        'token_id': Value('int32'),
        'mismatch': Value('int8'),
        'small_last_hidden_states': Sequence(Value('float32')),
        'small_logits': Sequence(Value('float32'), length=100),
        'small_indices': Sequence(Value('int32'), length=100),
    }
    if save_entropy:
        features_dict['small_entropy'] = Value('float32')
    features = Features(features_dict)

    def flush_buffer():
        """Flush buffered rows into a new Dataset chunk and clear the buffers
        to reduce peak memory usage.
        """
        if not all_data_ids:
            return
        buffer_dict = {
            'data_id': [str(x) for x in all_data_ids],
            'small_token': all_small_tokens,
            'real_token': all_real_tokens,
            'real_text': all_real_texts,
            'token_id': all_token_ids,
            'mismatch': all_mismatches,
            'small_last_hidden_states': [
                (x.cpu().numpy().astype('float32') if torch.is_tensor(x) else np.asarray(x, dtype='float32')).tolist()
                for x in all_hidden_states
            ],
            'small_logits': [
                (x.cpu().numpy().astype('float32') if torch.is_tensor(x) else np.asarray(x, dtype='float32')).tolist()
                for x in all_top_logits
            ],
            'small_indices': [
                (x.cpu().numpy().astype('int32') if torch.is_tensor(x) else np.asarray(x, dtype='int32')).tolist()
                for x in all_top_logits_indices
            ],
        }
        if save_entropy:
            buffer_dict['small_entropy'] = all_small_entropy
        ds_chunk = Dataset.from_dict(buffer_dict, features=features)
        datasets_chunks.append(ds_chunk)
        all_data_ids.clear()
        all_small_tokens.clear()
        all_real_tokens.clear()
        all_real_texts.clear()
        all_token_ids.clear()
        all_mismatches.clear()
        all_hidden_states.clear()
        all_top_logits.clear()
        all_top_logits_indices.clear()
        if save_entropy:
            all_small_entropy.clear()

    skipped = 0
    desc = f"Processing samples with {new_model_name.split('/')[-1]}"
    # tqdm as a context manager so the bar is closed even if the model raises.
    with torch.no_grad(), tqdm(total=len(data_iter_dataset), desc=desc) as pbar:
        for idx, row in enumerate(data_iter_dataset):
            # Each row holds one complete token sequence
            real_tokens = row['real_token']

            # Start of the kept region: first position whose mask equals 1
            mask_values = row['mask']
            if isinstance(mask_values, np.ndarray):
                hits = np.flatnonzero(mask_values == 1)
                mask_start = int(hits[0]) if hits.size else None
            else:
                mask_start = next((i for i, v in enumerate(mask_values) if v == 1), None)

            # A row with no mask==1 position, or nothing at/after it, would
            # contribute no output rows; skip it (instead of crashing on
            # list.index) and avoid the wasted forward pass.
            if mask_start is None or mask_start >= len(real_tokens):
                skipped += 1
                pbar.update(1)
                continue

            input_ids = torch.tensor(real_tokens, dtype=torch.long).unsqueeze(0).to(new_model.device)

            outputs = new_model(input_ids, output_hidden_states=True)
            logits = outputs.logits[0]  # [seq_len, vocab_size]
            hidden_states = outputs.hidden_states[-1][0]  # last layer, [seq_len, hidden_size]

            # Top 100 logits to match slm_prefill
            top_logits, top_logits_indices = torch.topk(logits, 100, dim=-1)

            # Entropy over the full vocabulary, computed in float32
            if save_entropy:
                log_probs = F.log_softmax(logits.float(), dim=-1)
                entropy = -(torch.exp(log_probs) * log_probs).sum(dim=-1)  # [seq_len]

            # Predict every position at once
            pred_tokens = sample_token(logits, temperature=0.0, top_p=1.0, top_k=-1).cpu().tolist()

            hidden_states = hidden_states.float().cpu()
            top_logits = top_logits.float().cpu()
            top_logits_indices = top_logits_indices.cpu()
            if save_entropy:
                entropy = entropy.cpu().float()

            data_id = row['data_id']
            mismatch_values = row['mismatch']
            kept_tokens = real_tokens[mask_start:]
            seq_len = len(kept_tokens)

            all_data_ids.extend([data_id] * seq_len)
            all_small_tokens.extend(pred_tokens[mask_start:])
            all_real_tokens.extend(kept_tokens)
            # Only the kept tokens are decoded; the prefix is never stored.
            all_real_texts.extend(tokenizer.decode([token]) for token in kept_tokens)
            all_token_ids.extend(range(seq_len))  # positions relative to mask_start
            if isinstance(mismatch_values, (list, np.ndarray)):
                all_mismatches.extend(int(x) for x in mismatch_values[mask_start:])
            else:
                all_mismatches.extend([int(mismatch_values)] * seq_len)
            all_hidden_states.extend(hidden_states[mask_start:])
            all_top_logits.extend(top_logits[mask_start:])
            all_top_logits_indices.extend(top_logits_indices[mask_start:])
            if save_entropy:
                all_small_entropy.extend(entropy[mask_start:].tolist())

            # Flush on every step to create a Dataset per sample
            flush_buffer()

            pbar.update(1)

            # Clear GPU cache periodically
            if idx % 100 == 0:
                torch.cuda.empty_cache()

    if skipped:
        print(f"Skipped {skipped} samples with no tokens at or after the first mask==1 position")

    # Flush remaining buffer
    flush_buffer()

    if not datasets_chunks:
        print("No valid samples processed")
        return None

    # Merge all chunks into a single Dataset
    processed_dataset = concatenate_datasets(datasets_chunks)

    # Split into train and test (matching slm_prefill logic)
    print(f"Splitting dataset into train/test with test_ratio={test_ratio}")
    split_dataset = processed_dataset.train_test_split(test_size=test_ratio, seed=42)
    train_dataset = split_dataset['train']
    test_dataset = split_dataset['test']

    print(f"Train dataset: {len(train_dataset)} samples")
    print(f"Test dataset: {len(test_dataset)} samples")

    # Create DatasetDict (matching slm_prefill format)
    dataset_dict = DatasetDict({
        'train': train_dataset,
        'test': test_dataset
    })

    print(f"Final dataset created with train/test splits")
    return dataset_dict

def process_single_gpu_hidden(args, device_id, indices, new_model_name):
    """Process one shard of dataset indices on one GPU and save it as a partial Dataset.

    Meant to run as the target of one worker process. Reloads the annotated
    dataset from ``args.dataset_path`` (splits concatenated into one Dataset)
    and runs ``new_model_name`` on the rows listed in ``indices``. It saves
    one output row per token, starting at each sample's first ``mask == 1``
    position. The output goes to ``results_gpu_{device_id}_{model}{suffix}``
    under ``args.output_path``; ``suffix`` is ``_with_entropy`` when
    ``args.save_entropy`` is set. Samples with no ``mask == 1`` position
    (or nothing at/after it) are skipped.

    Args:
        args: Parsed CLI namespace (reads ``dataset_path``, ``output_path``
            and, optionally, ``save_entropy``).
        device_id: CUDA device index to run on; also part of the output name.
        indices: Row indices of the flattened dataset handled by this worker.
        new_model_name: Model/tokenizer name or path for ``from_pretrained``.

    Returns:
        ``True`` once the partial dataset is saved. Returns ``None`` if the
        shard is empty, loading fails, nothing was produced, or processing
        raised. Errors are printed rather than propagated.
    """
    model_path = new_model_name.split("/")[-1]
    save_entropy = getattr(args, 'save_entropy', False)
    print(f"GPU {device_id}: Processing {len(indices)} samples for model {new_model_name}")

    # Bail out on an empty shard before paying for dataset/model loading
    if len(indices) == 0:
        print(f"GPU {device_id}: No indices to process")
        return None

    # Load dataset, reduce it to a single Dataset and take this worker's rows
    dataset_loaded = load_annotated_dataset(args.dataset_path)
    if dataset_loaded is None:
        return None
    data_iter_dataset = as_single_dataset(dataset_loaded).select(indices)

    # Load model and tokenizer
    new_model = load_model(new_model_name, device_id)
    if new_model is None:
        return None
    tokenizer = AutoTokenizer.from_pretrained(new_model_name)

    # Buffers matching the output features; drained by flush_buffer()
    all_data_ids = []
    all_small_tokens = []
    all_real_tokens = []
    all_real_texts = []
    all_token_ids = []
    all_mismatches = []
    all_hidden_states = []
    all_top_logits = []
    all_top_logits_indices = []
    all_small_entropy = [] if save_entropy else None
    datasets_chunks = []

    features_dict = {
        'data_id': Value('string'),
        'small_token': Value('int32'),
        'real_token': Value('int32'),
        'real_text': Value('string'),
        'token_id': Value('int32'),
        'mismatch': Value('int8'),
        'small_last_hidden_states': Sequence(Value('float32')),
        'small_logits': Sequence(Value('float32'), length=100),
        'small_indices': Sequence(Value('int32'), length=100),
    }
    if save_entropy:
        features_dict['small_entropy'] = Value('float32')
    features = Features(features_dict)

    def flush_buffer():
        """Move buffered rows into a new Dataset chunk and clear the buffers."""
        if not all_data_ids:
            return
        buffer_dict = {
            'data_id': [str(x) for x in all_data_ids],
            'small_token': all_small_tokens,
            'real_token': all_real_tokens,
            'real_text': all_real_texts,
            'token_id': all_token_ids,
            'mismatch': all_mismatches,
            'small_last_hidden_states': [
                (x.cpu().numpy().astype('float32') if torch.is_tensor(x) else np.asarray(x, dtype='float32')).tolist()
                for x in all_hidden_states
            ],
            'small_logits': [
                (x.cpu().numpy().astype('float32') if torch.is_tensor(x) else np.asarray(x, dtype='float32')).tolist()
                for x in all_top_logits
            ],
            'small_indices': [
                (x.cpu().numpy().astype('int32') if torch.is_tensor(x) else np.asarray(x, dtype='int32')).tolist()
                for x in all_top_logits_indices
            ],
        }
        if save_entropy:
            buffer_dict['small_entropy'] = all_small_entropy
        ds_chunk = Dataset.from_dict(buffer_dict, features=features)
        datasets_chunks.append(ds_chunk)
        all_data_ids.clear()
        all_small_tokens.clear()
        all_real_tokens.clear()
        all_real_texts.clear()
        all_token_ids.clear()
        all_mismatches.clear()
        all_hidden_states.clear()
        all_top_logits.clear()
        all_top_logits_indices.clear()
        if save_entropy:
            all_small_entropy.clear()

    skipped = 0
    try:
        # tqdm as a context manager so the bar is closed even on errors
        with torch.no_grad(), tqdm(
            total=len(data_iter_dataset), desc=f"GPU {device_id} - {model_path}", position=device_id
        ) as pbar:
            for idx, row in enumerate(data_iter_dataset):
                real_tokens = row['real_token']

                # Start of the kept region: first position whose mask equals 1
                mask_values = row['mask']
                if isinstance(mask_values, np.ndarray):
                    hits = np.flatnonzero(mask_values == 1)
                    mask_start = int(hits[0]) if hits.size else None
                else:
                    mask_start = next((i for i, v in enumerate(mask_values) if v == 1), None)

                # A row with no mask==1 position, or nothing at/after it, would
                # contribute no output rows. Skip it before the forward pass;
                # list.index would otherwise fail the whole shard.
                if mask_start is None or mask_start >= len(real_tokens):
                    skipped += 1
                    pbar.update(1)
                    continue

                input_ids = torch.tensor(real_tokens, dtype=torch.long).unsqueeze(0).to(new_model.device)

                outputs = new_model(input_ids, output_hidden_states=True)
                logits = outputs.logits[0]  # [seq_len, vocab_size]
                hidden_states = outputs.hidden_states[-1][0]  # last layer, [seq_len, hidden_size]
                top_logits, top_logits_indices = torch.topk(logits, 100, dim=-1)

                # Compute entropy if requested
                if save_entropy:
                    log_probs = F.log_softmax(logits.float(), dim=-1)
                    entropy = -(torch.exp(log_probs) * log_probs).sum(dim=-1)
                pred_tokens = sample_token(logits, temperature=0.0, top_p=1.0, top_k=-1).cpu().tolist()

                hidden_states = hidden_states.float().cpu()
                top_logits = top_logits.float().cpu()
                top_logits_indices = top_logits_indices.cpu()
                if save_entropy:
                    entropy = entropy.cpu().float()

                data_id = row['data_id']
                mismatch_values = row['mismatch']
                kept_tokens = real_tokens[mask_start:]
                seq_len = len(kept_tokens)

                all_data_ids.extend([data_id] * seq_len)
                all_small_tokens.extend(pred_tokens[mask_start:])
                all_real_tokens.extend(kept_tokens)
                # Only the kept tokens are decoded; the prefix is never stored.
                all_real_texts.extend(tokenizer.decode([token]) for token in kept_tokens)
                all_token_ids.extend(range(seq_len))  # positions relative to mask_start
                if isinstance(mismatch_values, (list, np.ndarray)):
                    all_mismatches.extend(int(x) for x in mismatch_values[mask_start:])
                else:
                    all_mismatches.extend([int(mismatch_values)] * seq_len)
                all_hidden_states.extend(hidden_states[mask_start:])
                all_top_logits.extend(top_logits[mask_start:])
                all_top_logits_indices.extend(top_logits_indices[mask_start:])
                if save_entropy:
                    all_small_entropy.extend(entropy[mask_start:].tolist())

                # One Dataset chunk per sample keeps the Python-list buffers small
                flush_buffer()
                pbar.update(1)
                if idx % 100 == 0:
                    torch.cuda.empty_cache()

        flush_buffer()

        if skipped:
            print(f"GPU {device_id}: Skipped {skipped} samples with no tokens at or after the first mask==1 position")

        if not datasets_chunks:
            print(f"GPU {device_id}: No valid samples processed")
            return None

        processed_dataset = concatenate_datasets(datasets_chunks)
        suffix = "_with_entropy" if save_entropy else ""
        output_file = os.path.join(args.output_path, f"results_gpu_{device_id}_{model_path}{suffix}")
        processed_dataset.save_to_disk(output_file)
        print(f"GPU {device_id}: Dataset saved to {output_file}")
        return True

    except Exception as e:
        print(f"GPU {device_id}: Error during processing: {e}")
        # The worker swallows the error, so print the traceback for debugging
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Best-effort release of GPU/host memory before the worker exits
        try:
            del new_model
            del tokenizer
            datasets_chunks.clear()
            if torch.cuda.is_available():
                torch.cuda.synchronize(device=device_id)
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
            import gc
            gc.collect()
            print(f"GPU {device_id}: Cleanup completed")
        except Exception as cleanup_error:
            print(f"GPU {device_id}: Error during cleanup: {cleanup_error}")

def merge_gpu_results_hidden(args, new_model_name):
    """Merge per-GPU partial datasets, split train/test, and save the final dataset.

    Looks for ``results_gpu_{i}_{model}{suffix}`` under ``args.output_path``
    for every ``i < args.num_gpu`` and concatenates whichever ones load. It
    splits the result with ``args.test_ratio`` (seed 42) and saves the
    ``DatasetDict`` to ``{model}_processed_with_hidden{suffix}``. Finally it
    deletes the per-GPU directories. ``suffix`` is ``_with_entropy`` when
    ``args.save_entropy`` is set.

    Returns:
        The saved ``DatasetDict`` reloaded from its final location (so it does
        not reference the deleted per-GPU files), or ``None`` if no per-GPU
        result could be loaded.
    """
    import shutil

    model_path = new_model_name.split("/")[-1]
    suffix = "_with_entropy" if getattr(args, 'save_entropy', False) else ""
    result_dirs = [
        os.path.join(args.output_path, f"results_gpu_{gpu_id}_{model_path}{suffix}")
        for gpu_id in range(args.num_gpu)
    ]

    all_datasets = []
    for gpu_id, result_dir in enumerate(result_dirs):
        if not os.path.exists(result_dir):
            print(f"No result found for GPU {gpu_id} at {result_dir}")
            continue
        try:
            all_datasets.append(load_from_disk(result_dir))
            print(f"Loaded dataset from GPU {gpu_id}")
        except Exception as e:
            print(f"Failed to load dataset from GPU {gpu_id}: {e}")

    if not all_datasets:
        print("No GPU datasets found to merge")
        return None

    merged_dataset = concatenate_datasets(all_datasets)

    print(f"Splitting dataset into train/test with test_ratio={args.test_ratio}")
    split_dataset = merged_dataset.train_test_split(test_size=args.test_ratio, seed=42)
    dataset_dict = DatasetDict({
        'train': split_dataset['train'],
        'test': split_dataset['test']
    })

    final_output_path = os.path.join(args.output_path, f"{model_path}_processed_with_hidden{suffix}")
    dataset_dict.save_to_disk(final_output_path)
    print(f"Processed dataset with train/test splits saved to {final_output_path}")

    # Datasets from load_from_disk are backed by the Arrow files inside the
    # per-GPU directories. Drop them before deleting those directories, and
    # return a copy backed by the final output instead.
    del dataset_dict, split_dataset, merged_dataset, all_datasets
    for result_dir in result_dirs:
        if os.path.exists(result_dir):
            shutil.rmtree(result_dir)

    return load_from_disk(final_output_path)

def process_dataset_multi_gpu_hidden(args):
    """Process the annotated dataset on ``args.num_gpu`` GPUs and merge the results.

    Sample indices are selected first. They are optionally restricted to
    ``args.index_range`` (clamped to the dataset) and subsampled to
    ``args.random_num`` with seed 42. The indices are split into
    ``args.num_gpu`` contiguous shards. Each shard is handled by
    ``process_single_gpu_hidden`` in its own spawned process, which writes a
    partial dataset under ``args.output_path``. The partials are then merged
    by ``merge_gpu_results_hidden``.

    A shard whose partial result directory already exists (e.g. left over
    from an interrupted run) is not recomputed, so re-running resumes. Remove
    stale directories first if the run settings changed.
    """
    global running_processes

    os.makedirs(args.output_path, exist_ok=True)

    print(f"Loading annotated dataset from {args.dataset_path}")
    dataset_loaded = load_annotated_dataset(args.dataset_path)
    if dataset_loaded is None:
        print("Failed to load dataset")
        return
    # Only the length is needed here; every worker reloads the dataset itself
    total_len = len(as_single_dataset(dataset_loaded))
    del dataset_loaded
    print(f"Dataset length: {total_len}")

    # Prepare indices according to index_range and random_num
    if args.index_range is not None:
        start_idx, end_idx = args.index_range
        base_indices = list(range(max(0, start_idx), min(end_idx, total_len)))
    else:
        base_indices = list(range(total_len))

    if args.random_num is not None and args.random_num > 0:
        if args.random_num < len(base_indices):
            # A private Random(42) draws the same sample as random.seed(42)
            # followed by random.sample, without reseeding the global RNG.
            base_indices = random.Random(42).sample(base_indices, args.random_num)
        else:
            print(f"random_num ({args.random_num}) >= available indices ({len(base_indices)}), using all indices")

    # Split indices into num_gpu parts
    data_splits = split_indices(base_indices, args.num_gpu)
    print(f"Dataset split into {args.num_gpu} parts with sizes: {[len(x) for x in data_splits]}")

    model_path = args.new_model_name.split("/")[-1]
    suffix = "_with_entropy" if getattr(args, 'save_entropy', False) else ""
    timeout = 24 * 60 * 60

    def stop_process(proc, grace):
        """Terminate ``proc``; kill it if it is still alive after ``grace`` seconds."""
        proc.terminate()
        proc.join(timeout=grace)
        if proc.is_alive():
            proc.kill()
            proc.join()

    print(f"Processing model: {args.new_model_name}")
    processes = []
    running_processes = []
    try:
        for gpu_id, shard in enumerate(data_splits):
            # Same path process_single_gpu_hidden writes as its final step and
            # merge_gpu_results_hidden reads, so an existing one normally
            # means this shard already finished in an earlier run.
            result_dir = os.path.join(args.output_path, f"results_gpu_{gpu_id}_{model_path}{suffix}")
            if os.path.exists(result_dir):
                print(f"GPU {gpu_id}: Found existing result at {result_dir}, skipping")
                continue
            p = mp.Process(
                target=process_single_gpu_hidden,
                args=(args, gpu_id, shard, args.new_model_name),
            )
            p.start()
            processes.append((gpu_id, p))
            # Tracked so signal_handler / main can stop it on interruption
            running_processes.append(p)

        for gpu_id, p in processes:
            try:
                p.join(timeout=timeout)
                if p.is_alive():
                    print(f"GPU {gpu_id}: Process timeout after {timeout} seconds, terminating...")
                    stop_process(p, 30)
                elif p.exitcode != 0:
                    print(f"GPU {gpu_id}: Process exited with code {p.exitcode}")
                else:
                    print(f"GPU {gpu_id}: Process completed successfully")
            except Exception as e:
                print(f"GPU {gpu_id}: Error during process join: {e}")
                if p.is_alive():
                    stop_process(p, 30)
    finally:
        # Never leave workers running behind us, whatever happened above
        for _, p in processes:
            if p.is_alive():
                stop_process(p, 60)
        running_processes = []

    print(f"All GPU processes completed for {args.new_model_name}")

    # Merge results and save final dataset
    merge_gpu_results_hidden(args, args.new_model_name)


def main():
    """Command-line entry point: parse arguments and run the extraction.

    With ``--num_gpu > 1`` the work is dispatched to
    ``process_dataset_multi_gpu_hidden``. Otherwise the dataset is processed
    on ``--device_id`` in this process, restricted to ``--index_range`` when
    given. The result is saved under ``--output_path``. After processing
    finishes without raising, the parsed arguments are written to
    ``args.json`` in the output directory.
    """
    global running_processes
    # Register signal handlers for graceful cleanup
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser(
        description="Extract hidden states and logits from pre-annotated dataset using a new small model. " + 
                   "Output format matches slm_prefill.py when save_hidden_logit=True."
    )
    parser.add_argument(
        "--dataset_path", 
        type=str, 
        required=True, 
        help="Path to the pre-annotated huggingface dataset"
    )
    parser.add_argument(
        "--new_model_name", 
        type=str, 
        required=True, 
        help="Name of the new small model to use for extraction"
    )
    parser.add_argument(
        "--output_path", 
        type=str, 
        required=True, 
        help="Directory to save output files"
    )
    parser.add_argument(
        "--device_id", 
        type=int, 
        default=0, 
        help="GPU device ID to use (single-GPU fallback)"
    )
    parser.add_argument(
        "--test_ratio",
        type=float,
        default=0.05,
        help="Ratio of data to use for test split"
    )
    parser.add_argument(
        "--random_num",
        type=int,
        default=None,
        help="Number of samples to randomly sample from the dataset"
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=200000,
        help="Number of rows to buffer before flushing to a chunked Dataset"
    )
    parser.add_argument(
        "--num_gpu",
        type=int,
        default=1,
        help="Number of GPUs to use for parallel processing"
    )
    parser.add_argument(
        "--index_range",
        nargs=2,
        type=int,
        default=None,
        help="Range of dataset samples to process [start_idx, end_idx]"
    )
    parser.add_argument(
        "--save_entropy",
        action="store_true",
        help="If set, also compute and save per-token entropy as 'small_entropy'"
    )

    args = parser.parse_args()

    # Workers use CUDA, so they are started with 'spawn' rather than 'fork'
    mp.set_start_method('spawn', force=True)

    try:
        if args.num_gpu and args.num_gpu > 1:
            process_dataset_multi_gpu_hidden(args)
        else:
            # Fallback: single GPU processing using existing path
            os.makedirs(args.output_path, exist_ok=True)
            dataset = load_annotated_dataset(args.dataset_path)
            if dataset is None:
                print("Failed to load dataset")
                return
            dataset = as_single_dataset(dataset)
            # Honor --index_range here as well (clamped to the dataset), as
            # the multi-GPU path does; random_num then samples within it.
            if args.index_range is not None:
                start_idx, end_idx = args.index_range
                dataset = dataset.select(range(max(0, start_idx), min(end_idx, len(dataset))))
            processed_dataset = process_dataset_with_new_model(
                dataset,
                args.new_model_name,
                args.device_id,
                args.test_ratio,
                args.random_num,
                args.chunk_size,
                args.save_entropy
            )
            if processed_dataset is not None:
                new_model_path = args.new_model_name.split("/")[-1]
                suffix = "_with_entropy" if args.save_entropy else ""
                output_file = os.path.join(args.output_path, f"{new_model_path}_processed_with_hidden{suffix}")
                processed_dataset.save_to_disk(output_file)
                print(f"Results saved to {output_file}")

        with open(os.path.join(args.output_path, "args.json"), "w") as f:
            json.dump(vars(args), f, indent=2)

        print("Processing completed!")
    except Exception as e:
        print(f"Error during processing: {e}")
        # Stop any workers still tracked before propagating the error
        for p in running_processes:
            if p.is_alive():
                p.terminate()
                p.join(timeout=60)
                if p.is_alive():
                    p.kill()
                    p.join()
        raise
    finally:
        running_processes = []


if __name__ == "__main__":
    main()
