import os
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import json
from itertools import islice
import re
import string
import unicodedata
from collections import Counter
import difflib
import tree_sitter_python as tspython
from tree_sitter import Language, Parser
from transformers import AutoTokenizer
import argparse

# Assuming 'utils' contains necessary helper functions like read_jsonl, split_thinking_action_simple, etc.
from utils import * 

# ================= Global Variables =================
# Per-process tokenizer and tree-sitter parser handles. Both start as None
# and are assigned by init_worker().
_global_tokenizer = None
_global_parser = None

# Protected categories (once assigned, cannot be overwritten by weight 1.0)
# NOTE(review): the code that enforces this protection is not visible in this
# part of the file — confirm against the AST weighting routine.
PROTECTED_CATEGORIES = {
    "comment",           # AST comment nodes
    "output_deleted",    # print/log statements
    "docstring",         # Docstrings
    "noise_zero",        # Tokens zeroed out by noise filter
    "first_occurrence_identifier",    # First occurrence of variable names
    "first_occurrence_function_name", # First occurrence of function names
    "first_occurrence_class_name",    # First occurrence of class names
}

class ParameterFilterConfig:
    """
    Configuration class for parameter filtering and weighting.

    Design Principles:
    1. Classification based on function and command structure.
    2. Detection restricted to repository file modifications.
    3. Independence between new_str_repo and new_str_repo_modified arrays.
    4. Support for frequency-based thinking noise filtering.
    5. Binary weights only (0.0 and 1.0).
    6. Complete filtering of comments and print statements.

    All settings are plain instance attributes populated by ``__init__``;
    ``load_thinking_noise_tokens`` is the only method that mutates them
    afterwards.
    """

    def __init__(self):
        """Populate every setting with its default value."""
        # [Core Config 1] Final list of parameter types (11 types)
        self.parameter_types = [
            'thinking',                      # 1. Thinking section
            'execute_bash',                  # 2. execute_bash function
            'submit',                        # 3. submit function
            'str_replace_editor_view',       # 4. view command (includes path and view_range)
            'str_replace_editor_create',     # 5. create command (includes path and file_text)
            'str_replace_editor_insert',     # 6. insert command (includes path, new_str, and insert_line)
            'old_str_repo',                  # 7. old_str in str_replace (repo files)
            'new_str_repo',                  # 8. new_str in str_replace (repo files)
            'old_str_created',               # 9. old_str in str_replace (created files)
            'new_str_created',               # 10. new_str in str_replace (created files)
            'new_str_repo_modified',         # 11. Actually modified tokens in repo files
        ]

        self.zero_xml_structure = True
        self.zero_noise_tokens = True

        # ===== Thinking Noise Filtering Configuration =====
        # Disabled until load_thinking_noise_tokens() succeeds.
        self.filter_thinking_by_frequency = False
        self.thinking_noise_token_file = None
        self.thinking_noise_token_ids = set()

        # ===== AST Filtering Configuration =====
        self.ast_filter_param_types = {
            'str_replace_editor_create',
            'str_replace_editor_insert',
            'old_str_repo',
            'new_str_repo',
            'old_str_created',
            'new_str_created',
        }

        # ===== Output Function Configuration =====
        # Lower-case call names treated as output (print/log) statements.
        self.output_function_names = {
            'print', 'pprint', 'printf',
            'log', 'debug', 'info', 'warning', 'warn', 'error', 'critical', 'exception',
            'write', 'writeln', 'puts', 'echo',
        }

        # ========================================
        # ===== Filtering Strategy Switches =====
        # ========================================

        # 1. Filter Comments
        self.filter_comments = True

        # 2. Filter Docstrings
        self.filter_docstrings = True

        # 3. Filter Output Statements (Print/Log)
        self.filter_output_statements = True

        # 4. Filter First Occurrence of Identifiers (Variables, Functions, Classes)
        self.enable_first_occurrence_filtering = True

        # 4.1 Parameter types where first occurrence filtering applies
        self.filter_first_occurrence_in_param_types = {
            'str_replace_editor_create',  # Test files
            'str_replace_editor_insert',  # Inserted code
        }

        # 4.2 Fine-grained control (effective only when enable_first_occurrence_filtering=True)
        self.filter_first_occurrence_identifiers = True      # Variable names
        self.filter_first_occurrence_function_names = True   # Function names
        self.filter_first_occurrence_class_names = True      # Class names

        # 5. Filter Specific File Types
        self.filter_files_by_extension = True
        self.filter_file_extensions = {'.md', '.txt', '.rst'}  # Extensions to filter

        # 6. Filter Path Parameters
        self.filter_path_parameters = True  # True: Filter path params, False: Keep them

        # 7. Remove Shebang for AST Parsing
        self.remove_shebang_for_ast = True

        # AST Weight Configuration (Binary: 0.0 or 1.0)
        self.ast_weights = {
            "identifier": 1.0,
            "attribute": 1.0,
            "if_statement": 1.0,
            "for_statement": 1.0,
            "while_statement": 1.0,
            "return_statement": 1.0,
            "break_statement": 1.0,
            "continue_statement": 1.0,
            "try_statement": 1.0,
            "except_clause": 1.0,
            "raise_statement": 1.0,
            "with_statement": 1.0,
            "function_definition": 1.0,
            "class_definition": 1.0,
            "binary_operator": 1.0,
            "boolean_operator": 1.0,
            "comparison_operator": 1.0,
            "assignment": 1.0,
            "import_statement": 1.0,
            "import_from_statement": 1.0,
            "return": 1.0,
            "if": 1.0,
            "else": 1.0,
            "elif": 1.0,
            "for": 1.0,
            "while": 1.0,
            "try": 1.0,
            "except": 1.0,
            "def": 1.0,
            "class": 1.0,
            "import": 1.0,
            "from": 1.0,
            "and": 1.0,
            "or": 1.0,
            "not": 1.0,
            "in": 1.0,
            "is": 1.0,
            "as": 1.0,
            "expression_statement": 1.0,
            "string": 1.0,
            "string_content": 1.0,
            "true": 1.0,
            "false": 1.0,
            "none": 1.0,
            "block": 1.0,
            "call": 1.0,
            "argument_list": 1.0,
            "parameter_list": 1.0,
            "list": 1.0,
            "dictionary": 1.0,
            "tuple": 1.0,
            "set": 1.0,
            "integer": 1.0,
            "float": 1.0,
            "comment": 0.0,
        }

        # Parameter Router Configuration (Unified weight 1.0)
        # "direct" entries carry a fixed "weight"; "ast" entries carry a
        # "base_weight". Note that "submit" is the one entry set to 0.0.
        self.param_router = {
            "thinking": {"type": "direct", "weight": 1.0},
            "execute_bash": {"type": "direct", "weight": 1.0},
            "submit": {"type": "direct", "weight": 0.0},
            "str_replace_editor_view": {"type": "direct", "weight": 1.0},
            "str_replace_editor_create": {"type": "ast", "base_weight": 1.0},
            "str_replace_editor_insert": {"type": "ast", "base_weight": 1.0},
            "old_str_repo": {"type": "ast", "base_weight": 1.0},
            "new_str_repo": {"type": "ast", "base_weight": 1.0},
            "old_str_created": {"type": "ast", "base_weight": 1.0},
            "new_str_created": {"type": "ast", "base_weight": 1.0},
            "new_str_repo_modified": {"type": "ast", "base_weight": 1.0},
        }

    def load_thinking_noise_tokens(self, npz_file_path):
        """Load high-frequency noise tokens from .npz file.

        The archive must contain a ``noise_token_ids`` array and a
        ``metadata`` array with at least four entries (unique-token count,
        total token count, number of noise tokens, coverage ratio).

        Args:
            npz_file_path: Path to the ``.npz`` archive.

        On success the ids are stored in ``thinking_noise_token_ids``, the
        path is recorded and ``filter_thinking_by_frequency`` is switched on.
        Any failure is reported on stdout and leaves frequency filtering
        disabled; no exception is propagated to the caller.
        """
        try:
            # np.load returns an NpzFile that keeps the archive open; the
            # context manager guarantees the handle is released. Both arrays
            # are materialised while the archive is still open.
            with np.load(npz_file_path) as data:
                self.thinking_noise_token_ids = set(data['noise_token_ids'].tolist())
                metadata = data['metadata']

            print(f"[INFO] Loaded thinking noise tokens from: {npz_file_path}")
            print(f"       Total unique tokens: {int(metadata[0])}")
            print(f"       Total token count: {int(metadata[1])}")
            print(f"       Noise tokens (Top K): {int(metadata[2])}")
            print(f"       Coverage ratio: {metadata[3]:.4%}")

            self.thinking_noise_token_file = npz_file_path
            self.filter_thinking_by_frequency = True

        except Exception as e:
            # Best effort: a missing or malformed file only disables the filter.
            print(f"[ERROR] Failed to load thinking noise tokens: {e}")
            self.filter_thinking_by_frequency = False

    def print_filter_config(self):
        """Print current filtering configuration.

        Set-valued options are printed in sorted order so the report is
        identical from run to run.
        """
        print("\n" + "="*80)
        print("FILTER CONFIGURATION")
        print("="*80)
        print(f"1. Filter Comments:              {self.filter_comments}")
        print(f"2. Filter Docstrings:            {self.filter_docstrings}")
        print(f"3. Filter Output Statements:     {self.filter_output_statements}")
        print(f"4. First Occurrence Filtering:   {self.enable_first_occurrence_filtering}")
        if self.enable_first_occurrence_filtering:
            print(f"   - Filter Identifiers:         {self.filter_first_occurrence_identifiers}")
            print(f"   - Filter Function Names:      {self.filter_first_occurrence_function_names}")
            print(f"   - Filter Class Names:         {self.filter_first_occurrence_class_names}")
            print(f"   - Param Types: {', '.join(sorted(self.filter_first_occurrence_in_param_types))}")
        print(f"5. Filter Files by Extension:    {self.filter_files_by_extension}")
        if self.filter_files_by_extension:
            print(f"   - Extensions: {', '.join(sorted(self.filter_file_extensions))}")
        print(f"6. Filter Path Parameters:       {self.filter_path_parameters}")
        print(f"7. Remove Shebang for AST:       {self.remove_shebang_for_ast}")
        print("="*80 + "\n")


