import os
import torch

def _config_token_id(config, attr_name, default):
    """Read a token id from ``config``, falling back to ``default``.

    The fallback applies both when the attribute is missing and when it is
    present but ``None``; a plain ``getattr(config, name, default)`` would
    return ``None`` in the second case.
    """
    value = getattr(config, attr_name, None)
    return default if value is None else value


def _pad_to_length(seq, target_length, tensor_name, config):
    """Right-pad the 1-D tensor ``seq`` up to ``target_length`` elements.

    The filler depends on which tensor is being padded:
    ``input_ids`` -> the config's pad token id (128257 if unset),
    ``labels`` -> -100, ``attention_mask`` -> 0, and ``position_ids`` ->
    consecutive values continuing after the last element of ``seq``.
    ``seq`` is returned unchanged when it is already long enough.
    """
    padding_size = target_length - seq.shape[0]
    if padding_size <= 0:
        return seq

    if tensor_name == 'position_ids':
        start = seq[-1].item() + 1 if len(seq) > 0 else 0
        padding = torch.arange(start, start + padding_size, device=seq.device, dtype=seq.dtype)
    else:
        if tensor_name == 'input_ids':
            pad_value = _config_token_id(config, 'pad_token_id', 128257)
        elif tensor_name == 'labels':
            pad_value = -100  # ignore label
        else:  # attention_mask
            pad_value = 0
        padding = torch.full((padding_size,), pad_value, device=seq.device, dtype=seq.dtype)

    return torch.cat([seq, padding])


