"""
MedicalDataset wrapper: Adds empty images and text to pure medical datasets to make them compatible with multimodal model interfaces
"""

import zlib

import torch
import torch.nn as nn
from torch.utils.data import Dataset

# Conditional import to support offline-only releases
try:
    from datapress.Aligned.medical_dataset_with_los import MedicalDatasetWithLOS
except ImportError:
    MedicalDatasetWithLOS = None

# Import collate_fn at module level to avoid import problems inside multiprocessing workers
try:
    from datapress.dataloader import collate_fn as dataloader_collate_fn
except ImportError:
    # If the import fails, leave the name as None; collate_medical_batch checks for this
    dataloader_collate_fn = None

# Simple tokenizer class defined at module level so that it can be pickled
class SimpleTokenizer:
    """Minimal whitespace tokenizer used as a fallback when AutoTokenizer is unavailable.

    The text is lower-cased, split on whitespace, truncated to ``max_length``
    tokens and right-padded with id 0. Token ids are derived from a CRC32
    checksum of the token, so the same token maps to the same id in every
    process and on every run.
    """

    def __call__(self, text, truncation=True, max_length=512, padding='max_length', return_tensors='pt'):
        """Encode ``text`` into fixed-length id / mask tensors.

        Args:
            text: Input string; ``None`` or an empty string yields an all-padding encoding.
            truncation: Accepted for call-signature compatibility; ignored
                (the token list is always cut to ``max_length``).
            max_length: Number of positions in every returned tensor.
            padding: Accepted for call-signature compatibility; ignored
                (the output is always padded to ``max_length``).
            return_tensors: Accepted for call-signature compatibility; ignored
                (torch tensors are always returned).

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``,
            each a ``torch.long`` tensor of shape ``[1, max_length]``.
        """
        tokens = (text or "").lower().split()[:max_length]
        # The builtin hash() of a str is salted per interpreter process, which
        # would give different ids across runs and across spawned DataLoader
        # workers. CRC32 is stable, so ids are reproducible.
        ids = [zlib.crc32(token.encode("utf-8")) % 10000 for token in tokens]
        pad_len = max_length - len(ids)
        attn = [1] * len(ids) + [0] * pad_len
        ids = ids + [0] * pad_len
        return {
            'input_ids': torch.tensor([ids], dtype=torch.long),
            'attention_mask': torch.tensor([attn], dtype=torch.long),
            'token_type_ids': torch.zeros(1, max_length, dtype=torch.long)
        }