# ================= Helper Functions =================

def should_zero_thinking_token_with_frequency(token_id, token_str, category, config):
    """Decide whether a thinking token's weight must be zeroed.

    A token is zeroed when the strict textual rules reject it, or when
    frequency filtering is enabled on ``config`` and ``token_id`` is one of
    the loaded noise token ids.

    Returns:
        True if the token should be zeroed, False otherwise.
    """
    # Textual rules take precedence over the frequency table.
    if should_zero_weight_strict(token_str, category, "thinking"):
        return True

    frequency_filter_on = bool(config.filter_thinking_by_frequency)
    return frequency_filter_on and token_id in config.thinking_noise_token_ids


def _should_filter_file_by_extension(file_path, config):
    """Tell whether ``file_path`` has one of the configured filtered extensions.

    The answer is False when extension filtering is switched off, when the
    path is empty/None, or when no extensions are configured. The file's
    extension is lower-cased before the membership test.
    """
    if config.filter_files_by_extension and file_path and config.filter_file_extensions:
        suffix = os.path.splitext(file_path)[1].lower()
        return suffix in config.filter_file_extensions
    return False


def classify_files_in_trajectory(function_calls):
    """Label every file touched by ``str_replace_editor`` as 'repo' or 'created'.

    Only calls whose function name is ``str_replace_editor`` and that carry a
    non-empty path are considered. A path counts as viewed when the command
    is ``view``; it counts as created when the command is ``create`` or when
    the call has a ``file_text`` parameter.

    Returns:
        Dict mapping path -> 'repo' | 'created'. A path that was viewed at
        any point is 'repo', even if it was also created.
    """
    seen_via_view = set()
    seen_via_create = set()

    for call in function_calls:
        name = call['function_name']
        command = call.get('command')
        target = call.get('path')
        call_params = call['params']

        if name != 'str_replace_editor' or not target:
            continue

        if command == 'view':
            seen_via_view.add(target)
        elif command == 'create':
            seen_via_create.add(target)

        # A file_text payload marks a creation regardless of the command.
        if 'file_text' in call_params:
            seen_via_create.add(target)

    # Every path in the union is in at least one of the two sets; viewed wins.
    return {
        target: ('repo' if target in seen_via_view else 'created')
        for target in seen_via_view | seen_via_create
    }


def extract_function_calls_with_params(decoded_text):
    """Parse every ``<function=...>...</function>`` block in ``decoded_text``.

    Returns:
        A list with one dict per function block, in order of appearance:
        ``function_name`` (surrounding quotes stripped), ``func_start`` /
        ``func_end`` (character span of the whole block), ``command`` and
        ``path`` (stripped values of those parameters, or None), and
        ``params`` mapping each parameter name to ``{'start', 'end',
        'value'}`` where start/end are character offsets of the value inside
        ``decoded_text``. If a parameter name repeats, the last one wins.
    """
    calls = []

    for fn in re.finditer(r'<function=([^>]+)>(.*?)</function>', decoded_text, re.DOTALL):
        body = fn.group(2)
        # Parameter offsets are found relative to the body; shift them so
        # they index into the full text.
        body_offset = fn.start(2)

        params = {
            pm.group(2).strip(): {
                'start': body_offset + pm.start(3),
                'end': body_offset + pm.end(3),
                'value': pm.group(3),
            }
            for pm in re.finditer(r'<parameter=(["\']?)(\w+)\1>(.*?)</parameter>', body, re.DOTALL)
        }

        command = params['command']['value'].strip() if 'command' in params else None
        path = params['path']['value'].strip() if 'path' in params else None

        calls.append({
            'function_name': fn.group(1).strip('\'"'),
            'func_start': fn.start(),
            'func_end': fn.end(),
            'command': command,
            'path': path,
            'params': params,
        })

    return calls


def pair_old_new_str(function_calls, file_types):
    """Group the old_str/new_str parameters of ``str_replace`` editor calls.

    Only calls with function name ``str_replace_editor`` and command
    ``str_replace`` are inspected. ``file_types`` maps a path to its file
    type; unknown paths default to 'repo'.

    Returns:
        Dict with three lists:
        - 'paired': calls carrying both parameters
          (``old``, ``new``, ``func_call``, ``file_type``, ``path``)
        - 'unpaired_old' / 'unpaired_new': calls carrying only one of them
          (``info``, ``func_call``, ``file_type``, ``path``)
    """
    grouped = {'paired': [], 'unpaired_old': [], 'unpaired_new': []}

    for call in function_calls:
        if call['function_name'] != 'str_replace_editor':
            continue
        if call['command'] != 'str_replace':
            continue

        call_params = call['params']
        target = call['path']
        old_present = 'old_str' in call_params
        new_present = 'new_str' in call_params

        # Fields shared by every record produced for this call.
        common = {
            'func_call': call,
            'file_type': file_types.get(target, 'repo'),
            'path': target,
        }

        if old_present and new_present:
            grouped['paired'].append({
                'old': call_params['old_str'],
                'new': call_params['new_str'],
                **common,
            })
        elif old_present:
            grouped['unpaired_old'].append({'info': call_params['old_str'], **common})
        elif new_present:
            grouped['unpaired_new'].append({'info': call_params['new_str'], **common})

    return grouped