def fix_image_token_mismatch_comprehensive(input_ids, pixel_values, config, **kwargs):
    """Repair a mismatch between image placeholder tokens and image features.

    For every row of ``input_ids`` the number of tokens equal to
    ``config.image_token_index`` (128256 if unset) is compared with
    ``pixel_values.shape[1]``. Surplus image tokens are removed from the end
    of the image-token run; missing ones are inserted right after the first
    image token. The same edit is applied to ``labels``, ``attention_mask``
    and ``position_ids`` when they are supplied, so the tensors stay aligned.
    Every row keeps its original length: rows are right-padded after a
    removal and truncated on the right after an insertion.

    Args:
        input_ids: 2-D tensor indexed as ``[batch, seq]``.
        pixel_values: 5-D tensor; the original author annotated it as
            ``[batch, num_patches, channels, height, width]``.
        config: object optionally exposing ``image_token_index`` and
            ``pad_token_id``.
        **kwargs: optional ``labels``, ``attention_mask`` and
            ``position_ids`` tensors, indexed like ``input_ids``.

    Returns:
        A dict with fixed clones of ``input_ids`` and of each optional tensor
        that was passed, or ``None`` when nothing needed fixing, when
        ``pixel_values`` is not 5-D, when a row lacks any image token to
        anchor an insertion, or when any error occurs. The inputs themselves
        are never modified.
    """
    try:
        image_token_index = _config_token_id(config, 'image_token_index', 128256)

        # Expected number of image tokens per row.
        if pixel_values.dim() == 5:
            expected_features = pixel_values.shape[1]
        else:
            print(f"[FIX] Unexpected pixel_values shape: {pixel_values.shape}")
            return None

        # Work on clones. Insertion order matters: it is the order in which
        # the tensors are edited below.
        result = {'input_ids': input_ids.clone()}
        for tensor_name in ('labels', 'attention_mask', 'position_ids'):
            tensor = kwargs.get(tensor_name)
            if tensor is not None:
                result[tensor_name] = tensor.clone()

        batch_size = input_ids.shape[0]
        any_fixed = False

        for batch_idx in range(batch_size):
            # Positions are taken from the untouched input, not from the clone.
            image_token_positions = (input_ids[batch_idx] == image_token_index).nonzero(as_tuple=True)[0]
            current_image_tokens = len(image_token_positions)

            print(f"[FIX] Batch {batch_idx}: Current image tokens: {current_image_tokens}, Expected: {expected_features}")

            difference = current_image_tokens - expected_features
            if difference == 0:
                continue

            if difference > 0:
                # Too many image tokens: drop the last `difference` of them.
                print(f"[FIX] Removing {difference} excess image tokens")
                positions_to_remove = image_token_positions[-difference:].tolist()

                for tensor_name, tensor in result.items():
                    row = tensor[batch_idx]
                    keep = torch.ones(row.shape[0], dtype=torch.bool, device=row.device)
                    keep[positions_to_remove] = False
                    kept = row[keep]

                    if tensor_name == 'position_ids':
                        # Close the gaps left by the removed slots so the
                        # positions stay contiguous, mirroring the renumbering
                        # done in the insertion branch.
                        removed_before = (~keep).cumsum(0)[keep].to(kept.dtype)
                        kept = kept - removed_before

                    tensor[batch_idx] = _pad_to_length(kept, row.shape[0], tensor_name, config)
            else:
                # Too few image tokens: insert the missing ones.
                tokens_to_add = -difference
                print(f"[FIX] Adding {tokens_to_add} missing image tokens")

                if len(image_token_positions) == 0:
                    print(f"[FIX] No image tokens found, cannot add missing tokens safely")
                    return None

                # Insert immediately after the first existing image token.
                insert_pos = image_token_positions[0].item() + 1

                for tensor_name, tensor in result.items():
                    row = tensor[batch_idx]

                    if tensor_name == 'input_ids':
                        insert_values = torch.full((tokens_to_add,), image_token_index,
                                                   device=row.device, dtype=row.dtype)
                    elif tensor_name == 'labels':
                        insert_values = torch.full((tokens_to_add,), -100,
                                                   device=row.device, dtype=row.dtype)
                    elif tensor_name == 'attention_mask':
                        insert_values = torch.ones((tokens_to_add,),
                                                   device=row.device, dtype=row.dtype)
                    else:  # position_ids
                        start_pos = row[insert_pos].item() if insert_pos < len(row) else len(row)
                        insert_values = torch.arange(start_pos, start_pos + tokens_to_add,
                                                     device=row.device, dtype=row.dtype)

                    new_seq = torch.cat([row[:insert_pos], insert_values, row[insert_pos:]])

                    # Shift the positions that now follow the inserted run.
                    if tensor_name == 'position_ids' and insert_pos < len(row):
                        new_seq[insert_pos + tokens_to_add:] += tokens_to_add

                    # Restore the original row length (truncate, then pad if short).
                    target_length = tensor.shape[1]
                    new_seq = _pad_to_length(new_seq[:target_length], target_length, tensor_name, config)

                    tensor[batch_idx] = new_seq

            any_fixed = True

        if not any_fixed:
            return None

        # Report the outcome; the result is returned even if counts still differ.
        final_image_tokens = (result['input_ids'] == image_token_index).sum().item()
        expected_total = expected_features * batch_size
        print(f"[FIX] Final verification: {final_image_tokens} total image tokens (expected: {expected_total})")
        return result

    except Exception as e:
        # Best-effort helper: report the failure and let the caller carry on.
        print(f"[FIX] Error in fix_image_token_mismatch_comprehensive: {e}")
        import traceback
        print(f"[FIX] Traceback: {traceback.format_exc()}")
        return None

