#!/usr/bin/env python3
# scripts/clear_roi_bbox.py

import os
import shutil
import argparse


def clear_roi_dirs(root: str):
    """
    Recursively find and delete all folders named 'roi_bbox' under ``root``.

    A one-line message is printed for every removal and every failure, followed
    by a summary line.

    Args:
        root: Directory whose tree is searched.

    Returns:
        list[str]: Paths of the directories that were actually removed, in the
        order they were encountered.
    """
    removed = []
    for dirpath, dirnames, _ in os.walk(root):
        if 'roi_bbox' not in dirnames:
            continue
        path = os.path.join(dirpath, 'roi_bbox')
        try:
            shutil.rmtree(path)
        except OSError as e:
            # Best effort: report and keep going with the rest of the tree.
            print(f"[ERROR] Failed to delete {path}: {e}")
            continue
        # The directory is gone; prune it so os.walk does not try to descend
        # into a path that no longer exists.
        dirnames.remove('roi_bbox')
        removed.append(path)
        print(f"Removed: {path}")

    if not removed:
        print("No roi_bbox directories found.")
    else:
        print(f"Successfully removed {len(removed)} roi_bbox directories.")
    return removed


# Script entry point: parse the dataset root from the command line and delete
# every 'roi_bbox' directory found beneath it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Clean roi_bbox folders in all subdirectories")
    # --dataset_root is optional; when omitted the default path below is used.
    parser.add_argument(
        "--dataset_root", "-d",
        default='/root/autodl-tmp/dataset',
        help="Dataset root directory, e.g., /root/autodl-tmp/dataset"
    )
    args = parser.parse_args()
    clear_roi_dirs(args.dataset_root)

# =====================================

# compute_entity_stats.py

import os
import json
import argparse
import numpy as np