def find_modified_tokens_precise(old_text, new_text, tokenizer, token_offsets, new_start_char, new_end_char):
    """Precisely identify tokens in new_str that are actually modified relative to old_str.

    A character-level diff of ``old_text`` against ``new_text`` yields the
    spans of ``new_text`` that were replaced or inserted. A token counts as
    modified when the midpoint of its character span falls inside one of
    those spans.

    Args:
        old_text: Text of the old_str parameter.
        new_text: Text of the new_str parameter.
        tokenizer: Unused; kept for interface compatibility.
        token_offsets: Sequence of ``(start, end)`` character offsets, one
            per token, in the coordinates of the full decoded text.
        new_start_char: Offset of ``new_text``'s first character in the
            full decoded text.
        new_end_char: Offset one past ``new_text``'s last character.

    Returns:
        Indices into ``token_offsets`` of the modified tokens, in ascending
        order. Empty when either text is empty or nothing was replaced or
        inserted.
    """
    if not old_text or not new_text:
        return []

    # autojunk must be off: with the default heuristic, any character that
    # makes up more than ~1% of a 200+ character string is dropped as a match
    # seed. Source code is dominated by such characters (spaces, common
    # letters), so a one-character edit could mark everything after it as
    # modified. The exact diff is slower on very large inputs.
    matcher = difflib.SequenceMatcher(None, old_text, new_text, autojunk=False)
    changed_spans = [
        (j1, j2)
        for tag, _i1, _i2, j1, j2 in matcher.get_opcodes()
        if tag in ('replace', 'insert')
    ]

    if not changed_spans:
        return []

    new_len = len(new_text)
    modified_token_indices = []

    for index, (t_start, t_end) in enumerate(token_offsets):
        starts_inside = new_start_char <= t_start < new_end_char
        ends_inside = new_start_char < t_end <= new_end_char
        if not (starts_inside or ends_inside):
            continue

        # Convert to coordinates relative to new_text; tokens straddling
        # either boundary of the parameter value are ignored.
        rel_start = t_start - new_start_char
        rel_end = t_end - new_start_char
        if rel_start < 0 or rel_end > new_len:
            continue

        center = (rel_start + rel_end) / 2
        if any(lo <= center < hi for lo, hi in changed_spans):
            modified_token_indices.append(index)

    return modified_token_indices


def is_decorative_symbol(token_str):
    """Report whether ``token_str`` is a purely decorative symbol.

    After stripping whitespace, the token is decorative when any of these
    holds:
    - it equals one of the listed check marks, arrows, bullets or shapes;
    - any character lies in one of the listed emoji/symbol code-point blocks;
    - it has at most three characters and every one has Unicode category
      'So' (Symbol, other).

    Empty or whitespace-only input is not decorative.
    """
    if not token_str or not token_str.strip():
        return False

    symbol = token_str.strip()

    known_symbols = {
        '✓', '✔', '✗', '✘', '☑', '☒', '✅', '❌',
        '→', '←', '↑', '↓', '⇒', '⇐', '➜', '➔', '⟶', '⟵',
        '●', '○', '■', '□', '▪', '▫', '◆', '◇', '★', '☆',
        '•', '·', '※', '§', '¶', '†', '‡', '⁂',
        '♠', '♣', '♥', '♦', '♪', '♫', '☀', '☁', '☂',
    }
    if symbol in known_symbols:
        return True

    # Inclusive code-point ranges; one hit anywhere in the token is enough.
    symbol_blocks = (
        (0x1F600, 0x1F64F),
        (0x1F300, 0x1F5FF),
        (0x1F680, 0x1F6FF),
        (0x1F900, 0x1F9FF),
        (0x2600, 0x26FF),
        (0x2700, 0x27BF),
        (0x2190, 0x21FF),
        (0x25A0, 0x25FF),
    )
    if any(low <= ord(ch) <= high for ch in symbol for low, high in symbol_blocks):
        return True

    # Short tokens made only of "Symbol, other" characters. Looking up the
    # category of a single character cannot fail, so no guard is needed.
    return len(symbol) <= 3 and all(unicodedata.category(ch) == 'So' for ch in symbol)

def should_zero_weight_strict(token_str, category, param_type):
    """Apply the strict textual rules that force a token's weight to zero.

    ``category`` and ``param_type`` are accepted but not consulted.

    A token is zeroed when it is empty or whitespace-only, a run of three or
    more identical separator characters, a ``<function...>`` /
    ``<parameter...>`` tag, a lone angle bracket, punctuation/whitespace only
    (except a short list of meaningful operators and brackets), a bare quote
    mark, or a decorative symbol.

    Returns:
        True if the token must be zeroed, False otherwise.
    """
    if not token_str:
        return True

    core = token_str.strip()
    if core == '':
        return True

    # Horizontal rules such as "=====" or "-----".
    is_uniform_run = len(core) >= 3 and len(set(core)) == 1
    if is_uniform_run and core[0] in {'=', '-', '*', '_', '#', '~', '+'}:
        return True

    # XML scaffolding of the tool-call format.
    if re.match(r'^</?(?:function|parameter).*?>$', token_str):
        return True
    if core in ('>', '<', '/>'):
        return True

    filler = string.punctuation + string.whitespace
    meaningful = (
        '=', '+', '-', '*', '/', '@', '&', '|', '!', '?', '%', '^', '~', '`',
        ':', '(', ')', '[', ']', '{', '}',
    )
    if all(ch in filler for ch in token_str) and core not in meaningful:
        return True

    if core in ('"', "'", '"""', "'''"):
        return True

    return True if is_decorative_symbol(token_str) else False


def init_worker(tokenizer_path):
    """Initialise the per-process tokenizer and tree-sitter Python parser.

    Sets ``TOKENIZERS_PARALLELISM=false`` in the environment, then stores the
    tokenizer loaded from ``tokenizer_path`` and a Python parser in the
    module-level ``_global_tokenizer`` / ``_global_parser``.

    Raises:
        Exception: whatever the tokenizer or parser construction raised,
        re-raised after a ``[FATAL]`` message is printed.
    """
    global _global_tokenizer, _global_parser
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    try:
        _global_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        python_language = Language(tspython.language())
        _global_parser = Parser(python_language)
    except Exception as exc:
        print(f"[FATAL] Worker initialization failed: {exc}")
        raise exc


def _get_offsets_by_action_segments(target_index_all, action_ranges, tokenizer, decoded_text):
    """Calculate offsets by action segments.

    Each ``(start, end)`` token range in ``action_ranges`` is decoded,
    located in ``decoded_text`` (searching forward from the end of the
    previous segment) and re-tokenized with offset mapping. The per-token
    offsets are shifted to ``decoded_text`` coordinates.

    Args:
        target_index_all: Full sequence of token ids.
        action_ranges: Iterable of ``(start, end)`` token index pairs.
        tokenizer: Object providing ``decode(ids, skip_special_tokens=...)``
            and, when called with ``return_offsets_mapping=True``, a mapping
            with an ``'offset_mapping'`` entry.
        decoded_text: Text the offsets refer to.

    Returns:
        A list with one ``(char_start, char_end)`` per token. Tokens outside
        every action range, and tokens of segments that could not be located
        or re-tokenized, keep the placeholder ``(0, 0)``.
    """
    total = len(target_index_all)
    offsets = [(0, 0)] * total
    current_search_pos = 0

    for start, end in action_ranges:
        seg_ids = target_index_all[start:end]

        seg_text = tokenizer.decode(seg_ids, skip_special_tokens=False)
        if not seg_text:
            continue

        base = decoded_text.find(seg_text, current_search_pos)

        if base == -1:
            # Fallback: if the segment matches once whitespace is ignored,
            # assume it starts at the current search position.
            seg_text_stripped = ''.join(seg_text.split())
            decoded_stripped = ''.join(decoded_text[current_search_pos:].split())
            if seg_text_stripped in decoded_stripped:
                base = current_search_pos
            else:
                current_search_pos += len(seg_text)
                continue

        current_search_pos = base + len(seg_text)

        try:
            enc = tokenizer(seg_text, return_offsets_mapping=True, add_special_tokens=False)
            seg_offsets = enc['offset_mapping']
        except Exception:
            # Best effort: leave this segment's placeholders untouched.
            continue

        # Re-tokenizing decoded text is not guaranteed to reproduce the
        # original token count. Never write past this segment's own slots,
        # otherwise extra offsets would overwrite the next segment or run
        # off the end of the list.
        capacity = max(0, min(end, total) - start)
        for j, (s, e) in enumerate(seg_offsets[:capacity]):
            offsets[start + j] = (base + s, base + e)
    return offsets

