import argparse
import itertools
import json
import os
import sys
import time
from pathlib import Path

from tqdm import tqdm

# Try to import swift.llm, handle error if not installed
try:
    from swift.llm import get_model_tokenizer, get_template
except ImportError:
    print("Error: 'ms-swift' library is required. Please install it via: pip install ms-swift")
    sys.exit(1)

def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Required: ``--input_file``, ``--output_file``, ``--model_path``.
    Optional: ``--template_type`` (default ``"hermes"``), ``--limit``
    (default ``-1``, meaning no limit) and the ``--verbose`` flag.
    """
    parser = argparse.ArgumentParser(
        description="Verify dataset tokenization and template formatting."
    )

    # Required path arguments, registered in a fixed order so --help is stable.
    required_paths = (
        ("--input_file", "Path to the input JSONL file (from converter step)"),
        ("--output_file", "Path to save the verification results"),
        ("--model_path", "Path to the model directory"),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--template_type",
        type=str,
        default="hermes",
        help="Agent template type (default: hermes)",
    )

    # Debugging aids.
    parser.add_argument(
        "--limit",
        type=int,
        default=-1,
        help="Limit number of lines to process for debugging (-1 for all)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging for every sample",
    )

    namespace = parser.parse_args()
    return namespace

def safe_decode(tokenizer, ids):
    """Decode a sequence of token ids to text for human inspection.

    Entries equal to -100 (the ignore index used for masked label positions)
    and ``None`` entries are dropped before decoding. Real special tokens are
    kept (``skip_special_tokens=False``) so template markers stay visible.

    Never raises; returns a placeholder string instead:
    ``"[EMPTY_OR_NONE]"`` when *ids* is ``None`` or empty,
    ``"[NO_VALID_TOKENS]"`` when every entry was filtered out, and
    ``"[DECODE_ERROR: <message>]"`` when ``tokenizer.decode`` fails.
    """
    if ids is None or len(ids) == 0:
        return "[EMPTY_OR_NONE]"

    # Drop the ignore index (-100) and None. ``token_id`` avoids shadowing
    # the builtin ``id``.
    valid_ids = [token_id for token_id in ids if token_id is not None and token_id != -100]

    if not valid_ids:
        return "[NO_VALID_TOKENS]"

    try:
        # NOTE(review): ``errors`` is forwarded as an extra kwarg; some
        # tokenizer implementations may ignore it. Confirm for the tokenizer in use.
        return tokenizer.decode(
            valid_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=True,
            errors="replace",
        )
    except Exception as e:
        # Best-effort: a decode failure is reported inline rather than
        # aborting the whole verification run.
        return f"[DECODE_ERROR: {str(e)}]"

def validate_toolcall_data_format(data):
    """Check that *data* is a chat sample with a usable ``messages`` list.

    A valid sample is a dict whose ``messages`` value is a non-empty list of
    dicts, each containing both a ``role`` and a ``content`` key (the values
    themselves are not inspected).

    Returns:
        ``(is_valid, reason)``. ``reason`` is ``"Valid"`` on success,
        otherwise a short description of the first problem found.
    """
    if not isinstance(data, dict):
        return False, "Data is not a dictionary"

    if "messages" not in data:
        return False, "Missing 'messages' field"

    messages = data["messages"]
    if not (isinstance(messages, list) and messages):
        return False, "'messages' must be a non-empty list"

    required_keys = ("role", "content")
    for position, message in enumerate(messages):
        if not isinstance(message, dict):
            return False, f"Message {position} is not a dict"
        if not all(key in message for key in required_keys):
            return False, f"Message {position} missing role or content"

    return True, "Valid"

def _compute_token_stats(input_ids, labels):
    """Summarize token counts for one encoded sample.

    Label entries equal to -100 count as masked; every other entry counts as
    a valid label token.

    Returns:
        dict with ``total_input_tokens``, ``total_labels_tokens``,
        ``valid_labels_tokens`` and ``valid_label_ratio`` (percentage of valid
        label tokens, rounded to 2 decimals; 0.0 when there are no labels).
    """
    # ``is not None`` instead of truthiness so array-like id sequences
    # (whose truth value is ambiguous) are handled as well as lists.
    total_input_tokens = len(input_ids) if input_ids is not None else 0
    total_labels_tokens = len(labels) if labels is not None else 0
    valid_labels_tokens = (
        sum(1 for token_id in labels if token_id != -100) if labels is not None else 0
    )
    valid_label_ratio = (
        valid_labels_tokens / total_labels_tokens * 100 if total_labels_tokens > 0 else 0.0
    )
    return {
        "total_input_tokens": total_input_tokens,
        "total_labels_tokens": total_labels_tokens,
        "valid_labels_tokens": valid_labels_tokens,
        "valid_label_ratio": round(valid_label_ratio, 2),
    }


def main():
    """Encode every JSONL sample with the model's template and record the results.

    Each non-blank line of ``--input_file`` produces exactly one JSON record
    in ``--output_file``, keyed by ``line_number``. A record holds either
    the decoded input/label text plus token statistics, or an ``error_msg``
    (format error or processing exception). A summary is printed at the end.
    Exits with status 1 if the tokenizer/template cannot be loaded.
    """
    args = parse_args()

    print("\n=== Starting Tokenization Verification ===")
    print(f"Input File: {os.path.basename(args.input_file)}")
    print(f"Model: {os.path.basename(os.path.normpath(args.model_path))}")
    print(f"Template: {args.template_type}")

    # --- 1. Load Tokenizer & Template ---
    print("Loading tokenizer and template...")
    try:
        # Only the tokenizer is used below, so model weights are not loaded.
        _, tokenizer = get_model_tokenizer(
            args.model_path,
            load_model=False,
            trust_remote_code=True,
        )
        template = get_template(
            tokenizer.model_meta.template,
            tokenizer,
            agent_template=args.template_type,
        )
        # Train mode: the loop below reads ``labels`` from encode() output.
        template.set_mode('train')
    except Exception as e:
        # Exit non-zero, consistent with the import check at the top of the
        # file. Previously a bare ``return`` reported success to the caller.
        print(f"Failed to load model/template: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 2. Setup Output ---
    Path(args.output_file).parent.mkdir(parents=True, exist_ok=True)

    # --- 3. Process Data ---
    token_stats_list = []
    start_time = time.perf_counter()

    processed_count = 0
    error_count = 0
    format_error_count = 0

    print("Processing data...")

    with open(args.input_file, 'r', encoding='utf-8') as f_in, \
         open(args.output_file, 'w', encoding='utf-8') as f_out:

        # Materialize the lines so tqdm knows the total. With --limit, read
        # only the first N lines instead of loading the whole file.
        if args.limit > 0:
            lines = list(itertools.islice(f_in, args.limit))
            print(f"Debugging mode: Processing first {len(lines)} lines only.")
        else:
            lines = f_in.readlines()
        total_lines = len(lines)

        for line_num, line in enumerate(tqdm(lines, desc="Verifying"), 1):
            if not line.strip():
                continue

            try:
                # 3.1 Parse JSON
                data = json.loads(line.strip())

                # 3.2 Validate Format
                is_valid, format_msg = validate_toolcall_data_format(data)
                if not is_valid:
                    format_error_count += 1
                    # Record the rejected sample; no encoding is attempted.
                    output_data = {
                        "line_number": line_num,
                        "original_data": data,
                        "error_msg": format_msg,
                        "has_none_ids": True,
                        "token_stats": {"total_input_tokens": 0},
                    }
                    f_out.write(json.dumps(output_data, ensure_ascii=False) + '\n')
                    continue

                # 3.3 Encode using Template
                encoded = template.encode(data)
                input_ids = encoded.get("input_ids")
                labels = encoded.get("labels")

                # 3.4 Calculate Stats
                token_stats = _compute_token_stats(input_ids, labels)
                token_stats_list.append(token_stats)

                # 3.5 Decode for Verification (the actual model input / label text)
                input_text = safe_decode(tokenizer, input_ids)
                label_text = safe_decode(tokenizer, labels)

                # 3.6 Write Output
                output_data = {
                    "line_number": line_num,
                    "original_data": data,
                    "actual_input": input_text,
                    "actual_label": label_text,
                    "has_none_ids": input_ids is None or labels is None,
                    "error_msg": "",
                    "token_stats": token_stats,
                }
                f_out.write(json.dumps(output_data, ensure_ascii=False) + '\n')
                processed_count += 1

                if args.verbose:
                    # tqdm.write prints above the bar instead of corrupting it.
                    tqdm.write(f"\n[Line {line_num}] Stats: {token_stats}")

            except Exception as e:
                # Per-sample boundary: record the failure and continue with
                # the next line.
                error_count += 1
                if args.verbose:
                    tqdm.write(f"Error on line {line_num}: {e}")
                # Minimal record so every non-blank input line still has an entry.
                f_out.write(json.dumps({"line_number": line_num, "error_msg": str(e)}, ensure_ascii=False) + '\n')

    # --- 4. Final Summary ---
    elapsed = time.perf_counter() - start_time
    print("\n" + "=" * 40)
    print("Verification Complete")
    print(f"Output File: {args.output_file}")
    print(f"Total Lines: {total_lines}")
    print(f"Successfully Processed: {processed_count}")
    print(f"Format Errors: {format_error_count}")
    print(f"Processing Errors: {error_count}")
    print(f"Time Taken: {elapsed:.2f}s")

    if token_stats_list:
        input_lens = [s["total_input_tokens"] for s in token_stats_list]
        print("\nToken Statistics (Input Lengths):")
        print(f"  - Max: {max(input_lens)}")
        print(f"  - Avg: {sum(input_lens) / len(input_lens):.1f}")
        print(f"  - Min: {min(input_lens)}")
    print("=" * 40)

if __name__ == "__main__":
    # Run the verification CLI only when executed as a script, not on import.
    main()