#!/usr/bin/env python3
"""
Precisely locate samples returning None
"""

import os
import sys
import torch
from pathlib import Path

# Insert the directory containing this file at the front of sys.path.
# This has to run before the project-local imports below
# (`data.dataset_english_medical`, `utils.yaml_loader`) -- presumably so that
# those packages are resolved relative to this directory.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from transformers import AutoTokenizer
from data.dataset_english_medical import EnglishMedicalDatasetFast
from torch.utils.data import DataLoader, Subset
from utils.yaml_loader import load_yaml


def check_sample_fields(sample, idx):
    """Print a per-field report for one sample and collect its problem markers.

    Args:
        sample: Object returned by ``dataset[idx]``; expected to expose
            ``.items()`` (e.g. a dict). ``None`` or any object without
            ``.items()`` is reported as a problem instead of raising.
        idx: Sample index, used only to label the printed output.

    Returns:
        list[str]: One marker per detected problem, in field order: the field
        name for a ``None`` value, ``"<name>_empty"`` for a zero-element
        tensor, ``"<name>_nan"`` / ``"<name>_inf"`` for a tensor containing
        NaN / Inf values, or ``["<sample>"]`` when the sample itself cannot be
        inspected. An empty list means no problem was detected.
    """
    print(f"\nSample {idx} detailed check:")

    # The sample itself can be the None we are hunting for; report it rather
    # than failing on ``.items()``.
    if sample is None:
        print(f"  ❌ Sample {idx} is None!")
        return ["<sample>"]
    if not hasattr(sample, "items"):
        print(f"  ❌ Sample {idx} is not a mapping: {type(sample)}")
        return ["<sample>"]

    problem_fields = []

    for key, value in sample.items():
        print(f"  {key}: {type(value)}")

        if value is None:
            problem_fields.append(key)
            print(f"    ❌ Field '{key}' is None!")
        elif isinstance(value, torch.Tensor):
            print(f"    shape: {value.shape}, dtype: {value.dtype}")
            if value.numel() == 0:
                problem_fields.append(f"{key}_empty")
                print(f"    ❌ Field '{key}' is empty tensor!")
            # NaN/Inf are recorded (not just printed) so callers flag the
            # sample, matching the batch-level check which treats NaN as a
            # failure.
            if torch.isnan(value).any():
                problem_fields.append(f"{key}_nan")
                print(f"    ❌ Field '{key}' contains NaN values!")
            if torch.isinf(value).any():
                problem_fields.append(f"{key}_inf")
                print(f"    ❌ Field '{key}' contains Inf values!")
        elif isinstance(value, (list, tuple)):
            print(f"    length: {len(value)}")
            # An empty list is only a warning; it is not added to the result.
            if len(value) == 0:
                print(f"    ⚠ Field '{key}' is empty list")
        else:
            print(f"    value: {repr(value)}")

    return problem_fields


def simulate_training_order(dataset, batch_size=16, max_batches=36):
    """Iterate a training-style DataLoader and report the first bad batch.

    The loader is built with ``shuffle=True`` and ``drop_last=True``, so the
    batch composition differs between runs. Batches are expected to be dicts
    (``batch.items()`` is called on each one).

    Args:
        dataset: Map-style dataset passed to ``DataLoader``.
        batch_size: Number of samples per batch.
        max_batches: Maximum number of batches to inspect before giving up.

    Returns:
        int | None: 1-based number of the first batch that either could not
        be loaded/collated or contains ``None`` / NaN values; ``None`` when
        every inspected batch was clean.
    """
    print(f"Simulating training order, batch size={batch_size}")
    print(f"Dataset size: {len(dataset)}")

    # Create DataLoader with same settings as training
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,  # Note: Using shuffle, so order may vary between runs
        num_workers=0,
        drop_last=True,
        pin_memory=False
    )

    print(f"DataLoader created successfully, total batches: {len(dataloader)}")

    batch_iter = iter(dataloader)
    batch_idx = 0

    while batch_idx < max_batches:
        batch_no = batch_idx + 1

        # Fetch explicitly instead of using a for-loop: a None field makes the
        # default collate raise *during iteration*, and that failure is
        # exactly what this function is supposed to locate.
        try:
            batch = next(batch_iter)
        except StopIteration:
            # Fewer batches than max_batches; nothing bad was seen.
            return None
        except Exception as e:
            print(f"\nBatch {batch_no}:")
            print(f"  ❌ Batch {batch_no} failed to load or collate: {e}")
            return batch_no

        print(f"\nBatch {batch_no}:")

        # Check each field in the batch
        batch_has_none = False
        for key, value in batch.items():
            if value is None:
                print(f"  ❌ Batch field '{key}' is None!")
                batch_has_none = True
            elif isinstance(value, torch.Tensor):
                print(f"  {key}: shape={value.shape}")
                if torch.isnan(value).any():
                    print(f"  ❌ Batch field '{key}' contains NaN values!")
                    batch_has_none = True
            elif isinstance(value, (list, tuple)):
                print(f"  {key}: length={len(value)}")
                # Check if list contains None values
                none_count = sum(1 for item in value if item is None)
                if none_count > 0:
                    print(f"  ❌ Batch field '{key}' contains {none_count} None values!")
                    batch_has_none = True

        if batch_has_none:
            print(f"❌ Batch {batch_no} contains None values!")
            return batch_no

        batch_idx += 1

    # Reached the inspection limit without finding a problem; with shuffling
    # enabled a later run may still hit one.
    print(f"First {batch_idx} batches are normal")
    return None


