"""
Multi-threaded data processing script for conversation format datasets.
Converts JSONL or JSON format datasets without model inference.

Inputs:
- A JSONL file (.jsonl extension) with conversations format (one JSON object per line).
    - Each line contains: {"conversations": [{"from": "human", "value": "..."}, {"from": "assistant", "value": "..."}], "system": "..."}
- Or a JSON file (.json extension) with conversations format (JSON array).
    - Contains an array of objects: [{"conversations": [{"from": "human", "value": "..."}, {"from": "assistant", "value": "..."}], "system": "..."}, ...]

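Outputs:
- A processed dataset (saved with `datasets.Dataset.save_to_disk`) with one row per data_id, containing
  data_id, real_text, real_token, mask, and mismatch fields.
- Statistics files written into the same directory: detailed_analysis.json, per_sample_statistics.csv,
  and analysis_report.txt.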
"""

import json
import os
import argparse
import signal
import sys
import threading
from tqdm import tqdm
from transformers import AutoTokenizer
import torch
import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

# Worker threads of the currently active ThreadPoolExecutor. Assigned in
# process_dataset_multi_thread and read by signal_handler so that an interrupt
# waits for in-flight work before the process exits.
running_threads = []
# Lock for guarding state shared between worker threads.
# NOTE(review): not acquired anywhere in the visible code.
thread_lock = threading.Lock()

# System prompt asking for <think>/<answer>-tagged output.
# NOTE(review): not referenced by the visible code; presumably kept for other
# prompt formats — confirm before removing.
SYSTEM_PROMPT = """
You are a helpful assistant. To answer the user's question, you first think about the reasoning process and then provide the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
"""

def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM: wait for worker threads, then exit.

    Python threads cannot be killed, so every thread in ``running_threads`` that
    is still alive is joined (up to 24 hours each) before exiting. The process
    exits with the conventional ``128 + signum`` status so that shells and job
    schedulers see the run as interrupted rather than successful.

    Args:
        signum: number of the received signal.
        frame: current stack frame (unused; required by ``signal.signal``).

    Raises:
        SystemExit: always, with code ``128 + signum``.
    """
    print(f"\nReceived signal {signum}, cleaning up threads...")

    # Note: Python threads cannot be forcefully terminated like processes
    # We can only wait for them to complete naturally
    print("Waiting for threads to complete naturally...")
    # Iterate over a snapshot in case the global list is rebound meanwhile.
    for thread in list(running_threads):
        if thread.is_alive():
            print(f"Waiting for thread {thread.name}...")
            thread.join(timeout=24 * 60 * 60)

    print("Cleanup completed, exiting...")
    sys.exit(128 + int(signum))

def _read_jsonl_records(file_path):
    """Parse one JSON value from every non-blank line of ``file_path``."""
    # utf-8-sig transparently drops a leading byte-order mark, which json.loads
    # would otherwise reject; files without a BOM are read exactly as with utf-8.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        stripped_lines = (line.strip() for line in f)
        return [json.loads(text) for text in stripped_lines if text]


def load_jsonl_json_dataset(file_path, index_range=None):
    """Load dataset records from a JSONL or JSON file, chosen by file extension.

    ``.jsonl`` files hold one JSON object per line (blank lines are skipped).
    ``.json`` files hold a single JSON value: a list is used as-is, any other
    value is wrapped in a one-element list. Any other extension is read as
    JSONL after printing a warning.

    Args:
        file_path: path of the dataset file.
        index_range: optional ``(start_idx, end_idx)`` pair; when given, the
            loaded records are sliced with ``data[start_idx:end_idx]``.

    Returns:
        list of the loaded records.

    Raises:
        json.JSONDecodeError: if the file (or a JSONL line) is not valid JSON.
    """
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension == '.json':
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            loaded_data = json.load(f)
        data = loaded_data if isinstance(loaded_data, list) else [loaded_data]
    else:
        if file_extension != '.jsonl':
            print(f"Warning: Unknown file extension '{file_extension}'. Trying to read as JSONL format.")
        data = _read_jsonl_records(file_path)

    print(f"Loaded {len(data)} samples from {file_extension if file_extension else 'unknown'} format file: {file_path}")

    if index_range:
        start_idx, end_idx = index_range
        data = data[start_idx:end_idx]
        print(f"Selected range [{start_idx}:{end_idx}], resulting in {len(data)} samples")

    return data

def split_dataset(dataset, num_splits):
    """Partition the indices of ``dataset`` into ``num_splits`` contiguous ranges.

    Sizes differ by at most one: the first ``len(dataset) % num_splits`` ranges
    get one extra item. When there are more splits than items, the trailing
    ranges are empty (``start == end``).

    Args:
        dataset: any sized collection; only ``len(dataset)`` is used.
        num_splits: number of ranges to produce; must be at least 1.

    Returns:
        list of ``(start_idx, end_idx)`` half-open index pairs covering
        ``range(len(dataset))`` in order.

    Raises:
        ValueError: if ``num_splits < 1``.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")

    chunk_size, remainder = divmod(len(dataset), num_splits)

    splits = []
    start_idx = 0
    for i in range(num_splits):
        end_idx = start_idx + chunk_size + (1 if i < remainder else 0)
        splits.append((start_idx, end_idx))
        start_idx = end_idx

    return splits

def parse_conversations(conversations):
    """Extract the user prompt and split the assistant reply into reasoning and response.

    ``conversations`` is a list of dicts with ``"from"`` and ``"value"`` keys.
    The last ``human``/``user`` turn is the input text and the last
    ``assistant``/``gpt`` turn is the assistant reply.

    If the reply contains ``<think>`` followed later by ``</think>``, the text
    between them (stripped) is the reasoning and the text after ``</think>``
    (stripped) is the response. Text before ``<think>`` is dropped. Otherwise
    the reasoning is ``""`` and the response is the reply unchanged.

    Returns:
        ``(input_text, model_reasoning, model_response)``, or
        ``(None, None, None)`` when the user or assistant turn is missing or empty.
    """
    open_tag, close_tag = "<think>", "</think>"

    input_text = None
    assistant_response = None
    for conv in conversations:
        speaker = conv["from"]
        if speaker in ("human", "user"):
            input_text = conv["value"]
        elif speaker in ("assistant", "gpt"):  # "gpt" is the ShareGPT spelling
            assistant_response = conv["value"]

    if not input_text or not assistant_response:
        return None, None, None

    open_idx = assistant_response.find(open_tag)
    # Search for the closing tag only after the opening one, so a stray
    # "</think>" earlier in the text cannot yield an inverted (empty) slice.
    if open_idx == -1:
        close_idx = -1
    else:
        close_idx = assistant_response.find(close_tag, open_idx + len(open_tag))

    if close_idx == -1:
        # No complete thinking block: the whole reply is the final response.
        return input_text, "", assistant_response

    model_reasoning = assistant_response[open_idx + len(open_tag):close_idx].strip()
    model_response = assistant_response[close_idx + len(close_tag):].strip()
    return input_text, model_reasoning, model_response

def apply_qwen_r1_chat_template(messages, add_generation_prompt=False):
    """Render chat ``messages`` in the Qwen/DeepSeek-R1 prompt layout.

    Output is the BOS marker, then the content of the last ``system`` message
    (empty if none), then every ``user`` turn as ``<｜User｜>content`` and every
    ``assistant`` turn whose content is not None as
    ``<｜Assistant｜>content<｜end▁of▁sentence｜>``, in message order. Messages
    with any other role are ignored.

    Args:
        messages: list of dicts with ``"role"`` and ``"content"`` keys.
        add_generation_prompt: if True, append an open assistant turn that
            starts a ``<think>`` block.

    Returns:
        the rendered prompt string.
    """
    system_text = ""
    for msg in messages:
        if msg["role"] == "system":
            system_text = msg["content"]  # a later system message overrides an earlier one

    pieces = ["<｜begin▁of▁sentence｜>", system_text]
    for msg in messages:
        role = msg["role"]
        if role == "user":
            pieces.append("<｜User｜>" + msg["content"])
        elif role == "assistant" and msg["content"] is not None:
            pieces.append("<｜Assistant｜>" + msg["content"] + "<｜end▁of▁sentence｜>")

    if add_generation_prompt:
        pieces.append("<｜Assistant｜><think>\n")

    return "".join(pieces)

def get_formatted_prompt_1(sample, tokenizer, model_name):
    """Format one sample as a Qwen/DeepSeek-R1 chat prompt.

    Samples with flat ``user_message`` / ``answer`` keys are rendered directly
    with apply_qwen_r1_chat_template (a missing key counts as ``""``). Samples
    that carry neither flat key but do carry a ``conversations`` list (the
    format described in the module docstring) are delegated to
    get_formatted_prompt. Without this, they would render as empty user and
    assistant turns.

    Args:
        sample: one dataset record (dict).
        tokenizer: tokenizer forwarded to get_formatted_prompt for conversation samples.
        model_name: model identifier forwarded to get_formatted_prompt.

    Returns:
        the prompt string, or None when get_formatted_prompt rejects the conversation.
    """
    if "conversations" in sample and "user_message" not in sample and "answer" not in sample:
        return get_formatted_prompt(sample, tokenizer, model_name)

    question = sample.get("user_message", "")
    answer = sample.get("answer", "")

    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]

    return apply_qwen_r1_chat_template(messages, add_generation_prompt=False)

def get_formatted_prompt(sample, tokenizer, model_name):
    """Build a chat prompt string from a ``conversations``-format sample.

    The user turn and the assistant reply (split into reasoning and response
    by parse_conversations) are rendered as a two-turn chat, preceded by the
    sample's ``system`` text when present and non-empty.

    If ``model_name`` contains "r1" (case-insensitive), the reasoning is wrapped
    as ``<think>\\n...\\n</think>\\n\\n`` and apply_qwen_r1_chat_template renders
    the prompt. Otherwise the reasoning is followed by ``\\n</think>\\n\\n`` (no
    opening tag) and ``tokenizer.apply_chat_template`` renders it with
    ``enable_thinking=True``. Without reasoning, the response alone is the
    assistant content.

    Returns:
        the prompt string, or None (after printing a message) when the
        conversation lacks a usable user or assistant turn.
    """
    input_text, model_reasoning, model_response = parse_conversations(sample.get("conversations", []))

    if not input_text or model_response is None:
        print("Invalid conversation format, skipping")
        return None

    is_r1 = "r1" in model_name.lower()

    if not model_reasoning:
        assistant_text = model_response
    elif is_r1:
        assistant_text = f"<think>\n{model_reasoning}\n</think>\n\n{model_response}"
    else:
        assistant_text = f"{model_reasoning}\n</think>\n\n{model_response}"

    messages = []
    system_prompt = sample.get("system", "")
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": input_text})
    messages.append({"role": "assistant", "content": assistant_text})

    if is_r1:
        return apply_qwen_r1_chat_template(messages, add_generation_prompt=False)
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, enable_thinking=True)

def categorize_masks(input_ids, tokenizer, start_token_id=151648):
    """Build a 0/1 per-token mask for the first sequence in ``input_ids``.

    Tokens before the first occurrence of ``start_token_id`` get 0 (system and
    query text). That token and every token after it get 1.

    Args:
        input_ids: 2-D tensor (or nested sequence) of token ids; only row 0 is used.
        tokenizer: unused; kept for call-site compatibility.
        start_token_id: id of the token that switches the mask to 1. The
            default 151648 is the value this function always used.
            NOTE(review): presumably the ``<think>`` token of the R1-distill
            Qwen vocabulary — confirm for other tokenizers.

    Returns:
        list of Python ints (0 or 1), one per token of ``input_ids[0]``.
    """
    row = torch.as_tensor(input_ids[0])
    # cumsum of the hit indicator is > 0 from the first hit onward.
    after_start = (row == start_token_id).cumsum(0) > 0
    return after_start.to(torch.int64).tolist()

def _select_thread_rows(dataset_path, index_range, data_range):
    """Return ``(rows, source_ids)`` for one worker's share of the selected data.

    ``data_range`` is a ``(start, end)`` pair relative to the dataset *after*
    ``index_range`` has been applied: process_dataset_multi_thread runs
    split_dataset on that sliced dataset. ``source_ids`` holds each row's
    0-based position in the input file. Without ``index_range`` this equals
    ``start + local_index``.
    """
    all_rows = load_jsonl_json_dataset(dataset_path)
    positions = range(len(all_rows))
    if index_range:
        range_start, range_end = index_range
        # range slicing follows the same rules as list slicing (incl. negatives).
        positions = positions[range_start:range_end]
    start_idx, end_idx = data_range
    positions = positions[start_idx:end_idx]
    return [all_rows[pos] for pos in positions], list(positions)


def _encode_sample(sample, tokenizer, model_name, data_id, max_input_length, thread_id):
    """Tokenize one sample into per-token tensors, or return None to skip it.

    The returned dict uses the per-thread output column names in output order:
    predictions, token_id, data_id, mask, real_token, mismatch. Each value is a
    1-D tensor with one entry per token.
    """
    prompt = get_formatted_prompt_1(sample, tokenizer, model_name)
    if prompt is None:
        return None

    # The rendered prompt already contains its special tokens (e.g. the BOS
    # marker written by the chat template), so the tokenizer must not add them
    # a second time.
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    seq_len = input_ids.shape[-1]
    if seq_len > max_input_length:
        print(f"Thread {thread_id}: Input length {seq_len} exceeds max length {max_input_length}, skipping")
        return None

    real_token = input_ids[0].cpu()
    token_id = torch.arange(0, seq_len, 1)
    return {
        # No inference is run, so the "predictions" simply mirror the real tokens.
        "predictions": real_token.clone(),
        "token_id": token_id,
        "data_id": torch.full_like(token_id, data_id),
        "mask": torch.tensor(categorize_masks(input_ids, tokenizer), dtype=torch.int32),
        "real_token": real_token,
        # Without model predictions there is nothing to mismatch: always 0.
        "mismatch": torch.zeros_like(real_token, dtype=torch.int32),
    }


def process_single_thread(args, thread_id, data_range, model_name, output_queue):
    """Tokenize one slice of the dataset and save it as a per-thread Dataset.

    Args:
        args: parsed CLI namespace; reads ``dataset_path``, ``index_range``,
            ``max_input_length`` and ``output_path``.
        thread_id: worker index; used in log lines, the progress-bar position and
            the output directory name.
        data_range: ``(start_idx, end_idx)`` relative to the ``index_range``
            selection, as produced by split_dataset.
        model_name: tokenizer name/path for ``AutoTokenizer.from_pretrained``.
        output_queue: receives ``(thread_id, dataset)`` on success or
            ``(thread_id, None)`` when nothing was produced or an error occurred.

    Writes ``results_thread_{thread_id}_{model}`` under ``args.output_path``
    with one row per token. Each token's ``data_id`` is the row's position in
    the input file.

    Raises:
        Exception: any processing error is reported on ``output_queue`` and
            then re-raised, so the executor's future reflects the failure.
    """
    start_idx, end_idx = data_range
    model_path = model_name.split("/")[-1]

    print(f"Thread {thread_id}: Processing data range {start_idx}-{end_idx} for model {model_name}")

    try:
        rows, source_ids = _select_thread_rows(args.dataset_path, args.index_range, data_range)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        columns = {key: [] for key in ("predictions", "token_id", "data_id", "mask", "real_token", "mismatch")}

        # The context manager closes the progress bar even if a sample raises.
        with tqdm(total=len(rows), desc=f"Thread {thread_id} - {model_path}", position=thread_id) as pbar:
            for data_id, sample in zip(source_ids, rows):
                encoded = _encode_sample(sample, tokenizer, model_name, data_id, args.max_input_length, thread_id)
                if encoded is not None:
                    for key, tensor in encoded.items():
                        columns[key].append(tensor)
                pbar.update(1)

        if not columns["token_id"]:
            print(f"Thread {thread_id}: No valid samples processed")
            output_queue.put((thread_id, None))
            return

        # Flatten to python lists for Dataset compatibility (one row per token).
        results_dict = {key: torch.cat(parts, dim=0).tolist() for key, parts in columns.items()}
        dataset = Dataset.from_dict(results_dict)

        output_file = os.path.join(args.output_path, f"results_thread_{thread_id}_{model_path}")
        dataset.save_to_disk(output_file)

        print(f"Thread {thread_id}: Dataset saved to {output_file}")
        output_queue.put((thread_id, dataset))

    except Exception as e:
        print(f"Thread {thread_id}: Error during processing: {e}")
        output_queue.put((thread_id, None))
        raise

def analyze_detailed_statistics(df, tokenizer):
    """Compute corpus-level and per-sample statistics for a token-level DataFrame.

    Args:
        df: one row per token with at least the columns ``data_id``, ``mask``
            and ``mismatch``. Rows with ``mask == 0`` are reported as
            system/user tokens and rows with ``mask == 1`` as assistant tokens.
        tokenizer: unused; kept for call-site compatibility.

    Returns:
        ``(analysis_results, sample_stats)``:

        - ``analysis_results`` has the keys ``basic``, ``mask_analysis``,
          ``length_analysis`` and ``mismatch_distribution``.
        - ``sample_stats`` is a list with one dict per ``data_id``, in
          ascending order.

        All ratios are percentages. For an empty ``df``, every statistic is 0
        instead of raising.
    """
    is_prompt = df['mask'] == 0
    is_assistant = df['mask'] == 1
    mismatch = df['mismatch']

    # Basic statistics
    total_tokens = len(df)
    total_samples = df['data_id'].nunique()
    total_mismatch_tokens = int(mismatch.sum())

    analysis_results = {
        'basic': {
            'total_tokens': int(total_tokens),
            'total_samples': int(total_samples),
            'total_mismatch_tokens': total_mismatch_tokens,
            'mismatch_ratio': float(total_mismatch_tokens / total_tokens * 100) if total_tokens > 0 else 0.0,
        }
    }

    # Mask-based analysis (0=system/user, 1=assistant)
    mask_0_tokens = int(is_prompt.sum())
    mask_1_tokens = int(is_assistant.sum())
    mask_0_mismatch = int(mismatch[is_prompt].sum())
    mask_1_mismatch = int(mismatch[is_assistant].sum())

    analysis_results['mask_analysis'] = {
        'system_user_tokens': mask_0_tokens,
        'assistant_tokens': mask_1_tokens,
        'system_user_mismatch': mask_0_mismatch,
        'assistant_mismatch': mask_1_mismatch,
        'system_user_mismatch_ratio': float(mask_0_mismatch / mask_0_tokens * 100) if mask_0_tokens > 0 else 0.0,
        'assistant_mismatch_ratio': float(mask_1_mismatch / mask_1_tokens * 100) if mask_1_tokens > 0 else 0.0,
    }

    # Per-sample counts via one vectorised groupby (sorted by data_id) instead
    # of boolean-filtering every group separately.
    count_columns = ['total_tokens', 'assistant_tokens', 'mismatch_tokens', 'assistant_mismatch_tokens']
    if df.empty:
        per_sample = pd.DataFrame({col: pd.Series(dtype='int64') for col in count_columns})
    else:
        per_sample = (
            pd.DataFrame({
                'data_id': df['data_id'],
                'is_assistant': is_assistant.astype('int64'),
                'mismatch': mismatch,
                'assistant_mismatch': mismatch.where(is_assistant, 0),
            })
            .groupby('data_id', sort=True)
            .agg(
                total_tokens=('mismatch', 'size'),
                assistant_tokens=('is_assistant', 'sum'),
                mismatch_tokens=('mismatch', 'sum'),
                assistant_mismatch_tokens=('assistant_mismatch', 'sum'),
            )
        )

    token_lengths = per_sample['total_tokens'].to_numpy()
    assistant_token_lengths = per_sample['assistant_tokens'].to_numpy()
    # Every group has at least one token, so the division is always defined.
    mismatch_ratios = (per_sample['mismatch_tokens'] / per_sample['total_tokens'] * 100).to_numpy(dtype=float)

    sample_stats = []
    for row, sample_mismatch_ratio in zip(per_sample.itertuples(), mismatch_ratios):
        assistant_tokens = int(row.assistant_tokens)
        assistant_mismatch = int(row.assistant_mismatch_tokens)
        sample_stats.append({
            'data_id': int(row.Index),
            'total_tokens': int(row.total_tokens),
            'assistant_tokens': assistant_tokens,
            'mismatch_tokens': int(row.mismatch_tokens),
            'assistant_mismatch_tokens': assistant_mismatch,
            'mismatch_ratio': float(sample_mismatch_ratio),
            'assistant_mismatch_ratio': float(assistant_mismatch / assistant_tokens * 100) if assistant_tokens > 0 else 0.0,
        })

    def _stat(reduce_fn, values, cast):
        # np.min/np.max raise on empty arrays (and np.mean warns); report 0 instead.
        return cast(reduce_fn(values)) if len(values) else cast(0)

    # Token length statistics
    analysis_results['length_analysis'] = {
        'avg_tokens_per_sample': _stat(np.mean, token_lengths, float),
        'median_tokens_per_sample': _stat(np.median, token_lengths, float),
        'min_tokens_per_sample': _stat(np.min, token_lengths, int),
        'max_tokens_per_sample': _stat(np.max, token_lengths, int),
        'std_tokens_per_sample': _stat(np.std, token_lengths, float),
        'avg_assistant_tokens': _stat(np.mean, assistant_token_lengths, float),
        'median_assistant_tokens': _stat(np.median, assistant_token_lengths, float),
        'min_assistant_tokens': _stat(np.min, assistant_token_lengths, int),
        'max_assistant_tokens': _stat(np.max, assistant_token_lengths, int),
    }

    # Mismatch ratio distribution
    analysis_results['mismatch_distribution'] = {
        'avg_mismatch_ratio': _stat(np.mean, mismatch_ratios, float),
        'median_mismatch_ratio': _stat(np.median, mismatch_ratios, float),
        'min_mismatch_ratio': _stat(np.min, mismatch_ratios, float),
        'max_mismatch_ratio': _stat(np.max, mismatch_ratios, float),
        'std_mismatch_ratio': _stat(np.std, mismatch_ratios, float),
        'samples_with_no_mismatch': int((mismatch_ratios == 0).sum()),
        'samples_with_high_mismatch': int((mismatch_ratios > 50).sum()),
    }

    # Token-frequency analysis is intentionally disabled.

    return analysis_results, sample_stats

def _print_statistics_summary(analysis_results):
    """Print the human-readable statistics summary (as built by analyze_detailed_statistics)."""
    basic = analysis_results['basic']
    mask = analysis_results['mask_analysis']
    length = analysis_results['length_analysis']
    mismatch_dist = analysis_results['mismatch_distribution']

    print("\n" + "=" * 80)
    print("DETAILED STATISTICS SUMMARY")
    print("=" * 80)

    print("Basic Statistics:")
    print(f"  Total samples: {basic['total_samples']:,}")
    print(f"  Total tokens: {basic['total_tokens']:,}")
    print(f"  Total mismatch tokens: {basic['total_mismatch_tokens']:,}")
    print(f"  Overall mismatch ratio: {basic['mismatch_ratio']:.2f}%")

    print("\nMask-based Analysis:")
    print(f"  System/User tokens (mask=0): {mask['system_user_tokens']:,}")
    print(f"  Assistant tokens (mask=1): {mask['assistant_tokens']:,}")
    print(f"  System/User mismatch: {mask['system_user_mismatch']:,} ({mask['system_user_mismatch_ratio']:.2f}%)")
    print(f"  Assistant mismatch: {mask['assistant_mismatch']:,} ({mask['assistant_mismatch_ratio']:.2f}%)")

    print("\nToken Length Analysis:")
    print(f"  Avg tokens per sample: {length['avg_tokens_per_sample']:.1f}")
    print(f"  Median tokens per sample: {length['median_tokens_per_sample']:.1f}")
    print(f"  Token range: {length['min_tokens_per_sample']:.0f} - {length['max_tokens_per_sample']:.0f}")
    print(f"  Avg assistant tokens: {length['avg_assistant_tokens']:.1f}")

    print("\nMismatch Distribution:")
    print(f"  Avg mismatch ratio per sample: {mismatch_dist['avg_mismatch_ratio']:.2f}%")
    print(f"  Median mismatch ratio: {mismatch_dist['median_mismatch_ratio']:.2f}%")
    print(f"  Samples with no mismatch: {mismatch_dist['samples_with_no_mismatch']}")
    print(f"  Samples with >50% mismatch: {mismatch_dist['samples_with_high_mismatch']}")

    print("=" * 80)


def _write_analysis_files(analysis_results, sample_stats, analysis_dir):
    """Write the JSON statistics, per-sample CSV and text report into ``analysis_dir``.

    Returns:
        ``(detailed_stats_file, sample_stats_file, report_file)`` paths.
    """
    basic = analysis_results['basic']
    mask = analysis_results['mask_analysis']
    length = analysis_results['length_analysis']
    mismatch_dist = analysis_results['mismatch_distribution']

    detailed_stats_file = os.path.join(analysis_dir, "detailed_analysis.json")
    with open(detailed_stats_file, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2, ensure_ascii=False)

    sample_stats_file = os.path.join(analysis_dir, "per_sample_statistics.csv")
    pd.DataFrame(sample_stats).to_csv(sample_stats_file, index=False)

    report_lines = [
        "COMPREHENSIVE DATA ANALYSIS REPORT\n",
        "=" * 50 + "\n\n",

        "1. BASIC STATISTICS\n",
        "-" * 20 + "\n",
        f"Total samples: {basic['total_samples']:,}\n",
        f"Total tokens: {basic['total_tokens']:,}\n",
        f"Total mismatch tokens: {basic['total_mismatch_tokens']:,}\n",
        f"Overall mismatch ratio: {basic['mismatch_ratio']:.4f}%\n\n",

        "2. MASK-BASED ANALYSIS\n",
        "-" * 20 + "\n",
        f"System/User tokens (mask=0): {mask['system_user_tokens']:,}\n",
        f"Assistant tokens (mask=1): {mask['assistant_tokens']:,}\n",
        f"System/User mismatch: {mask['system_user_mismatch']:,} ({mask['system_user_mismatch_ratio']:.4f}%)\n",
        f"Assistant mismatch: {mask['assistant_mismatch']:,} ({mask['assistant_mismatch_ratio']:.4f}%)\n\n",

        "3. TOKEN LENGTH ANALYSIS\n",
        "-" * 25 + "\n",
        f"Average tokens per sample: {length['avg_tokens_per_sample']:.2f}\n",
        f"Median tokens per sample: {length['median_tokens_per_sample']:.2f}\n",
        f"Min tokens per sample: {length['min_tokens_per_sample']:.0f}\n",
        f"Max tokens per sample: {length['max_tokens_per_sample']:.0f}\n",
        f"Std deviation: {length['std_tokens_per_sample']:.2f}\n",
        f"Average assistant tokens: {length['avg_assistant_tokens']:.2f}\n",
        f"Median assistant tokens: {length['median_assistant_tokens']:.2f}\n\n",

        "4. MISMATCH DISTRIBUTION\n",
        "-" * 25 + "\n",
        f"Average mismatch ratio per sample: {mismatch_dist['avg_mismatch_ratio']:.4f}%\n",
        f"Median mismatch ratio: {mismatch_dist['median_mismatch_ratio']:.4f}%\n",
        f"Min mismatch ratio: {mismatch_dist['min_mismatch_ratio']:.4f}%\n",
        f"Max mismatch ratio: {mismatch_dist['max_mismatch_ratio']:.4f}%\n",
        f"Std deviation: {mismatch_dist['std_mismatch_ratio']:.4f}%\n",
        f"Samples with no mismatch: {mismatch_dist['samples_with_no_mismatch']}\n",
        f"Samples with >50% mismatch: {mismatch_dist['samples_with_high_mismatch']}\n\n",
    ]
    report_file = os.path.join(analysis_dir, "analysis_report.txt")
    with open(report_file, 'w', encoding='utf-8') as f:
        f.writelines(report_lines)

    return detailed_stats_file, sample_stats_file, report_file


def process_and_convert_dataset(merged_dataset, model_name, output_path):
    """Group the merged token-level dataset by data_id, save it, and write statistics.

    Each output row holds one sample:
    - ``data_id``;
    - ``real_text``, the tokenizer-decoded ``real_token`` ids;
    - ``real_token``, ``mask`` and ``mismatch`` as per-token lists.

    The dataset is saved to ``<output_path>/processed_data_<model>``. The
    analysis files (detailed_analysis.json, per_sample_statistics.csv,
    analysis_report.txt) are written into that same directory.

    Args:
        merged_dataset: token-level ``datasets.Dataset`` with ``data_id``,
            ``real_token``, ``mask`` and ``mismatch`` columns.
        model_name: tokenizer name/path used for decoding and the output name.
        output_path: parent directory for the processed dataset.

    Returns:
        the processed ``datasets.Dataset``.
    """
    print("Converting dataset to final format...")

    # Load tokenizer for text decoding
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Convert to pandas for easier grouping
    df = merged_dataset.to_pandas()
    grouped = df.groupby('data_id')
    print(f"Found {len(grouped)} unique data_ids.")

    print("Performing detailed statistical analysis...")
    analysis_results, sample_stats = analyze_detailed_statistics(df, tokenizer)

    final_data_list = []
    print("Processing groups...")
    for data_id, group in tqdm(grouped):
        real_tokens = group['real_token'].tolist()
        final_data_list.append({
            'data_id': data_id,
            'real_text': tokenizer.decode(real_tokens),
            'real_token': real_tokens,
            'mask': group['mask'].tolist(),
            'mismatch': group['mismatch'].tolist(),
        })

    _print_statistics_summary(analysis_results)

    processed_dataset = Dataset.from_pandas(pd.DataFrame(final_data_list))
    print("Processed dataset info:")
    print(processed_dataset)

    model_path = model_name.split("/")[-1]
    final_output_path = os.path.join(output_path, f"processed_data_{model_path}")
    processed_dataset.save_to_disk(final_output_path)
    print(f"Processed dataset saved to {final_output_path}")

    # Analysis files live next to the saved dataset.
    detailed_stats_file, sample_stats_file, report_file = _write_analysis_files(
        analysis_results, sample_stats, final_output_path
    )

    print("Detailed analysis saved to:")
    print(f"  - JSON format: {detailed_stats_file}")
    print(f"  - Per-sample CSV: {sample_stats_file}")
    print(f"  - Text report: {report_file}")

    return processed_dataset

def merge_thread_results(args, model_name):
    """Merge the per-thread datasets on disk, convert them, then delete them.

    Looks for ``results_thread_{i}_{model}`` directories under
    ``args.output_path`` for every ``i`` in ``range(args.num_threads)``.
    Missing directories are skipped. The found datasets are concatenated in
    thread order and passed to process_and_convert_dataset. The per-thread
    directories are removed only after that call returns.

    Returns:
        the processed dataset, or None if no per-thread directory was found.
    """
    import shutil

    model_tag = model_name.split("/")[-1]
    shard_dirs = [
        (thread_id, os.path.join(args.output_path, f"results_thread_{thread_id}_{model_tag}"))
        for thread_id in range(args.num_threads)
    ]

    shards = []
    for thread_id, shard_dir in shard_dirs:
        if not os.path.exists(shard_dir):
            continue
        shards.append(Dataset.load_from_disk(shard_dir))
        print(f"Loaded dataset from thread {thread_id}")

    if not shards:
        print("No thread datasets found to merge")
        return None

    processed_dataset = process_and_convert_dataset(concatenate_datasets(shards), model_name, args.output_path)

    # Remove the intermediate per-thread outputs now that they are merged.
    for _, shard_dir in shard_dirs:
        if os.path.exists(shard_dir):
            shutil.rmtree(shard_dir)

    return processed_dataset

def process_dataset_multi_thread(args):
    """Tokenize the dataset for every model in ``args.test_model_list`` using a thread pool.

    The (optionally ``index_range``-sliced) dataset is split into
    ``args.num_threads`` contiguous ranges. Each range is handled by
    process_single_thread, and the per-thread outputs are then merged by
    merge_thread_results.

    Args:
        args: parsed CLI namespace; reads ``output_path``, ``dataset_path``,
            ``index_range``, ``num_threads`` and ``test_model_list``.
    """
    global running_threads
    import shutil
    from queue import Queue

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_path, exist_ok=True)

    # Load full dataset to get length and split
    print(f"Loading dataset from {args.dataset_path}")
    full_dataset = load_jsonl_json_dataset(args.dataset_path, args.index_range)
    print(f"Dataset length: {len(full_dataset)}")

    data_splits = split_dataset(full_dataset, args.num_threads)
    print(f"Dataset split into {args.num_threads} parts: {data_splits}")

    for model_name in args.test_model_list:
        model_path = model_name.split("/")[-1]

        print(f"Processing model: {model_name}")

        # merge_thread_results merges every per-thread directory it finds, so a
        # leftover from an earlier interrupted run would be silently mixed in
        # whenever the matching thread fails in this run. Remove them first.
        for thread_id in range(args.num_threads):
            stale_dir = os.path.join(args.output_path, f"results_thread_{thread_id}_{model_path}")
            if os.path.exists(stale_dir):
                print(f"Removing stale thread results: {stale_dir}")
                shutil.rmtree(stale_dir)

        output_queue = Queue()

        with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
            future_to_thread = {
                executor.submit(
                    process_single_thread,
                    args, thread_id, data_splits[thread_id], model_name, output_queue
                ): thread_id
                for thread_id in range(args.num_threads)
            }

            # Expose the worker threads to signal_handler.
            # NOTE: relies on the private ThreadPoolExecutor._threads attribute.
            running_threads = list(executor._threads)

            for completed_threads, future in enumerate(as_completed(future_to_thread), start=1):
                thread_id = future_to_thread[future]
                try:
                    future.result()  # re-raises any exception from the worker
                    print(f"Thread {thread_id}: Completed successfully")
                except Exception as exc:
                    print(f"Thread {thread_id}: Generated an exception: {exc}")

                print(f"Progress: {completed_threads}/{args.num_threads} threads completed")

        print(f"All threads completed for {model_name}")

        running_threads = []

        if merge_thread_results(args, model_name) is None:
            print(f"Failed to process results for {model_name}")

    print("Multi-thread processing completed!")

def _append_run_args(args):
    """Append this run's CLI arguments plus a timestamp to ``<output_path>/args.json``.

    The file holds a JSON list with one entry per run. An existing single JSON
    object is wrapped into a list. Malformed JSON is discarded. If the file
    cannot be read, a warning is printed and a fresh list is started.
    """
    import datetime

    args_file_path = os.path.join(args.output_path, "args.json")

    existing_args = []
    if os.path.exists(args_file_path):
        try:
            with open(args_file_path, "r") as f:
                content = f.read().strip()
            if content:
                try:
                    loaded = json.loads(content)
                except json.JSONDecodeError:
                    loaded = []
                existing_args = loaded if isinstance(loaded, list) else [loaded]
        except Exception as e:
            print(f"Warning: Could not read existing args.json: {e}")
            existing_args = []

    current_args = dict(vars(args))
    current_args["run_timestamp"] = datetime.datetime.now().isoformat()
    existing_args.append(current_args)

    with open(args_file_path, "w") as f:
        json.dump(existing_args, f, indent=2)


def main():
    """CLI entry point: parse arguments, run the pipeline, and record the run's args."""
    global running_threads

    # Register signal handlers for graceful cleanup
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser(
        description="Run multi-threaded data processing on JSONL or JSON datasets"
    )
    parser.add_argument(
        "--dataset_path", type=str, required=True, help="Path to the dataset file (supports .jsonl and .json formats)"
    )
    parser.add_argument(
        "--num_threads", type=int, default=8, help="Number of threads to use"
    )
    parser.add_argument(
        "--test_model_list",
        nargs="+",
        type=str,
        required=True,
        help="List of test models for tokenizer (no actual model loading)",
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Directory to save output files"
    )
    parser.add_argument(
        "--max_input_length",
        type=int,
        default=32768,
        help="Maximum length of input tokens",
    )
    parser.add_argument(
        "--index_range",
        nargs=2,
        type=int,
        default=None,
        help="Range of dataset samples to process [start_idx, end_idx]",
    )
    args = parser.parse_args()

    try:
        process_dataset_multi_thread(args)
        # Append (rather than overwrite) so earlier runs' args are preserved.
        _append_run_args(args)
        print("All processing completed!")
    except Exception as e:
        print(f"Error during processing: {e}")
        raise
    finally:
        running_threads = []

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()