import torch
import sys
sys.path.append('/ssd/0/wzq/Multi_Med/datapress')
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from transformers import AutoTokenizer
from Aligned.multiModelAligned_dataset import MultiModalAlignedDataset
from PIL import Image
import tqdm
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torch.multiprocessing as mp
import numpy as np
from functools import lru_cache
import concurrent.futures
from typing import Dict, List, Tuple, Any
import time
# Set the multiprocessing start method to 'spawn' at import time. force=True
# overrides any start method chosen earlier, so merely importing this module
# changes the start method for the whole interpreter, not just this file.
mp.set_start_method('spawn', force=True)

@lru_cache(maxsize=10000)
def _hash_token_ids(text: str, max_length: int, vocab_size: int) -> Tuple[int, ...]:
    """Map the whitespace tokens of lower-cased *text* to ids in ``[2, vocab_size - 1]``.

    ``zlib.crc32`` is used instead of the builtin ``hash`` because ``str`` hashes are
    salted per interpreter process, which would give every spawned DataLoader worker
    a different vocabulary. A tuple is returned so cached entries cannot be mutated
    by callers. At most ``max_length`` ids are returned.
    """
    import zlib

    ids = tuple(
        zlib.crc32(token.encode("utf-8")) % (vocab_size - 2) + 2
        for token in text.lower().split()
    )
    return ids[:max_length]


class SimpleTokenizer:
    """Dependency-free fallback tokenizer with a HuggingFace-like call signature.

    Text is lower-cased, split on whitespace, and each token is hashed to an id in
    ``[2, vocab_size - 1]``; ``pad_token_id`` fills padding positions. Defined at
    module level (not inside ``MultiDataset.__init__``) so instances can be pickled
    into ``spawn``-started DataLoader workers.
    """

    def __init__(self, vocab_size: int = 10000, pad_token_id: int = 0, unk_token_id: int = 1):
        if vocab_size <= 2:
            raise ValueError(f"vocab_size must be > 2, got {vocab_size}")
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id  # reserved; the hash mapping never emits it

    def __call__(self, text: str, truncation=True, max_length=512, padding='max_length', return_tensors='pt'):
        """Tokenize one string into ``[1, L]`` long tensors.

        ``truncation`` and ``return_tensors`` exist for signature compatibility only:
        output is always truncated to ``max_length`` and returned as torch tensors.
        With ``padding='max_length'`` the result is padded to exactly ``max_length``.
        ``None`` is treated as the empty string.
        """
        ids = list(_hash_token_ids(text or "", max_length, self.vocab_size))
        attn = [1] * len(ids)
        if padding == 'max_length':
            pad_len = max(0, max_length - len(ids))
            ids += [self.pad_token_id] * pad_len
            attn += [0] * pad_len
        input_ids = torch.tensor([ids], dtype=torch.long)
        attention_mask = torch.tensor([attn], dtype=torch.long)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': torch.zeros_like(attention_mask),
        }


