"""
Ultra-simplified training script for the English medical MedCLIP model.
Uses the fastest available dataset-loading path and a streamlined training process.
"""

import os
import sys
import torch
import warnings
import argparse
import time
from pathlib import Path

# Add project root directory to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from utils.yaml_loader import load_yaml
from utils.logger import UnifiedLogger
from train.trainer_english_medical import EnglishMedicalTrainer
from models.model_factory import create_model_from_config
from data.dataset_factory import (
    create_dataset_from_config_simple,
    create_dataset_with_preset_simple,
    print_dataset_summary_simple,
    validate_dataset_path,
    estimate_loading_time
)
from transformers import AutoTokenizer

# Ignore warnings: suppress every UserWarning and FutureWarning for the whole
# process. These filters are installed at import time of this script.
# NOTE(review): this also hides warnings that may point at real problems;
# consider narrowing with `module=` if one specific library is the noise source.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

def setup_environment(seed=42, debug_sync=False):
    """Configure global PyTorch state before training.

    When CUDA is available, cuDNN is enabled in benchmark (autotuning) mode
    for throughput. PyTorch's CPU RNG and, when present, all CUDA RNGs are
    seeded with ``seed``.

    Args:
        seed: Seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``. Defaults to 42, the value this
            script has always used.
        debug_sync: When True, set ``CUDA_LAUNCH_BLOCKING=1`` so CUDA kernel
            launches run synchronously, which helps locate the op behind a
            CUDA error. Off by default because it serializes every launch and
            slows training. When False the variable is not touched, so a
            value exported by the user still applies.
    """
    # Set first, before any torch.cuda call below.
    if debug_sync:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    cuda_available = torch.cuda.is_available()

    if cuda_available:
        # Favor speed over bitwise reproducibility: benchmark mode lets cuDNN
        # choose the fastest kernels, which may be nondeterministic even with
        # a fixed seed.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.enabled = True

    torch.manual_seed(seed)
    if cuda_available:
        torch.cuda.manual_seed_all(seed)

def load_tokenizer_fast(
    cfg,
    fallback_local_path="/root/autodl-tmp/pubmedbert-base-uncased-abstract-local",
    hub_model_id='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
):
    """Load the text tokenizer, preferring local copies over a download.

    Candidates are tried in this order:

    1. ``cfg['model']['text']['pretrained_path']``, if it exists on disk.
       An explicit config value wins over the built-in default.
    2. ``fallback_local_path``, if it exists on disk.
    3. Otherwise ``hub_model_id`` is fetched from the HuggingFace Hub.

    Args:
        cfg: Parsed configuration dict. Missing or null ``model`` / ``text``
            sections are treated as empty.
        fallback_local_path: Machine-specific tokenizer directory checked
            after the configured path.
        hub_model_id: HuggingFace model id loaded when no local path exists.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    print("🚀 Loading tokenizer...")

    # `or {}` also covers keys that are present but null in the YAML.
    model_cfg = cfg.get('model') or {}
    text_cfg = model_cfg.get('text') or {}
    pretrained_path = text_cfg.get('pretrained_path')

    # Local paths take priority over a download. The configured path is
    # checked first so it is not shadowed by the hard-coded fallback.
    for path in (pretrained_path, fallback_local_path):
        if path and os.path.exists(path):
            print(f"Using local tokenizer: {path}")
            tokenizer = AutoTokenizer.from_pretrained(path)
            print(f"✅ Tokenizer loaded successfully (vocabulary size: {len(tokenizer)})")
            return tokenizer

    # Download online
    print("Local files not found, downloading from HuggingFace...")
    tokenizer = AutoTokenizer.from_pretrained(hub_model_id)
    print(f"✅ Online tokenizer loaded successfully (vocabulary size: {len(tokenizer)})")
    return tokenizer

def create_dataset_with_preset_simple_wrapper(cfg, preset_name, tokenizer):
    """Build the training dataset, applying a named preset when one is given.

    A falsy ``preset_name`` (``None`` or an empty string) builds the dataset
    straight from ``cfg`` via ``create_dataset_from_config_simple``. Any other
    value is forwarded to ``create_dataset_with_preset_simple``.
    """
    if not preset_name:
        return create_dataset_from_config_simple(cfg, tokenizer)
    return create_dataset_with_preset_simple(cfg, preset_name, tokenizer)

def print_simple_summary(cfg, args):
    """Print a short banner summarizing the dataset, training and hardware setup.

    The displayed data ratio is resolved as follows: start from
    ``dataset.sample_ratio`` (default 1.0), let the selected preset's
    ``dataset.sample_ratio`` override it, then let ``args.sample_ratio``
    override both whenever it is not None (so an explicit 0.0 is honoured).

    Args:
        cfg: Parsed configuration dict. Missing or null ``dataset``,
            ``training`` or ``presets`` sections are treated as empty.
        args: Parsed CLI namespace; reads ``args.preset`` and
            ``args.sample_ratio``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("🚀 Ultra-fast English Medical MedCLIP Training")
    print(banner)

    # Dataset information
    dataset_cfg = cfg.get('dataset') or {}
    sample_ratio = dataset_cfg.get('sample_ratio', 1.0)

    if args.preset:
        preset_cfg = (cfg.get('presets') or {}).get(args.preset) or {}
        preset_dataset = preset_cfg.get('dataset') or {}
        sample_ratio = preset_dataset.get('sample_ratio', sample_ratio)

    # Compare against None, not truthiness, so `--sample-ratio 0.0` is shown.
    if args.sample_ratio is not None:
        sample_ratio = args.sample_ratio

    print(f"📁 Dataset: {dataset_cfg.get('root')}")
    print(f"📊 Data ratio: {sample_ratio*100:.1f}%")
    print(f"🖼️  Image size: {dataset_cfg.get('image_size', 224)}")
    print(f"📝 Text length: {dataset_cfg.get('max_text_length', 128)}")

    # Training information
    training_cfg = cfg.get('training') or {}
    print(f"🏋️  Batch size: {training_cfg.get('batch_size', 16)}")
    print(f"📈 Learning rate: {training_cfg.get('learning_rate', 2e-4)}")

    # Hardware information
    if torch.cuda.is_available():
        # Use one device index for both queries so the name and the memory
        # shown belong to the same GPU.
        device_index = torch.cuda.current_device()
        gpu_name = torch.cuda.get_device_name(device_index)
        gpu_memory = torch.cuda.get_device_properties(device_index).total_memory / 1e9
        print(f"🖥️  GPU: {gpu_name} ({gpu_memory:.1f}GB)")

    print(banner)

