import argparse
import json
import os
import sys

from tqdm import tqdm

def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``input_file`` (str, required),
        ``output_file`` (str, required) and ``max_tokens`` (int, default 32000).
    """
    cli = argparse.ArgumentParser(
        description="Filter dataset based on verification statistics."
    )

    # Both paths are mandatory; there is no sensible default for either.
    required_paths = (
        ("--input_file", "Path to the verification result file (.jsonl)"),
        ("--output_file", "Path to save the final filtered dataset (.jsonl)"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--max_tokens",
        type=int,
        default=32000,
        help="Maximum allowed input tokens (default: 32000)",
    )

    parsed = cli.parse_args()
    return parsed

def _rejection_reason(record, max_tokens):
    """Classify one verification record as kept or rejected.

    Args:
        record: Parsed JSON value for one line of the verification file.
        max_tokens: Exclusive upper bound on
            ``record["token_stats"]["total_input_tokens"]``.

    Returns:
        The ``stats`` key to increment when the record must be dropped
        (``"filtered_error"``, ``"filtered_none"`` or ``"filtered_length"``),
        or ``None`` when the record passes every check.

    Raises:
        AttributeError: if *record* or its ``token_stats`` is not a dict.
        TypeError: if ``total_input_tokens`` cannot be compared with an int.
    """
    # Read every field up front so a malformed record raises (and is counted
    # as an error) regardless of which check would have rejected it first.
    error_msg = record.get("error_msg", "")
    # A missing flag counts as "has None IDs": only records explicitly
    # marked clean survive.
    has_none_ids = record.get("has_none_ids", True)
    token_stats = record.get("token_stats")
    if token_stats is None:
        token_stats = {}
    # Missing token counts are treated as 0, i.e. never over length.
    total_input_tokens = token_stats.get("total_input_tokens", 0)

    # Checks run in priority order so each record gets exactly one reason.
    if error_msg:
        return "filtered_error"
    if has_none_ids:
        return "filtered_none"
    if total_input_tokens >= max_tokens:
        return "filtered_length"
    return None


def _print_summary(stats, output_file):
    """Print the end-of-run report of kept/removed record counts."""
    print("\n" + "=" * 40)
    print("Filtering Complete")
    print(f"Total Processed:   {stats['total']}")
    print(f"✅ Kept (Valid):    {stats['valid']}")
    print(f"❌ Removed Total:   {stats['total'] - stats['valid']}")
    print("-" * 20)
    print(f"   - Error/Empty:   {stats['filtered_error']}")
    print(f"   - Has None IDs:  {stats['filtered_none']}")
    print(f"   - Over Length:   {stats['filtered_length']}")

    retention_rate = (stats['valid'] / stats['total'] * 100) if stats['total'] > 0 else 0
    print(f"Retention Rate:    {retention_rate:.2f}%")
    print(f"Final Dataset:     {output_file}")
    print("=" * 40)


def main():
    """Filter a verification-result JSONL file down to its clean records.

    Each non-blank input line is parsed as JSON and rejected if it has a
    non-empty ``error_msg``, a truthy (or missing) ``has_none_ids``,
    ``token_stats["total_input_tokens"] >= --max_tokens``, or an empty or
    missing ``original_data``. For every kept record, ``original_data`` is
    written to ``--output_file`` as one JSON line. Unparseable or malformed
    lines are counted as errors and reported with their line number.

    Exits with status 1 if the input file does not exist, or if the input
    and output paths refer to the same file. Blank input lines are skipped
    and are not counted in the totals.
    """
    args = parse_args()

    print("\n=== Starting Dataset Filtering ===")
    print(f"Input File: {os.path.basename(args.input_file)}")
    print(f"Output File: {args.output_file}")
    print(f"Max Tokens: {args.max_tokens}")

    if not os.path.exists(args.input_file):
        print(f"Error: Input file not found: {args.input_file}", file=sys.stderr)
        raise SystemExit(1)

    # Opening the output with 'w' truncates it; if it is the input file the
    # data would be destroyed before a single line is read.
    if os.path.exists(args.output_file) and os.path.samefile(args.input_file, args.output_file):
        print(f"Error: Output file is the same as input file: {args.output_file}", file=sys.stderr)
        raise SystemExit(1)

    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    stats = {
        "total": 0,
        "valid": 0,
        "filtered_error": 0,
        "filtered_length": 0,
        "filtered_none": 0,
    }

    print("Filtering data...")

    with open(args.input_file, 'r', encoding='utf-8') as f_in, \
         open(args.output_file, 'w', encoding='utf-8') as f_out:

        for line_no, raw_line in enumerate(tqdm(f_in, desc="Processing"), start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines are not records; don't report them as errors.
                continue
            stats["total"] += 1

            try:
                data = json.loads(line)
                reason = _rejection_reason(data, args.max_tokens)
            except (ValueError, TypeError, AttributeError) as e:
                # ValueError covers json.JSONDecodeError; the others come from
                # records whose structure is not the expected dict layout.
                # tqdm.write keeps the progress bar intact.
                tqdm.write(f"Error processing line {line_no}: {e}", file=sys.stderr)
                stats["filtered_error"] += 1
                continue

            if reason is not None:
                stats[reason] += 1
                continue

            original_data = data.get("original_data")
            if not original_data:
                # Passed verification but carries no payload to keep.
                stats["filtered_error"] += 1
                continue

            f_out.write(json.dumps(original_data, ensure_ascii=False) + '\n')
            stats["valid"] += 1

    _print_summary(stats, args.output_file)

if "__main__" == __name__:
    # Script entry point: importing this module does not run the filter.
    main()