def patch_llava_next_forward():
    """Monkey-patch ``LlavaNextForConditionalGeneration.forward`` to skip bad batches.

    The wrapper calls the original ``forward``. If that raises a
    ``ValueError`` whose message contains
    "Image features and image tokens do not match", the wrapper prints
    diagnostics, appends them to ``/tmp/llava_skipped_samples_<pid>.txt``
    (best effort) and returns a minimal output object carrying a zero loss
    instead of propagating the error. Every other exception is re-raised
    unchanged.

    Diagnostics are only collected when ``input_ids`` and ``pixel_values``
    are passed as keyword arguments.

    The patch is applied at most once per class, and the wrapper keeps the
    original signature and metadata via ``functools.wraps`` so that code
    inspecting ``forward`` still sees the real parameters. Any failure while
    installing the patch (for example ``transformers`` not being importable)
    is printed and swallowed.
    """
    try:
        import functools

        import transformers

        # Locate the llava_next modeling module dynamically.
        for name in dir(transformers.models):
            if "llava_next" in name:
                mod = getattr(transformers.models, name)
                if hasattr(mod, "modeling_llava_next"):
                    modeling = getattr(mod, "modeling_llava_next")
                    break
        else:
            print("llava_next modeling module not found!")
            return

        model_cls = modeling.LlavaNextForConditionalGeneration
        orig_forward = model_cls.forward

        # Do not wrap our own wrapper if this function is called again.
        if getattr(orig_forward, "_llava_skip_patch_applied", False):
            return

        class SkippedOutput:
            """Minimal stand-in for a model output that only exposes ``loss``.

            Supports attribute access (``.loss``), key access (``["loss"]``),
            index access (``[0]``), ``in`` checks and ``.get``.
            """

            def __init__(self, loss):
                self.loss = loss

            def __getitem__(self, key):
                if key == "loss" or key == 0:
                    return self.loss
                raise KeyError(f"Key {key} not found")

            def __contains__(self, key):
                return key == "loss" or key == 0

            def get(self, key, default=None):
                if key == "loss":
                    return self.loss
                return default

        @functools.wraps(orig_forward)
        def new_forward(self, *args, **kwargs):
            try:
                return orig_forward(self, *args, **kwargs)
            except ValueError as e:
                if "Image features and image tokens do not match" not in str(e):
                    # Unrelated ValueError: propagate unchanged.
                    raise

                print(f"[SKIP] Detecting token mismatch: {str(e)}")

                # Inputs are only inspected when passed by keyword.
                input_ids = kwargs.get('input_ids')
                pixel_values = kwargs.get('pixel_values')

                if input_ids is not None and pixel_values is not None:
                    image_token_index = getattr(self.config, 'image_token_index', 128256)
                    current_image_tokens = (input_ids == image_token_index).sum().item()
                    expected_features = pixel_values.shape[1] if pixel_values.dim() == 5 else 0
                    expected_total = expected_features * input_ids.shape[0]

                    print(f"[SKIP] Problematic data - Current tokens: {current_image_tokens}, Expected: {expected_total}")
                    print(f"[SKIP] Input shape: {input_ids.shape}, Pixel values shape: {pixel_values.shape}")

                    # Record the offending batch for later analysis. A failure
                    # to write the log must not defeat the skip itself.
                    try:
                        with open(f"/tmp/llava_skipped_samples_{os.getpid()}.txt", "a") as f:
                            f.write(f"=== Skipped Sample ===\n")
                            f.write(f"Error: {str(e)}\n")
                            f.write(f"Current tokens: {current_image_tokens}, Expected: {expected_total}\n")
                            f.write(f"Input shape: {input_ids.shape}, Pixel values shape: {pixel_values.shape}\n")
                            f.write(f"Batch size: {input_ids.shape[0]}\n")

                            # Per-row image-token counts.
                            for batch_idx in range(input_ids.shape[0]):
                                batch_image_tokens = (input_ids[batch_idx] == image_token_index).sum().item()
                                f.write(f"  Batch {batch_idx}: {batch_image_tokens} image tokens\n")

                            f.write("\n")
                    except OSError as log_err:
                        print(f"[SKIP] Could not write skipped-sample log: {log_err}")

                # Skip the batch by handing back a zero loss.
                print(f"[SKIP] Skipping problematic batch and continuing training...")

                device = input_ids.device if input_ids is not None else next(self.parameters()).device
                # A leaf tensor with requires_grad=True so backward() can be
                # called on it; it is not connected to any model parameter.
                zero_loss = torch.tensor(0.0, device=device, requires_grad=True)

                return SkippedOutput(zero_loss)

        new_forward._llava_skip_patch_applied = True
        model_cls.forward = new_forward
    except Exception as ex:
        # Installing the patch is best effort; never break module import.
        print(f"[patch_llava_next_forward] Exception: {ex}")