def benchmark_loading_speed(dataset_root, tokenizer, ratios=(0.01, 0.1)):
    """Benchmark how quickly the fast dataset builds and serves a sample.

    For each ratio this prints the estimate from ``estimate_loading_time``
    (when it returns one), then builds the dataset with
    ``create_fast_dataset`` and reports the sample count, build time,
    first-sample access time and throughput.

    Args:
        dataset_root: Dataset root directory passed to both loaders.
        tokenizer: Tokenizer forwarded to ``create_fast_dataset``.
        ratios: Fractions of the dataset to benchmark. Defaults to 1% and
            10%; the full dataset is skipped by default.
    """
    from data.dataset_english_medical_fast import create_fast_dataset

    print("\n🏁 Dataset loading speed benchmark")
    print("-" * 50)

    for ratio in ratios:
        print(f"Testing {ratio*100:.0f}% of data...")

        # Estimate time
        estimate = estimate_loading_time(dataset_root, ratio)
        if estimate:
            print(f"  📊 Estimated samples: {estimate['estimated_samples']:,}")
            print(f"  ⏱️  Estimated time: {estimate['time_range']}")

        # perf_counter is monotonic and high-resolution; time.time() is neither
        # guaranteed, which matters for short intervals like these.
        start = time.perf_counter()
        dataset = create_fast_dataset(
            dataset_root=dataset_root,
            tokenizer=tokenizer,
            sample_ratio=ratio
        )
        load_time = time.perf_counter() - start
        num_samples = len(dataset)

        # Test sample loading. A small ratio can yield an empty dataset, and
        # indexing it would raise.
        sample_time = None
        if num_samples > 0:
            start = time.perf_counter()
            _ = dataset[0]
            sample_time = time.perf_counter() - start

        print(f"  ✅ Actual samples: {num_samples:,}")
        print(f"  ⏱️  Actual loading time: {load_time:.2f} seconds")
        if sample_time is not None:
            print(f"  🔍 Single sample time: {sample_time*1000:.1f} milliseconds")
        else:
            print("  🔍 Single sample time: n/a (dataset is empty)")
        # Guard against a zero-length interval on coarse clocks.
        if load_time > 0:
            print(f"  🚀 Loading speed: {num_samples/load_time:.0f} samples/second")
        print()