def load_trajectory_from_jsonl(args):
    """Load one trajectory from a JSONL file and return its LLM token ids.

    Args:
        args: Dict with ``jsonl_path``, ``line_index`` and optionally
            ``max_tokens`` (default 32700).

    Returns:
        On success: ``{'line_index', 'llm_token_ids', 'status': 'success',
        'error': None}`` where ``llm_token_ids`` is the concatenation of the
        first fragment's ``llm_token_ids``.
        On failure (line missing, no fragments, or any exception during
        extraction): ``{'line_index', 'status': 'failed', 'error': <message>}``.
    """
    jsonl_path = args['jsonl_path']
    line_index = args['line_index']
    max_tokens = args.get('max_tokens', 32700)

    # Skip ahead to the requested record; None means the file is too short.
    record = next(islice(read_jsonl(jsonl_path), line_index, None), None)
    if record is None:
        return {
            'line_index': line_index,
            'status': 'failed',
            'error': f'Line {line_index} not found'
        }

    messages = record['messages']

    try:
        fragments = extract_dynamic_fragments_with_precise_indices(
            messages, _global_tokenizer, max_tokens=max_tokens
        )
        if not fragments or len(fragments) == 0:
            raise ValueError("No fragments extracted")

        token_ids = np.concatenate(fragments[0]['llm_token_ids'])
    except Exception as exc:
        return {
            'line_index': line_index,
            'status': 'failed',
            'error': str(exc)
        }

    return {
        'line_index': line_index,
        'llm_token_ids': token_ids,
        'status': 'success',
        'error': None
    }


# ================= AST Helper Functions =================

def _get_function_name(call_node, code_str):
    """Extract function name from call node.

    Scans the call node's children in order. The first ``identifier`` child
    yields its source text; the first ``attribute`` child yields its
    identifier parts joined with dots.

    Args:
        call_node: Node exposing ``children``; each child exposes ``type``,
            ``start_byte`` and ``end_byte``.
        code_str: Source code the node's byte offsets refer to.

    Returns:
        The name, or "" when no identifier/attribute child exists or the
        text cannot be extracted.
    """
    # Node offsets are byte offsets, so slice the UTF-8 encoding, not the str.
    code_bytes = code_str.encode('utf-8')

    for child in call_node.children:
        if child.type == "identifier":
            try:
                func_name_bytes = code_bytes[child.start_byte:child.end_byte]
                return func_name_bytes.decode('utf-8', errors='ignore')
            except Exception:
                # Narrowed from a bare except so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                return ""

        if child.type == "attribute":
            parts = []
            _extract_attribute_parts(child, code_bytes, parts)
            return ".".join(parts) if parts else ""

    return ""


def _extract_attribute_parts(node, code_bytes, parts):
    """Recursively extract parts of an attribute node.

    Appends the text of every ``identifier`` child to ``parts`` in child
    order and descends into nested ``attribute`` children. Children of any
    other type are ignored.

    Args:
        node: Node exposing ``children``.
        code_bytes: UTF-8 encoded source the byte offsets refer to.
        parts: List that receives the identifier texts (mutated in place).
    """
    for child in node.children:
        if child.type == "identifier":
            try:
                part_bytes = code_bytes[child.start_byte:child.end_byte]
                parts.append(part_bytes.decode('utf-8', errors='ignore'))
            except Exception:
                # Best effort: skip a part that cannot be sliced/decoded.
                # Narrowed from a bare except so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                pass
        elif child.type == "attribute":
            _extract_attribute_parts(child, code_bytes, parts)


def _is_output_function(call_node, code_str, config):
    """Tell whether a call node invokes one of the configured output functions.

    The called name is lower-cased and, for dotted names, reduced to its last
    component before being looked up in ``config.output_function_names``.
    A call whose name cannot be determined is not an output call.
    """
    name = _get_function_name(call_node, code_str)
    if not name:
        return False

    # "logger.Info" -> "info"; an undotted name is returned unchanged.
    final_component = name.lower().rsplit('.', 1)[-1]
    return final_component in config.output_function_names


def _is_docstring(string_node):
    """Determine if a node is a docstring.

    A string node is treated as a docstring when all of the following hold:
    - its parent is an ``expression_statement``;
    - its text, after dropping any string prefix letters, starts with a
      triple quote;
    - that statement is the first ``expression_statement`` child of a
      ``module``, or of a ``block`` directly under a ``function_definition``
      or ``class_definition``.

    Args:
        string_node: Node exposing ``parent`` and ``text`` (bytes).

    Returns:
        True if the node is considered a docstring, False otherwise.
    """
    parent = string_node.parent
    if not parent or parent.type != "expression_statement":
        return False

    try:
        string_content = string_node.text.decode('utf-8', errors='ignore')
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt and SystemExit
        # are no longer swallowed.
        return False

    # Drop prefix letters; 'u'/'U' is a legal prefix too, so u"""...""" must
    # be recognised as well.
    stripped = string_content.lstrip('rfbuRFBU')

    if not stripped.startswith(('"""', "'''")):
        return False

    grandparent = parent.parent
    if not grandparent:
        return False

    if grandparent.type == "module":
        container = grandparent
    elif grandparent.type == "block":
        great_grandparent = grandparent.parent
        if not (great_grandparent
                and great_grandparent.type in ("function_definition", "class_definition")):
            return False
        container = grandparent
    else:
        return False

    # Only the first expression statement of the container qualifies.
    for child in container.children:
        if child.type == "expression_statement":
            return child == parent
    return False


# ===== Helper Functions for First Occurrence Filtering =====

def _reset_seen_identifiers():
    """Empty the first-occurrence tracking sets (independent per file).

    The sets live as attributes on ``_apply_ast_recursive_to_array``:
    ``seen_identifiers``, ``seen_function_names`` and ``seen_class_names``.
    An existing set is cleared in place; a missing one is created empty.
    """
    holder = _apply_ast_recursive_to_array
    for attr in ('seen_identifiers', 'seen_function_names', 'seen_class_names'):
        if hasattr(holder, attr):
            # Clear in place so existing references to the set stay valid.
            getattr(holder, attr).clear()
        else:
            setattr(holder, attr, set())


def _get_node_text(node, code_str):
    """Extract text content from a node.

    Args:
        node: Node exposing ``start_byte`` and ``end_byte``.
        code_str: Source code the byte offsets refer to.

    Returns:
        The source text covered by the node, or "" if it cannot be
        extracted. Bytes that do not form valid UTF-8 are dropped.
    """
    try:
        # Offsets are byte offsets, so slice the UTF-8 encoding, not the str.
        code_bytes = code_str.encode('utf-8')
        node_bytes = code_bytes[node.start_byte:node.end_byte]
        return node_bytes.decode('utf-8', errors='ignore')
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt and SystemExit
        # are no longer swallowed.
        return ""


def _is_assignment_target(node):
    """Tell whether an identifier node is being bound by its parent.

    Recognised positions, by parent node type:
    - ``assignment``: the node is a direct child before the operator token;
    - ``for_statement``: the node is a direct child before the ``in`` token;
    - ``as_pattern``: the node is the child immediately after an ``as`` token.

    A node without a parent, or under any other parent type, is not a target.
    """
    parent = node.parent
    if not parent:
        return False

    kind = parent.type

    if kind == "assignment":
        operator_tokens = (
            "=", "+=", "-=", "*=", "/=", "//=", "%=", "**=",
            "&=", "|=", "^=", ">>=", "<<=",
        )
        for sibling in parent.children:
            if sibling == node:
                return True
            if sibling.type in operator_tokens:
                # Reached the operator first: the node is on the right side.
                return False
        return False

    if kind == "for_statement":
        for sibling in parent.children:
            if sibling == node:
                return True
            if sibling.type == "in":
                return False
        return False

    if kind == "as_pattern":
        siblings = parent.children
        return any(
            sibling.type == "as"
            and position + 1 < len(siblings)
            and siblings[position + 1] == node
            for position, sibling in enumerate(siblings)
        )

    return False


def _get_function_name_node(function_def_node):
    """Return the first ``identifier`` child of a function_definition node.

    Returns None when the node has no identifier child.
    """
    identifiers = (child for child in function_def_node.children if child.type == "identifier")
    return next(identifiers, None)