# Install the LLaVA-Next forward patch as a side effect of importing this module.
patch_llava_next_forward()

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


# Module-level logger, created through the project's get_logger helper.
logger = get_logger(__name__)


def _filter_dataset_with_report(data_filter, dataset, model_config, label):
    """Apply ``data_filter`` to ``dataset`` and print a before/after size report.

    Args:
        data_filter: callable taking ``(dataset, model_config)`` and returning
            the filtered dataset.
        dataset: the dataset to filter.
        model_config: passed through to ``data_filter``.
        label: lowercase description used in the messages, e.g.
            ``"train dataset"``; its first letter is upper-cased for the
            summary line.

    Returns:
        The filtered dataset, or ``dataset`` itself when it has no
        ``__len__`` (such a dataset cannot be sized for the report, so it is
        left untouched instead of raising ``TypeError``).
    """
    if not hasattr(dataset, "__len__"):
        print(f"[FILTER] Skipping {label}: dataset has no length and cannot be filtered by index")
        return dataset

    print(f"[FILTER] Filtering {label}...")
    original_size = len(dataset)
    filtered = data_filter(dataset, model_config)
    filtered_size = len(filtered)
    title = label[0].upper() + label[1:]
    print(f"[FILTER] {title}: {original_size} -> {filtered_size} samples ({original_size-filtered_size} filtered out)")
    return filtered


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    """Run the supervised fine-tuning workflow: train, evaluate, predict.

    Loads tokenizer, template, datasets and model, optionally filters the
    datasets for models whose config has ``image_token_index``, builds the
    collator and ``CustomSeq2SeqTrainer``, then runs whichever of
    ``do_train`` / ``do_eval`` / ``do_predict`` are enabled on
    ``training_args`` and finally creates the model card.

    The dataset filtering is on by default when training such a model and is
    disabled by setting the environment variable ``LLAVA_FILTER_DATA=0``.
    """
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    # Automatically filter problematic samples (LLaVA-style models only).
    if hasattr(model.config, 'image_token_index') and training_args.do_train:
        if os.environ.get('LLAVA_FILTER_DATA', '1') != '0':
            print("[FILTER] Auto data filtering enabled for LLaVA model (set LLAVA_FILTER_DATA=0 to disable)")
            data_filter = create_data_filter_for_llava()

            # Training split.
            if dataset_module.get('train_dataset') is not None:
                dataset_module['train_dataset'] = _filter_dataset_with_report(
                    data_filter, dataset_module['train_dataset'], model.config, "train dataset"
                )

            # Evaluation split(s): either a dict of named datasets or a single one.
            eval_dataset = dataset_module.get('eval_dataset')
            if eval_dataset is not None:
                if isinstance(eval_dataset, dict):
                    for key in list(eval_dataset.keys()):
                        eval_dataset[key] = _filter_dataset_with_report(
                            data_filter, eval_dataset[key], model.config, f"eval dataset '{key}'"
                        )
                else:
                    dataset_module['eval_dataset'] = _filter_dataset_with_report(
                        data_filter, eval_dataset, model.config, "eval dataset"
                    )
        else:
            print("[INFO] LLaVA model detected but data filtering is disabled (LLAVA_FILTER_DATA=0)")

    if getattr(model, "is_quantized", False) and not training_args.do_train:
        setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=model if not training_args.predict_with_generate else None,
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Metric utils: similarity metrics when generating, otherwise optional accuracy.
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    # Initialize our Trainer
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        gen_kwargs=gen_kwargs,
        **dataset_module,
        **tokenizer_module,
        **metric_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            # One loss/accuracy pair per named eval set, or the generic pair.
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                for key in dataset_module["eval_dataset"].keys():
                    keys += [f"eval_{key}_loss", f"eval_{key}_accuracy"]
            else:
                keys += ["eval_loss", "eval_accuracy"]

            plot_loss(training_args.output_dir, keys=keys)

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"  # use left-padding in generation

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