class MultiDataset(Dataset):
    """Turn one aligned (medical records + chest X-ray) sample into tensors.

    ``aligned_dataset[idx]`` must be a dict with ``'med_items'`` and ``'cxr_items'``
    lists. ``__getitem__`` returns the 6-tuple
    ``(images, text_encodings, time_series, time_series_lengths, static_data, labels)``
    that ``collate_fn`` unpacks.
    """

    def __init__(self, aligned_dataset, image_size=224, max_len=512, text_model_path="/ssd/0/wzq/Multi_Med/checkpoints/text_encoder/local_model_dir"):
        """
        Args:
            aligned_dataset: indexable source of aligned samples (see class docstring).
            image_size: side length images are resized to.
            max_len: token length every text is truncated/padded to.
            text_model_path: local HuggingFace tokenizer directory tried first.
        """
        self.aligned_dataset = aligned_dataset
        self.image_size = image_size
        self.max_len = max_len

        # Resize -> tensor -> ImageNet mean/std normalisation, built once.
        self.image_transforms = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.tokenizer = self._load_tokenizer(text_model_path)

        # Placeholders used when a sample has no images / no med items.
        self._empty_image = torch.zeros(1, 3, image_size, image_size)
        self._empty_static = torch.zeros(1, 3, 2)

    @staticmethod
    def _load_tokenizer(text_model_path):
        """Return the first locally available tokenizer, else a ``SimpleTokenizer``.

        Tries *text_model_path*, then a locally cached ``bert-base-uncased``.
        ``local_files_only=True`` prevents any download attempt.
        """
        for source in (text_model_path, "bert-base-uncased"):
            try:
                return AutoTokenizer.from_pretrained(source, local_files_only=True)
            except Exception:  # best effort: any load failure falls through to the next option
                pass
        print("Warning: no local HuggingFace tokenizer found; falling back to SimpleTokenizer")
        return SimpleTokenizer()

    def __len__(self):
        return len(self.aligned_dataset)

    def _encode_text(self, text):
        """Tokenize *text* into 1-D tensors keyed like a HuggingFace encoding.

        ``None`` becomes ``''``. ``token_type_ids`` is filled with zeros when the
        tokenizer does not produce it, because ``collate_fn`` reads that key.
        """
        encoding = self.tokenizer(text or '', truncation=True, max_length=self.max_len,
                                  padding='max_length', return_tensors='pt')
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        if 'token_type_ids' not in encoding:
            encoding['token_type_ids'] = torch.zeros_like(encoding['input_ids'])
        return encoding

    def _process_image(self, image_tensor):
        """Convert one image tensor to PIL and apply ``self.image_transforms``.

        Accepts ``[C, H, W]`` with C in {1, 3}, ``[H, W, C]`` and ``[H, W]`` tensors.
        Floating-point input is treated as ``[0, 1]`` data and clamped to that range
        before the uint8 conversion, so out-of-range values saturate instead of
        wrapping around modulo 256. Integer input is clamped to ``[0, 255]``. The PIL
        image is forced to RGB so the 3-channel ``Normalize`` always applies.

        NOTE(review): the ``__main__`` demo passes ``MIMICCXRDataset`` a transform that
        ends in ``Normalize``. If that dataset applies it, the tensor arriving here is
        already normalised and will be clamped, re-scaled and normalised again.
        Confirm whether ``cxr_item['image']`` is meant to be a raw ``[0, 1]`` tensor.

        Returns ``None`` when *image_tensor* is ``None``.
        """
        if image_tensor is None:
            return None

        image = image_tensor.detach().cpu()
        if image.dim() == 3 and image.shape[0] in (1, 3):
            image = image.permute(1, 2, 0)  # channels-first -> channels-last for PIL
        if image.dim() == 3 and image.shape[-1] == 1:
            image = image[..., 0]  # single channel -> 2-D grayscale

        if image.is_floating_point():
            image = (image.clamp(0.0, 1.0) * 255).to(torch.uint8)
        else:
            image = image.clamp(0, 255).to(torch.uint8)

        image_pil = Image.fromarray(image.numpy()).convert('RGB')
        return self.image_transforms(image_pil)

    def __getitem__(self, idx):
        """Build the per-sample 6-tuple.

        Returns:
            images: stacked transformed CXR images ``[n_img, 3, H, W]``, or a zero
                ``[1, 3, image_size, image_size]`` placeholder when none are usable.
            all_med_encodings: list of token dicts, one per med-item text note
                followed by one per CXR report.
            all_time_series: list of each med item's ``dynamic_data`` (1-D tensors
                are promoted to ``[1, N]``).
            all_time_series_lengths: list of the first dimension of each series.
            static_data: each med item's ``static_data`` stacked on a new leading
                axis, or a zero ``[1, 3, 2]`` placeholder.
            all_labels: dict of label name -> tensor concatenated over med items,
                plus ``'next_step'`` built from ``dynamic_data_now``.
        """
        item = self.aligned_dataset[idx]

        all_time_series = []
        all_time_series_lengths = []
        all_static_data = []
        all_med_encodings = []
        all_labels = {}

        for med_item in item['med_items']:
            # A med item may arrive wrapped in a list; use its first element.
            # An empty list is skipped like any other non-dict entry.
            if isinstance(med_item, list):
                med_item = med_item[0] if med_item else None

            if not isinstance(med_item, dict):
                print(f"Warning: med_item is not a dict, type: {type(med_item)}")
                continue

            time_series = med_item['dynamic_data']
            if time_series.dim() == 1:
                time_series = time_series.unsqueeze(0)
            all_time_series.append(time_series)
            all_time_series_lengths.append(time_series.shape[0])

            all_static_data.append(med_item['static_data'].unsqueeze(0))

            all_med_encodings.append(self._encode_text(med_item.get('text_note', '')))

            for key, value in med_item['label'].items():
                all_labels.setdefault(key, []).append(torch.as_tensor(value).unsqueeze(0))
            all_labels.setdefault('next_step', []).append(med_item['dynamic_data_now'].unsqueeze(0))

        cxr_images = []
        for cxr_item in item['cxr_items']:
            processed_image = self._process_image(cxr_item['image'])
            if processed_image is not None:
                cxr_images.append(processed_image)
            all_med_encodings.append(self._encode_text(cxr_item['report']))

        if all_static_data:
            concatenated_static_data = torch.cat(all_static_data, dim=0)
        else:
            # Clone so callers never share (and could mutate) the placeholder.
            concatenated_static_data = self._empty_static.clone()

        all_labels = {
            key: torch.cat(values, dim=0) if values else torch.empty(0)
            for key, values in all_labels.items()
        }

        images = torch.stack(cxr_images) if cxr_images else self._empty_image.clone()

        return images, all_med_encodings, all_time_series, all_time_series_lengths, concatenated_static_data, all_labels


