"""
时序对齐的多模态数据集
为CXR图像、报告和医疗时序数据添加时间标签，支持多模态图模型构建
"""

import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset
from datetime import datetime, timedelta
import pandas as pd
from typing import Dict, List, Tuple, Optional
import warnings

class TemporalAlignedDataset(Dataset):
    """Multimodal dataset aligned on a shared per-patient time axis.

    For every aligned subject the constructor eagerly builds:

    * the subject's CXR items and medical time-series items, each annotated
      with ``timestamp``, ``time_index`` and ``relative_time_hours``;
    * the uniformly spaced time axis itself (a list of ``datetime`` objects);
    * a temporal graph (nodes, edges and per-time-slice node lists).

    Indexing the dataset returns one dict per subject (see ``__getitem__``).
    """

    # Formats tried, in order, when a timestamp is given as a string.
    _TIME_FORMATS = (
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%dT%H:%M:%S',
        '%Y%m%d%H%M%S',
        '%Y-%m-%d',
        '%H:%M:%S',
    )

    # Reference instant used whenever no real time information is available.
    _BASE_TIME = datetime(2020, 1, 1)

    # Subject-ID file consulted when no explicit ``sid_json_path`` is given.
    _DEFAULT_SID_JSON = 'aligned_subjects.json'

    def __init__(self, cxr_dataset, med_dataset, sid_json_path=None,
                 time_resolution_hours=2, max_time_window_hours=336):
        """Build the alignment index and the time encoder.

        Args:
            cxr_dataset: CXR dataset. It must expose a ``data`` sequence of
                dict-like records (read via ``.get('subject_id')``) and
                support integer indexing.
            med_dataset: Medical time-series dataset. ``sample_lut`` and
                ``get_patient_data_by_sid`` are used when present.
            sid_json_path: Path of a JSON file holding the aligned subject
                IDs. It is only consulted when ``med_dataset`` has no
                ``sample_lut``; when omitted, ``aligned_subjects.json`` in
                the working directory is tried instead.
            time_resolution_hours: Spacing of the time axis, in hours.
            max_time_window_hours: Length of the default time window, in
                hours; it also sizes the positional-encoding table.

        Raises:
            ValueError: If ``time_resolution_hours`` is not positive (a zero
                or negative step would never terminate the axis builder).
        """
        if time_resolution_hours <= 0:
            raise ValueError("time_resolution_hours must be positive")

        self.cxr_dataset = cxr_dataset
        self.med_dataset = med_dataset
        self.sid_json_path = sid_json_path
        self.time_resolution_hours = time_resolution_hours
        self.max_time_window_hours = max_time_window_hours

        # Build the per-subject alignment index.
        self._build_temporal_alignment()

        # Build the time encoder.
        self._build_time_encoder()

    def _num_time_slots(self) -> int:
        """Return the number of slots in the default window (at least 1).

        The lower bound keeps the default time axis and the positional
        encoding table non-empty even when the window is shorter than one
        resolution step.
        """
        return max(1, int(self.max_time_window_hours // self.time_resolution_hours))

    def _build_temporal_alignment(self):
        """Resolve the aligned subjects and align each of them in turn."""
        print("构建时序对齐索引...")

        if hasattr(self.med_dataset, 'sample_lut'):
            # Derive the subjects from the medical dataset's lookup table.
            self.aligned_subjects = self._get_aligned_subjects()
        else:
            # Fall back to a previously saved list of subjects.
            self.aligned_subjects = self._get_default_aligned_subjects()

        self.temporal_alignment = {
            subject_id: self._align_patient_temporal_data(subject_id)
            for subject_id in self.aligned_subjects
        }

        print(f"完成时序对齐，共 {len(self.aligned_subjects)} 个患者")

    def _get_aligned_subjects(self):
        """Return the subject IDs present in both datasets, sorted.

        Medical subject IDs are read from the parent directory name of each
        ``sample_lut`` path: every ``_``-separated part starting with
        ``SID`` contributes ``"p" + <rest of the part>``. The result is
        sorted so dataset indices are stable across runs (set iteration
        order is not).
        """
        med_subjects = set()
        for path in self.med_dataset.sample_lut:
            patient_dir = os.path.basename(os.path.dirname(path))
            for part in patient_dir.split('_'):
                if part.startswith("SID"):
                    med_subjects.add(f"p{part[3:]}")

        cxr_subjects = set()
        for item in self.cxr_dataset.data:
            subject_id = item.get('subject_id')
            if subject_id:
                cxr_subjects.add(subject_id)

        return sorted(med_subjects & cxr_subjects)

    def _get_default_aligned_subjects(self):
        """Load a saved list of aligned subject IDs from JSON.

        Reads ``sid_json_path`` when one was supplied, otherwise
        ``aligned_subjects.json``. Returns an empty list when the file does
        not exist, in which case the alignment must be built by other means.
        """
        path = getattr(self, 'sid_json_path', None) or self._DEFAULT_SID_JSON
        if not os.path.exists(path):
            return []
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _align_patient_temporal_data(self, subject_id):
        """Build the alignment record for a single subject.

        Returns:
            A dict with keys ``cxr_items``, ``med_items``, ``time_axis`` and
            ``temporal_graph``.
        """
        cxr_items = self._get_patient_cxr_data(subject_id)
        med_items = self._get_patient_med_data(subject_id)

        # One shared axis for both modalities.
        time_axis = self._build_time_axis(cxr_items, med_items)

        stamped_cxr = self._assign_cxr_timestamps(cxr_items, time_axis)
        stamped_med = self._assign_med_timestamps(med_items, time_axis)

        return {
            'cxr_items': stamped_cxr,
            'med_items': stamped_med,
            'time_axis': time_axis,
            'temporal_graph': self._build_temporal_graph(
                stamped_cxr, stamped_med, time_axis
            ),
        }

    def _get_patient_cxr_data(self, subject_id):
        """Return every CXR sample whose record belongs to ``subject_id``.

        The positional index comes from ``enumerate`` rather than
        ``list.index`` so that records comparing equal still load their own
        sample, and the scan stays linear.
        """
        return [
            self.cxr_dataset[idx]
            for idx, item in enumerate(self.cxr_dataset.data)
            if item.get('subject_id') == subject_id
        ]

    def _get_patient_med_data(self, subject_id):
        """Return the subject's medical items, or ``[]`` if unsupported.

        Relies on the optional ``get_patient_data_by_sid`` method of the
        medical dataset.
        """
        if hasattr(self.med_dataset, 'get_patient_data_by_sid'):
            return self.med_dataset.get_patient_data_by_sid(subject_id)
        return []

    def _build_time_axis(self, cxr_items, med_items):
        """Build a uniformly spaced axis covering every known timestamp.

        The axis starts at the earliest timestamp and advances by
        ``time_resolution_hours`` until it passes the latest one. When no
        timestamp can be extracted, the default axis is returned.

        NOTE(review): the axis length is not capped by
        ``max_time_window_hours``; widely separated timestamps produce a
        very long axis. Confirm whether a cap is wanted.
        """
        all_timestamps = []

        for cxr_item in cxr_items:
            timestamp = self._extract_cxr_timestamp(cxr_item)
            if timestamp is not None:
                all_timestamps.append(timestamp)

        for med_item in med_items:
            all_timestamps.extend(self._extract_med_timestamps(med_item))

        if not all_timestamps:
            return self._create_default_time_axis()

        start_time = min(all_timestamps)
        end_time = max(all_timestamps)
        step = timedelta(hours=self.time_resolution_hours)

        time_axis = []
        current_time = start_time
        while current_time <= end_time:
            time_axis.append(current_time)
            current_time += step
        return time_axis

    def _extract_cxr_timestamp(self, cxr_item):
        """Extract a timestamp from a CXR item, or return ``None``.

        Sources, in order: the ``study_time`` field (authoritative when
        present, even if it fails to parse), any ``/``-separated component
        of ``image_path`` that parses as a timestamp, and finally a
        synthetic time derived from ``stay_id``.
        """
        if 'study_time' in cxr_item:
            return self._parse_timestamp(cxr_item['study_time'])

        # ``or ''`` tolerates an explicit ``None`` path.
        image_path = cxr_item.get('image_path', '') or ''
        for part in image_path.split('/'):
            timestamp = self._parse_timestamp(part)
            if timestamp is not None:
                return timestamp

        stay_id = cxr_item.get('stay_id')
        if stay_id:
            return self._stay_id_to_timestamp(stay_id)

        return None

    def _extract_med_timestamps(self, med_item):
        """Collect every parseable timestamp carried by a medical item.

        Reads the ``timestamps`` field and, for ``dynamic_data`` objects
        with at least two dimensions, the entries of the first column
        (assumed by the original author to hold time information).
        """
        timestamps = []

        if 'timestamps' in med_item:
            for raw in med_item['timestamps']:
                parsed = self._parse_timestamp(raw)
                if parsed is not None:
                    timestamps.append(parsed)

        if 'dynamic_data' in med_item:
            dynamic_data = med_item['dynamic_data']
            if hasattr(dynamic_data, 'shape') and len(dynamic_data.shape) >= 2:
                time_column = dynamic_data[:, 0] if dynamic_data.shape[1] > 0 else []
                for raw in time_column:
                    parsed = self._parse_timestamp(raw)
                    if parsed is not None:
                        timestamps.append(parsed)

        return timestamps

    def _parse_timestamp(self, time_str):
        """Parse a string or a POSIX number into a ``datetime``.

        Strings are matched against the known formats in order. ``int`` and
        ``float`` values are interpreted with ``datetime.fromtimestamp``
        (local time). Falsy values, unsupported types and unparseable or
        out-of-range inputs all yield ``None`` instead of raising.
        """
        if isinstance(time_str, str):
            if not time_str:
                return None
            for fmt in self._TIME_FORMATS:
                try:
                    return datetime.strptime(time_str, fmt)
                except ValueError:
                    continue
            return None

        if isinstance(time_str, (int, float)):
            # Zero is treated as "no time information".
            if not time_str:
                return None
            try:
                return datetime.fromtimestamp(time_str)
            except (ValueError, TypeError, OverflowError, OSError):
                # NaN, or a value outside the platform's supported range.
                return None

        # Any other type (None, arrays, tensors, ...) carries no usable time.
        return None

    def _stay_id_to_timestamp(self, stay_id):
        """Map a ``stay_id`` to a synthetic timestamp.

        For string IDs of at least eight characters, the digits are joined,
        reduced modulo 1000 and added as hours to the base time
        (2020-01-01). Everything else maps to the base time itself.
        """
        if isinstance(stay_id, str) and len(stay_id) >= 8:
            digits = ''.join(filter(str.isdigit, stay_id))
            if digits:
                try:
                    # Modulo keeps the offset small.
                    offset_hours = int(digits) % 1000
                except ValueError:
                    # Some characters satisfy isdigit() but not int().
                    return self._BASE_TIME
                return self._BASE_TIME + timedelta(hours=offset_hours)

        return self._BASE_TIME

    def _create_default_time_axis(self):
        """Return the default axis: one point per slot from the base time.

        With the default settings this is 14 days at 2-hour spacing. The
        axis always holds at least one point.
        """
        step_hours = self.time_resolution_hours
        return [
            self._BASE_TIME + timedelta(hours=i * step_hours)
            for i in range(self._num_time_slots())
        ]

    def _attach_time_info(self, item, timestamp, time_axis):
        """Return a copy of ``item`` extended with its time annotations.

        A ``None`` timestamp places the item on the first axis point.
        ``time_axis`` must be non-empty.
        """
        origin = time_axis[0]
        if timestamp is None:
            return {
                **item,
                'timestamp': origin,
                'time_index': 0,
                'relative_time_hours': 0,
            }
        return {
            **item,
            'timestamp': timestamp,
            'time_index': self._find_nearest_time_index(timestamp, time_axis),
            'relative_time_hours': (timestamp - origin).total_seconds() / 3600,
        }

    def _assign_cxr_timestamps(self, cxr_items, time_axis):
        """Annotate every CXR item with its timestamp and axis position."""
        return [
            self._attach_time_info(
                cxr_item, self._extract_cxr_timestamp(cxr_item), time_axis
            )
            for cxr_item in cxr_items
        ]

    def _assign_med_timestamps(self, med_items, time_axis):
        """Annotate medical items, emitting one entry per timestamp.

        An item with several timestamps is repeated once for each of them;
        an item without any is placed on the first axis point.
        """
        timestamped_med = []
        for med_item in med_items:
            timestamps = self._extract_med_timestamps(med_item)
            if not timestamps:
                timestamped_med.append(
                    self._attach_time_info(med_item, None, time_axis)
                )
                continue
            for timestamp in timestamps:
                timestamped_med.append(
                    self._attach_time_info(med_item, timestamp, time_axis)
                )
        return timestamped_med

    def _find_nearest_time_index(self, timestamp, time_axis):
        """Return the index of the axis point closest to ``timestamp``.

        Ties resolve to the earliest index; an empty axis yields 0.
        """
        if not time_axis:
            return 0
        return min(
            range(len(time_axis)),
            key=lambda i: abs((timestamp - time_axis[i]).total_seconds()),
        )

    def _build_temporal_graph(self, cxr_items, med_items, time_axis):
        """Build the temporal graph for one subject.

        Nodes are the annotated CXR and medical items. ``temporal`` edges
        connect every node of a time slice to every node of the next slice;
        ``cross_modal`` edges connect CXR and medical nodes of the same
        slice.

        Node ids have the form ``<prefix>_<stay_id>_<time_index>``. When
        several items would share an id, later ones receive a ``#<n>``
        suffix so ids stay unique and edges remain unambiguous.
        """
        graph = {
            'nodes': [],
            'edges': [],
            'time_slices': {},
        }

        # One (initially empty) slice per axis point.
        for i, time_point in enumerate(time_axis):
            graph['time_slices'][i] = {
                'cxr_nodes': [],
                'med_nodes': [],
                'timestamp': time_point,
            }

        id_counts = {}

        def make_unique(base_id):
            seen = id_counts.get(base_id, 0)
            id_counts[base_id] = seen + 1
            return base_id if seen == 0 else f"{base_id}#{seen}"

        modalities = (
            (cxr_items, 'cxr', 'cxr', 'cxr_nodes'),
            (med_items, 'med', 'medical', 'med_nodes'),
        )
        for items, prefix, node_type, slice_key in modalities:
            for item in items:
                time_index = item['time_index']
                node_id = make_unique(
                    f"{prefix}_{item.get('stay_id', 'unknown')}_{time_index}"
                )
                graph['nodes'].append({
                    'id': node_id,
                    'type': node_type,
                    'time_index': time_index,
                    'data': item,
                })
                graph['time_slices'][time_index][slice_key].append(node_id)

        # Temporal edges between consecutive slices.
        for i in range(len(time_axis) - 1):
            here = graph['time_slices'][i]
            after = graph['time_slices'][i + 1]
            current_nodes = here['cxr_nodes'] + here['med_nodes']
            next_nodes = after['cxr_nodes'] + after['med_nodes']

            for curr_node in current_nodes:
                for next_node in next_nodes:
                    graph['edges'].append({
                        'source': curr_node,
                        'target': next_node,
                        'type': 'temporal',
                        'time_gap': self.time_resolution_hours,
                    })

        # Cross-modal edges inside each slice.
        for time_slice in graph['time_slices'].values():
            for cxr_node in time_slice['cxr_nodes']:
                for med_node in time_slice['med_nodes']:
                    graph['edges'].append({
                        'source': cxr_node,
                        'target': med_node,
                        'type': 'cross_modal',
                        'time_gap': 0,
                    })

        return graph

    def _build_time_encoder(self):
        """Create the positional table and the linear time embedding.

        Both are randomly initialised. They live on the dataset object, not
        on an ``nn.Module``, so nothing in this class trains them.
        """
        self.time_encoding_dim = 64

        # One learnable row per slot of the default window.
        self.time_position_encoding = torch.nn.Parameter(
            torch.randn(self._num_time_slots(), self.time_encoding_dim)
        )

        # Maps a scalar hour offset to the encoding dimension.
        self.time_embedding = torch.nn.Linear(1, self.time_encoding_dim)

    def encode_time(self, relative_time_hours):
        """Encode an hour offset as positional row plus linear embedding.

        The offset is converted to a slot index and clamped to the table
        bounds, so offsets beyond the window reuse the last row.

        Returns:
            A tensor of shape ``(time_encoding_dim,)``.
        """
        last_slot = self.time_position_encoding.shape[0] - 1
        time_index = int(relative_time_hours / self.time_resolution_hours)
        time_index = max(0, min(time_index, last_slot))

        pos_encoding = self.time_position_encoding[time_index]

        time_tensor = torch.tensor([[relative_time_hours]], dtype=torch.float32)
        time_embedding = self.time_embedding(time_tensor).squeeze(0)

        return pos_encoding + time_embedding

    def __len__(self):
        """Return the number of aligned subjects."""
        return len(self.aligned_subjects)

    def __getitem__(self, idx):
        """Return the aligned sample of the ``idx``-th subject.

        Returns:
            A dict with keys ``subject_id``, ``cxr_items``, ``med_items``,
            ``time_axis``, ``temporal_graph`` and ``time_encoding``.
        """
        subject_id = self.aligned_subjects[idx]
        alignment = self.temporal_alignment[subject_id]

        return {
            'subject_id': subject_id,
            'cxr_items': alignment['cxr_items'],
            'med_items': alignment['med_items'],
            'time_axis': alignment['time_axis'],
            'temporal_graph': alignment['temporal_graph'],
            'time_encoding': self._encode_sample_time(alignment),
        }

    def _encode_sample_time(self, alignment):
        """Encode every point of the sample's time axis.

        Runs without gradient tracking: samples are plain data, and tensors
        attached to the autograd graph cannot be passed between DataLoader
        worker processes.

        Returns:
            A tensor of shape ``(len(time_axis), time_encoding_dim)``.
        """
        time_axis = alignment['time_axis']
        if not time_axis:
            return torch.empty(0, self.time_encoding_dim)

        origin = time_axis[0]
        with torch.no_grad():
            encodings = [
                self.encode_time((time_point - origin).total_seconds() / 3600)
                for time_point in time_axis
            ]
            return torch.stack(encodings)

    def get_temporal_graph(self, subject_id):
        """Return the subject's temporal graph, or ``None`` if unknown."""
        alignment = self.temporal_alignment.get(subject_id)
        if alignment is None:
            return None
        return alignment['temporal_graph']

    def get_time_aligned_data(self, subject_id):
        """Return the subject's full alignment record, or ``None``."""
        return self.temporal_alignment.get(subject_id)


def create_temporal_aligned_dataset(cxr_dataset, med_dataset, **kwargs):
    """Factory helper: build a ``TemporalAlignedDataset``.

    Every extra keyword argument is forwarded unchanged to the
    ``TemporalAlignedDataset`` constructor.
    """
    dataset = TemporalAlignedDataset(cxr_dataset, med_dataset, **kwargs)
    return dataset