import json
import os
import argparse

def parse_args():
    """Parse the command-line options of the token merge script.

    Returns:
        argparse.Namespace: ``input_files`` holds the list of one or more
        paths given to ``--input_files``; ``output_file`` holds the string
        given to ``--output_file``. Both options are required.
    """
    cli = argparse.ArgumentParser(
        description="Merge multiple token files into one.",
    )

    # Declare the options as data, then register them in declaration order.
    option_specs = (
        (
            "--input_files",
            {
                "nargs": "+",
                "required": True,
                "help": "List of token files to merge (e.g. structural.json content.json ...)",
            },
        ),
        (
            "--output_file",
            {
                "type": str,
                "required": True,
                "help": "Path to save the final merged token list",
            },
        ),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    namespace = cli.parse_args()
    return namespace

def main():
    """Merge the token lists of several JSON files into one de-duplicated file.

    The input files are read in the order given on the command line. A token
    is kept the first time it is seen, so the output preserves first-occurrence
    order. Tokens are tracked in a set, so each token must be hashable
    (e.g. a string).

    Input files that do not exist are skipped with a warning. If any input
    is not valid JSON, an error is printed and the function returns without
    writing the output file.
    """
    args = parse_args()

    print("\n=== Starting Token Merge ===")
    print(f"Inputs ({len(args.input_files)} files):")
    for f in args.input_files:
        print(f"  - {os.path.basename(f)}")
    print(f"Output: {args.output_file}")

    merged_tokens = []
    seen_tokens = set()  # O(1) membership test for de-duplication

    total_inputs = 0  # tokens read across all files, duplicates included

    for file_path in args.input_files:
        if not os.path.exists(file_path):
            print(f"⚠️  Warning: File not found, skipping: {file_path}")
            continue

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                tokens = json.load(f)
        except json.JSONDecodeError:
            # Abort the whole merge rather than write a partial result.
            print(f"❌ Error: Invalid JSON format in {file_path}")
            return

        file_count = 0
        new_count = 0

        for token in tokens:
            file_count += 1
            if token not in seen_tokens:
                merged_tokens.append(token)
                seen_tokens.add(token)
                new_count += 1

        total_inputs += file_count
        print(f"  -> Loaded {file_count} tokens from {os.path.basename(file_path)} (Added new: {new_count})")

    # os.path.dirname() returns "" for a bare filename such as "merged.json",
    # and os.makedirs("") raises FileNotFoundError, so only create the
    # directory when the output path actually has one.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(merged_tokens, f, indent=2, ensure_ascii=False)

    print("\n" + "="*40)
    print("Merge Complete")
    print(f"Total Input Tokens: {total_inputs}")
    print(f"Final Unique Tokens: {len(merged_tokens)}")
    print(f"Saved to: {args.output_file}")
    print("="*40)

# Run the merge only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()