def collate_fn(batch, max_item=16, max_seq_len=167):
    """Collate ``MultiDataset`` samples into fixed-size, padded batch tensors.

    Each sample is the 6-tuple
    ``(images, text_encodings, time_series_list, time_series_lengths, static_data, labels)``.

    Args:
        batch: non-empty list of samples.
        max_item: cap on the number of images and of med items kept per sample.
        max_seq_len: number of time steps every series is truncated/padded to.

    Returns:
        images: ``[B, N_img, 3, S, S]`` zero padded, where S is the last dim of the
            first sample's images (224 if that sample has none).
        text_item: dict with ``input_ids``/``attention_mask``/``token_type_ids``,
            each a long tensor ``[B, N_img, L]``.
        padded_time_series: ``[B, N_med, max_seq_len, F, 2]``; channel 1 starts at 1
            (missing-value indicator) and is overwritten where data exists.
        time_series_lengths: long ``[B, N_med]``.
        static_data: ``[B, N_med, max(10, D)]`` flattened static features, zero
            padded to width 10.
        batched_labels: dict of label tensors, first dim ``B``, second dim at most
            ``N_med``.
    """
    max_images = min(max(item[0].shape[0] for item in batch), max_item)
    max_ts = min(max(item[4].shape[0] for item in batch), max_item)

    images = _collate_images(batch, max_images)
    text_item = _collate_text(batch, max_images, max_ts)
    padded_time_series, time_series_lengths, n_features = _collate_time_series(batch, max_ts, max_seq_len)
    static_data = _collate_static(batch, max_ts)
    batched_labels = _collate_labels(batch, max_ts, n_features)

    return images, text_item, padded_time_series, time_series_lengths, static_data, batched_labels


def _collate_images(batch, max_images):
    """Zero-pad each sample's image stack into one ``[B, max_images, 3, S, S]`` tensor."""
    first = batch[0][0]
    image_size = first.shape[-1] if first.numel() > 0 else 224
    images = torch.zeros(len(batch), max_images, 3, image_size, image_size)
    for i, item in enumerate(batch):
        sample_images = item[0][:max_images]
        images[i, :sample_images.shape[0]] = sample_images
    return images