class MedicalDatasetWrapper(Dataset):
    """Wrap a medical-only dataset so its samples look like multimodal samples.

    Each dict sample of the wrapped dataset is converted into the tuple
    ``(images, [text_item], time_series_list, time_series_lengths_list,
    concatenated_static_data, labels)``, where ``images`` is an all-zero
    placeholder and ``text_item`` is the tokenized ``text_note``.
    """

    def __init__(self, medical_dataset, max_text_length=512, image_size=224):
        """
        Args:
            medical_dataset: Indexable dataset (e.g. a MedicalDatasetWithLOS
                instance) whose items are dicts with ``dynamic_data`` and
                ``static_data`` tensors and optional ``text_note`` / ``label``
                entries, or already-wrapped tuples.
            max_text_length: Length every text encoding is padded/truncated to.
            image_size: Height and width of the placeholder image.
        """
        self.medical_dataset = medical_dataset
        self.max_text_length = max_text_length
        self.image_size = image_size

        # Prefer the transformers tokenizer; fall back to the module-level
        # SimpleTokenizer (picklable) when it cannot be imported or loaded.
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        except Exception as exc:  # not a bare except: keep Ctrl-C / SystemExit working
            import warnings
            warnings.warn(
                f"Could not load 'bert-base-uncased' tokenizer ({exc!r}); "
                "falling back to SimpleTokenizer.",
                RuntimeWarning,
            )
            self.tokenizer = SimpleTokenizer()

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.medical_dataset)

    def __getitem__(self, index):
        """Return sample ``index`` in the multimodal tuple format.

        Returns:
            The wrapped sample unchanged if it is already a tuple; otherwise
            ``(images, [text_item], time_series_list, time_series_lengths_list,
            concatenated_static_data, labels)``.

        Raises:
            TypeError: If the wrapped sample is neither a dict nor a tuple.
            KeyError: If a dict sample lacks ``dynamic_data`` or ``static_data``.
        """
        med_sample = self.medical_dataset[index]

        # A tuple is treated as an already-wrapped sample and passed through.
        if isinstance(med_sample, tuple):
            return med_sample
        if not isinstance(med_sample, dict):
            raise TypeError(f"Expected dict or tuple, got {type(med_sample)}")

        # Placeholder: a single all-zero RGB image.
        images = torch.zeros(1, 3, self.image_size, self.image_size)

        text_item = self._encode_text(med_sample.get('text_note', ''))

        time_series_list, time_series_lengths_list = self._split_time_series(
            med_sample['dynamic_data']
        )
        n_med = len(time_series_list)

        # One copy of the static features per time series.
        concatenated_static_data = self._expand_static(med_sample['static_data'], n_med)

        # Internal sanity checks on the layout built above: [n_med, ...], at least 2D.
        assert concatenated_static_data.dim() >= 2, f"static_data should be at least 2D, got {concatenated_static_data.dim()}D with shape {concatenated_static_data.shape}"
        assert concatenated_static_data.shape[0] == n_med, f"First dimension should be {n_med}, got {concatenated_static_data.shape[0]}"

        labels = med_sample.get('label', {})

        return images, [text_item], time_series_list, time_series_lengths_list, concatenated_static_data, labels

    def _encode_text(self, text_note):
        """Tokenize ``text_note`` and drop the leading batch dimension of each tensor."""
        encoding = self.tokenizer(
            text_note or '',  # the key may be present with a None value
            truncation=True,
            max_length=self.max_text_length,
            padding='max_length',
            return_tensors='pt'
        )
        # squeeze(0) removes a size-1 leading dim: [1, max_len] -> [max_len]
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        if 'token_type_ids' in encoding:
            token_type_ids = encoding['token_type_ids'].squeeze(0)
        else:
            # Not every tokenizer emits segment ids; use all zeros in that case.
            token_type_ids = torch.zeros_like(input_ids)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }

    @staticmethod
    def _valid_length(values):
        """Count the rows of a ``[T, F]`` tensor that are not entirely NaN."""
        return (~torch.isnan(values).all(dim=-1)).sum().item()

    @classmethod
    def _split_time_series(cls, dynamic_data):
        """Split ``dynamic_data`` into a list of series plus their valid lengths.

        Handled layouts (decided from ``dim()`` and the size of the last axis):
          * 4D ``[n_med, T, F, 2]``: one ``[T, F, 2]`` series per leading index.
          * 3D with last axis 2, ``[T, F, 2]``: a single series.
          * other 3D, ``[n_med, T, F]``: each ``[T, F]`` slice is stacked with an
            all-ones tensor to form ``[T, F, 2]``.
          * anything else: passed through as one series whose length is its
            first dimension (1 for a 0-dim tensor).
        """
        if dynamic_data.dim() == 4:
            series = [dynamic_data[i] for i in range(dynamic_data.shape[0])]
            # Channel 0 of the last axis is used to decide which time steps are valid.
            lengths = [cls._valid_length(ts[:, :, 0]) for ts in series]
            return series, lengths

        if dynamic_data.dim() == 3:
            if dynamic_data.shape[-1] == 2:
                return [dynamic_data], [cls._valid_length(dynamic_data[:, :, 0])]
            series, lengths = [], []
            for i in range(dynamic_data.shape[0]):
                ts = dynamic_data[i]  # [T, F]
                # Add the second channel as all ones.
                # NOTE(review): presumably the missing-value indicator channel - confirm
                # that all-ones is the intended value even where ts is NaN.
                series.append(torch.stack([ts, torch.ones_like(ts)], dim=-1))
                lengths.append(cls._valid_length(ts))
            return series, lengths

        # Unknown layout: best-effort pass-through.
        return [dynamic_data], [dynamic_data.shape[0] if dynamic_data.dim() >= 1 else 1]

    @staticmethod
    def _expand_static(static_data, n_med):
        """Replicate ``static_data`` ``n_med`` times along a new leading axis.

        Empty input becomes zeros of shape ``[n_med, 3, 2]``. 0-dim and 1D
        inputs are first brought to 2D with a last axis of size 2: an
        even-length vector is viewed as ``[len // 2, 2]``, while an odd-length
        vector or a scalar is paired with a zero column. Inputs that are
        already 2D or higher are replicated as they are.
        NOTE(review): the original author's notes say this layout mirrors
        MultiDataset / the project collate_fn - neither is visible here; confirm.
        """
        if static_data.numel() == 0:
            return torch.zeros(n_med, 3, 2)

        if static_data.dim() == 0:
            static_data = static_data.unsqueeze(0).unsqueeze(0)  # [1, 1]
            pad = torch.zeros(1, 1, dtype=static_data.dtype, device=static_data.device)
            static_data = torch.cat([static_data, pad], dim=-1)  # [1, 2]
        elif static_data.dim() == 1:
            static_dim = static_data.shape[0]
            if static_dim % 2 == 0:
                static_data = static_data.view(static_dim // 2, 2)
            else:
                static_data = static_data.unsqueeze(-1)  # [static_dim, 1]
                pad = torch.zeros(static_dim, 1, dtype=static_data.dtype, device=static_data.device)
                static_data = torch.cat([static_data, pad], dim=-1)  # [static_dim, 2]

        # [*shape] -> [1, *shape] -> [n_med, *shape]; clone() materialises the copies.
        return static_data.unsqueeze(0).expand(n_med, *static_data.shape).clone()

def collate_medical_batch(batch):
    """Collate a batch of MedicalDatasetWrapper samples.

    Delegates to the ``collate_fn`` imported from ``datapress.dataloader`` at
    module import time; the samples are passed through without modification.

    Args:
        batch: List of sample tuples as produced by the dataset.

    Returns:
        Whatever the imported ``collate_fn`` returns for ``batch``.

    Raises:
        ImportError: If ``datapress.dataloader.collate_fn`` could not be imported.
    """
    # There is no built-in substitute collation: without the project
    # collate_fn the only option is to fail loudly.
    if dataloader_collate_fn is None:
        raise ImportError("Cannot import collate_fn from datapress.dataloader. Please check the import path.")
    return dataloader_collate_fn(batch)

