import json
import os
import re
import warnings
import argparse
from collections import Counter
from tqdm import tqdm
from transformers import AutoTokenizer

# Silence SyntaxWarning and UserWarning messages process-wide so they do not
# clutter the script's console output.
warnings.filterwarnings("ignore", category=SyntaxWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Set at import time, before any tokenizer is created.
# NOTE(review): presumably disables the Hugging Face `tokenizers` thread pool
# (and its fork warning) — confirm against the library documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ================= Configuration =================
# Words that count as "structural" when they appear as a standalone token
# (or as the name inside an <...> tag). Grouped by role for readability only;
# membership is all that matters at runtime.
ALLOWED_KEYWORDS = {
    # Schema identification, references and annotations
    "$schema", "$id", "$ref", "$comment", "definitions",
    "title", "description", "default", "examples",
    "readOnly", "writeOnly",

    # Generic validation
    "type", "enum", "const", "format",

    # Object validation
    "properties", "patternProperties", "additionalProperties",
    "propertyNames", "required", "dependencies",
    "minProperties", "maxProperties",

    # Array validation
    "items", "additionalItems", "contains",
    "minItems", "maxItems", "uniqueItems",

    # Numeric validation
    "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum", "multipleOf",

    # String validation
    "pattern", "minLength", "maxLength", "contentMediaType", "contentEncoding",

    # Schema composition and conditionals
    "allOf", "anyOf", "oneOf", "not", "if", "then", "else",

    # Type names and literal values
    "string", "number", "integer", "object", "array", "boolean", "null",
    "true", "false",

    # Tool / function-call envelope (tag names and field names)
    "tools", "tool_call", "tool_response", "function", "name", "arguments",
    "parameters",
}

# Single characters that count as structural: JSON punctuation, the characters
# used by angle-bracket tags, plus tab and space (newline is deliberately absent).
ALLOWED_SYMBOLS = {
    "{", "}", "[", "]",
    ":", ",", '"',
    "/", "<", ">",
    "\t", " ",
}

# ================= Helper Functions =================

def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the string attributes ``input_file``,
        ``output_file`` and ``model_path`` (all mandatory) and the integer
        attribute ``min_frequency`` (5 unless overridden).
    """
    cli = argparse.ArgumentParser(description="Mine structural tokens from dataset.")

    # The three path-like options share one shape: a mandatory string.
    mandatory_string_options = (
        ("--input_file", "Path to the input dataset (.jsonl)"),
        ("--output_file", "Full path to save the scored tokens (.json)"),
        ("--model_path", "Path to the base model/tokenizer"),
    )
    for flag, help_text in mandatory_string_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--min_frequency",
        type=int,
        default=5,
        help="Minimum frequency to consider a token (default: 5)",
    )
    return cli.parse_args()

def is_structural_atom(token):
    """Tell whether a single scanner token is purely structural.

    A token qualifies when it is non-empty and either consists solely of
    characters from ``ALLOWED_SYMBOLS`` or is exactly one of the words in
    ``ALLOWED_KEYWORDS``.

    Args:
        token: The token string to classify.

    Returns:
        True if the token is structural, False otherwise (including for an
        empty token).
    """
    if not token:
        return False
    # Subset test: every distinct character must be an allowed symbol.
    only_symbols = set(token) <= ALLOWED_SYMBOLS
    return only_symbols or token in ALLOWED_KEYWORDS

# Compiled once at import time. Alternatives are tried left to right:
#   <[^>]+>        an angle-bracket tag such as <tools> or </tool_call>
#   \s+            a run of whitespace
#   [{}[\],:"]+    a run of JSON punctuation
#   [\w$]+         a word (letters, digits, underscore, '$')
#   .              any other single character (DOTALL, so nothing is skipped)
_SCAN_PATTERN = re.compile(r'<[^>]+>|\s+|[{}[\],:"]+|[\w$]+|.', re.DOTALL)


def tokenize_for_scanning(text):
    """Split *text* into coarse scanner tokens without losing any character.

    Every character of the input ends up in exactly one token, so
    ``"".join(tokenize_for_scanning(text)) == text``. The final catch-all
    alternative matters: without it characters such as ``-``, ``.`` or ``=``
    would be dropped silently, and a caller that joins neighbouring tokens
    would glue together pieces that are not adjacent in the original text.

    Args:
        text: The string to tokenize.

    Returns:
        A list of non-empty token strings in input order (empty for ``""``).
    """
    return _SCAN_PATTERN.findall(text)

def mine_structural_islands(text):
    """Collect maximal runs of structural tokens ("islands") found in *text*.

    The text is tokenized with ``tokenize_for_scanning``. A token extends the
    current run when it is an ``<...>`` tag whose name (after stripping any
    leading ``/``) is in ``ALLOWED_KEYWORDS``, or when ``is_structural_atom``
    accepts it. Any other token closes the current run.

    Args:
        text: The string to scan.

    Returns:
        A list of strings, each the concatenation of one run of consecutive
        structural tokens, in order of appearance.
    """
    found = []
    pending = []  # tokens of the run currently being extended

    for piece in tokenize_for_scanning(text):
        is_allowed_tag = (
            piece.startswith("<")
            and piece.endswith(">")
            and piece[1:-1].lstrip('/') in ALLOWED_KEYWORDS
        )
        if is_allowed_tag or is_structural_atom(piece):
            pending.append(piece)
        elif pending:
            # A non-structural token ends the run; emit it and start afresh.
            found.append("".join(pending))
            pending = []

    # Emit a run that reaches the end of the text.
    if pending:
        found.append("".join(pending))
    return found

def analyze_token_savings(token_str, freq, tokenizer):
    """Estimate how many tokens would be saved by making *token_str* one token.

    Args:
        token_str: Candidate string to evaluate.
        freq: Number of times the candidate was observed.
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            returning a sequence of ids, and ``decode(ids)`` returning a string.

    Returns:
        None when the tokenizer already encodes the string in at most one id.
        Otherwise a dict with:
            ``old_len``  - current number of ids for the string,
            ``savings``  - ``(old_len - 1) * freq``,
            ``view``     - each id decoded on its own, joined by ``" | "``.
    """
    piece_ids = tokenizer.encode(token_str, add_special_tokens=False)
    piece_count = len(piece_ids)
    if piece_count <= 1:
        return None

    # Decode ids one at a time to show where the tokenizer currently splits.
    pieces = []
    for piece_id in piece_ids:
        pieces.append(tokenizer.decode([piece_id]))

    return {
        "old_len": piece_count,
        "savings": (piece_count - 1) * freq,
        "view": " | ".join(pieces),
    }

def _count_islands(counter, segments):
    """Mine every string in *segments* and tally each stripped island in *counter*."""
    for segment in segments:
        for island in mine_structural_islands(segment):
            counter[island.strip()] += 1


def _tools_segments(tools_field):
    """Return the strings to mine for a record's 'tools' field.

    Only a non-empty string is considered. If it parses as a JSON list, each
    element is re-serialised on its own and the sequence is wrapped in
    ``<tools>`` / ``</tools>``. If it is not valid JSON, the raw string is
    wrapped instead. Valid JSON that is not a list yields nothing.
    """
    if not tools_field or not isinstance(tools_field, str):
        return []
    try:
        parsed = json.loads(tools_field)
    except json.JSONDecodeError:
        return ['<tools>', tools_field, '</tools>']
    if not isinstance(parsed, list):
        return []
    serialised = [json.dumps(tool, ensure_ascii=False) for tool in parsed]
    return ['<tools>', *serialised, '</tools>']


def _tool_call_segments(messages):
    """Return the strings to mine for the 'tool_call' messages of a record.

    Anything that is not a list of dicts (including ``None``) is tolerated:
    malformed entries are skipped instead of raising.
    """
    if not isinstance(messages, list):
        return []
    segments = []
    for msg in messages:
        if not isinstance(msg, dict) or msg.get('role') != 'tool_call':
            continue
        content = msg.get('content', '')
        if content and isinstance(content, str):
            segments.extend(('<tool_call>', content, '</tool_call>'))
    return segments


def main():
    """Mine structural token candidates from a JSONL dataset and score them.

    Steps:
        1. Parse CLI arguments and verify the input file exists.
        2. Load the tokenizer from ``--model_path``.
        3. Scan every JSONL record, counting structural islands found in the
           'tools' field and in messages whose role is 'tool_call'. Lines that
           are not valid JSON objects are skipped.
        4. For candidates seen at least ``--min_frequency`` times and at least
           two characters long, compute the tokenizer savings.
        5. Write the candidates, sorted by total savings (descending), as JSON
           to ``--output_file`` and print the top ten.

    Errors loading the tokenizer or locating the input file are reported on
    stdout and end the run early.
    """
    args = parse_args()

    print("\n=== Starting Structural Token Mining ===")
    print(f"Input: {os.path.basename(args.input_file)}")
    print(f"Output: {args.output_file}")

    # Fail fast on a missing input before paying for the tokenizer load.
    if not os.path.exists(args.input_file):
        print(f"Error: Input file not found: {args.input_file}")
        return

    # 1. Load Tokenizer
    print("Loading tokenizer...")
    try:
        # NOTE(review): trust_remote_code=True allows Python code shipped with
        # the model at model_path to run; only use with trusted checkpoints.
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        return

    candidate_counter = Counter()

    # 2. Mining Phase
    print("Mining structural patterns...")
    with open(args.input_file, 'r', encoding='utf-8') as f:
        # First pass only counts lines so the progress bar has a total.
        total_lines = sum(1 for _ in f)
        f.seek(0)

        for line in tqdm(f, total=total_lines, desc="Scanning"):
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # blank or malformed line
            if not isinstance(data, dict):
                continue  # records must be JSON objects

            # Logic 1: Handle 'tools' field
            _count_islands(candidate_counter, _tools_segments(data.get('tools')))
            # Logic 2: Handle 'tool_call' messages
            _count_islands(candidate_counter, _tool_call_segments(data.get('messages')))

    print(f"\nMining complete. Found {len(candidate_counter)} candidates.")
    print("Analyzing tokenizer savings...")

    # 3. Analysis Phase
    scored_tokens = []
    # Drop rare candidates and the empty string left behind by strip().
    frequent_candidates = {
        k: v for k, v in candidate_counter.items() if v >= args.min_frequency and k
    }

    for token_str, freq in tqdm(frequent_candidates.items(), desc="Scoring"):
        if len(token_str) < 2:
            continue

        analysis = analyze_token_savings(token_str, freq, tokenizer)
        if analysis:
            scored_tokens.append({
                "token": token_str,
                "freq": freq,
                "old_token_count": analysis['old_len'],
                "tokenizer_view": analysis['view'],
                "total_savings": analysis['savings']
            })

    scored_tokens.sort(key=lambda x: x["total_savings"], reverse=True)

    # 4. Save Results
    # dirname is '' for a bare file name, and os.makedirs('') raises, so only
    # create the directory when the path actually has one.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(scored_tokens, f, indent=2, ensure_ascii=False)

    print("\n" + "="*80)
    print("🏆 Top 10 Mined Structural Tokens")
    print("-" * 80)
    for rank, item in enumerate(scored_tokens[:10], start=1):
        preview = item['token'].replace('\n', '↵')
        if len(preview) > 50:
            preview = preview[:47] + "..."
        print(f"{rank:<4} | Savings: {item['total_savings']:<8} | {preview}")
    print("-" * 80)
    print(f"\n✅ Results saved to: {args.output_file}")

# Run the miner only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()