def _collate_text(batch, n_slots, max_ts):
    """Stack token encodings into zero-padded ``[B, n_slots, L]`` long tensors.

    Each sample's encoding list is capped at ``max_ts`` entries and then only its
    first ``n_slots`` are stored. L is the first encoding's length (at most 512),
    or 512 when no sample has an encoding.
    """
    trimmed = [item[1][:max_ts] for item in batch]

    max_text_len = 512
    for encodings in trimmed:
        if encodings:
            max_text_len = min(len(encodings[0]['input_ids']), 512)
            break

    shape = (len(batch), n_slots, max_text_len)
    text_item = {
        key: torch.zeros(shape, dtype=torch.long)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    for i, encodings in enumerate(trimmed):
        for j, encoding in enumerate(encodings[:n_slots]):
            for key, out in text_item.items():
                value = encoding.get(key)
                if value is None:
                    continue  # e.g. tokenizers without token_type_ids: keep zeros
                value = torch.as_tensor(value)[:max_text_len]
                out[i, j, :value.shape[0]] = value  # shorter rows stay zero padded
    return text_item


def _fit_last_dim(x, size):
    """Zero-pad or truncate the last dimension of ``x`` to exactly ``size``."""
    if x.shape[-1] < size:
        return torch.nn.functional.pad(x, (0, size - x.shape[-1]))
    return x[..., :size]


def _collate_time_series(batch, n_med, max_seq_len):
    """Pad per-item series into ``[B, n_med, max_seq_len, F, 2]``.

    Returns ``(padded, lengths, n_features)``.
    """
    series_per_sample = [item[2] for item in batch]
    # F is the widest feature axis among 3-D series [T, F, C]. 2-D series [T, C]
    # only ever fill feature slot 0, so on their own they need F = 1.
    n_features = max(
        (ts.shape[1] for series in series_per_sample for ts in series if ts.dim() == 3),
        default=1,
    )

    padded = torch.zeros(len(batch), n_med, max_seq_len, n_features, 2)
    padded[..., 1] = 1  # missing-value indicator until real data overwrites it
    lengths = torch.zeros(len(batch), n_med, dtype=torch.long)

    for i, series in enumerate(series_per_sample):
        for j, ts in enumerate(series[:n_med]):
            if ts.dim() not in (2, 3):
                continue  # unsupported rank: slot stays empty, length 0
            t = min(ts.shape[0], max_seq_len)
            if t == 0:
                continue
            cur = ts[:t]
            if cur.dim() == 2:
                padded[i, j, :t, 0] = _fit_last_dim(cur, 2)
            elif cur.shape[1] == 1 and cur.shape[2] >= 2:
                padded[i, j, :t, 0] = cur[:, 0, :2]
            elif cur.shape[2] == 2:
                f = min(cur.shape[1], n_features)
                padded[i, j, :t, :f] = cur[:, :f]
            else:
                # Unknown layout: flatten and keep the first two values per step.
                padded[i, j, :t, 0] = _fit_last_dim(cur.reshape(t, -1), 2)
            lengths[i, j] = t
    return padded, lengths, n_features


def _collate_static(batch, max_static_len, width=10):
    """Stack static features into ``[B, max_static_len, max(width, D)]``.

    Per-item shape comes from the first non-empty sample (``(3, 2)`` if all are
    empty). When that per-item shape is 2-D with a last dim of at least 2, channel 1
    of padded rows is set to 1 as a missing-value indicator. Other shapes carry no
    indicator and padded rows stay zero.
    """
    first = next((item[4] for item in batch if item[4].numel() > 0), None)
    static_shape = tuple(first.shape[1:]) if first is not None else (3, 2)

    static = torch.zeros(len(batch), max_static_len, *static_shape)
    if static.dim() == 4 and static.shape[-1] >= 2:
        static[..., 1] = 1

    for i, item in enumerate(batch):
        n = min(item[4].shape[0], max_static_len)
        static[i, :n] = item[4][:n]

    static = static.reshape(len(batch), max_static_len, -1)
    if static.shape[-1] < width:
        static = torch.nn.functional.pad(static, (0, width - static.shape[-1]))
    return static


def _collate_labels(batch, max_label_len, n_features):
    """Pad every label key to ``[B, <= max_label_len, ...]``; ``next_step`` comes first."""
    label_dicts = [item[5] if isinstance(item[5], dict) else {} for item in batch]
    keys = set().union(*label_dicts)
    batched = {}

    # Forecasting target: same layout as one time step of the series.
    if 'next_step' in keys:
        first = next((d['next_step'] for d in label_dicts
                      if 'next_step' in d and d['next_step'].numel() > 0), None)
        shape = tuple(first.shape[1:]) if first is not None else (n_features, 2)
        next_step = torch.zeros(len(batch), max_label_len, *shape)
        if next_step.dim() == 4 and next_step.shape[-1] >= 2:
            next_step[..., 1] = 1  # missing-value indicator
        for i, d in enumerate(label_dicts):
            if 'next_step' in d:
                n = min(d['next_step'].shape[0], max_label_len)
                next_step[i, :n] = d['next_step'][:n]
        batched['next_step'] = next_step

    # Remaining task labels (classification etc.), in a deterministic key order.
    for key in sorted(keys - {'next_step'}):
        first = next((d[key] for d in label_dicts if key in d and d[key].numel() > 0), None)
        if first is None:
            continue
        shape = tuple(first.shape[1:])
        length = max(d[key].shape[0] if key in d else 0 for d in label_dicts)
        # Cap like the series/next_step so labels stay aligned with N_med.
        length = min(length, max_label_len) if max_label_len > 0 else 1
        out = torch.zeros(len(batch), length, *shape, dtype=first.dtype)
        for i, d in enumerate(label_dicts):
            if key in d:
                n = min(d[key].shape[0], length)
                out[i, :n] = d[key][:n]
        batched[key] = out
    return batched


def create_data_loader(dataset, batch_size, image_size=224, max_len=512, out_label=96, shuffle=True, num_workers=4, text_model_path=None):
    """Wrap *dataset* in a ``DataLoader`` that batches with ``collate_fn``.

    Args:
        dataset: map-style dataset yielding ``MultiDataset``-style samples.
        batch_size: samples per batch.
        image_size, max_len, out_label, text_model_path: accepted for call-site
            compatibility; they are not used here.
        shuffle: reshuffle every epoch.
        num_workers: worker processes; 0 loads in the main process.

    Returns:
        A configured ``torch.utils.data.DataLoader`` (pinned memory, last partial
        batch kept).
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,   # faster host-to-GPU copies
        drop_last=False,   # keep the final partial batch
    )
    if num_workers > 0:
        # Worker-only options. They are passed only when workers exist because
        # PyTorch < 2.0 rejects a non-default prefetch_factor with num_workers == 0.
        loader_kwargs.update(
            persistent_workers=True,         # keep workers alive across epochs
            prefetch_factor=2,               # batches prefetched per worker
            multiprocessing_context='spawn',
        )
    return DataLoader(dataset, **loader_kwargs)

def benchmark_dataloader(data_loader, num_batches=10, warmup_batches=2):
    """Measure how quickly *data_loader* yields batches.

    Pulls ``warmup_batches`` batches untimed, then times ``num_batches`` batches on a
    fresh pass. Each timing sample is the wall-clock gap between consecutive batch
    arrivals (the first one includes creating the iterator), so it reflects the
    loader's fetch cost. Exactly ``warmup_batches + num_batches`` batches are pulled.

    Returns:
        dict with ``total_time``, ``avg_batch_time``, ``std_batch_time`` (seconds)
        and ``batches_per_second``. If no batch was timed, the averages are NaN and
        ``batches_per_second`` is 0.0.
    """
    print("Benchmarking DataLoader performance...")
    print(f"Warmup batches: {warmup_batches}, Test batches: {num_batches}")

    # zip() exhausts the range first, so no extra batch is pulled from the loader.
    for _ in zip(range(warmup_batches), data_loader):
        pass

    batch_times = []
    start_time = time.perf_counter()
    previous = start_time
    for _ in zip(range(num_batches), data_loader):
        now = time.perf_counter()
        batch_times.append(now - previous)
        previous = now
    total_time = time.perf_counter() - start_time

    if batch_times:
        avg_batch_time = float(np.mean(batch_times))
        std_batch_time = float(np.std(batch_times))
        batches_per_second = 1 / avg_batch_time if avg_batch_time > 0 else float('inf')
    else:
        avg_batch_time = std_batch_time = float('nan')
        batches_per_second = 0.0

    print("Results:")
    print(f"  Total time: {total_time:.2f}s")
    print(f"  Average batch time: {avg_batch_time:.3f}s ± {std_batch_time:.3f}s")
    print(f"  Batches per second: {batches_per_second:.2f}")
    print(f"  Memory usage: {torch.cuda.memory_allocated() / 1024**3:.2f} GB" if torch.cuda.is_available() else "  CUDA not available")

    return {
        'total_time': total_time,
        'avg_batch_time': avg_batch_time,
        'std_batch_time': std_batch_time,
        'batches_per_second': batches_per_second
    }

def save_dataloader_to_single_file(dataloader, output_file_path, max_batches=12):
    """Save the first batches of *dataloader* to ``<output_file_path>/all_data.pt``.

    ``output_file_path`` is treated as a directory and is created if missing.
    Each saved entry is a dict with ``images``, ``text_item``, ``time_series_data``,
    ``static_data`` and ``labels``. Tensor entries get ``.squeeze(0)``, which only
    drops a leading singleton dimension. Two batch layouts are accepted:

    * the 6-tuple from ``collate_fn``
      ``(images, text_item, padded_time_series, time_series_lengths, static_data, labels)``,
      which additionally stores ``time_series_lengths``;
    * a legacy 5-tuple ``(images, text_item, packed_time_series, static_data, labels)``
      whose time series is a ``PackedSequence``. Its ``data`` and ``batch_sizes`` are
      stored separately (``time_series_batch_sizes``).

    Args:
        dataloader: iterable of batches in one of the layouts above.
        output_file_path: output directory.
        max_batches: number of batches to save (default 12); ``None`` saves all.
    """
    import itertools
    import os

    all_processed_batches = []
    print(f"开始收集并保存 DataLoader 中的数据到: {output_file_path}")

    # islice pulls exactly max_batches batches (all of them when None).
    for batch in itertools.islice(dataloader, max_batches):
        if len(batch) == 6:
            images, text_item, time_series, time_series_lengths, static_data, labels = batch
        else:
            images, text_item, time_series, static_data, labels = batch
            time_series_lengths = None

        batch_to_save = {'images': images.squeeze(0), 'text_item': text_item}
        if hasattr(time_series, 'batch_sizes'):
            # PackedSequence: store its flat data and batch sizes separately.
            batch_to_save['time_series_data'] = time_series.data.squeeze(0)
            batch_to_save['time_series_batch_sizes'] = time_series.batch_sizes
        else:
            batch_to_save['time_series_data'] = time_series.squeeze(0)
        if time_series_lengths is not None:
            batch_to_save['time_series_lengths'] = time_series_lengths
        batch_to_save['static_data'] = static_data.squeeze(0)
        batch_to_save['labels'] = labels
        all_processed_batches.append(batch_to_save)

    # The file is written *inside* output_file_path, so that directory must exist.
    os.makedirs(output_file_path, exist_ok=True)
    torch.save(all_processed_batches, os.path.join(output_file_path, 'all_data.pt'))
    print(f"所有 {len(all_processed_batches)} 个批次数据已成功保存到: {output_file_path}")


if __name__ == "__main__":
    import os
    from Aligned.mimiccxr_dataset import MIMICCXRDataset
    from Aligned.medical_dataset import MedicalDataset
    from omegaconf import OmegaConf

    # Dataset locations -- adjust for the local machine.
    base_data_path = '/ssd/0/wzq/Multi_Med/'
    index_file = os.path.join(base_data_path, 'data_dir/MIMIC/index.json')
    image_dir = '/hdd/0/dkm/REFERS-master/data/MIMIC/data/files'
    reports_dir = os.path.join(base_data_path, 'mimic-cxr-reports')

    # Image pipeline handed to the CXR dataset: resize, to-tensor, ImageNet mean/std.
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    cxr_dataset = MIMICCXRDataset(
        index_file_path=index_file,
        image_root=image_dir,
        reports_root=reports_dir,
        transform=transform,
    )

    # Structured-EHR dataset configured from the experiment YAML.
    opt = OmegaConf.load("/ssd/0/wzq/Multi_Med/exp/mimic_data/exp_mix_age.yaml")
    med_dataset = MedicalDataset(**opt.data.train_val, **opt.data.shared_param)

    # Pair both sources subject-by-subject using the alignment JSON.
    multi_dataset = MultiModalAlignedDataset(
        cxr_dataset,
        med_dataset,
        sid_json_path=os.path.join(base_data_path, 'aligned_subjects.json'),
    )
    print(f"Aligned multimodal dataset size: {len(multi_dataset)}")

    data_loader = create_data_loader(multi_dataset, batch_size=2, num_workers=4)

    # Smoke test: print the shapes of the first three batches.
    print("\n==== 测试 DataLoader 多模态 batch 输出 ====")
    for batch_idx, batch in enumerate(data_loader):
        images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
        print(f"\nBatch {batch_idx}:")
        print("  images:", images.shape)
        print("  text_item['input_ids']:", text_item['input_ids'].shape)
        print("  padded_time_series:", padded_time_series.shape)
        print("  time_series_lengths:", time_series_lengths.shape)
        print("  static_data:", static_data.shape)
        for label_name, label_tensor in labels.items():
            print(f"  labels['{label_name}']:", label_tensor.shape)
        if batch_idx >= 2:
            break
    print("\n==== DataLoader 测试完成 ====")

    # Throughput measurement on the same loader.
    print("\n==== 性能基准测试 ====")
    benchmark_results = benchmark_dataloader(data_loader, num_batches=5, warmup_batches=1)
    print("==== 基准测试完成 ====")