import json
import os
import argparse
from collections import Counter
from tqdm import tqdm
from transformers import AutoTokenizer

def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace carrying ``input_file``, ``output_file`` and
        ``model_path`` (required strings) together with ``num_rounds``
        (int, default 3) and ``merge_size`` (int, default 10000).
    """
    cli = argparse.ArgumentParser(
        description="Mine tokens based on BPE-like frequency analysis."
    )

    # Mandatory string options, registered in the order shown by --help.
    required_paths = (
        ("--input_file",
         "Path to the verification debug file (must contain 'actual_input')"),
        ("--output_file",
         "Full path to save the mined tokens json file"),
        ("--model_path",
         "Path to the base model/tokenizer"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional integer knobs controlling the merge loop.
    tuning_knobs = (
        ("--num_rounds", 3,
         "Number of merge rounds (default: 3)"),
        ("--merge_size", 10000,
         "Number of pairs to merge per round (default: 10000)"),
    )
    for flag, default_value, help_text in tuning_knobs:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    return cli.parse_args()

def load_and_tokenize_data(file_path, tokenizer):
    """
    Read a JSON-lines file and convert each record's text into token IDs.

    The text of a record is its 'actual_input' field. When that field is
    missing, null or otherwise empty, the text is rebuilt from 'messages' by
    concatenating the 'content' of every entry, each followed by a newline.

    Args:
        file_path: Path to a UTF-8 encoded file with one JSON object per line.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a sequence of token IDs.

    Returns:
        A list with one token-ID sequence per usable record. Records without
        text, and records that encode to fewer than two tokens (no adjacent
        pair to count), are left out.

    Loading is best effort: a line that cannot be parsed or tokenized is
    skipped rather than aborting the run, and the number of such lines is
    printed once at the end so failures are not invisible.
    """
    corpus_ids = []
    skipped = 0
    print(f"Loading data from {os.path.basename(file_path)}...")

    with open(file_path, 'r', encoding='utf-8') as f:
        # Read everything up front so the progress bar knows its total.
        lines = f.readlines()

    for line in tqdm(lines, desc="Tokenizing"):
        line = line.strip()
        if not line:
            # Blank lines carry no record; they are not counted as failures.
            continue
        try:
            data = json.loads(line)
            # Prefer 'actual_input' from verification step. `or ''` also
            # covers an explicit JSON null, which dict.get's default does not.
            content = data.get('actual_input') or ''

            # Fallback: construct from messages if actual_input is missing
            if not content and 'messages' in data:
                content = "".join(
                    str(msg.get('content', '')) + "\n"
                    for msg in data['messages']
                )

            if not content:
                continue

            # Tokenize (keep special tokens separated to avoid merging them)
            ids = tokenizer.encode(content, add_special_tokens=False)
            if len(ids) > 1:
                corpus_ids.append(ids)
        except Exception:
            # Deliberately broad: one malformed record must not stop the run.
            skipped += 1
            continue

    if skipped:
        print(f"Warning: skipped {skipped} line(s) that could not be parsed or tokenized.")
    return corpus_ids

def count_id_pairs(corpus_ids, blocklist_ids):
    """
    Tally every ordered pair of neighbouring IDs across the corpus.

    Args:
        corpus_ids: Iterable of ID sequences, one per sample.
        blocklist_ids: Container of IDs that must never be part of a pair.

    Returns:
        Counter mapping ``(left_id, right_id)`` to its number of occurrences.
        Pairs in which either side is blocklisted are not counted.
    """
    tally = Counter()
    for seq in tqdm(corpus_ids, desc="Counting Pairs", unit="sent"):
        # Walk each element together with its right-hand neighbour.
        for left, right in zip(seq, seq[1:]):
            # Critical Safety: only count a pair when neither ID is a special token
            if left not in blocklist_ids and right not in blocklist_ids:
                tally[(left, right)] += 1
    return tally

def merge_ids_in_corpus(corpus_ids, pairs_to_merge_map):
    """
    Rewrite every sequence, replacing mapped adjacent pairs by their new ID.

    Each sequence is scanned once from left to right. Whenever the current
    element and its successor form a key of ``pairs_to_merge_map``, both are
    consumed and the mapped ID is emitted; otherwise the current element is
    copied unchanged. Because a merged pair is consumed, its second element
    can no longer start another merge in the same pass.

    Args:
        corpus_ids: Iterable of ID sequences.
        pairs_to_merge_map: Mapping from ``(id1, id2)`` to the replacement ID.

    Returns:
        A new list of rewritten sequences, in the same order as the input.
    """
    merged_corpus = []
    for seq in tqdm(corpus_ids, desc="Merging Corpus", unit="sent"):
        rewritten = []
        length = len(seq)
        pos = 0
        while pos < length:
            # A pair only exists when there is a successor element.
            candidate = (seq[pos], seq[pos + 1]) if pos + 1 < length else None
            if candidate is not None and candidate in pairs_to_merge_map:
                rewritten.append(pairs_to_merge_map[candidate])
                pos += 2
            else:
                rewritten.append(seq[pos])
                pos += 1
        merged_corpus.append(rewritten)
    return merged_corpus

def main():
    """
    Command-line entry point: mine frequent token sequences and save them.

    Steps:
        1. Parse arguments, check the input file exists, load the tokenizer.
        2. Tokenize the input file into sequences of token IDs.
        3. For ``num_rounds`` rounds: count adjacent ID pairs, take the
           ``merge_size`` most common ones, assign each a fresh virtual ID
           (starting at ``len(tokenizer)``) and apply the merges to the corpus.
        4. Write one JSON entry per mined token ('token', 'freq',
           'old_token_count', 'tokenizer_view', 'total_savings'), sorted by
           'total_savings' in descending order, and print the top ten.

    Returns None in every case; on a missing input file or a tokenizer that
    fails to load, an error is printed and the function returns early.
    """
    args = parse_args()

    print("\n=== Starting Content-Aware Token Mining ===")
    print(f"Input: {os.path.basename(args.input_file)}")
    print(f"Output: {args.output_file}")

    # Fail fast: check the cheap precondition before the tokenizer load.
    if not os.path.exists(args.input_file):
        print(f"Error: Input file not found: {args.input_file}")
        return

    # 1. Load Tokenizer
    print("Loading tokenizer...")
    try:
        # NOTE(security): trust_remote_code=True allows code shipped with the
        # model repository to run locally; only use trusted --model_path values.
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        return

    vocab_size = len(tokenizer)
    # Identify special tokens (like <|im_start|>) to prevent merging across boundaries
    blocklist_ids = set(tokenizer.all_special_ids)
    print(f"Locked {len(blocklist_ids)} special token IDs.")

    # 2. Prepare Data
    corpus_ids = load_and_tokenize_data(args.input_file, tokenizer)
    print(f"Loaded {len(corpus_ids)} samples.")

    # Maps virtual ID -> metadata of the mined token it stands for.
    mined_tokens_info = {}
    # Virtual IDs start right after the existing vocabulary so they can never
    # collide with a real token ID.
    current_virtual_id = vocab_size

    # 3. Mining Loop
    for round_num in range(1, args.num_rounds + 1):
        print(f"\n>>> [Round {round_num}/{args.num_rounds}]")

        # A. Count
        pair_counts = count_id_pairs(corpus_ids, blocklist_ids)

        # B. Select Top N
        sorted_pairs = pair_counts.most_common(args.merge_size)
        if not sorted_pairs:
            print("No more pairs to merge.")
            break
        print(f"Merging top {len(sorted_pairs)} pairs...")

        # C. Register
        temp_merge_map = {}
        for pair, freq in sorted_pairs:
            id1, id2 = pair

            # Trace back to original IDs: a virtual ID from an earlier round
            # expands to the base-vocabulary sequence it was built from.
            seq1 = mined_tokens_info[id1]['original_ids'] if id1 in mined_tokens_info else [id1]
            seq2 = mined_tokens_info[id2]['original_ids'] if id2 in mined_tokens_info else [id2]

            full_sequence = seq1 + seq2
            token_str = tokenizer.decode(full_sequence)

            # Store metadata. 'freq' is the pair count taken before this
            # round's merges are applied; since all selected pairs are merged
            # in one left-to-right pass, overlapping pairs may end up merged
            # fewer times than counted.
            mined_tokens_info[current_virtual_id] = {
                "token_str": token_str,
                "freq": freq,
                "original_ids": full_sequence,
                # Each occurrence replaces len(full_sequence) tokens by one.
                "savings": freq * (len(full_sequence) - 1)
            }
            temp_merge_map[pair] = current_virtual_id
            current_virtual_id += 1

        # D. Apply
        corpus_ids = merge_ids_in_corpus(corpus_ids, temp_merge_map)

    # 4. Save Results
    print("\nFinalizing results...")
    final_results = []
    for info in mined_tokens_info.values():
        raw_tokens = tokenizer.convert_ids_to_tokens(info['original_ids'])
        # Tokens may be returned as bytes by some tokenizers; decode those so
        # the view shows text instead of a b'...' repr.
        safe_view = " | ".join(
            t.decode("utf-8", errors="replace") if isinstance(t, bytes) else str(t)
            for t in raw_tokens
        )

        final_results.append({
            "token": info['token_str'],
            "freq": info['freq'],
            "old_token_count": len(info['original_ids']),
            "tokenizer_view": safe_view,
            "total_savings": info['savings']
        })

    final_results.sort(key=lambda x: x['total_savings'], reverse=True)

    # A bare filename has an empty dirname, and os.makedirs('') raises, so
    # only create the directory when the path actually names one.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=2, ensure_ascii=False)

    print("\n🎉 Content Mining Complete!")
    print(f"Results saved to: {args.output_file}")

    print("\n🏆 Top 10 Content Tokens")
    print("-" * 60)
    for rank, item in enumerate(final_results[:10], start=1):
        # Keep each preview on one line and at most 40 characters wide.
        preview = item['token'].replace('\n', '↵')
        if len(preview) > 40:
            preview = preview[:37] + "..."
        print(f"{rank:<4} | Savings: {item['total_savings']:<8} | {preview}")

# Run the mining pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()