def create_data_filter_for_llava():
    """Build a filter that drops samples likely to cause an image-token mismatch.

    Returns:
        A callable ``filter_problematic_samples(dataset, model_config)``.
    """

    def _with_shape(value):
        """Return ``value`` as an object exposing ``shape``/``dim``, or ``None``.

        Lists and tuples are converted with ``torch.tensor``; anything else
        lacking both attributes, or a sequence that cannot be converted,
        yields ``None``.
        """
        if hasattr(value, 'shape') or hasattr(value, 'dim'):
            return value
        if isinstance(value, (list, tuple)):
            try:
                return torch.tensor(value)
            except Exception:
                return None
        return None

    def _inspect_sample(sample, image_token_index):
        """Classify one sample.

        Returns ``(reason, detail)``. ``reason`` is ``None`` when the sample
        should be kept, otherwise one of ``'invalid_format'``,
        ``'missing_fields'`` or ``'token_mismatch'``. ``detail`` is extra log
        text for mismatches and ``None`` otherwise.
        """
        if not isinstance(sample, dict):
            return 'invalid_format', None
        if 'input_ids' not in sample:
            return 'missing_fields', None

        input_ids = sample['input_ids']
        if isinstance(input_ids, (list, tuple)):
            input_ids = torch.tensor(input_ids)
        elif not hasattr(input_ids, 'shape'):
            return 'invalid_format', None

        actual_image_tokens = (input_ids == image_token_index).sum().item()

        pixel_values = sample.get('pixel_values')
        if pixel_values is None:
            # No image data: harmless if there are no image tokens either
            # (a text-only sample), otherwise the sample is incomplete.
            if actual_image_tokens == 0:
                return None, None
            return 'missing_fields', None

        pixel_values = _with_shape(pixel_values)
        if pixel_values is None:
            return 'invalid_format', None

        rank = pixel_values.dim() if hasattr(pixel_values, 'dim') else len(pixel_values.shape)
        if rank == 4:  # original author's note: [num_patches, channels, height, width]
            expected_image_tokens = pixel_values.shape[0]
        elif rank == 5:  # original author's note: [batch, num_patches, channels, height, width]
            expected_image_tokens = pixel_values.shape[1]
        else:
            return 'invalid_format', None

        if actual_image_tokens != expected_image_tokens:
            detail = (
                f"Token mismatch - actual: {actual_image_tokens}, expected: {expected_image_tokens}\n"
                f"  Input shape: {input_ids.shape}, Pixel values shape: {pixel_values.shape}\n"
            )
            return 'token_mismatch', detail

        return None, None

    def filter_problematic_samples(dataset, model_config):
        """Return ``dataset`` restricted to the samples that pass the checks.

        A sample is kept when it is a dict with ``input_ids`` and either
        (a) has no ``pixel_values`` and no image tokens, or (b) its number of
        image tokens equals the patch dimension of ``pixel_values``
        (``shape[0]`` for 4-D, ``shape[1]`` for 5-D). Rejections are counted
        per reason, printed, and written to a log file in the system
        temporary directory whose name includes the timestamp and the PID.

        The result is ``dataset.select(indices)`` when the dataset has a
        ``select`` method, otherwise a small index-mapping wrapper. If
        anything fails at the top level, the original dataset is returned
        unfiltered.
        """
        try:
            import tempfile
            import time

            total_samples = len(dataset)
            print(f"[FILTER] Starting to filter dataset with {total_samples} samples...")

            filtered_indices = []
            skipped_count = 0
            detailed_stats = {
                'missing_fields': 0,
                'invalid_format': 0,
                'token_mismatch': 0,
                'processing_error': 0
            }

            image_token_index = getattr(model_config, 'image_token_index', 128256)

            # Include the PID so processes started within the same second do
            # not truncate or interleave each other's log.
            log_file = os.path.join(
                tempfile.gettempdir(),
                f"llava_filtered_samples_{int(time.time())}_{os.getpid()}.txt",
            )

            # The log is opened once for the whole run rather than per sample.
            with open(log_file, "w", encoding="utf-8") as log:
                log.write(f"=== LLaVA Data Filtering Log ===\n")
                log.write(f"Total samples: {total_samples}\n")
                log.write(f"Image token index: {image_token_index}\n")
                log.write(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

                for idx in range(total_samples):
                    if idx % 1000 == 0 and idx > 0:
                        print(f"[FILTER] Processed {idx}/{total_samples} samples, kept {len(filtered_indices)}, skipped {skipped_count}")

                    try:
                        reason, detail = _inspect_sample(dataset[idx], image_token_index)
                    except Exception as e:
                        # One unreadable sample must not abort the whole pass.
                        reason, detail = 'processing_error', f"Processing error - {str(e)}\n"

                    if reason is None:
                        filtered_indices.append(idx)
                        continue

                    detailed_stats[reason] += 1
                    skipped_count += 1
                    if detail is not None:
                        log.write(f"Sample {idx}: {detail}")

                # Summary section of the log.
                log.write(f"\n=== Filtering Summary ===\n")
                log.write(f"Total samples: {total_samples}\n")
                log.write(f"Kept samples: {len(filtered_indices)}\n")
                log.write(f"Filtered samples: {skipped_count}\n")
                log.write(f"Missing fields: {detailed_stats['missing_fields']}\n")
                log.write(f"Invalid format: {detailed_stats['invalid_format']}\n")
                log.write(f"Token mismatch: {detailed_stats['token_mismatch']}\n")
                log.write(f"Processing errors: {detailed_stats['processing_error']}\n")
                log.write(f"End time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

            # Console report.
            print(f"[FILTER] Filtering complete!")
            print(f"[FILTER] Total samples: {total_samples}")
            print(f"[FILTER] Kept samples: {len(filtered_indices)}")
            print(f"[FILTER] Filtered samples: {skipped_count}")
            print(f"[FILTER] Breakdown:")
            print(f"[FILTER]   - Missing fields: {detailed_stats['missing_fields']}")
            print(f"[FILTER]   - Invalid format: {detailed_stats['invalid_format']}")
            print(f"[FILTER]   - Token mismatch: {detailed_stats['token_mismatch']}")
            print(f"[FILTER]   - Processing errors: {detailed_stats['processing_error']}")
            print(f"[FILTER] Detailed log saved to: {log_file}")

            # Build the filtered dataset.
            if hasattr(dataset, 'select'):
                # Datasets exposing select(indices), e.g. HuggingFace datasets.
                return dataset.select(filtered_indices)

            class FilteredDataset:
                """Read-only view of ``original_dataset`` restricted to ``indices``."""

                def __init__(self, original_dataset, indices):
                    self.dataset = original_dataset
                    self.indices = indices

                def __len__(self):
                    return len(self.indices)

                def __getitem__(self, idx):
                    return self.dataset[self.indices[idx]]

            return FilteredDataset(dataset, filtered_indices)

        except Exception as e:
            # Filtering is an optimisation, not a requirement: fall back to
            # the unfiltered dataset rather than failing the run.
            print(f"[FILTER] Error during filtering: {e}")
            print(f"[FILTER] Returning original dataset without filtering")
            return dataset

    return filter_problematic_samples