def _get_class_name_node(class_def_node):
    """Return the first ``identifier`` child of a class_definition node.

    Returns None when the node has no identifier child.
    """
    identifiers = (child for child in class_def_node.children if child.type == "identifier")
    return next(identifiers, None)


# ================= Core Weight Calculation =================

def get_per_parameter_weights(target_index_all, tokenizer, parser, config, decoded_text=None):
    """Generate independent weight arrays for each parameter type.

    The trajectory is split into thinking/action segments, the function calls
    found in the decoded text are classified, and every parameter type listed
    in ``config.parameter_types`` receives its own per-token arrays.

    Args:
        target_index_all: Token ids of one trajectory.
        tokenizer: Tokenizer used to decode ``target_index_all`` (whole
            sequence and single tokens).
        parser: Parser forwarded to ``_apply_weight_with_ast``.
        config: Filtering configuration. This function reads
            ``parameter_types``, ``param_router``, ``filter_path_parameters``
            and ``zero_noise_tokens`` and forwards the object to helpers.
        decoded_text: Optional pre-decoded text. When omitted it is produced
            with ``tokenizer.decode(..., skip_special_tokens=True)``.

    Returns:
        dict mapping each parameter type to a dict with the keys
        ``'weights'`` (float32 array, one entry per token), ``'categories'``
        and ``'token_param_types'`` (per-token label lists, ``"none"`` by
        default) and ``'occurrence_count'``.
    """
    if decoded_text is None:
        decoded_text = tokenizer.decode(target_index_all, skip_special_tokens=True)

    num_tokens = len(target_index_all)

    param_weights = {
        param_type: {
            'weights': np.zeros(num_tokens, dtype=np.float32),
            'categories': ["none"] * num_tokens,
            'token_param_types': ["none"] * num_tokens,
            'occurrence_count': 0,
        }
        for param_type in config.parameter_types
    }

    thinking_mask, action_mask, _thinking_ranges, action_ranges = split_thinking_action_simple(
        target_index_all,
        tokenizer=tokenizer,
        decoded_text=decoded_text
    )
    token_offsets = _get_offsets_by_action_segments(target_index_all, action_ranges, tokenizer, decoded_text)

    # Collect (thinking run, directly following action run) pairs. A thinking
    # run that is not followed by any action token is dropped.
    thinking_action_pairs = []
    i = 0
    while i < num_tokens:
        if thinking_mask[i]:
            thinking_start = i
            while i < num_tokens and thinking_mask[i]:
                i += 1
            thinking_end = i

            action_start = i
            while i < num_tokens and action_mask[i]:
                i += 1
            action_end = i

            if action_end > action_start:
                thinking_action_pairs.append({
                    'thinking_start': thinking_start,
                    'thinking_end': thinking_end,
                    'action_start': action_start,
                    'action_end': action_end
                })
        else:
            i += 1

    function_calls = extract_function_calls_with_params(decoded_text)
    file_types = classify_files_in_trajectory(function_calls)

    # ===== Process Thinking =====
    # The n-th thinking/action pair is labelled after the n-th extracted
    # function call; pairs beyond the extracted calls get the generic label.
    thinking_entry = param_weights['thinking']
    for pair_idx, pair in enumerate(thinking_action_pairs):
        thinking_start = pair['thinking_start']
        thinking_end = pair['thinking_end']

        if pair_idx >= len(function_calls):
            thinking_category = "thinking"
        else:
            func_call = function_calls[pair_idx]
            func_name = func_call['function_name']
            command = func_call.get('command')

            thinking_category = None
            if func_name == 'execute_bash':
                thinking_category = 'thinking_for_execute_bash'
            elif func_name == 'submit':
                thinking_category = 'thinking_for_submit'
            elif func_name == 'str_replace_editor':
                if command == 'view':
                    thinking_category = 'thinking_for_str_replace_editor_view'
                elif command == 'create':
                    thinking_category = 'thinking_for_str_replace_editor_create'
                elif command == 'insert':
                    thinking_category = 'thinking_for_str_replace_editor_insert'
                elif command == 'str_replace':
                    path = func_call.get('path')
                    file_type = file_types.get(path, 'repo')
                    thinking_category = (
                        'thinking_for_new_str_repo' if file_type == 'repo'
                        else 'thinking_for_new_str_created'
                    )

        # Calls that match none of the cases above leave the run unlabelled.
        if thinking_category:
            weights = thinking_entry['weights']
            categories = thinking_entry['categories']
            token_types = thinking_entry['token_param_types']

            for i in range(thinking_start, thinking_end):
                weights[i] = 1.0
                categories[i] = thinking_category
                token_types[i] = "thinking"

            thinking_entry['occurrence_count'] += 1

    # ===== Process Action =====
    for func_call in function_calls:
        func_name = func_call['function_name']
        command = func_call.get('command')
        func_start = func_call['func_start']
        func_end = func_call['func_end']
        params = func_call['params']

        if func_name == 'execute_bash':
            entry = param_weights['execute_bash']
            weights = entry['weights']
            categories = entry['categories']
            token_types = entry['token_param_types']

            entry['occurrence_count'] += 1

            if 'command' in params:
                param_info = params['command']
                _apply_overlap_weight_to_array(
                    weights, categories, token_types, token_offsets,
                    param_info['start'], param_info['end'], 1.0,
                    "execute_bash_command", "execute_bash"
                )

        elif func_name == 'submit':
            entry = param_weights['submit']
            weights = entry['weights']
            categories = entry['categories']
            token_types = entry['token_param_types']

            entry['occurrence_count'] += 1

            # The whole submit call span is zero-weighted.
            _apply_overlap_weight_to_array(
                weights, categories, token_types, token_offsets,
                func_start, func_end, 0.0, "submit", "submit"
            )

        elif func_name == 'str_replace_editor':
            if command == 'view':
                entry = param_weights['str_replace_editor_view']
                weights = entry['weights']
                categories = entry['categories']
                token_types = entry['token_param_types']

                entry['occurrence_count'] += 1

                # view command: no filtering for any parameter.
                for param_name in ['path', 'view_range']:
                    if param_name in params:
                        param_info = params[param_name]
                        _apply_overlap_weight_to_array(
                            weights, categories, token_types, token_offsets,
                            param_info['start'], param_info['end'], 1.0,
                            f"view_{param_name}", "str_replace_editor_view"
                        )

            elif command == 'create':
                entry = param_weights['str_replace_editor_create']
                weights = entry['weights']
                categories = entry['categories']
                token_types = entry['token_param_types']

                entry['occurrence_count'] += 1

                # File types filtered by extension get their content zeroed.
                path = func_call.get('path')
                should_filter_file = _should_filter_file_by_extension(path, config)

                for param_name in ['path', 'file_text']:
                    if param_name in params:
                        param_info = params[param_name]

                        if param_name == 'path':
                            # Path tokens are kept or dropped based on config.
                            path_weight = 0.0 if config.filter_path_parameters else 1.0
                            _apply_overlap_weight_to_array(
                                weights, categories, token_types, token_offsets,
                                param_info['start'], param_info['end'], path_weight,
                                f"create_{param_name}", "str_replace_editor_create"
                            )

                        elif param_name == 'file_text':
                            strategy = config.param_router['str_replace_editor_create']
                            _apply_weight_with_ast(
                                param_info, weights, categories, token_types,
                                token_offsets, parser, config,
                                "str_replace_editor_create", strategy,
                                force_zero=should_filter_file
                            )

            elif command == 'insert':
                entry = param_weights['str_replace_editor_insert']
                weights = entry['weights']
                categories = entry['categories']
                token_types = entry['token_param_types']

                entry['occurrence_count'] += 1

                path = func_call.get('path')
                should_filter_file = _should_filter_file_by_extension(path, config)

                for param_name in ['path', 'new_str', 'insert_line']:
                    if param_name in params:
                        param_info = params[param_name]

                        if param_name == 'new_str':
                            strategy = config.param_router['str_replace_editor_insert']
                            _apply_weight_with_ast(
                                param_info, weights, categories, token_types,
                                token_offsets, parser, config,
                                "str_replace_editor_insert", strategy,
                                force_zero=should_filter_file
                            )
                        else:
                            # insert command: path and insert_line are kept.
                            _apply_overlap_weight_to_array(
                                weights, categories, token_types, token_offsets,
                                param_info['start'], param_info['end'], 1.0,
                                f"insert_{param_name}", "str_replace_editor_insert"
                            )

    # ===== Process str_replace =====
    pairing_result = pair_old_new_str(function_calls, file_types)

    for pair in pairing_result['paired']:
        old_info = pair['old']
        new_info = pair['new']
        file_type = pair['file_type']

        # Decide per pair, from this pair's own path, whether the file type is
        # filtered. Previously the flag leaked in from the last create/insert
        # call handled above (or was unbound when there was none).
        should_filter_file = _should_filter_file_by_extension(pair['path'], config)

        suffix = '_created' if file_type == 'created' else '_repo'

        # old_str
        old_param_type = f'old_str{suffix}'
        old_entry = param_weights[old_param_type]
        old_entry['occurrence_count'] += 1

        strategy = config.param_router.get(old_param_type, {"type": "ast", "base_weight": 1.0})
        _apply_weight_with_ast(
            old_info, old_entry['weights'], old_entry['categories'],
            old_entry['token_param_types'],
            token_offsets, parser, config, old_param_type, strategy
        )

        # new_str
        new_param_type = f'new_str{suffix}'
        new_entry = param_weights[new_param_type]
        new_entry['occurrence_count'] += 1

        strategy = config.param_router.get(new_param_type, {"type": "ast", "base_weight": 1.0})
        _apply_weight_with_ast(
            new_info, new_entry['weights'], new_entry['categories'],
            new_entry['token_param_types'],
            token_offsets, parser, config, new_param_type, strategy,
            force_zero=should_filter_file
        )

        # Modification detection is limited to repository files and written
        # to its own array, independent of new_str_repo.
        if file_type == 'repo':
            modified_entry = param_weights['new_str_repo_modified']
            modified_weights = modified_entry['weights']
            modified_categories = modified_entry['categories']
            modified_token_types = modified_entry['token_param_types']

            modified_indices = find_modified_tokens_precise(
                old_info['value'],
                new_info['value'],
                tokenizer,
                token_offsets,
                new_info['start'],
                new_info['end']
            )

            for idx in modified_indices:
                if 0 <= idx < num_tokens:
                    modified_weights[idx] = 1.0
                    modified_categories[idx] = "modified"
                    modified_token_types[idx] = "new_str_repo_modified"

            if modified_indices:
                modified_entry['occurrence_count'] += 1

    # ===== Noise Filtering =====
    if config.zero_noise_tokens:
        # Decode each relevant token once instead of once per parameter type.
        token_strs = [
            tokenizer.decode([target_index_all[i]])
            if (thinking_mask[i] or action_mask[i]) else None
            for i in range(num_tokens)
        ]

        for param_type in config.parameter_types:
            entry = param_weights[param_type]
            weights = entry['weights']
            categories = entry['categories']
            token_param_types = entry['token_param_types']

            for i in range(num_tokens):
                if thinking_mask[i]:
                    if should_zero_thinking_token_with_frequency(
                            target_index_all[i], token_strs[i], categories[i], config):
                        weights[i] = 0.0
                        categories[i] = "noise_zero"

                elif action_mask[i]:
                    if should_zero_weight_strict(token_strs[i], categories[i], token_param_types[i]):
                        weights[i] = 0.0
                        categories[i] = "noise_zero"

    return param_weights


