import os
import json
import cv2
import torch
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
from torch.utils.data import Dataset
from torchvision import transforms
from omegaconf import OmegaConf
from mimiccxr_dataset import MIMICCXRDataset
from medical_dataset import MedicalDataset
from multiModelAligned_dataset import MultiModalAlignedDataset

# === Global variables for annotation ===
# Shared mutable state for the interactive annotation session. The callbacks
# below read and rebind these names via ``global``.
annotations = []  # annotation dicts appended by onselect() and dumped by save_annotations()
current_image = None  # image currently displayed; set to the RGB-converted cv2 image when one is loaded
current_phrases = []  # phrases still to be annotated for the current image
image_index = 0  # index into image_files of the image being annotated
phrase_index = 0  # index into current_phrases of the phrase being annotated
output_dir = "annotations"  # directory that receives annotations.json
# Created at import time so save_annotations() can always open its output file.
os.makedirs(output_dir, exist_ok=True)

# === Simple Phrase Extractor ===
def simple_phrase_extractor(report: str) -> list:
    """Split a report into unique phrases, one per non-trivial line.

    Each line is stripped of surrounding whitespace and of leading/trailing
    periods. Lines whose stripped length is 5 characters or fewer are dropped.

    Args:
        report: Raw report text, with phrases separated by newlines.

    Returns:
        The distinct phrases in order of first appearance in the report.
    """
    lines = report.strip().split('\n')
    phrases = [line.strip().strip('.') for line in lines if len(line.strip()) > 5]
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # return the phrases in hash-randomized order, which makes the annotation
    # sequence differ from run to run.
    return list(dict.fromkeys(phrases))

# === Annotation Callback ===
def onselect(eclick, erelease):
    """Record the drawn rectangle as the bounding box of the current phrase.

    Used as the callback of the ``RectangleSelector`` created in
    ``show_image``. Appends one annotation dict to ``annotations``, advances
    ``phrase_index``, and then either redraws for the next phrase or, when the
    image's phrases are exhausted, saves and moves on to the next image.

    Args:
        eclick: Mouse-press event; ``xdata``/``ydata`` give one corner.
        erelease: Mouse-release event; ``xdata``/``ydata`` give the other corner.
    """
    global phrase_index

    # The selector stays attached to the figure after the last image has been
    # processed, so a further selection would index past the end of
    # image_files / current_phrases. Ignore it instead of raising IndexError.
    if image_index >= len(image_files) or phrase_index >= len(current_phrases):
        return

    # Mouse events carry None data coordinates when they fall outside the
    # axes; int(None) would raise, so such a selection is discarded.
    corners = (eclick.xdata, eclick.ydata, erelease.xdata, erelease.ydata)
    if any(value is None for value in corners):
        return

    x1, y1, x2, y2 = (int(value) for value in corners)
    # Normalise to top-left corner plus width/height, independent of the
    # direction in which the rectangle was dragged.
    x, y, w, h = min(x1, x2), min(y1, y2), abs(x2 - x1), abs(y2 - y1)

    ann = {
        "image": image_files[image_index],
        "phrase": current_phrases[phrase_index],
        "bbox": [x, y, w, h],
        "time_index": time_index[image_index]  # optional
    }
    annotations.append(ann)
    print(f"Annotated: {ann['phrase']} => {ann['bbox']}")
    phrase_index += 1

    if phrase_index < len(current_phrases):
        show_image()
    else:
        save_annotations()
        next_image()

# === Display Function ===
def show_image():
    """Draw the current image and arm a rectangle selector for the current phrase.

    Clears the active figure, shows ``current_image`` titled with the image
    path and the phrase at ``phrase_index``, and attaches a left-button
    ``RectangleSelector`` whose selections are delivered to ``onselect``.
    """
    plt.clf()
    plt.imshow(current_image)
    plt.title(f"{image_files[image_index]}\nPhrase: {current_phrases[phrase_index]}")
    # The selector must stay referenced to remain responsive. It is stored on
    # this function because ``toggle_selector`` is not defined in this module
    # and assigning to it raised NameError.
    # ``drawtype`` is omitted: 'box' was its default, and recent Matplotlib
    # releases reject the argument.
    show_image.selector = RectangleSelector(
        plt.gca(),
        onselect,
        useblit=True,
        button=[1],
        minspanx=5,
        minspany=5,
        spancoords='pixels',
    )
    plt.show()

# === Load Next Image ===
def next_image():
    """Advance to the next image that can be annotated and display it.

    Resets ``phrase_index`` and moves ``image_index`` forward. Images that
    have no phrases in ``phrase_dict`` or that OpenCV cannot read are skipped
    with a message, because ``show_image`` cannot display them: it indexes
    into ``current_phrases`` and needs a decoded image. Prints a completion
    message and returns when the end of ``image_files`` is reached.
    """
    global image_index, phrase_index, current_image, current_phrases
    phrase_index = 0
    while True:
        image_index += 1
        if image_index >= len(image_files):
            print("All images annotated.")
            return

        path = image_files[image_index]
        phrases = phrase_dict.get(os.path.basename(path), [])
        if not phrases:
            print(f"Skipping {path}: no phrases to annotate.")
            continue

        # cv2.imread signals failure by returning None rather than raising.
        bgr = cv2.imread(path)
        if bgr is None:
            print(f"Skipping {path}: image could not be read.")
            continue
        break

    current_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    current_phrases = phrases
    show_image()