def _print_raw_sample(dataset, idx):
    """Best-effort dump of ``dataset.samples[idx]`` for a flagged sample."""
    raw_samples = getattr(dataset, "samples", None)
    if raw_samples is None:
        # Dataset does not expose raw samples; nothing to show.
        return

    try:
        if idx >= len(raw_samples):
            return
        raw_sample = raw_samples[idx]
    except (TypeError, IndexError, KeyError) as e:
        print(f"Raw sample data unavailable: {e}")
        return

    print(f"Raw sample data:")
    if hasattr(raw_sample, "items"):
        for key, value in raw_sample.items():
            print(f"    {key}: {type(value)} = {repr(value)}")
    else:
        print(f"    {type(raw_sample)} = {repr(raw_sample)}")


def check_specific_samples(dataset, start_idx=0, num_samples=100):
    """Load and inspect a contiguous range of samples.

    Args:
        dataset: Indexable dataset supporting ``len()`` and ``dataset[i]``.
        start_idx: Index of the first sample to inspect.
        num_samples: Number of samples to inspect; the range is clipped to
            ``len(dataset)``.

    Returns:
        list[tuple]: One ``(index, issue)`` entry per problematic sample,
        where ``issue`` is either the list returned by
        ``check_sample_fields`` or an ``"Exception: ..."`` string when loading
        or checking the sample raised.
    """
    print(f"\nChecking samples {start_idx} to {start_idx + num_samples - 1}")

    problem_samples = []
    end_idx = min(start_idx + num_samples, len(dataset))

    for i in range(start_idx, end_idx):
        # Only loading/checking is guarded here, so a failure while printing
        # the raw data can no longer be misreported as a load failure or add
        # a second entry for the same index.
        try:
            sample = dataset[i]
            none_fields = check_sample_fields(sample, i)
        except Exception as e:
            print(f"❌ Sample {i} failed to load: {e}")
            problem_samples.append((i, f"Exception: {e}"))
            continue

        if not none_fields:
            continue

        print(f"❌ Sample {i} has problematic fields: {none_fields}")
        problem_samples.append((i, none_fields))

        # If problematic sample found, show the raw data when available
        _print_raw_sample(dataset, i)

    return problem_samples


def test_collate_with_specific_samples(dataset, problematic_indices):
    """Try to collate each flagged sample on its own and print the outcome.

    At most the first three indices are tried. Each sample is fetched from
    ``dataset``, wrapped in a one-element batch and passed to PyTorch's
    ``default_collate``; failures are printed together with a traceback.
    Nothing is returned.
    """
    if not problematic_indices:
        print("No problematic samples found, cannot test collate")
        return

    print(f"\nTesting collate process with problematic samples...")

    import traceback
    from torch.utils.data._utils.collate import default_collate

    # Only the first 3 problematic samples are exercised
    for sample_idx in problematic_indices[:3]:
        try:
            print(f"\nTesting sample {sample_idx}:")

            # Build a batch holding just this one sample
            single_batch = [dataset[sample_idx]]
            print("Single-sample batch created successfully")

            # Run the default collate on it; the result itself is not needed
            default_collate(single_batch)
            print("Single-sample collate successful")

        except Exception as err:
            print(f"Sample {sample_idx} collate failed: {err}")
            traceback.print_exc()


def main():
    """Run the three diagnostic passes against the English medical dataset.

    Pass 1 inspects the first 100 samples individually, pass 2 iterates a
    training-style DataLoader, and pass 3 (only when neither earlier pass
    found anything) inspects further samples in blocks of 100 up to index 500.
    """

    def banner(title):
        # Section header framed by two 60-character rules
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    print("🔍 Precisely locating None sample issues...")

    # Configuration, tokenizer and dataset
    cfg = load_yaml('configs/english_medical.yaml')

    tokenizer = AutoTokenizer.from_pretrained(
        "/root/autodl-tmp/pubmedbert-base-uncased-abstract-local"
    )

    dataset = EnglishMedicalDatasetFast(
        dataset_root=cfg['dataset']['root'],
        tokenizer=tokenizer,
        sample_ratio=0.01,  # Use same ratio as training
    )

    print(f"Dataset loaded: {len(dataset)} samples")

    # Method 1: look at the first samples for obvious None fields
    banner("Method 1: Checking first 100 samples")

    problem_samples = check_specific_samples(dataset, 0, 100)

    if not problem_samples:
        print("First 100 samples are normal")
    else:
        print(f"\nFound {len(problem_samples)} problematic samples:")
        for idx, issue in problem_samples:
            print(f"  Sample {idx}: {issue}")

        # Feed the flagged samples through collate one by one
        problematic_indices = [
            idx for idx, _ in problem_samples if isinstance(idx, int)
        ]
        test_collate_with_specific_samples(dataset, problematic_indices)

    # Method 2: replay the training DataLoader
    banner("Method 2: Simulating training DataLoader")

    failed_batch = simulate_training_order(dataset, batch_size=16)

    if not failed_batch:
        print("First 36 batches are normal, issue might be due to random shuffle")
    else:
        print(f"Problem found at batch {failed_batch}")

    # Method 3 runs only when both earlier methods came up empty
    if problem_samples or failed_batch:
        return

    banner("Method 3: Expanding check range")

    # Walk further into the dataset, 100 samples at a time, stopping at the
    # first block that contains a problem
    upper_bound = min(500, len(dataset))
    for start in range(100, upper_bound, 100):
        print(f"\nChecking samples {start}-{start + 99}")
        batch_problems = check_specific_samples(dataset, start, 100)
        if batch_problems:
            problem_samples.extend(batch_problems)
            break


# Run the diagnostics only when this file is executed as a script.
if __name__ == "__main__":
    main()