def _apply_weight_with_ast(param_info, weights, categories, token_types, 
                           token_offsets, parser, config, param_type, strategy,
                           force_zero=False):
    """Weight the tokens of one parameter value, optionally refined by the AST.

    Args:
        param_info: Dict with ``'start'``/``'end'`` (range of the value, in
            the same coordinates as ``token_offsets``) and ``'value'`` (the
            value text).
        weights, categories, token_types: Per-token arrays updated in place.
        token_offsets: Per-token ``(start, end)`` offsets.
        parser: Parser used for the AST pass; a falsy parser skips that pass.
        config: Filtering configuration (``ast_filter_param_types``,
            ``remove_shebang_for_ast``, ``enable_first_occurrence_filtering``,
            ``filter_first_occurrence_in_param_types``, ``ast_weights``).
        param_type: Parameter type label recorded on the touched tokens.
        strategy: Dict with ``"type"`` equal to ``"direct"`` (uses
            ``"weight"``) or ``"ast"`` (uses ``"base_weight"``). Any other
            type leaves the arrays untouched.
        force_zero: When true the whole value is zero-weighted with the
            category ``filtered_<param_type>`` and nothing else is done.
    """
    val_start = param_info['start']
    val_end = param_info['end']
    val_str = param_info['value']

    # ===== Force zero weight =====
    if force_zero:
        _apply_overlap_weight_to_array(
            weights, categories, token_types, token_offsets,
            val_start, val_end, 0.0, f"filtered_{param_type}", param_type
        )
        return

    if strategy["type"] == "direct":
        _apply_overlap_weight_to_array(
            weights, categories, token_types, token_offsets,
            val_start, val_end, strategy["weight"], f"val_{param_type}", param_type
        )

    elif strategy["type"] == "ast":
        # Base weight first; the AST pass below only refines it.
        _apply_overlap_weight_to_array(
            weights, categories, token_types, token_offsets,
            val_start, val_end, strategy["base_weight"], f"val_{param_type}", param_type
        )

        if parser and val_str.strip() and param_type in config.ast_filter_param_types:
            try:
                # ===== Remove shebang =====
                val_str_for_ast = val_str
                # Character position of val_str_for_ast[0]. It must move
                # together with the text when the shebang line is cut off,
                # otherwise every AST range would land too far to the left.
                ast_offset = val_start

                if config.remove_shebang_for_ast and val_str.startswith('#!'):
                    first_newline = val_str.find('\n')
                    if first_newline != -1:
                        val_str_for_ast = val_str[first_newline + 1:]
                        ast_offset = val_start + first_newline + 1
                    else:
                        val_str_for_ast = ""

                if not val_str_for_ast.strip():
                    return

                # ===== Reset identifier tracking =====
                _reset_seen_identifiers()

                # ===== Determine if first occurrence filtering is needed =====
                should_filter_first_occurrence = (
                    config.enable_first_occurrence_filtering and 
                    param_type in config.filter_first_occurrence_in_param_types
                )

                tree = parser.parse(bytes(val_str_for_ast, "utf8"))

                # Code that does not parse cleanly keeps only the base weight.
                if tree.root_node.has_error:
                    return

                _apply_ast_recursive_to_array(
                    tree.root_node, 
                    weights, 
                    categories,
                    token_types,
                    token_offsets, 
                    ast_offset,
                    val_str_for_ast,
                    config.ast_weights,
                    param_type,
                    config,
                    should_filter_first_occurrence
                )
            except Exception:
                # AST refinement is best-effort: on failure the base weight
                # applied above is kept as the result.
                pass


def _apply_overlap_weight_to_array(weights, categories, token_param_types, token_offsets, 
                                    target_start, target_end, weight, label, param_type):
    """Apply weight to specified array range."""
    for i, (t_start, t_end) in enumerate(token_offsets):
        t_center = (t_start + t_end) / 2
        
        if target_start <= t_center <= target_end:
            current_category = categories[i]
            is_protected = current_category in PROTECTED_CATEGORIES
            
            if weight == 0.0:
                weights[i] = 0.0
                categories[i] = label
                token_param_types[i] = param_type
            
            elif weight == 1.0:
                if not is_protected:
                    weights[i] = 1.0
                    categories[i] = label
                    token_param_types[i] = param_type