def main():
    """Main function"""
    parser = argparse.ArgumentParser(description='Ultra-fast English Medical MedCLIP Training')
    parser.add_argument('--config', '-c', type=str,
                       default='configs/english_medical_simple.yaml',
                       help='Path to configuration file')
    parser.add_argument('--preset', type=str, default=None,
                       choices=['quick_test', 'small_test', 'full_train'],
                       help='Use preset configuration')
    parser.add_argument('--sample-ratio', type=float, default=None,
                       help='Dataset ratio (0.01=1%, 0.1=10%, 1.0=100%)')
    parser.add_argument('--benchmark', action='store_true',
                       help='Run loading speed benchmark')
    parser.add_argument('--test-dataset', action='store_true',
                       help='Only test dataset loading')
    parser.add_argument('--resume', type=str, default=None,
                       help='Path to checkpoint for resuming training')

    args = parser.parse_args()

    # Set up environment
    setup_environment()

    # Load configuration
    print(f"📄 Loading configuration: {args.config}")
    if not os.path.exists(args.config):
        raise FileNotFoundError(f"Configuration file not found: {args.config}")

    cfg = load_yaml(args.config)

    # Command line override
    if args.sample_ratio is not None:
        cfg['dataset']['sample_ratio'] = args.sample_ratio

    # Print summary
    print_simple_summary(cfg, args)

    # Print dataset summary and validation
    print("\n🔍 Validating dataset...")
    try:
        validate_dataset_path(cfg['dataset']['root'])
        print_dataset_summary_simple(cfg['dataset']['root'])
    except Exception as e:
        print(f"❌ Dataset validation failed: {e}")
        return

    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🖥️  Device: {device}")

    # Create output directories
    os.makedirs(cfg['output_dir'], exist_ok=True)
    os.makedirs(cfg['log_dir'], exist_ok=True)

    # Load tokenizer
    tokenizer = load_tokenizer_fast(cfg)

    # Benchmark
    if args.benchmark:
        benchmark_loading_speed(cfg['dataset']['root'], tokenizer)
        return

    # Create dataset
    print("\n🚀 Creating dataset...")
    start_time = time.time()

    dataset = create_dataset_with_preset_simple_wrapper(cfg, args.preset, tokenizer)

    load_time = time.time() - start_time
    print(f"✅ Dataset created: {len(dataset)} samples, time taken {load_time:.2f} seconds")

    # Only test dataset
    if args.test_dataset:
        from data.dataset_english_medical_fast import create_fast_dataset

        print("\n🔍 Testing dataset samples...")

        # Create dataset directly for testing
        test_dataset = create_fast_dataset(
            dataset_root=cfg['dataset']['root'],
            tokenizer=tokenizer,
            sample_ratio=0.01  # Use only 1% data for testing
        )

        # Test first 5 samples
        for i in range(min(5, len(test_dataset))):
            start_time = time.time()
            sample = test_dataset[i]
            sample_time = time.time() - start_time

            print(f"Sample {i+1}:")
            print(f"  ⏱️  Loading time: {sample_time*1000:.1f}ms")
            print(f"  🖼️  Image shape: {sample['image'].shape}")
            print(f"  🔍 ROI type: {sample['roi_type']}")
            print(f"  ❌ No Finding: {sample['is_no_finding']}")
            print(f"  📝 Negative samples count: {sample['negative_ids'].shape[0]}")
            print(f"  🏥 Domain: {sample['domain']}")

        # Statistical information
        no_finding_count = 0
        roi_count = 0
        for i in range(min(100, len(test_dataset))):
            sample = test_dataset[i]
            if sample['is_no_finding']:
                no_finding_count += 1
            if sample['has_roi']:
                roi_count += 1

        print(f"\n📊 Statistical information (first 100 samples):")
        print(f"  No Finding ratio: {no_finding_count}%")
        print(f"  ROI existence ratio: {roi_count}%")
        return

    # Create model
    print("\n🤖 Creating model...")
    model, model_manager = create_model_from_config(cfg, model_type='english_medclip')

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✅ Model created")
    print(f"  📊 Total parameters: {total_params:,}")
    print(f"  🎯 Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")

    # Initialize logger
    logger = UnifiedLogger(log_dir=cfg['log_dir'], name='english_medical_simple')

    # Create trainer
    print("\n🏋️  Creating trainer...")