# === Save ===
def save_annotations():
    """Write every collected annotation to ``annotations.json`` in ``output_dir``.

    The file is overwritten on each call with the full ``annotations`` list,
    and the number of saved entries is printed afterwards.
    """
    target_path = os.path.join(output_dir, "annotations.json")
    with open(target_path, "w") as out_file:
        json.dump(annotations, out_file, indent=2)
    print(f"Saved {len(annotations)} annotations.")

# === Phrase-Image Contrastive Pair Constructor ===
def build_phrase_image_pairs(aligned_dataset, phrase_extractor):
    """Pair every CXR image with each phrase extracted from its report.

    Args:
        aligned_dataset: Indexable, sized collection of samples. Each sample
            is read through the keys ``'subject_id'`` and ``'cxr_items'``;
            each CXR item through ``'image'`` and ``'report'``.
        phrase_extractor: Callable that maps a report to an iterable of phrases.

    Returns:
        A list of dicts with keys ``'image'``, ``'phrase'``, ``'subject_id'``
        and ``'is_positive'`` (always ``True``), ordered by sample, then CXR
        item, then phrase.
    """
    pairs = []
    # Index-based access is kept on purpose: only __len__ and __getitem__ are
    # required of the dataset.
    for idx in range(len(aligned_dataset)):
        record = aligned_dataset[idx]
        sid = record['subject_id']
        for cxr in record['cxr_items']:
            img = cxr['image']
            pairs.extend(
                {
                    'image': img,
                    'phrase': phrase,
                    'subject_id': sid,
                    'is_positive': True,
                }
                for phrase in phrase_extractor(cxr['report'])
            )
    return pairs

# === Main Execution and Dataset Pipeline ===
if __name__ == "__main__":
    # Dataset locations, all rooted at the local data directory.
    base_data_path = '/ssd/0/wzq/Multi_Med/'
    index_file = os.path.join(base_data_path, 'mimic-cxr-images-512/index.json')
    image_dir = os.path.join(base_data_path, 'mimic-cxr-images-512')
    reports_dir = os.path.join(base_data_path, 'mimic-cxr-reports')
    json_path = os.path.join(base_data_path, 'datapress', 'aligned_subjects.json')

    # Image preprocessing: resize to 224x224, convert to a tensor, then
    # normalise each channel with MEAN / STD.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    preprocessing_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # Load the CXR dataset.
    cxr_dataset = MIMICCXRDataset(
        index_file_path=index_file, image_root=image_dir,
        reports_root=reports_dir, transform=transform,
    )

    # Load the MED dataset from its experiment configuration.
    opt = OmegaConf.load("/ssd/0/wzq/Multi_Med/exp/mimic_data/exp_mix_age.yaml")
    data_cfg = opt.data
    med_dataset = MedicalDataset(**data_cfg.train_val, **data_cfg.shared_param)

    # Build the aligned multimodal dataset from both sources.
    multi_dataset = MultiModalAlignedDataset(
        cxr_dataset, med_dataset, sid_json_path=json_path
    )
    print(f"Aligned multimodal dataset size: {len(multi_dataset)}")

    # Construct the phrase-image pairs for contrastive learning.
    pairs = build_phrase_image_pairs(multi_dataset, simple_phrase_extractor)
    print(f"Constructed {len(pairs)} phrase-image pairs with simple phrase extractor.")

# === Static Demo Image Folder Annotation ===
# Demo annotation session over a local folder of .jpg images. Reports are
# looked up by file stem under "reports".
image_folder = "images/"
image_files = sorted(
    os.path.join(image_folder, name)
    for name in os.listdir(image_folder)
    if name.endswith('.jpg')
)
time_index = list(range(len(image_files)))  # optional time sequence

# Build the phrase dictionary automatically: image basename -> phrases.
phrase_dict = {}
for f in image_files:
    # Derive the report path from the file stem. Chained str.replace on the
    # whole path would also rewrite "images" or ".jpg" inside a file name.
    stem = os.path.splitext(os.path.basename(f))[0]
    report_path = os.path.join("reports", stem + ".txt")
    if os.path.exists(report_path):
        with open(report_path, 'r') as rf:
            report_text = rf.read()
        phrases = simple_phrase_extractor(report_text)
        phrase_dict[os.path.basename(f)] = phrases

if not image_files:
    print("No images found in folder.")
else:
    # Start on the first image that can actually be annotated. show_image()
    # indexes into current_phrases and needs a decoded image, so images with
    # no phrases or that cv2 cannot read are skipped.
    for start_index, start_path in enumerate(image_files):
        start_phrases = phrase_dict.get(os.path.basename(start_path), [])
        if not start_phrases:
            print(f"Skipping {start_path}: no phrases to annotate.")
            continue
        # cv2.imread signals failure by returning None rather than raising.
        start_bgr = cv2.imread(start_path)
        if start_bgr is None:
            print(f"Skipping {start_path}: image could not be read.")
            continue
        image_index = start_index
        current_image = cv2.cvtColor(start_bgr, cv2.COLOR_BGR2RGB)
        current_phrases = start_phrases
        show_image()
        break
    else:
        print("No annotatable images found in folder.")