def _apply_ast_recursive_to_array(node, weights, categories, token_param_types, token_offsets, 
                                   code_offset_char, code_str, ast_weights, param_type, config,
                                   should_filter_first_occurrence=False):
    """Walk the AST under ``node`` and apply per-node weights (switch-controlled).

    Processing order for every visited node:

    1. Forced zero: comments, docstrings and output (print/log) statements,
       each guarded by its ``config`` switch. These subtrees are not descended.
    2. First-occurrence filtering of function names, class names and
       assignment-target identifiers (only when
       ``should_filter_first_occurrence`` and
       ``config.enable_first_occurrence_filtering`` are both true).
    3. The weight registered for the node type in ``ast_weights``, if any.
    4. Recursion into the children.

    Args:
        node: AST node to start from.
        weights, categories, token_param_types: Per-token arrays updated in
            place through ``_apply_weight_to_range``.
        token_offsets: Per-token ``(start, end)`` character offsets.
        code_offset_char: Character offset of ``code_str[0]`` in the
            coordinate system of ``token_offsets``.
        code_str: Source text that ``node`` was parsed from.
        ast_weights: Mapping from node type to weight.
        param_type: Parameter type label recorded on touched tokens.
        config: Filtering configuration holding the switches read here.
        should_filter_first_occurrence: Enables step 2 for this walk.

    The "already seen" name sets live as attributes on this function so that
    ``_reset_seen_identifiers`` can clear them between parameter values.
    """
    # ===== Initialize tracking sets =====
    state = _apply_ast_recursive_to_array
    if not hasattr(state, 'seen_identifiers'):
        state.seen_identifiers = set()
    if not hasattr(state, 'seen_function_names'):
        state.seen_function_names = set()
    if not hasattr(state, 'seen_class_names'):
        state.seen_class_names = set()
    seen_identifiers = state.seen_identifiers
    seen_function_names = state.seen_function_names
    seen_class_names = state.seen_class_names

    # Encode once for the whole walk instead of once per visited node.
    code_bytes = code_str.encode('utf-8')
    total_bytes = len(code_bytes)
    # Equal lengths mean every character is one byte, so byte offsets and
    # character offsets coincide.
    ascii_only = total_bytes == len(code_str)
    char_index_cache = {}

    def char_index(byte_offset):
        """Convert a byte offset in ``code_bytes`` to a character offset."""
        if ascii_only and 0 <= byte_offset <= total_bytes:
            return byte_offset
        cached = char_index_cache.get(byte_offset)
        if cached is None:
            cached = len(code_bytes[:byte_offset].decode('utf-8', errors='ignore'))
            char_index_cache[byte_offset] = cached
        return cached

    def char_span(n):
        """Return the (start, end) character range of ``n`` in token coordinates."""
        return (code_offset_char + char_index(n.start_byte),
                code_offset_char + char_index(n.end_byte))

    def node_text(n):
        """Return the source text covered by ``n``."""
        return code_bytes[n.start_byte:n.end_byte].decode('utf-8', errors='ignore')

    def mark(start_char, end_char, weight, category):
        """Apply ``weight``/``category`` to the tokens centred in the range."""
        _apply_weight_to_range(weights, categories, token_param_types, token_offsets,
                               start_char, end_char, weight, category, param_type)

    filter_first = should_filter_first_occurrence and config.enable_first_occurrence_filtering

    def visit(current):
        node_type = current.type
        node_start_char, node_end_char = char_span(current)

        # ===== Priority 1: Force Zero (Based on switches) =====

        # 1.1 Comment
        if node_type == "comment" and config.filter_comments:
            mark(node_start_char, node_end_char, 0.0, "comment")
            return

        # 1.2 Docstring
        if node_type == "string" and config.filter_docstrings:
            if _is_docstring(current):
                mark(node_start_char, node_end_char, 0.0, "docstring")
                return

        # 1.3 Print/Log Statements (only direct ``call`` children are checked)
        if node_type == "expression_statement" and config.filter_output_statements:
            contains_output = any(
                child.type == "call" and _is_output_function(child, code_str, config)
                for child in current.children
            )
            if contains_output:
                mark(node_start_char, node_end_char, 0.0, "output_deleted")
                return

        # ===== Priority 2: Filter First Occurrence =====
        if filter_first:

            # 2.1 Function Name: only the name is zeroed, the walk continues.
            if node_type == "function_definition" and config.filter_first_occurrence_function_names:
                func_name_node = _get_function_name_node(current)
                if func_name_node:
                    func_name_text = node_text(func_name_node)
                    if func_name_text not in seen_function_names:
                        name_start, name_end = char_span(func_name_node)
                        mark(name_start, name_end, 0.0, "first_occurrence_function_name")
                        seen_function_names.add(func_name_text)

            # 2.2 Class Name: only the name is zeroed, the walk continues.
            if node_type == "class_definition" and config.filter_first_occurrence_class_names:
                class_name_node = _get_class_name_node(current)
                if class_name_node:
                    class_name_text = node_text(class_name_node)
                    if class_name_text not in seen_class_names:
                        name_start, name_end = char_span(class_name_node)
                        mark(name_start, name_end, 0.0, "first_occurrence_class_name")
                        seen_class_names.add(class_name_text)

            # 2.3 Variable Name
            if node_type == "identifier" and config.filter_first_occurrence_identifiers:
                identifier_text = node_text(current)
                if _is_assignment_target(current) and identifier_text not in seen_identifiers:
                    mark(node_start_char, node_end_char, 0.0, "first_occurrence_identifier")
                    seen_identifiers.add(identifier_text)
                    return

        # ===== Priority 3: Apply AST Weights =====
        target_weight = ast_weights.get(node_type)
        if target_weight is not None:
            mark(node_start_char, node_end_char, target_weight, node_type)

        # ===== Priority 4: Recursively Process Children =====
        for child in current.children:
            visit(child)

    visit(node)


def _apply_weight_to_range(weights, categories, token_param_types, token_offsets,
                           start_char, end_char, weight, category, param_type):
    """Apply weight to a specific character range."""
    for i, (t_start, t_end) in enumerate(token_offsets):
        if t_start == 0 and t_end == 0:
            continue
        
        t_center = (t_start + t_end) / 2
        if start_char <= t_center <= end_char:
            current_category = categories[i]
            is_protected = current_category in PROTECTED_CATEGORIES
            
            if weight == 0.0:
                weights[i] = 0.0
                categories[i] = category
                token_param_types[i] = param_type
            
            elif weight == 1.0:
                if not is_protected:
                    weights[i] = 1.0
                    categories[i] = category
                    token_param_types[i] = param_type


def process_trajectory_with_per_parameter_weights(args):
    """Worker entry point: load one trajectory and compute its weights.

    ``args`` is the task dict; it is handed to ``load_trajectory_from_jsonl``
    and its ``'config'`` entry is used for the weight computation. The
    tokenizer and parser come from the module-level worker globals.

    Returns the loader's result unchanged when loading failed, otherwise a
    dict with ``'status'`` set to ``'success'`` (plus token ids, weights and
    token count) or ``'failed'`` (plus the error text and traceback).
    """
    loaded = load_trajectory_from_jsonl(args)
    if loaded['status'] == 'failed':
        return loaded

    token_ids = loaded['llm_token_ids']
    line_no = loaded['line_index']
    cfg = args['config']

    try:
        per_param = get_per_parameter_weights(
            token_ids, _global_tokenizer, _global_parser, cfg
        )
        outcome = {
            'line_index': line_no,
            'llm_token_ids': token_ids,
            'param_weights': per_param,
            'token_count': len(token_ids),
            'status': 'success',
            'error': None
        }
    except Exception as e:
        # Report the failure to the parent process instead of raising, so one
        # bad trajectory does not abort the whole batch.
        import traceback
        outcome = {
            'line_index': line_no,
            'status': 'failed',
            'error': f"Processing error: {str(e)}\n{traceback.format_exc()}"
        }

    return outcome


