from torch.utils.data import Dataset
import json
import sys
sys.path.append('/ssd/0/wzq/Multi_Med')
from datapress.Aligned.mimiccxr_dataset import MIMICCXRDataset
from datapress.Aligned.medical_dataset import MedicalDataset
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from PIL import Image
import pandas as pd
from collections import defaultdict
import os

class MultiModalAlignedDataset(Dataset):
    """Align MIMIC-CXR samples with medical (EHR) samples by ``subject_id``.

    One item per subject present in both datasets; each item bundles *all* CXR
    samples and *all* medical samples of that subject.

    Index structures are built once and cached as files under ``cache_dir``:

    - ``idx_dict.csv``: per aligned subject, its first CXR index and first
      medical index (exposed as ``cxr_indices`` / ``med_indices``).
    - ``cxr_sid2idx.json`` / ``med_sid2idx.json``: subject_id -> list of indices.
    - ``stayid2cxr.json``: stay_id -> list of ``{'cxr_idx', 'subject_id',
      'image_path'}`` dicts for every CXR sample of an aligned subject.
    - ``cache_meta.json``: cache format version plus a fingerprint of the subject
      whitelist the cache was built with; any mismatch forces a rebuild, so a
      changed ``sid_json_path`` is never silently ignored.
    """

    # Default cache location (the path this class historically hard-coded).
    DEFAULT_CACHE_DIR = '/ssd/0/wzq/Multi_Med/data_dir/cache'
    # Bump whenever the on-disk cache layout or its semantics change.
    CACHE_VERSION = 2
    # Attributes probed on the CXR dataset for the path of a JSON index file.
    _INDEX_PATH_ATTRS = ('index_file_path', 'index_path', 'meta_path', 'manifest_path')
    # Logical cache entry -> file name inside cache_dir.
    _CACHE_FILES = {
        'idx_dict': 'idx_dict.csv',
        'cxr_sid2idx': 'cxr_sid2idx.json',
        'med_sid2idx': 'med_sid2idx.json',
        'stayid2cxr': 'stayid2cxr.json',
        'meta': 'cache_meta.json',
    }

    def __init__(self, cxr_dataset, med_dataset, sid_json_path=None,
                 cache_dir=None, rebuild_cache=False):
        """
        Args:
            cxr_dataset: CXR dataset. Its items/metadata records are expected to be
                dicts carrying ``subject_id`` (and optionally ``stay_id``,
                ``image_path``; ``image``/``report`` for ``visualize_all``).
            med_dataset: medical dataset exposing ``sample_lut``, a sequence of file
                paths whose parent directory name contains a ``SID<id>`` token
                (tokens separated by ``_``).
            sid_json_path: optional JSON whitelist of subject ids, either a list or
                ``{"subjects": [...]}``. Missing, unreadable or empty -> no filtering.
            cache_dir: directory for the index cache (``DEFAULT_CACHE_DIR`` if None).
            rebuild_cache: ignore any existing cache and rebuild it.
        """
        self.cxr_dataset = cxr_dataset
        self.med_dataset = med_dataset

        cache_dir = cache_dir or self.DEFAULT_CACHE_DIR
        os.makedirs(cache_dir, exist_ok=True)
        paths = {key: os.path.join(cache_dir, name) for key, name in self._CACHE_FILES.items()}

        whitelist_sids = self._load_whitelist(sid_json_path)
        fingerprint = self._whitelist_fingerprint(whitelist_sids)

        # Reuse the cache only if it is complete and was built for the same whitelist.
        if not rebuild_cache and self._cache_is_valid(paths, fingerprint):
            self._load_cache(paths)
            return

        self._build_index(whitelist_sids)
        self._write_cache(paths, fingerprint)

    # ================== cache handling ==================
    def _safe_load_json(self, path, default):
        """Load JSON from ``path``; return ``default`` if it is missing or malformed."""
        try:
            with open(path, 'r') as f:
                return json.load(f)
        except (OSError, ValueError):  # json.JSONDecodeError is a ValueError
            return default

    @staticmethod
    def _whitelist_fingerprint(whitelist_sids):
        """Stable hash of the whitelist (None when there is no whitelist)."""
        if not whitelist_sids:
            return None
        import hashlib
        return hashlib.sha1('\n'.join(sorted(whitelist_sids)).encode('utf-8')).hexdigest()

    def _cache_is_valid(self, paths, fingerprint):
        """True if every cache file exists and the cache matches this version/whitelist."""
        data_keys = ('idx_dict', 'cxr_sid2idx', 'med_sid2idx', 'stayid2cxr')
        if not all(os.path.exists(paths[key]) for key in data_keys):
            return False
        meta = self._safe_load_json(paths['meta'], default=None)
        return (isinstance(meta, dict)
                and meta.get('version') == self.CACHE_VERSION
                and meta.get('whitelist') == fingerprint)

    def _load_cache(self, paths):
        """Populate the index attributes from a previously written cache."""
        # Read subject ids as strings so ids with leading zeros survive the round trip.
        idx_df = pd.read_csv(paths['idx_dict'], dtype={'subject_id': str})
        self.common_sids = idx_df['subject_id'].astype(str).tolist()
        self.cxr_indices = [int(i) for i in idx_df['cxr_idx']]
        self.med_indices = [int(i) for i in idx_df['med_idx']]
        self.cxr_sid2idx = self._safe_load_json(paths['cxr_sid2idx'], default={})
        self.med_sid2idx = self._safe_load_json(paths['med_sid2idx'], default={})
        self.stayid2cxr = self._safe_load_json(paths['stayid2cxr'], default={})

    def _write_cache(self, paths, fingerprint):
        """Persist the index attributes; the meta file is written last as a commit marker."""
        # Drop the old marker first so a crash mid-write cannot leave a cache that looks valid.
        if os.path.exists(paths['meta']):
            os.remove(paths['meta'])

        pd.DataFrame({
            'subject_id': self.common_sids,
            'cxr_idx': self.cxr_indices,
            'med_idx': self.med_indices,
        }).to_csv(paths['idx_dict'], index=False)

        for key, obj in (('cxr_sid2idx', self.cxr_sid2idx),
                         ('med_sid2idx', self.med_sid2idx),
                         ('stayid2cxr', self.stayid2cxr)):
            with open(paths[key], 'w') as f:
                json.dump(obj, f, indent=2)

        with open(paths['meta'], 'w') as f:
            json.dump({'version': self.CACHE_VERSION, 'whitelist': fingerprint}, f)

    # ================== index construction ==================
    @classmethod
    def _load_whitelist(cls, sid_json_path):
        """Return the whitelisted subject ids as a set of str, or None for "no filtering"."""
        import warnings

        if not sid_json_path or not os.path.exists(sid_json_path):
            return None
        try:
            with open(sid_json_path, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            warnings.warn(f"Ignoring unreadable subject whitelist {sid_json_path!r}: {err}")
            return None
        if isinstance(data, dict):
            data = data.get('subjects')
        if not isinstance(data, (list, tuple)):
            warnings.warn(f"Ignoring subject whitelist {sid_json_path!r}: expected a list "
                          "or an object with a 'subjects' list")
            return None
        sids = {cls._to_str_id(x) for x in data} - {None, ''}
        # An empty whitelist means "no filtering" (the build step tests it for truthiness).
        return sids or None

    def _build_index(self, whitelist_sids):
        """Build sid->index maps for both modalities, their intersection and stayid2cxr."""
        cxr_sid2idx, cxr_meta_rows = self._fast_build_cxr_sid2idx(self.cxr_dataset, whitelist_sids)
        med_sid2idx = self._build_med_sid2idx(self.med_dataset, whitelist_sids)

        # Aligned subjects, plus the first index of each modality (historical semantics).
        common_sids = sorted(cxr_sid2idx.keys() & med_sid2idx.keys())
        self.common_sids = common_sids
        self.cxr_indices = [cxr_sid2idx[sid][0] for sid in common_sids]
        self.med_indices = [med_sid2idx[sid][0] for sid in common_sids]
        self.cxr_sid2idx = cxr_sid2idx
        self.med_sid2idx = med_sid2idx

        # stay_id -> CXR metadata for every CXR of an aligned subject (matches what
        # __getitem__ exposes), built from metadata without loading images.
        aligned = set(common_sids)
        stayid2cxr = defaultdict(list)
        for row in cxr_meta_rows:
            if row['subject_id'] in aligned:
                stayid2cxr[row.get('stay_id', '')].append({
                    'cxr_idx': row['idx'],
                    'subject_id': row['subject_id'],
                    'image_path': row.get('image_path', ''),
                })
        self.stayid2cxr = dict(stayid2cxr)

    @staticmethod
    def _build_med_sid2idx(med_dataset, whitelist_sids=None):
        """Map subject_id -> med_dataset indices from the ``SID<id>`` token in each
        sample path's parent directory name; paths without that token are skipped."""
        med_sid2idx = defaultdict(list)
        for idx, path in enumerate(med_dataset.sample_lut):
            patient_dir = os.path.basename(os.path.dirname(path))
            sid = next((part[3:] for part in patient_dir.split('_') if part.startswith('SID')), None)
            if not sid:
                continue
            if whitelist_sids and sid not in whitelist_sids:
                continue
            med_sid2idx[sid].append(idx)
        return dict(med_sid2idx)

    # ================== CXR metadata helpers ==================
    @staticmethod
    def _to_str_id(value):
        """Normalize an id to str. None/NaN -> None; integral floats (pandas upcasts)
        lose their trailing '.0' so they match ids parsed from paths."""
        if value is None:
            return None
        if isinstance(value, float):
            if value != value:  # NaN
                return None
            if value.is_integer():
                value = int(value)
        return str(value)

    @classmethod
    def _pick(cls, rec, *keys):
        """First value among ``keys`` in dict ``rec`` that is not None/NaN/''; else None."""
        for key in keys:
            value = rec.get(key)
            if cls._to_str_id(value):
                return value
        return None

    @staticmethod
    def _record_fields(rec):
        """(subject_id, stay_id, image_path) of a metadata record; all None for non-dicts."""
        if isinstance(rec, dict):
            return rec.get('subject_id'), rec.get('stay_id'), rec.get('image_path')
        return None, None, None

    def _new_collector(self, whitelist_sids):
        """Return fresh ``(sid2idx, meta_rows, add_row)``; ``add_row`` applies the
        whitelist and normalizes ids before recording a dataset index."""
        sid2idx = defaultdict(list)
        meta_rows = []
        to_id = self._to_str_id

        def add_row(i, sid, stay_id=None, image_path=None):
            s = to_id(sid)
            if not s:
                return
            if whitelist_sids and s not in whitelist_sids:
                return
            i = int(i)  # numpy ints are not JSON-serializable
            sid2idx[s].append(i)
            meta_rows.append({
                'idx': i,
                'subject_id': s,
                'stay_id': to_id(stay_id) or '',
                'image_path': to_id(image_path) or '',
            })

        return sid2idx, meta_rows, add_row

    def _try_build_cxr_from_index_file(self, cxr_dataset, whitelist_sids=None):
        """Build CXR metadata from the dataset's JSON index file without calling
        ``__getitem__``.

        The file path is taken from the first of ``_INDEX_PATH_ATTRS`` that names an
        existing file; the file must be a JSON array of objects whose position equals
        the dataset index. Streams with ``ijson`` when installed, else ``json.load``.

        Returns:
            ``(sid2idx, meta_rows)`` or ``(None, None)`` when no usable index exists:
            no path, parse error, no records, or a record count different from
            ``len(cxr_dataset)`` (positions would not line up with dataset indices).
        """
        import warnings

        index_path = None
        for attr in self._INDEX_PATH_ATTRS:
            candidate = getattr(cxr_dataset, attr, None)
            if isinstance(candidate, str) and os.path.exists(candidate):
                index_path = candidate
                break
        if index_path is None:
            return None, None

        sid2idx, meta_rows, add_row = self._new_collector(whitelist_sids)

        def consume(records):
            count = 0
            for i, rec in enumerate(records):
                if not isinstance(rec, dict):
                    raise TypeError(f"index record {i} is not a JSON object")
                add_row(i,
                        self._pick(rec, 'subject_id', 'sid', 'patient'),
                        self._pick(rec, 'stay_id', 'stay', 'hadm_id'),
                        self._pick(rec, 'image_path', 'path', 'img'))
                count += 1
            return count

        try:
            try:
                import ijson  # optional streaming parser for very large index files
            except ImportError:
                ijson = None
            if ijson is not None:
                with open(index_path, 'rb') as f:
                    n_records = consume(ijson.items(f, 'item'))
            else:
                with open(index_path, 'r') as f:
                    records = json.load(f)
                if not isinstance(records, list):
                    return None, None
                n_records = consume(records)
        except Exception as err:  # best-effort fast path; the caller falls back
            warnings.warn(f"Could not parse CXR index file {index_path!r} ({err}); falling back.")
            return None, None

        try:
            n_dataset = len(cxr_dataset)
        except TypeError:
            n_dataset = n_records
        if n_records == 0:
            return None, None
        if n_records != n_dataset:
            warnings.warn(f"CXR index file {index_path!r} has {n_records} records but the dataset "
                          f"has {n_dataset} samples; positions cannot be trusted, falling back.")
            return None, None

        # Expose the lightweight table for later builds, but never clobber an
        # attribute the dataset itself may rely on.
        if not hasattr(cxr_dataset, 'index_df'):
            try:
                cxr_dataset.index_df = pd.DataFrame(meta_rows)
            except AttributeError:
                pass

        return dict(sid2idx), meta_rows

    def _build_cxr_from_attributes(self, cxr_dataset, whitelist_sids):
        """Build from an in-memory metadata container on the dataset, probed in order:
        ``index`` (list/tuple), ``index_df`` (DataFrame), ``samples``, ``meta``.
        Returns ``(sid2idx, meta_rows)`` or None when no container is present."""
        sid2idx, meta_rows, add_row = self._new_collector(whitelist_sids)

        records = None
        index = getattr(cxr_dataset, 'index', None)
        if isinstance(index, (list, tuple)):
            records = enumerate(index)
        else:
            df = getattr(cxr_dataset, 'index_df', None)
            if isinstance(df, pd.DataFrame):
                # A table with an 'idx' column (e.g. injected above) may be filtered,
                # so its row positions are not dataset indices.
                positions = df['idx'].tolist() if 'idx' in df.columns else range(len(df))
                records = zip(positions, df.to_dict('records'))
            else:
                for attr in ('samples', 'meta'):
                    container = getattr(cxr_dataset, attr, None)
                    if isinstance(container, (list, tuple)):
                        records = enumerate(container)
                        break
        if records is None:
            return None

        for i, rec in records:
            add_row(i, *self._record_fields(rec))
        return dict(sid2idx), meta_rows

    def _fast_build_cxr_sid2idx(self, cxr_dataset, whitelist_sids=None):
        """
        Returns:
          - cxr_sid2idx: {sid: [idx, ...]}
          - meta_rows:   [{'idx': i, 'subject_id': sid, 'stay_id': did, 'image_path': p}, ...]
        Tries, in order: the JSON index file, in-memory metadata attributes, and finally
        per-sample ``get_meta(i)`` / ``__getitem__`` (slow: may load images).
        """
        import warnings

        sid2idx, meta_rows = self._try_build_cxr_from_index_file(cxr_dataset, whitelist_sids=whitelist_sids)
        if sid2idx is not None:
            return sid2idx, meta_rows

        try:
            built = self._build_cxr_from_attributes(cxr_dataset, whitelist_sids)
        except Exception as err:  # partial results are discarded with the collector
            warnings.warn(f"CXR metadata fast path failed ({err}); using per-sample fallback.")
            built = None
        if built is not None:
            return built

        # Slow fallback with a fresh collector so nothing is counted twice.
        sid2idx, meta_rows, add_row = self._new_collector(whitelist_sids)
        get_meta = getattr(cxr_dataset, 'get_meta', None)
        for i in range(len(cxr_dataset)):
            rec = None
            if callable(get_meta):
                try:
                    rec = get_meta(i)  # convention: metadata only, no image decoding
                except Exception:
                    rec = None
            if rec is None:
                rec = cxr_dataset[i]  # last resort (possibly slow)
            add_row(i, *self._record_fields(rec))
        return dict(sid2idx), meta_rows

    # ================== Dataset API ==================
    def __len__(self):
        """Number of aligned subjects."""
        return len(self.common_sids)

    def __getitem__(self, idx):
        """Return ``{'subject_id', 'cxr_items', 'med_items'}`` holding every CXR and
        medical sample of the ``idx``-th aligned subject."""
        subject_id = self.common_sids[idx]
        cxr_items = [self.cxr_dataset[i] for i in self.cxr_sid2idx[subject_id]]
        # med_sid2idx stores integer indices; the wrapped dataset produces the samples.
        med_items = [self.med_dataset[i] for i in self.med_sid2idx[subject_id]]
        return {
            'subject_id': subject_id,
            'cxr_items': cxr_items,
            'med_items': med_items,
        }

    def get_cxr_by_stay_id(self, stay_id):
        """Return the CXR *metadata* for ``stay_id`` (int or str): a list of
        ``{'cxr_idx', 'subject_id', 'image_path'}`` dicts, or None if unknown.
        Load a sample with ``self.cxr_dataset[entry['cxr_idx']]``."""
        return self.stayid2cxr.get(self._to_str_id(stay_id), None)

    def get_cxr_by_subject_id(self, subject_id):
        """Return every loaded CXR sample of ``subject_id`` (int or str); [] if unknown."""
        idxs = self.cxr_sid2idx.get(self._to_str_id(subject_id), [])
        return [self.cxr_dataset[i] for i in idxs]

    def get_med_by_sid(self, subject_id):
        """Delegate to ``med_dataset.get_patient_data_by_sid``; documented as accepting
        subject ids with or without a 'p' prefix and returning a list of dicts."""
        return self.med_dataset.get_patient_data_by_sid(subject_id)

    def get_med_by_did(self, stay_id):
        """Delegate to ``med_dataset.get_patient_data_by_did``; documented as accepting
        stay ids with or without an 's' prefix and returning a list of dicts."""
        return self.med_dataset.get_patient_data_by_did(stay_id)

    def visualize_all(self, idx, mean=None, std=None, save_path=None):
        """Show every CXR image/report and medical sample of aligned subject(s) ``idx``.

        Args:
            idx: an index or a list/tuple of indices.
            mean, std: per-channel normalization to undo before display (both or neither).
            save_path: optional directory; figures are additionally saved there as PNG.
                Existing contents are never deleted.
        """
        import matplotlib.pyplot as plt

        if isinstance(idx, (list, tuple)):
            for i in idx:
                print(f"\n=== Visualizing aligned multimodal sample idx={i} ===")
                self.visualize_all(i, mean=mean, std=std, save_path=save_path)
            return

        if save_path is not None:
            os.makedirs(save_path, exist_ok=True)

        sample = self[idx]
        subject_id = sample['subject_id']
        cxr_items = sample['cxr_items']
        print(f"Subject ID: {subject_id}")
        print(f"Number of CXR items: {len(cxr_items)}")
        for i, cxr_item in enumerate(cxr_items):
            stay_id = cxr_item.get('stay_id')
            print(f"\n[CXR {i+1}/{len(cxr_items)}] Stay ID: {stay_id}")
            img = cxr_item.get('image')
            if img is None:
                print("(no image)")
            else:
                if mean is not None and std is not None:
                    img = img.clone()
                    for t, m, s in zip(img, mean, std):
                        t.mul_(s).add_(m)  # undo Normalize: x * std + mean
                    image_np = np.clip(img.permute(1, 2, 0).cpu().numpy(), 0, 1)
                else:
                    image_np = img.permute(1, 2, 0).cpu().numpy()
                if image_np.ndim == 3 and image_np.shape[-1] == 1:
                    image_np = image_np[..., 0]  # imshow rejects (H, W, 1)
                fig = plt.figure()
                plt.imshow(image_np, cmap='gray' if image_np.ndim == 2 else None)
                plt.title(f"Stay ID: {stay_id}")
                plt.axis('off')
                if save_path is not None:
                    fig.savefig(os.path.join(save_path, f"sample{idx}_sid{subject_id}_cxr{i}.png"),
                                bbox_inches='tight')
                plt.show()
                plt.close(fig)
            report = cxr_item.get('report') or ''
            print("Report snippet:")
            print(report[:500] + ("..." if len(report) > 500 else ""))

        print("\n[Medical Data]")
        # med_items were loaded from med_sid2idx[subject_id] in order, so index i matches.
        med_idxs = self.med_sid2idx.get(subject_id, [])
        can_visualize = hasattr(self.med_dataset, "visualize_item")
        for i, med_item in enumerate(sample['med_items']):
            if can_visualize and i < len(med_idxs):
                self.med_dataset.visualize_item(med_idxs[i], show=True)
            else:
                print(med_item)

def main():
    """Smoke test: build the aligned dataset from the project's data layout and
    print the first aligned sample.

    Relies on the module-level imports of ``MIMICCXRDataset`` / ``MedicalDataset``
    (from ``datapress.Aligned.*``) and ``transforms``; re-importing them from the
    bare module names only worked when run from this file's directory.
    """
    from omegaconf import OmegaConf

    # Data locations (adjust to your environment).
    base_data_path = '/ssd/0/wzq/Multi_Med/'
    index_file = os.path.join(base_data_path, 'mimic-cxr-images-512/index.json')
    # Images live on another disk; joining this absolute path onto base_data_path
    # would discard base_data_path anyway, so state it directly.
    image_dir = '/hdd/0/dkm/REFERS-master/data/MIMIC/'
    reports_dir = os.path.join(base_data_path, 'mimic-cxr-reports')

    # ImageNet normalization statistics, also needed to undo it when visualizing.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ])

    cxr_dataset = MIMICCXRDataset(
        index_file_path=index_file,
        image_root=image_dir,
        reports_root=reports_dir,
        transform=transform
    )

    opt = OmegaConf.load("/ssd/0/wzq/Multi_Med/exp/mimic_data/exp_mix_age.yaml")
    med_dataset = MedicalDataset(**opt.data.train_val, **opt.data.shared_param)

    # Build the aligned dataset, restricted to the subjects listed in the whitelist JSON.
    json_path = os.path.join(base_data_path, 'datapress', 'aligned_subjects.json')
    multi_dataset = MultiModalAlignedDataset(cxr_dataset, med_dataset, sid_json_path=json_path)
    print(f"Aligned multimodal dataset size: {len(multi_dataset)}")

    if len(multi_dataset) == 0:
        print("No subjects are present in both datasets; nothing to show.")
        return

    # Print the first aligned sample.
    sample = multi_dataset[0]
    print("Sample keys:", sample.keys())
    print("Subject ID:", sample['subject_id'])

    print(sample['med_items'])  # dynamic_data, static_data, label, dynamic_data_now, patient_info
    print(sample['cxr_items'])

    # Optional: display (and save) images/reports of the first sample.
    # multi_dataset.visualize_all(0, mean=MEAN, std=STD, save_path="outputs/sample_visualization/")


if __name__ == "__main__":
    main()