def collect_entity_counts(dataset_root):
    """
    Collect per-record entity counts from all entity JSONL files under a root.

    Every file whose name ends with ``_en_entities.jsonl`` or
    ``_zh_entities.jsonl`` anywhere below ``dataset_root`` is read line by line.
    Each non-blank line is parsed as JSON and the length of its ``entities``
    value is recorded; a missing or null ``entities`` counts as 0.

    Args:
        dataset_root: Directory that is walked recursively.

    Returns:
        tuple[list[int], list[int]]: ``(en_counts, zh_counts)``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    en_counts = []
    zh_counts = []

    for dirpath, _dirnames, filenames in os.walk(dataset_root):
        for fn in filenames:
            if fn.endswith('_en_entities.jsonl'):
                bucket = en_counts
            elif fn.endswith('_zh_entities.jsonl'):
                bucket = zh_counts
            else:
                continue

            path = os.path.join(dirpath, fn)
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Blank lines (e.g. a trailing newline) are not records.
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    # `or []` also covers an explicit JSON null.
                    bucket.append(len(data.get('entities') or []))

    return en_counts, zh_counts


def print_stats(name, counts):
    """
    Print summary statistics for a list of per-report entity counts.

    The output contains the zero / non-zero split, min / max / mean / median,
    selected percentiles and a small histogram. The histogram lists individual
    counts below ``min(max_count, 10)`` and folds everything at or above that
    value into one final ``>=`` bucket.

    Args:
        name: Label used in the section header.
        counts: Sequence of integer entity counts, one per report. An empty
            sequence prints only the header and a zero total.
    """
    arr = np.array(counts, dtype=int)
    total = len(arr)

    print(f"\n--- {name} Entity Statistics ---")
    print(f"Total reports: {total}")
    if total == 0:
        # Ratios, min/max and percentiles are undefined without data.
        return

    zeros = int((arr == 0).sum())
    nonzeros = total - zeros
    print(f"Reports with 0 entities : {zeros} ({zeros / total:.2%})")
    print(f"Reports with >=1 entities: {nonzeros} ({nonzeros / total:.2%})\n")

    print(f"Min entities per report   : {arr.min()}")
    print(f"Max entities per report   : {arr.max()}")
    print(f"Mean entities per report  : {arr.mean():.2f}")
    print(f"Median entities per report: {int(np.median(arr))}\n")

    print("Percentiles:")
    for p in [25, 50, 75, 90, 95, 99]:
        print(f"  {p:>2}th percentile: {int(np.percentile(arr, p))}")

    # Simplified distribution statistics
    print("\nCount distribution (entities -> reports):")
    max_display = min(int(arr.max()), 10)
    for i in range(max_display):
        cnt = int((arr == i).sum())
        print(f"  {i:>3} -> {cnt}")
    cnt_ge = int((arr >= max_display).sum())
    print(f" >={max_display} -> {cnt_ge}")


def main():
    """
    Command-line entry point for the entity statistics report.

    Reads the required ``--dataset_root`` / ``-d`` argument, gathers the English
    and Chinese entity counts below it and prints one statistics section per
    language. A language without data gets a one-line notice instead, and when
    neither language has data a single notice is printed.
    """
    cli = argparse.ArgumentParser(
        description="Count entity distribution in all subdirectories for en/zh"
    )
    cli.add_argument(
        '--dataset_root', '-d',
        required=True,
        help='Dataset root directory, e.g., /root/autodl-tmp/dataset',
    )
    options = cli.parse_args()

    en_counts, zh_counts = collect_entity_counts(options.dataset_root)
    if not (en_counts or zh_counts):
        print("No *_en_entities.jsonl or *_zh_entities.jsonl files found")
        return

    # (section title, counts, message shown when that language has no data)
    sections = (
        ("English Report", en_counts, "No English entity files found"),
        ("Chinese Report", zh_counts, "No Chinese entity files found"),
    )
    for title, counts, missing_message in sections:
        if not counts:
            print(missing_message)
            continue
        print_stats(title, counts)


# Run the statistics report only when executed as a script, not on import.
if __name__ == "__main__":
    main()

# =====================================

import os
import json
import torch
from torch.utils.data import Dataset
from data.dataset_utils import ImageLoader


class CoarseDataset(Dataset):
    """
    Paired English/Chinese image-caption dataset read from JSONL files.

    Every immediate subdirectory of ``dataset_root`` that contains an ``images``
    folder is scanned for files ending in ``_en.jsonl`` or ``_en_entities.jsonl``.
    A file is used only if its ``_zh`` counterpart exists; both files are then
    read in lockstep, pairing line i of one with line i of the other. Records
    whose image file is missing are skipped.

    Each JSON record is expected to provide ``image``, ``positive_caption`` and
    ``negative_captions``; ``short_caption`` is optional and falls back to the
    positive caption.

    Args:
        dataset_root: Root directory containing one subdirectory per subset.
        tokenizer_en: Callable tokenizer for English text.
        tokenizer_zh: Callable tokenizer for Chinese text.
        max_length: ``max_length`` passed to both tokenizers.
        sample_n: If truthy, keep only the first ``sample_n`` samples.
        image_size: ``target_size`` passed to ``ImageLoader``.
        max_negatives: Number of negative captions returned per language.
    """

    def __init__(
            self, dataset_root, tokenizer_en, tokenizer_zh,
            max_length=128, sample_n=None,
            image_size=224, max_negatives=10
    ):
        self.tokenizer_en = tokenizer_en
        self.tokenizer_zh = tokenizer_zh
        self.loader = ImageLoader(target_size=image_size)
        self.max_length = max_length
        self.max_negatives = max_negatives
        self.samples = []

        # sorted() keeps the sample order stable across filesystems.
        for sub in sorted(os.listdir(dataset_root)):
            img_dir = os.path.join(dataset_root, sub, "images")
            if not os.path.isdir(img_dir):
                continue
            for fn in os.listdir(os.path.join(dataset_root, sub)):
                if not fn.endswith('_en.jsonl') and not fn.endswith('_en_entities.jsonl'):
                    continue
                en_path = os.path.join(dataset_root, sub, fn)
                zh_path = en_path.replace('_en.jsonl', '_zh.jsonl') \
                    .replace('_en_entities.jsonl', '_zh_entities.jsonl')
                if not os.path.exists(zh_path):
                    continue

                with open(en_path, encoding='utf-8') as f_en, \
                        open(zh_path, encoding='utf-8') as f_zh:
                    # zip stops at the shorter file; extra lines are ignored.
                    for line_en, line_zh in zip(f_en, f_zh):
                        d_en = json.loads(line_en)
                        d_zh = json.loads(line_zh)
                        # The image name is taken from the English record only.
                        img_path = os.path.join(img_dir, d_en['image'])
                        if not os.path.isfile(img_path):
                            continue
                        self.samples.append({
                            'img_path': img_path,
                            'positive_en': d_en['positive_caption'],
                            'negative_en': d_en['negative_captions'],
                            'short_en': d_en.get('short_caption', d_en['positive_caption']),
                            'positive_zh': d_zh['positive_caption'],
                            'negative_zh': d_zh['negative_captions'],
                            'short_zh': d_zh.get('short_caption', d_zh['positive_caption']),
                        })

        if sample_n:
            self.samples = self.samples[:sample_n]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def _encode(self, tokenizer, text):
        """Tokenize ``text`` (a string or list of strings) to ``max_length``."""
        return tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def _pad_negatives(self, negatives):
        """
        Return exactly ``max_negatives`` captions.

        Longer lists are truncated. Shorter lists are padded by repeating the
        first caption, or with empty strings when there is no caption at all
        (previously an empty list raised ``IndexError``).
        """
        if isinstance(negatives, str):
            negatives = [negatives]
        negatives = list(negatives or [])[:self.max_negatives]
        missing = self.max_negatives - len(negatives)
        if missing > 0:
            filler = negatives[0] if negatives else ""
            negatives = negatives + [filler] * missing
        return negatives

    def __getitem__(self, idx):
        """
        Load the image and tokenize all captions of sample ``idx``.

        Returns:
            dict: ``image`` plus, for each language suffix ``en`` / ``zh``,
            ``pos_ids``/``pos_mask``, ``short_ids``/``short_mask`` (batch
            dimension squeezed) and ``neg_ids``/``neg_mask`` (one row per
            negative caption, i.e. ``max_negatives`` rows).
        """
        with torch.no_grad():
            rec = self.samples[idx]
            image = self.loader.load(rec['img_path'])

            # English: positive, short, negatives
            pos_en = self._encode(self.tokenizer_en, rec['positive_en'])
            short_en = self._encode(self.tokenizer_en, rec['short_en'])
            neg_en = self._encode(self.tokenizer_en, self._pad_negatives(rec['negative_en']))

            # Chinese: positive, short, negatives
            pos_zh = self._encode(self.tokenizer_zh, rec['positive_zh'])
            short_zh = self._encode(self.tokenizer_zh, rec['short_zh'])
            neg_zh = self._encode(self.tokenizer_zh, self._pad_negatives(rec['negative_zh']))

            return {
                'image': image,

                'pos_ids_en': pos_en['input_ids'].squeeze(0),
                'pos_mask_en': pos_en['attention_mask'].squeeze(0),
                'short_ids_en': short_en['input_ids'].squeeze(0),
                'short_mask_en': short_en['attention_mask'].squeeze(0),
                'neg_ids_en': neg_en['input_ids'],  # [max_negatives, L]
                'neg_mask_en': neg_en['attention_mask'],  # [max_negatives, L]

                'pos_ids_zh': pos_zh['input_ids'].squeeze(0),
                'pos_mask_zh': pos_zh['attention_mask'].squeeze(0),
                'short_ids_zh': short_zh['input_ids'].squeeze(0),
                'short_mask_zh': short_zh['attention_mask'].squeeze(0),
                'neg_ids_zh': neg_zh['input_ids'],  # [max_negatives, L]
                'neg_mask_zh': neg_zh['attention_mask'],  # [max_negatives, L]
            }


# =====================================

import os
import json
import torch
import cv2
import glob
import random
import re
from torch.utils.data import Dataset
from data.dataset_utils import safe_tokenize
import torch.nn.functional as F
from collections import defaultdict


class OptimizedROIManager:
    """
    Index of ROI image files per medical domain, with tiered lookup.

    On construction every known domain folder below ``dataset_root`` is scanned
    for a ``roi_images`` directory. Files whose name contains ``Normal`` or
    ``Abnormal`` are indexed by the image base name encoded in the filename so
    that ``find_optimal_roi_match`` can resolve an image to an ROI file without
    touching the filesystem again.

    Attributes:
        dataset_root: Root directory that contains the domain folders.
        roi_cache: domain -> dict with the file lists and base-name mappings.
        domain_stats: domain -> dict with file counts and Normal/Abnormal ratios.
    """

    def __init__(self, dataset_root):
        self.dataset_root = dataset_root
        self.roi_cache = {}
        self.domain_stats = {}
        self._build_optimized_cache()

    @staticmethod
    def _extract_basename(filename, prefix, roi_type):
        """
        Return the image base name encoded in an ROI filename, or None.

        NOTE(review): assumes files are named ``<prefix>_<roi_type>_<base>.png``
        -- confirm against the real data. Stripping the known prefix first is
        required for prefixes that themselves contain an underscore (e.g.
        ``Bone_Joint``); otherwise the original "drop the first two
        underscore-separated tokens" rule is applied.
        """
        expected = f"{prefix}_{roi_type}_"
        if filename.startswith(expected):
            stem = filename[len(expected):]
        else:
            parts = filename.split('_')
            if len(parts) < 3:
                return None
            stem = '_'.join(parts[2:])
        return stem.replace('.png', '')

    def _build_optimized_cache(self):
        """Scan every domain's ``roi_images`` folder and fill the caches."""
        print("Building optimized ROI cache (adapted for single ROI type design)...")

        # Domain folder name -> filename prefix used by its ROI files
        domain_mapping = {
            'Abdominal Imaging': 'Abdominal',
            'Bone and Joint Imaging': 'Bone_Joint',
            'Breast Imaging': 'Breast',
            'Cardiac Imaging': 'Cardiac',
            'Chest Imaging': 'Chest',
            'Cranial Imaging': 'Cranial',
            'Dental Imaging': 'Dental',
            'Dermatological Imaging': 'Dermatological',
            'Endoscopy Imaging': 'Endoscopy',
            'Fundus Imaging': 'Fundus',
            'Gynecological Imaging': 'Gynecological',
            'Pathology Slide Imaging': 'Pathology'
        }

        for domain, prefix in domain_mapping.items():
            roi_dir = os.path.join(self.dataset_root, domain, 'roi_images')
            if not os.path.isdir(roi_dir):
                continue

            # Sorted so that mappings and tie-breaks do not depend on the
            # filesystem's listing order.
            roi_files = sorted(os.listdir(roi_dir))
            # Case-sensitive: 'Normal' is not a substring of 'Abnormal'.
            normal_files = [f for f in roi_files if 'Normal' in f]
            abnormal_files = [f for f in roi_files if 'Abnormal' in f]

            normal_to_basename = {}
            for f in normal_files:
                basename = self._extract_basename(f, prefix, 'Normal')
                if basename is not None:
                    normal_to_basename[basename] = f

            abnormal_to_basename = {}
            for f in abnormal_files:
                basename = self._extract_basename(f, prefix, 'Abnormal')
                if basename is not None:
                    abnormal_to_basename[basename] = f

            self.roi_cache[domain] = {
                # 'domain' is read by _find_similar_roi to rebuild the ROI path
                'domain': domain,
                'prefix': prefix,
                'normal_files': normal_files,
                'abnormal_files': abnormal_files,
                'normal_to_basename': normal_to_basename,
                'abnormal_to_basename': abnormal_to_basename,
            }

            normal_ratio = len(normal_files) / len(roi_files) if roi_files else 0
            abnormal_ratio = len(abnormal_files) / len(roi_files) if roi_files else 0

            self.domain_stats[domain] = {
                'total_roi_files': len(roi_files),
                'normal_count': len(normal_files),
                'abnormal_count': len(abnormal_files),
                'normal_ratio': normal_ratio,
                'abnormal_ratio': abnormal_ratio,
            }

            print(
                f"  {domain}: {len(roi_files)} ROI, Normal:{len(normal_files)}({normal_ratio:.1%}), Abnormal:{len(abnormal_files)}({abnormal_ratio:.1%})")

    def find_optimal_roi_match(self, image_name, domain, is_no_finding):
        """
        Resolve an image to an ROI file using progressively weaker strategies.

        Args:
            image_name: Image filename; its extension is ignored.
            domain: Domain folder name (a key of the ROI cache).
            is_no_finding: True selects 'Normal' as the expected ROI type,
                False selects 'Abnormal'.

        Returns:
            tuple: ``(roi_path, strategy, weight)``. ``roi_path`` is None and
            ``weight`` is 0.0 when the domain has no ROI directory or no
            usable file.
        """
        if domain not in self.roi_cache:
            return None, 'no_roi_dir', 0.0

        cache = self.roi_cache[domain]
        roi_dir = os.path.join(self.dataset_root, domain, 'roi_images')
        base_name = os.path.splitext(image_name)[0]

        # Strategy 1: exact base-name match of the expected type
        target_type = 'Normal' if is_no_finding else 'Abnormal'
        target_mapping = cache['normal_to_basename'] if is_no_finding else cache['abnormal_to_basename']

        if base_name in target_mapping:
            roi_path = os.path.join(roi_dir, target_mapping[base_name])
            return roi_path, f'{target_type.lower()}_perfect_match', 1.0

        # Strategy 2: exact base-name match of the opposite type
        opposite_type = 'Abnormal' if is_no_finding else 'Normal'
        opposite_mapping = cache['abnormal_to_basename'] if is_no_finding else cache['normal_to_basename']

        if base_name in opposite_mapping:
            roi_path = os.path.join(roi_dir, opposite_mapping[base_name])
            # Weight stays fairly high: this is treated as a dataset design
            # feature rather than an error.
            return roi_path, f'{opposite_type.lower()}_data_design', 0.8

        # Strategy 3: nearest trailing number among files of the expected type
        similarity_match = self._find_similar_roi(base_name, cache, target_type)
        if similarity_match:
            roi_path, similarity_score = similarity_match
            return roi_path, f'{target_type.lower()}_similarity_match', 0.6 * similarity_score

        # Strategy 4: fixed pick (the file at the 25% position of the sorted
        # list) among files of the expected type
        target_files = cache['normal_files'] if is_no_finding else cache['abnormal_files']
        if target_files:
            sorted_files = sorted(target_files)
            selected_file = sorted_files[len(sorted_files) // 4]
            roi_path = os.path.join(roi_dir, selected_file)
            return roi_path, f'{target_type.lower()}_quality_fallback', 0.4

        # Strategy 5: first file (in sorted order) of the opposite type
        opposite_files = cache['abnormal_files'] if is_no_finding else cache['normal_files']
        if opposite_files:
            selected_file = sorted(opposite_files)[0]
            roi_path = os.path.join(roi_dir, selected_file)
            return roi_path, f'{opposite_type.lower()}_final_fallback', 0.3

        return None, 'no_roi_available', 0.0

    def _find_similar_roi(self, base_name, cache, target_type):
        """
        Find the ROI file whose trailing number is closest to the image's.

        Only candidates whose last number differs by less than 100 qualify;
        the score falls linearly from 1.0 (same number) and is floored at 0.1.

        Args:
            base_name: Image filename without extension.
            cache: One entry of ``self.roi_cache``.
            target_type: 'Normal' or 'Abnormal'.

        Returns:
            ``(roi_path, score)`` for the best candidate, or None when the base
            name contains no digits or nothing is close enough.
        """
        target_mapping = cache['normal_to_basename'] if target_type == 'Normal' else cache['abnormal_to_basename']
        # The cache entry carries its own domain so the path resolves inside
        # the domain folder; previously the key was missing and the result
        # pointed at <dataset_root>/roi_images.
        roi_dir = os.path.join(self.dataset_root, cache.get('domain', ''), 'roi_images')

        base_numbers = re.findall(r'\d+', base_name)
        if not base_numbers:
            return None

        base_num = int(base_numbers[-1])  # Use the last number
        best_match = None
        best_score = 0

        for basename, filename in target_mapping.items():
            candidate_numbers = re.findall(r'\d+', basename)
            if not candidate_numbers:
                continue
            diff = abs(base_num - int(candidate_numbers[-1]))
            if diff >= 100:
                continue
            score = max(0.1, 1.0 - diff / 100.0)
            if score > best_score:
                best_score = score
                best_match = os.path.join(roi_dir, filename)

        if best_match:
            return best_match, best_score
        return None


class EnglishMedicalDatasetFast(Dataset):
    """
    English medical image/report dataset with lightweight loading.

    Sample metadata is read from every ``*_region_en.jsonl`` file of the known
    domain folders; images and ROI crops are only loaded in ``__getitem__``.

    Rules applied to each record:
    1. No Finding -> negative_captions = []
    2. Has Finding -> negative_captions are cleaned (not necessarily 5 items)
    3. ROI logic: Normal corresponds to No Finding, Abnormal to Has Finding

    Args:
        dataset_root: Root directory containing the domain folders.
        tokenizer: Callable tokenizer used for all text fields.
        image_size: Side length images are resized to.
        max_text_length: Token length of the report text.
        sample_ratio: Fraction in (0, 1) of samples to keep; any other value
            keeps the full dataset.
    """

    # Number of negative captions every item is padded/truncated to
    _NUM_NEGATIVES = 5
    # Token length used for region captions and negative captions
    _SHORT_TEXT_LENGTH = 64

    def __init__(self, dataset_root, tokenizer, image_size=224, max_text_length=128, sample_ratio=1.0):
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.max_text_length = max_text_length
        self.dataset_root = dataset_root

        print(f"Starting fast dataset loading (ratio: {sample_ratio * 100:.1f}%)")

        # Created lazily on first ROI lookup (see _find_roi_path_optimized)
        self.optimized_roi_manager = None

        self.samples = []
        self._load_all_samples()

        # Sample by ratio
        if 0 < sample_ratio < 1.0:
            original_count = len(self.samples)
            target_count = int(original_count * sample_ratio)
            # NOTE(review): this reseeds the process-wide `random` generator,
            # which also affects any later use of `random` by the caller.
            random.seed(42)
            self.samples = random.sample(self.samples, target_count)
            print(f"Sampled: {original_count} -> {target_count} ({sample_ratio * 100:.1f}%)")

        print(f"Dataset loading complete: {len(self.samples)} samples")

    def _load_all_samples(self):
        """Read all JSONL records; unreadable files and bad lines are skipped."""
        # 12 medical domains
        domains = [
            'Abdominal Imaging', 'Bone and Joint Imaging', 'Breast Imaging',
            'Cardiac Imaging', 'Chest Imaging', 'Cranial Imaging',
            'Dental Imaging', 'Dermatological Imaging', 'Endoscopy Imaging',
            'Fundus Imaging', 'Gynecological Imaging', 'Pathology Slide Imaging'
        ]

        for domain in domains:
            domain_path = os.path.join(self.dataset_root, domain)
            if not os.path.exists(domain_path):
                continue

            # glob returns files in arbitrary order; sorting keeps the sample
            # order (and therefore the seeded sub-sampling) reproducible.
            jsonl_files = sorted(glob.glob(os.path.join(domain_path, "*_region_en.jsonl")))

            for file_path in jsonl_files:
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        for line_num, line in enumerate(f, 1):
                            if not line.strip():
                                continue
                            try:
                                data = json.loads(line)
                                sample = self._process_sample(data, domain, file_path, line_num)
                            except Exception:
                                continue  # Skip error lines
                            if sample:
                                self.samples.append(sample)
                except Exception:
                    continue  # Skip error files

    def _process_sample(self, data, domain, file_path, line_num):
        """
        Convert one parsed JSON record into a sample dict.

        Returns None when the record has no ``image`` entry or cannot be
        processed. ``file_path`` and ``line_num`` are accepted for context but
        not used.
        """
        try:
            image_name = data.get('image')
            if not image_name:
                return None

            report = data.get('report', '')
            region_caption = data.get('region_caption', '')
            negative_captions = data.get('negative_captions', [])

            # Determine if No Finding
            is_no_finding = ('no finding' in region_caption.lower().strip())

            if is_no_finding:
                # No Finding case: negatives are dropped
                processed_negatives = []
            else:
                # Has Finding case: clean whatever negatives are present
                processed_negatives = self._process_negatives_robust(negative_captions)

            return {
                'image_name': image_name,
                'image_path': os.path.join(self.dataset_root, domain, 'images', image_name),
                'report': report,
                'region_caption': region_caption,
                'negative_captions': processed_negatives,
                'is_no_finding': is_no_finding,
                'domain': domain
            }

        except Exception:
            return None

    def _process_negatives_robust(self, negative_captions):
        """
        Clean a raw ``negative_captions`` value.

        A single string is wrapped in a list, non-list values yield ``[]``,
        and blank or non-string items are dropped. A non-empty result is then
        padded to at least 3 items (repeating the first) and capped at 8.
        """
        if not negative_captions:
            return []

        # Ensure it's a list
        if isinstance(negative_captions, str):
            negative_captions = [negative_captions]
        elif not isinstance(negative_captions, list):
            return []

        # Filter empty and invalid items
        valid_negatives = [
            item.strip()
            for item in negative_captions
            if item and isinstance(item, str) and item.strip()
        ]
        if not valid_negatives:
            return []

        if len(valid_negatives) < 3:
            valid_negatives.extend([valid_negatives[0]] * (3 - len(valid_negatives)))
        elif len(valid_negatives) > 8:
            valid_negatives = valid_negatives[:8]

        return valid_negatives

    def _find_roi_path_optimized(self, image_name, domain, is_no_finding):
        """Look up the ROI file for an image, creating the manager on first use."""
        if getattr(self, 'optimized_roi_manager', None) is None:
            self.optimized_roi_manager = OptimizedROIManager(self.dataset_root)

        return self.optimized_roi_manager.find_optimal_roi_match(image_name, domain, is_no_finding)

    def _load_image_fast(self, image_path):
        """
        Load an image as a float tensor in [0, 1], channels first.

        Returns an all-zero ``(3, image_size, image_size)`` tensor when the
        file cannot be read or processed.
        """
        try:
            img = cv2.imread(image_path)
            if img is None:
                return torch.zeros(3, self.image_size, self.image_size)

            # cv2 loads BGR; convert to RGB
            if len(img.shape) == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if img.shape[:2] != (self.image_size, self.image_size):
                img = cv2.resize(img, (self.image_size, self.image_size))

            # HWC uint8 -> CHW float
            img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
            return torch.clamp(img_tensor, 0.0, 1.0)

        except Exception:
            return torch.zeros(3, self.image_size, self.image_size)

    def _tokenize_robust(self, text, max_length):
        """Tokenize ``text``; non-string input or a tokenizer error falls back to ""."""
        if not text or not isinstance(text, str):
            text = ""

        try:
            return self.tokenizer(
                text,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
        except Exception:
            # Tokenization failed, return result for empty text
            return self.tokenizer(
                "",
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

    def _fixed_negatives(self, negative_captions):
        """
        Return exactly ``_NUM_NEGATIVES`` captions.

        No captions -> all empty strings; too few -> repeat the first caption;
        too many -> truncate.
        """
        if not negative_captions:
            return [""] * self._NUM_NEGATIVES
        fixed = list(negative_captions)[:self._NUM_NEGATIVES]
        fixed.extend([fixed[0]] * (self._NUM_NEGATIVES - len(fixed)))
        return fixed

    def _process_negatives_for_training(self, negative_captions):
        """
        Tokenize negatives to a fixed number of rows.

        If tokenization fails, a warning is printed and the tokenization of
        empty strings is returned instead.
        """
        fixed_negatives = self._fixed_negatives(negative_captions)

        try:
            return self.tokenizer(
                fixed_negatives,
                max_length=self._SHORT_TEXT_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
        except Exception as e:
            print(f"Warning: Negative sample tokenization failed, using default values: {e}")
            return self.tokenizer(
                [""] * self._NUM_NEGATIVES,
                max_length=self._SHORT_TEXT_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

    def _get_default_sample(self):
        """Return an all-zero placeholder item used when a sample fails to load."""
        short_len = self._SHORT_TEXT_LENGTH
        return {
            'image': torch.zeros(3, self.image_size, self.image_size),
            'roi': torch.zeros(3, self.image_size, self.image_size),
            'roi_type': 'failed',
            'roi_weight': torch.tensor(0.1, dtype=torch.float32),
            'has_roi': False,
            'is_no_finding': True,
            'report_ids': torch.zeros(self.max_text_length, dtype=torch.long),
            'report_mask': torch.zeros(self.max_text_length, dtype=torch.long),
            'region_ids': torch.zeros(short_len, dtype=torch.long),
            'region_mask': torch.zeros(short_len, dtype=torch.long),
            'negative_ids': torch.zeros(self._NUM_NEGATIVES, short_len, dtype=torch.long),
            'negative_mask': torch.zeros(self._NUM_NEGATIVES, short_len, dtype=torch.long),
            'domain': 'Unknown',
            'image_path': '',
            'roi_path': ''
        }

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Build the training item for sample ``idx``.

        Loads the main image and its ROI image (falling back to a copy of the
        main image with weight 0.1 when no ROI file is found) and tokenizes
        the report, region caption and negatives. Any exception results in the
        placeholder from ``_get_default_sample`` instead of propagating.
        """
        try:
            sample = self.samples[idx]

            # Load main image
            main_img = self._load_image_fast(sample['image_path'])

            # ROI lookup
            roi_path, roi_strategy, roi_weight = self._find_roi_path_optimized(
                sample['image_name'],
                sample['domain'],
                sample['is_no_finding']
            )

            # Load ROI image
            if roi_path and os.path.exists(roi_path):
                try:
                    roi_img = self._load_image_fast(roi_path)
                    has_roi = True

                    # Only output debug info for first 3 samples
                    if idx < 3:
                        print(f"Sample {idx}: {sample['image_name']} -> {os.path.basename(roi_path)} "
                              f"({roi_strategy}, weight:{roi_weight:.2f})")

                except Exception:
                    roi_img = main_img.clone()
                    has_roi = False
                    roi_strategy = 'load_failed'
                    roi_weight = 0.1
            else:
                roi_img = main_img.clone()
                has_roi = False
                roi_strategy = 'no_roi_found'
                roi_weight = 0.1

            # Text processing
            report_tokens = self._tokenize_robust(sample['report'], self.max_text_length)
            region_tokens = self._tokenize_robust(sample['region_caption'], self._SHORT_TEXT_LENGTH)

            # Negatives: fixed count so that items can be batched
            negative_tokens = self.tokenizer(
                self._fixed_negatives(sample['negative_captions']),
                max_length=self._SHORT_TEXT_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            return {
                'image': main_img,
                'roi': roi_img,
                'roi_type': roi_strategy,
                'roi_weight': torch.tensor(roi_weight, dtype=torch.float32),
                'has_roi': has_roi,
                'is_no_finding': sample['is_no_finding'],

                'report_ids': report_tokens['input_ids'].squeeze(0),
                'report_mask': report_tokens['attention_mask'].squeeze(0),
                'region_ids': region_tokens['input_ids'].squeeze(0),
                'region_mask': region_tokens['attention_mask'].squeeze(0),
                'negative_ids': negative_tokens['input_ids'],
                'negative_mask': negative_tokens['attention_mask'],

                'domain': sample['domain'],
                'image_path': sample['image_path'],
                'roi_path': roi_path or ""
            }

        except Exception as e:
            print(f"Sample {idx} processing error: {e}")
            return self._get_default_sample()


# Simplified dataset factory function
def create_fast_dataset(dataset_root, tokenizer, sample_ratio=1.0, **kwargs):
    """
    Build an ``EnglishMedicalDatasetFast`` from the given settings.

    Args:
        dataset_root: Dataset root directory
        tokenizer: Tokenizer handed to the dataset
        sample_ratio: Fraction of the dataset to keep (0.01=1%, 0.1=10%, 1.0=100%)
        **kwargs: Extra keyword arguments forwarded unchanged to the dataset
            constructor

    Returns:
        The constructed dataset instance
    """
    dataset = EnglishMedicalDatasetFast(
        dataset_root=dataset_root,
        tokenizer=tokenizer,
        sample_ratio=sample_ratio,
        **kwargs,
    )
    return dataset


# Test final optimization solution
def test_final_optimization(dataset_root):
    """
    Run a fixed set of ROI lookups and report which strategies were used.

    Three image/domain pairs are each queried once as "No Finding" and once as
    "Has Finding". The share of lookups resolved by an exact base-name match
    (same or opposite ROI type) is printed as the expected quality ratio,
    followed by a verdict.

    Args:
        dataset_root: Dataset root directory handed to ``OptimizedROIManager``.
    """
    print("Testing final optimized ROI solution...")

    roi_manager = OptimizedROIManager(dataset_root)

    # Every pair is probed with is_no_finding=True first, then False
    image_domain_pairs = (
        ('img_00001', 'Bone and Joint Imaging'),
        ('img_00005', 'Gynecological Imaging'),
        ('img_00010', 'Pathology Slide Imaging'),
    )
    probes = [
        (image, domain, flag)
        for image, domain in image_domain_pairs
        for flag in (True, False)
    ]

    print("\nFinal optimized ROI matching test:")
    tally = defaultdict(int)

    for image_name, domain, is_no_finding in probes:
        roi_path, strategy, weight = roi_manager.find_optimal_roi_match(
            image_name, domain, is_no_finding
        )
        tally[strategy] += 1

        if is_no_finding:
            finding_type = "No Finding"
        else:
            finding_type = "Has Finding"
        status = "Found" if roi_path else "Not Found"

        print(f"{domain} - {image_name} ({finding_type}): {status} - {strategy} (weight:{weight:.1f})")

    print(f"\nStrategy distribution: {dict(tally)}")

    # Strategies that resolved the image by exact base name
    exact_match_labels = (
        'normal_perfect_match', 'abnormal_perfect_match',
        'normal_data_design', 'abnormal_data_design',
    )
    exact_hits = 0
    for label in exact_match_labels:
        exact_hits += tally[label]
    lookups = sum(tally.values())

    quality_ratio = exact_hits / lookups if lookups > 0 else 0
    print(f"Expected ROI quality ratio: {quality_ratio:.1%}")

    verdict = (
        "Final optimization solution expected effect excellent!"
        if quality_ratio >= 0.8
        else "May need further tuning"
    )
    print(verdict)


# Test functions
def test_fast_dataset(dataset_root, tokenizer_path, sample_ratio=0.01):
    """
    Smoke-test dataset creation and single-sample loading, printing timings.

    Args:
        dataset_root: Dataset root directory.
        tokenizer_path: Path or name passed to ``AutoTokenizer.from_pretrained``.
        sample_ratio: Fraction of the dataset to load.
    """
    from transformers import AutoTokenizer
    import time

    print("=" * 60)
    print(f"Testing fast dataset loading (ratio: {sample_ratio * 100:.1f}%)")
    print("=" * 60)

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    print("Tokenizer loaded successfully")

    # Create dataset
    print("Creating dataset...")
    start_time = time.time()

    dataset = create_fast_dataset(
        dataset_root=dataset_root,
        tokenizer=tokenizer,
        sample_ratio=sample_ratio
    )

    load_time = time.time() - start_time
    print(f"Dataset creation complete, time taken: {load_time:.2f}s")

    # Test sample loading
    print("Testing sample loading...")
    start_time = time.time()

    sample = dataset[0]

    sample_time = time.time() - start_time
    print(f"Sample loading complete, time taken: {sample_time * 1000:.1f}ms")

    # Print sample info
    print("\nSample information:")
    print(f"  Image shape: {sample['image'].shape}")
    print(f"  ROI shape: {sample['roi'].shape}")
    print(f"  Is No Finding: {sample['is_no_finding']}")
    print(f"  ROI type: {sample['roi_type']}")
    print(f"  ROI weight: {sample['roi_weight']:.2f}")
    print(f"  Report length: {sample['report_ids'].shape}")
    print(f"  Region description length: {sample['region_ids'].shape}")
    print(f"  Negative sample count: {sample['negative_ids'].shape[0]}")
    print(f"  Domain: {sample['domain']}")

    # No Finding ratio over the samples actually inspected. The raw count is
    # only a percentage when exactly 100 samples exist, so divide by the real
    # number checked.
    checked = min(100, len(dataset))
    no_finding_count = sum(1 for i in range(checked) if dataset[i]['is_no_finding'])
    if checked:
        print(f"\nNo Finding ratio in first {checked} samples: {no_finding_count / checked:.1%}")
    else:
        print("\nNo Finding ratio: dataset is empty")

    print("=" * 60)
    print("Test completed!")
    print("=" * 60)


# Script entry point: run the dataset smoke test against hard-coded local paths.
if __name__ == "__main__":
    # Usage example; both paths are machine-specific and must exist locally
    dataset_root = "/root/autodl-tmp/dataset"
    tokenizer_path = "/root/autodl-tmp/pubmedbert-base-uncased-abstract-local"

    # Test with 1% of the data
    test_fast_dataset(dataset_root, tokenizer_path, sample_ratio=0.01)