def save_per_parameter_weights(results, output_path, config):
    """Save weights for each parameter type into one compressed ``.npz`` file.

    Args:
        results: Dict mapping a line index to a dict with ``'llm_token_ids'``,
            ``'token_count'`` and ``'param_weights'``.
        output_path: Destination path. ``np.savez_compressed`` appends
            ``.npz`` when the path does not already end with it.
        config: Object whose ``parameter_types`` lists the parameter types to
            store.

    Stored keys: per line ``line_<idx>_llm_token_ids``,
    ``line_<idx>_token_count`` and, per parameter type,
    ``line_<idx>_<type>_weights`` / ``_categories`` / ``_occurrence_count``;
    plus the global ``_all_text_indices``, ``_total_trajectories``,
    ``_parameter_types`` and ``_total_<type>_occurrences``. A summary table is
    printed after saving.
    """
    print(f"\n[INFO] Saving per-parameter weights to: {output_path}")

    sorted_indices = sorted(results)
    save_dict = {}

    param_occurrence_total = Counter()

    for text_idx in sorted_indices:
        data = results[text_idx]

        save_dict[f"line_{text_idx}_llm_token_ids"] = data['llm_token_ids']
        save_dict[f"line_{text_idx}_token_count"] = data['token_count']

        for param_type in config.parameter_types:
            param_data = data['param_weights'][param_type]
            key_prefix = f"line_{text_idx}_{param_type}"

            save_dict[f"{key_prefix}_weights"] = param_data['weights']
            save_dict[f"{key_prefix}_categories"] = np.array(param_data['categories'], dtype=object)
            save_dict[f"{key_prefix}_occurrence_count"] = param_data['occurrence_count']

            param_occurrence_total[param_type] += param_data['occurrence_count']

    save_dict['_all_text_indices'] = np.array(sorted_indices, dtype=np.int64)
    save_dict['_total_trajectories'] = len(sorted_indices)
    save_dict['_parameter_types'] = np.array(config.parameter_types, dtype=object)

    for param_type in config.parameter_types:
        save_dict[f'_total_{param_type}_occurrences'] = param_occurrence_total[param_type]

    np.savez_compressed(output_path, **save_dict)

    # numpy appends ".npz" when it is missing, so measure the file that was
    # actually written rather than the path that was passed in.
    saved_path = os.fspath(output_path)
    if not saved_path.endswith('.npz'):
        saved_path += '.npz'

    file_size_mb = os.path.getsize(saved_path) / (1024 * 1024)
    print(f"[INFO] Saved {len(sorted_indices)} trajectories. Size: {file_size_mb:.2f} MB")

    print("\n" + "="*80)
    print("PARAMETER OCCURRENCE STATISTICS")
    print("="*80)
    print(f"{'Parameter Type':<35} {'Total Occurrences':>20}")
    print("-"*80)

    for param_type in config.parameter_types:
        count = param_occurrence_total[param_type]
        print(f"{param_type:<35} {count:>20,}")

    print("="*80)
    print("\n✅ Using Configurable Filtering System")
    print("="*80)


def batch_process_jsonl_per_parameter(
    jsonl_path,
    tokenizer_path,
    output_path,
    config,
    start_line=0,
    end_line=None,
    max_tokens=32700,
    max_workers=8
):
    """Process a range of JSONL lines in parallel and save the resulting weights.

    One task per line index is submitted to a process pool whose workers are
    initialised with ``init_worker(tokenizer_path)``. Each task is handled by
    ``process_trajectory_with_per_parameter_weights``, which is expected to
    return a dict carrying 'status' and 'line_index' plus either
    'llm_token_ids' / 'param_weights' / 'token_count' (status 'success') or
    'error' (any other status).

    Args:
        jsonl_path: Path to the input JSONL file; passed to every task.
        tokenizer_path: Tokenizer location handed to the worker initializer.
        output_path: Destination NPZ path. Failures are written next to it as
            ``<output without .npz>_failed.json``.
        config: Filtering configuration; passed to every task and to
            ``save_per_parameter_weights``.
        start_line: First line index to process (inclusive).
        end_line: Line index to stop at (exclusive). ``None`` means "up to the
            number of lines in the file", which costs one extra pass over it.
        max_tokens: Token limit forwarded to every task.
        max_workers: Number of worker processes.

    Returns:
        Dict mapping each successfully processed line index to a dict with
        'llm_token_ids', 'param_weights' and 'token_count'.
    """
    if end_line is None:
        print("[INFO] Counting total lines...")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for _ in f)
        end_line = total_lines
        print(f"[INFO] Total lines: {total_lines}")

    tasks = [
        {
            'jsonl_path': jsonl_path,
            'line_index': line_idx,
            'max_tokens': max_tokens,
            'config': config
        }
        for line_idx in range(start_line, end_line)
    ]

    print(f"[INFO] Processing lines {start_line} to {end_line} with {max_workers} workers")

    results = {}
    failed_list = []

    with ProcessPoolExecutor(max_workers=max_workers,
                             initializer=init_worker,
                             initargs=(tokenizer_path,)) as executor:

        futures = {executor.submit(process_trajectory_with_per_parameter_weights, task): task for task in tasks}

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
            try:
                result = future.result()
            except Exception as exc:
                # A worker that raises (or a broken pool) must not abort the
                # whole batch and discard results already collected; record it
                # as a failure for that line and keep going.
                failed_list.append({
                    'line_index': futures[future]['line_index'],
                    'error': f"{type(exc).__name__}: {exc}"
                })
                continue

            if result['status'] == 'success':
                results[result['line_index']] = {
                    'llm_token_ids': result['llm_token_ids'],
                    'param_weights': result['param_weights'],
                    'token_count': result['token_count']
                }
            else:
                failed_list.append({
                    'line_index': result['line_index'],
                    'error': result['error']
                })

    print(f"\n[INFO] Success: {len(results)}, Failed: {len(failed_list)}")

    if failed_list:
        # Completion order is arbitrary; sort so the report is reproducible.
        failed_list.sort(key=lambda entry: entry['line_index'])

        # Strip only a trailing '.npz'. str.replace() would leave a path
        # without that suffix unchanged (writing the report over the output
        # path) and would also rewrite '.npz' inside directory names.
        failed_base = os.fspath(output_path)
        if failed_base.endswith('.npz'):
            failed_base = failed_base[:-len('.npz')]
        failed_path = failed_base + '_failed.json'

        with open(failed_path, 'w', encoding='utf-8') as f:
            json.dump(failed_list, f, indent=2)
        print(f"[INFO] Failed list saved to: {failed_path}")

    if results:
        save_per_parameter_weights(results, output_path, config)
    else:
        print("[WARN] No results to save.")

    return results


# ================= Main Execution =================
if __name__ == "__main__":
    # Command-line options, listed in the order they should appear in --help.
    cli_options = (
        ("--jsonl_path", dict(type=str, required=True, help="Path to input JSONL file")),
        ("--tokenizer_path", dict(type=str, required=True, help="Path to tokenizer")),
        ("--output_path", dict(type=str, required=True, help="Path to output NPZ file")),
        ("--start_line", dict(type=int, default=0, help="Start line index")),
        ("--end_line", dict(type=int, default=500, help="End line index")),
        ("--max_workers", dict(type=int, default=32, help="Number of worker processes")),
        ("--max_tokens", dict(type=int, default=131000, help="Max tokens per trajectory")),
    )

    parser = argparse.ArgumentParser(
        description="Process JSONL for AST-based parameter weighting."
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Filtering strategy for this run. Each entry is assigned onto the config
    # object as an attribute, in the order written here.
    filter_settings = {
        'filter_comments': True,
        'filter_docstrings': False,
        'filter_output_statements': False,
        'enable_first_occurrence_filtering': False,
        'filter_first_occurrence_identifiers': False,
        'filter_first_occurrence_function_names': False,
        'filter_first_occurrence_class_names': False,
        'filter_files_by_extension': True,
        'filter_file_extensions': {'.md', '.txt', '.rst'},
        'filter_path_parameters': True,
        'remove_shebang_for_ast': False,
    }

    config = ParameterFilterConfig()
    for setting_name, setting_value in filter_settings.items():
        setattr(config, setting_name, setting_value)

    # Echo the effective configuration before the (potentially long) run.
    config.print_filter_config()

    batch_process_jsonl_per_parameter(
        jsonl_path=args.jsonl_path,
        tokenizer_path=args.tokenizer_path,
        output_path=args.output_path,
        config=config,
        start_line=args.start_line,
        end_line=args.end_line,
        max_tokens=args.max_tokens,
        max_workers=args.max_workers,
    )