#!/usr/bin/env python3
"""
Script to scan the training dataset and identify samples where all labels are -100.
This can cause training crashes since there are no valid tokens to compute loss on.
"""

import argparse
import os
from typing import List, Optional, Tuple

import torch
import yaml
from accelerate import Accelerator
from transformers import AutoTokenizer

from hr2r.train import CustomHR2RDataCollator
from hr2r.utils.data_prepare import preprocess_dataset


def _count_valid_labels(labels) -> int:
    """Return how many entries of *labels* differ from the ignore value -100."""
    return sum(1 for label in labels if label != -100)


def _step_geometry(training_config) -> Tuple[int, int]:
    """Return ``(batch_size, samples_per_step)`` read from the training config."""
    batch_size = training_config['per_device_train_batch_size']
    samples_per_step = batch_size * training_config['gradient_accumulation_steps']
    return batch_size, samples_per_step


def scan_all_negative_labels(config_path: str, sample_limit: Optional[int] = None, check_batches: bool = False):
    """Scan the preprocessed training set for samples with no (or few) valid labels.

    A label counts as valid when it is not -100. Every scanned sample is
    visited once; its valid-label count decides whether it is reported as
    "all labels are -100" (zero valid labels) and/or as "very few valid
    labels" (fewer than 10 valid labels, or fewer than 5% of all labels).

    Args:
        config_path: Path to a YAML config with ``model``, ``data``,
            ``training`` and ``tokenizer`` sections.
        sample_limit: Scan only the first ``sample_limit`` samples. ``None``
            scans the whole training set; negative values are treated as 0.
        check_batches: When true, additionally run the data collator over up
            to the first 10 steps' worth of samples (including a trailing
            partial step) and report batches with no or few valid labels.

    Returns:
        A dict with three keys:
        ``'all_negative_samples'`` -- list of sample indices with zero valid labels;
        ``'few_labels_samples'`` -- list of ``(index, valid, total, ratio)`` tuples
        (this list also contains the zero-valid samples);
        ``'batch_issues'`` -- list of ``(step, batch_start, batch_end)`` tuples for
        collated batches without any valid label (empty unless ``check_batches``).
    """
    # Load config
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    accelerator = Accelerator(mixed_precision='bf16')

    model_config = config['model']
    data_config = config['data']
    training_config = config['training']
    tokenizer_config = config['tokenizer']

    accelerator.print("=== Scanning Dataset for All-Negative Labels ===")
    accelerator.print(f"Config: {config_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_config['name'],
        trust_remote_code=model_config['trust_remote_code'],
        padding_side="right"
    )

    if tokenizer.pad_token is None and tokenizer_config['use_pad_token_as_eos']:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset (only the first element of the returned triple is used here)
    accelerator.print("Loading and preprocessing dataset...")
    processed_train_dataset, _, _ = preprocess_dataset(
        training_config, data_config, model_config, accelerator
    )

    dataset_size = len(processed_train_dataset)
    accelerator.print(f"Train dataset size: {dataset_size}")

    if sample_limit is None:
        total_samples = dataset_size
    else:
        # Clamp so a negative limit scans nothing instead of reporting a negative count.
        total_samples = max(0, min(sample_limit, dataset_size))

    accelerator.print(f"Scanning {total_samples} samples for all-negative labels...")

    # Single pass: one valid-label count per sample feeds both reports, so the
    # dataset is indexed (and its labels iterated) only once per sample.
    problematic_samples = []
    few_labels_samples = []

    for idx in range(total_samples):
        if idx % 1000 == 0:
            accelerator.print(f"  Progress: {idx}/{total_samples}")

        example = processed_train_dataset[idx]
        labels = example['labels']

        total_labels = len(labels)
        valid_labels = _count_valid_labels(labels)
        valid_ratio = valid_labels / total_labels if total_labels > 0 else 0

        # Flag samples with very few valid labels (less than 5% or less than 10 tokens)
        if valid_labels < 10 or valid_ratio < 0.05:
            few_labels_samples.append((idx, valid_labels, total_labels, valid_ratio))

        # Zero valid labels means every label is -100 (an empty label list included).
        if valid_labels != 0:
            continue

        problematic_samples.append(idx)
        accelerator.print(f"  ⚠️  Sample {idx}: ALL labels are -100!")

        # Print some details about this sample
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']
        iter_count = example['iter_count']

        accelerator.print(f"    Input IDs length: {len(input_ids)}")
        accelerator.print(f"    Labels length: {len(labels)}")
        accelerator.print(f"    Attention mask sum: {sum(attention_mask)}")
        accelerator.print(f"    Iter count values: {set(iter_count)}")

        # Decoding is a best-effort preview; a failure must not abort the scan.
        try:
            decoded_text = tokenizer.decode(input_ids, skip_special_tokens=False)
            accelerator.print(f"    Text preview: {decoded_text[:200]}...")
        except Exception as e:
            accelerator.print(f"    Failed to decode: {e}")

    # Report samples with few (but not zero) valid labels; the zero-valid ones
    # were already reported in detail above.
    accelerator.print("\nScanning for samples with very few valid labels...")
    for idx, valid_labels, total_labels, valid_ratio in few_labels_samples:
        if valid_labels == 0:
            continue
        accelerator.print(f"  ⚠️  Sample {idx}: Only {valid_labels}/{total_labels} valid labels ({valid_ratio:.1%})")

    # Check batches if requested
    batch_issues = []
    if check_batches:
        accelerator.print("\nChecking data collator batches...")

        data_collator = CustomHR2RDataCollator(
            tokenizer=tokenizer,
            padding=True,
            important_token_noise=training_config.get('important_token_noise', 0.0),
            normal_token_noise=training_config.get('normal_token_noise', 0.0),
        )

        batch_size, samples_per_step = _step_geometry(training_config)

        # Ceiling division: a trailing partial step is checked too, so a small
        # sample_limit no longer results in zero batches being inspected.
        steps_available = (total_samples + samples_per_step - 1) // samples_per_step

        # Check first few steps worth of batches
        for step in range(min(10, steps_available)):
            step_start = step * samples_per_step
            step_end = min(step_start + samples_per_step, total_samples)

            for batch_start in range(step_start, step_end, batch_size):
                batch_end = min(batch_start + batch_size, total_samples)
                batch_examples = [processed_train_dataset[i] for i in range(batch_start, batch_end)]

                # Any failure while collating or counting is reported and the
                # scan moves on to the next batch.
                try:
                    batch = data_collator(batch_examples)
                    labels = batch['labels']

                    if torch.is_tensor(labels):
                        valid_labels_in_batch = (labels != -100).sum().item()
                    else:
                        valid_labels_in_batch = sum(_count_valid_labels(sample_labels) for sample_labels in labels)
                except Exception as e:
                    accelerator.print(f"  ❌ Step {step}, Batch {batch_start}-{batch_end}: Collation failed - {e}")
                    continue

                if valid_labels_in_batch == 0:
                    batch_issues.append((step, batch_start, batch_end))
                    accelerator.print(f"  ⚠️  Step {step}, Batch {batch_start}-{batch_end}: NO valid labels in entire batch!")
                elif valid_labels_in_batch < 10:
                    accelerator.print(f"  ⚠️  Step {step}, Batch {batch_start}-{batch_end}: Only {valid_labels_in_batch} valid labels")

    # Summary
    accelerator.print("\n" + "=" * 60)
    accelerator.print("=== SCAN RESULTS ===")
    accelerator.print(f"Total samples scanned: {total_samples}")
    accelerator.print(f"Samples with ALL labels = -100: {len(problematic_samples)}")
    accelerator.print(f"Samples with very few valid labels: {len(few_labels_samples)}")

    if batch_issues:
        accelerator.print(f"Batches with no valid labels: {len(batch_issues)}")

    if problematic_samples:
        accelerator.print("\n🚨 CRITICAL ISSUE FOUND!")
        accelerator.print(f"Samples with all labels = -100: {problematic_samples}")
        accelerator.print("These samples will cause training issues!")

        # Map each sample index to a 1-based step number.
        # NOTE(review): this treats the dataset as consumed in index order by a
        # single process; confirm against the real sampler / process count.
        _, samples_per_step = _step_geometry(training_config)
        affected_steps = {idx // samples_per_step + 1 for idx in problematic_samples}

        accelerator.print(f"These would affect training steps: {sorted(affected_steps)}")

    if few_labels_samples:
        accelerator.print("\n⚠️  Samples with very few valid labels:")
        for idx, valid, total, ratio in few_labels_samples[:20]:  # Show first 20
            accelerator.print(f"  Sample {idx}: {valid}/{total} valid labels ({ratio:.1%})")
        if len(few_labels_samples) > 20:
            accelerator.print(f"  ... and {len(few_labels_samples) - 20} more")

    # Recommendations
    accelerator.print("\n💡 RECOMMENDATIONS:")
    if problematic_samples:
        accelerator.print("1. URGENT: Remove or fix samples with all labels = -100")
        accelerator.print("2. These samples have no learnable content")
        accelerator.print("3. They likely contain only prompt tokens (mask=0 in original data)")

    if few_labels_samples:
        accelerator.print("4. Consider reviewing samples with very few valid labels")
        accelerator.print("5. They might have very short responses or incorrect masking")

    # Return results for further processing
    return {
        'all_negative_samples': problematic_samples,
        'few_labels_samples': few_labels_samples,
        'batch_issues': batch_issues
    }


def create_filtered_dataset(config_path: str, problematic_indices: List[int], output_suffix: str = "_filtered"):
    """Save a copy of the training dataset without the given rows, plus a config that points to it.

    The dataset is loaded from ``config['data']['train_data_path']`` with
    ``datasets.load_from_disk``, the rows at ``problematic_indices`` are
    dropped, and the result is saved next to the original under
    ``<train_data_path><output_suffix>``. A copy of the config with the new
    ``train_data_path`` is written beside the original config, with
    ``_filtered`` inserted before the file extension.

    NOTE(review): the indices are applied to the dataset as stored on disk.
    The caller in this file passes indices obtained from the preprocessed
    dataset, which is only correct if preprocessing keeps row order and count
    -- confirm before relying on the output.

    Args:
        config_path: Path to the YAML config to read.
        problematic_indices: Row positions to remove; positions outside the
            dataset are ignored.
        output_suffix: Suffix appended to the dataset path for the output.

    Returns:
        ``(output_path, config_output)``: the filtered dataset path and the
        path of the newly written config.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    data_path = config['data']['train_data_path']
    # Drop trailing slashes so the suffix extends the directory name instead of
    # creating "<dataset>/_filtered" inside the source dataset directory.
    output_path = data_path.rstrip('/') + output_suffix

    print("Creating filtered dataset...")
    print(f"Input: {data_path}")
    print(f"Output: {output_path}")
    print(f"Removing {len(problematic_indices)} problematic samples")

    from datasets import load_from_disk

    # Load the raw dataset (before preprocessing)
    dataset = load_from_disk(data_path)
    print(f"Original dataset size: {len(dataset)}")

    # Build the lookup set once; rebuilding it per row made this quadratic.
    excluded = set(problematic_indices)
    good_indices = [i for i in range(len(dataset)) if i not in excluded]
    filtered_dataset = dataset.select(good_indices)

    print(f"Filtered dataset size: {len(filtered_dataset)}")
    print(f"Removed: {len(dataset) - len(filtered_dataset)} samples")

    # Save filtered dataset
    filtered_dataset.save_to_disk(output_path)
    print(f"✅ Filtered dataset saved to: {output_path}")

    # Create updated config. Splitting on the real extension guarantees the new
    # name differs from config_path (a plain '.yaml' string replace left '.yml'
    # and extension-less paths unchanged, overwriting the original config).
    config['data']['train_data_path'] = output_path
    config_root, config_ext = os.path.splitext(config_path)
    config_output = f"{config_root}_filtered{config_ext}"

    # allow_unicode=True can emit non-ASCII characters, so pin the encoding.
    with open(config_output, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

    print(f"✅ Updated config saved to: {config_output}")
    return output_path, config_output


def main():
    """Parse the command line, run the scan, and optionally write a filtered dataset."""
    cli = argparse.ArgumentParser(description='Scan dataset for samples with all labels = -100')

    # Each entry is (flag, keyword arguments for add_argument).
    option_table = (
        ('--config', {
            'type': str,
            'default': 'script/recipes/dpsk_1.5/sft_hr2r.yaml',
            'help': 'Path to configuration file',
        }),
        ('--sample-limit', {
            'type': int,
            'default': None,
            'help': 'Limit number of samples to scan (for quick testing)',
        }),
        ('--check-batches', {
            'action': 'store_true',
            'help': 'Also check data collator batches',
        }),
        ('--create-filtered', {
            'action': 'store_true',
            'help': 'Create filtered dataset removing problematic samples',
        }),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    # The scan always runs; filtering is an optional follow-up.
    scan_report = scan_all_negative_labels(opts.config, opts.sample_limit, opts.check_batches)

    if not opts.create_filtered:
        return

    bad_indices = scan_report['all_negative_samples']
    if not bad_indices:
        return

    print("\n" + "=" * 60)
    create_filtered_dataset(opts.config, bad_indices)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()