import ast
import itertools
import json
import os
import random
import urllib.request
import zipfile
from pathlib import Path
from typing import Dict, List

import pandas as pd
from datasets import load_dataset
from PIL import Image
from pycocotools.coco import COCO


class COCODatasetProcessor:
    """Download MS-COCO 2017 captions and build missing-modality train/test splits.

    Split files are written to ``<data_dir>/<img_only_ratio>/``.
    """

    def __init__(self, data_dir="./coco_data"):
        self.data_dir = data_dir
        os.makedirs(data_dir, exist_ok=True)

    def download_coco_subset(self):
        """Download and extract the COCO 2017 annotations and train images into ``data_dir``.

        Note: despite the name, this fetches the complete ``train2017.zip``
        archive; subsetting happens later in ``create_subset_and_split``.
        """
        annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
        images_url = "http://images.cocodataset.org/zips/train2017.zip"

        print("Downloading COCO annotations...")
        self._download_and_extract(annotations_url, "annotations_trainval2017.zip")

        print("Downloading COCO images...")
        self._download_and_extract(images_url, "train2017.zip")

    def _download_and_extract(self, url, filename, block_size=8192):
        """Download ``url`` to ``<data_dir>/<filename>`` if absent, then unzip it into ``data_dir``.

        The body is streamed to ``<filename>.part`` and renamed only once the
        download has finished, so an interrupted run never leaves a truncated
        archive that a later run would mistake for a complete one.

        Args:
            url: source URL of the zip archive.
            filename: archive file name inside ``data_dir``.
            block_size: read size in bytes per chunk.

        Raises:
            urllib.error.URLError: if the download fails (``HTTPError`` for non-2xx).
            zipfile.BadZipFile: if the file is not a valid zip archive.
        """
        filepath = os.path.join(self.data_dir, filename)
        if not os.path.exists(filepath):
            tmp_path = filepath + ".part"
            print(f"Downloading {filename}...")
            # urlopen raises HTTPError on non-2xx responses instead of silently
            # saving an error page as the archive.
            with urllib.request.urlopen(url) as response, open(tmp_path, 'wb') as f:
                # Content-Length may be absent (e.g. chunked transfer); 0 means "unknown".
                total_size = int(response.headers.get('Content-Length') or 0)
                downloaded = 0
                for chunk in iter(lambda: response.read(block_size), b''):
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        done = min(50, int(50 * downloaded / total_size))
                        print(f"\r[{'=' * done}{' ' * (50-done)}] {downloaded}/{total_size} bytes", end='')
                    else:
                        print(f"\r{downloaded} bytes", end='')
            print()  # terminate the in-place progress line
            os.replace(tmp_path, filepath)

        print(f"Extracting {filename}...")
        with zipfile.ZipFile(filepath, 'r') as zip_ref:
            zip_ref.extractall(self.data_dir)

    def create_subset_and_split(self, train_size=5000, test_size=1000,
                                img_only_ratio=0.2, text_only_ratio=0.2):
        """Sample disjoint train/test image sets and write modality-split JSON files.

        Reads ``<data_dir>/annotations/captions_train2017.json``, shuffles all
        image ids, takes the first ``train_size`` for training and the next
        ``test_size`` for testing. The test split is always fully multimodal.

        Returns:
            ``(train_data, test_data)``: dicts keyed by ``image_only``,
            ``text_only`` and ``multimodal``.
        """
        coco = COCO(os.path.join(self.data_dir, 'annotations/captions_train2017.json'))

        img_ids = list(coco.imgs.keys())
        random.shuffle(img_ids)

        selected_ids = img_ids[:train_size + test_size]
        train_ids = selected_ids[:train_size]
        test_ids = selected_ids[train_size:]

        train_data = self._create_modal_split(coco, train_ids, img_only_ratio, text_only_ratio)
        test_data = self._create_modal_split(coco, test_ids, 0.0, 0.0)

        # NOTE: the output directory is keyed only by img_only_ratio, so runs
        # that differ only in text_only_ratio overwrite each other.
        save_dir = os.path.join(self.data_dir, str(img_only_ratio))
        os.makedirs(save_dir, exist_ok=True)

        with open(os.path.join(save_dir, 'train_split.json'), 'w') as f:
            json.dump(train_data, f)
        with open(os.path.join(save_dir, 'test_split.json'), 'w') as f:
            json.dump(test_data, f)

        print(f"Created dataset split with {len(train_ids)} training and {len(test_ids)} test samples")
        print(f"Training set: {len(train_data['image_only'])} image-only, {len(train_data['text_only'])} text-only, {len(train_data['multimodal'])} multimodal")
        print(f"Test set: {len(test_data['image_only'])} image-only, {len(test_data['text_only'])} text-only, {len(test_data['multimodal'])} multimodal")

        return train_data, test_data

    def _create_modal_split(self, coco, img_ids, img_only_ratio, text_only_ratio):
        """Build records for ``img_ids`` and partition them by modality.

        Images with no caption annotations are dropped; only the first caption
        of each image is kept. After shuffling, the first ``img_only_ratio``
        fraction is tagged ``image_only``, the next ``text_only_ratio`` fraction
        ``text_only`` and the rest ``multimodal`` (tag stored under ``modality``).
        """
        data = []
        for img_id in img_ids:
            img_info = coco.imgs[img_id]
            anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
            if not anns:
                continue
            data.append({
                'image_id': img_id,
                'image_path': os.path.join('train2017', img_info['file_name']),
                'caption': anns[0]['caption'],
            })

        random.shuffle(data)
        n = len(data)
        img_cut = int(img_only_ratio * n)
        text_cut = int((img_only_ratio + text_only_ratio) * n)

        img_only = data[:img_cut]
        text_only = data[img_cut:text_cut]
        multimodal = data[text_cut:]

        for item in img_only:
            item['modality'] = 'image_only'
        for item in text_only:
            item['modality'] = 'text_only'
        for item in multimodal:
            item['modality'] = 'multimodal'

        return {
            'image_only': img_only,
            'text_only': text_only,
            'multimodal': multimodal
        }

class FlickrDatasetProcessor:
    """Build missing-modality train/test splits for Flickr30k from a caption CSV.

    The CSV must provide ``filename``, ``img_id`` and ``raw`` columns, where
    ``raw`` holds a Python-literal list of captions (parsed with
    ``ast.literal_eval``). Split files are written to ``data_dir``.
    """

    def __init__(self, data_dir="./flickr30k_kb", archive_path="./flickr30k_kb/flickr30k-images.zip",
                 caption_csv="./flickr30k_kb/flickr_annotations_30k.csv"
                 ):
        self.archive_path = archive_path  # stored only; not read by any method of this class
        self.caption_csv = caption_csv
        self.data_dir = data_dir
        self.image_dir = os.path.join(data_dir, "images/flickr30k-images")

    def create_subset_and_split(self, train_size=5000, test_size=1000,
                                img_only_ratio=0.2, text_only_ratio=0.2):
        """Sample train/test records from the caption CSV and write split JSON files.

        One caption per image is chosen at random. Records are shuffled; the
        first ``train_size`` become training data and the next ``test_size``
        test data (always fully multimodal).

        Returns:
            ``(train_split, test_split)``: dicts keyed by ``image_only``,
            ``text_only`` and ``multimodal``.
        """
        df = pd.read_csv(self.caption_csv)
        df['captions'] = df['raw'].apply(ast.literal_eval)
        df['caption'] = df['captions'].apply(random.choice)
        records = df[['filename', 'caption', 'img_id']].to_dict(orient='records')
        random.shuffle(records)

        selected = records[:train_size + test_size]
        train_data = selected[:train_size]
        test_data = selected[train_size:]

        train_split = self._create_modal_split(train_data, img_only_ratio, text_only_ratio)
        test_split = self._create_modal_split(test_data, 0.0, 0.0)

        # __init__ does not create data_dir, so make sure it exists before writing.
        os.makedirs(self.data_dir, exist_ok=True)
        with open(os.path.join(self.data_dir, 'train_split.json'), 'w') as f:
            json.dump(train_split, f)
        with open(os.path.join(self.data_dir, 'test_split.json'), 'w') as f:
            json.dump(test_split, f)

        print(f"Created dataset split with {len(train_data)} training and {len(test_data)} test samples")
        print(f"Training set: {len(train_split['image_only'])} image-only, "
              f"{len(train_split['text_only'])} text-only, {len(train_split['multimodal'])} multimodal")
        print(f"Test set: {len(test_split['image_only'])} image-only, "
              f"{len(test_split['text_only'])} text-only, {len(test_split['multimodal'])} multimodal")

        return train_split, test_split

    def _create_modal_split(self, data, img_only_ratio, text_only_ratio):
        """Shuffle ``data`` in place and partition it by modality.

        Each record is tagged under ``modality`` and gets ``image_path``
        (relative to ``data_dir``) and ``image_id`` (copied from ``img_id``).
        """
        random.shuffle(data)
        n = len(data)
        img_cut = int(img_only_ratio * n)
        text_cut = int((img_only_ratio + text_only_ratio) * n)

        img_only = data[:img_cut]
        text_only = data[img_cut:text_cut]
        multimodal = data[text_cut:]

        for item in img_only:
            item['modality'] = 'image_only'
        for item in text_only:
            item['modality'] = 'text_only'
        for item in multimodal:
            item['modality'] = 'multimodal'

        # The partition lists share the same dict objects as `data`.
        for item in data:
            item['image_path'] = os.path.join('images/flickr30k-images', item['filename'])
            item['image_id'] = item['img_id']
        return {
            'image_only': img_only,
            'text_only': text_only,
            'multimodal': multimodal
        }

class CC3MDatasetProcessor:
    """Utility class to create a CC3M training subset with missing-modality splits.

    Images are written to ``<kb_dir>/images_train`` and the split JSON files are
    written to the same directory (``data_dir`` is an alias of ``img_dir``).
    Missing modalities are blanked out: image-only items get ``caption == ""``
    and text-only items get ``image_path == ""``.
    """

    def __init__(self, kb_dir: str = "./cc3m_kb"):
        self.kb_dir = Path(kb_dir)
        self.img_dir = self.kb_dir / "images_train"
        # Split files are saved alongside the images.
        self.data_dir = self.img_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def create_subset_and_split(
        self,
        train_size: int = 10000,
        test_size: int = 0,
        img_only_ratio: float = 0.0,
        text_only_ratio: float = 0.0,
        *,
        num_workers: int = 16,
    ):
        """Materialize ``train_size + test_size`` pairs and write train/test split JSON files.

        Args:
            train_size: number of training pairs.
            test_size: number of test pairs (the test split is always fully multimodal).
            img_only_ratio: fraction of training pairs marked image-only.
            text_only_ratio: fraction of training pairs marked text-only.
            num_workers: forwarded as ``num_proc`` to ``load_dataset``.

        Returns:
            ``(train_split, test_split)``: dicts keyed by ``image_only``,
            ``text_only`` and ``multimodal``.
        """
        total_needed = train_size + test_size
        entries = self._prepare_entries(total_needed, num_workers=num_workers)

        random.shuffle(entries)
        selected = entries[:total_needed]
        train_entries = selected[:train_size]
        test_entries = selected[train_size:]

        train_split = self._modal_split(train_entries, img_only_ratio, text_only_ratio)
        test_split = self._modal_split(test_entries, 0.0, 0.0)

        # ensure_ascii=False emits raw non-ASCII captions, so the file encoding
        # must be explicit rather than the platform default (e.g. cp1252).
        (self.data_dir / "train_split.json").write_text(
            json.dumps(train_split, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        (self.data_dir / "test_split.json").write_text(
            json.dumps(test_split, ensure_ascii=False, indent=2), encoding="utf-8"
        )

        print(
            f"Created dataset split with {len(train_entries)} training and {len(test_entries)} test samples\n"
            f"Training set: {len(train_split['image_only'])} image‑only, {len(train_split['text_only'])} text‑only, {len(train_split['multimodal'])} multimodal\n"
            f"Test set: {len(test_split['image_only'])} image‑only, {len(test_split['text_only'])} text‑only, {len(test_split['multimodal'])} multimodal"
        )
        return train_split, test_split

    def _prepare_entries(self, n: int, *, num_workers: int) -> List[dict]:
        """Ensure *n* image-caption pairs are available locally and return them.

        Iterates the ``pixparse/cc3m-wds`` train split in order, saving each
        usable image as ``<img_dir>/<index>.jpg`` and skipping samples whose
        ``txt`` or ``jpg`` field is missing. Returns fewer than *n* entries if
        the split runs out first. ``image_path`` is relative to ``kb_dir``'s parent.
        """
        self.img_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): without streaming=True this downloads/prepares the whole
        # train split even when only *n* samples are needed.
        ds = load_dataset(
            "pixparse/cc3m-wds",
            split="train",
            num_proc=num_workers,
        )
        entries: List[dict] = []
        for i, ex in enumerate(ds):
            if len(entries) >= n:
                break
            caption = ex.get("txt")
            if caption is None:
                continue  # skip bad sample

            image = ex.get("jpg")  # presumably a decoded PIL image (has .save / .mode)
            if image is None:
                continue
            # The JPEG encoder rejects modes such as RGBA or P; normalize those to RGB.
            if image.mode not in ("RGB", "L", "CMYK"):
                image = image.convert("RGB")
            img_filename = self.img_dir / f"{i}.jpg"
            image.save(img_filename)

            entries.append({
                "image_id": i,
                "image_path": str(img_filename.relative_to(self.kb_dir.parent)),
                "caption": caption,
            })
            if len(entries) % 500 == 0:
                print(f"Prepared {len(entries)}/{n} pairs…", end="\r")

        print(f"Finished preparing {len(entries)} image‑caption pairs.")
        return entries

    def _modal_split(
        self,
        data: List[dict],
        img_only_ratio: float,
        text_only_ratio: float,
    ) -> Dict[str, List[dict]]:
        """Shuffle ``data`` in place and partition it by modality.

        Image-only items have their caption blanked; text-only items have their
        image path blanked. Each item is tagged under ``modality``.
        """
        random.shuffle(data)
        n = len(data)
        img_cut = int(img_only_ratio * n)
        text_cut = int((img_only_ratio + text_only_ratio) * n)

        img_only = data[:img_cut]
        text_only = data[img_cut:text_cut]
        multimodal = data[text_cut:]

        for item in img_only:
            item["modality"] = "image_only"
            item["caption"] = ""
        for item in text_only:
            item["modality"] = "text_only"
            item["image_path"] = ""
        for item in multimodal:
            item["modality"] = "multimodal"

        return {
            "image_only": img_only,
            "text_only": text_only,
            "multimodal": multimodal,
        }


class RSICDDatasetProcessor:
    """Build missing-modality train/test splits for RSICD (``arampacha/rsicd``).

    Images are saved once under ``<data_dir>/images`` and split files are
    written to ``<data_dir>/splits``.
    """

    def __init__(self, data_dir="./rsicd_data"):
        self.data = data_dir  # root output directory
        self.ds_train = None  # populated by prepare_data()
        self.ds_test = None  # populated by prepare_data()

        self.images_dir = os.path.join(self.data, "images")
        os.makedirs(self.images_dir, exist_ok=True)

    def prepare_data(self):
        """Load the RSICD train and test splits and shuffle each with seed 42."""
        self.ds_train = load_dataset("arampacha/rsicd", split="train").shuffle(seed=42)
        self.ds_test = load_dataset("arampacha/rsicd", split="test").shuffle(seed=42)

    def create_subset_and_split(self, train_size=8734, test_size=1000,
                                img_only_ratio=0.2, text_only_ratio=0.2):
        """Take the first ``train_size`` / ``test_size`` shuffled examples and write split JSON files.

        Previously these size arguments were silently ignored and the full
        splits were always used; they now cap the number of examples taken
        from each (seeded) shuffled split. Pass ``None`` to use a whole split.
        The test split is always fully multimodal.

        Returns:
            ``(train_data, test_data)``: dicts keyed by ``image_only``,
            ``text_only`` and ``multimodal``.
        """
        self.prepare_data()

        train_data = self._create_modal_split(self.ds_train, img_only_ratio, text_only_ratio,
                                              limit=train_size)
        test_data = self._create_modal_split(self.ds_test, 0.0, 0.0, limit=test_size)

        save_dir = os.path.join(self.data, "splits")
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "train_split.json"), 'w') as f:
            json.dump(train_data, f, indent=2)
        with open(os.path.join(save_dir, "test_split.json"), 'w') as f:
            json.dump(test_data, f, indent=2)

        print(f"Train: {len(train_data['image_only'])} image-only, "
              f"{len(train_data['text_only'])} text-only, "
              f"{len(train_data['multimodal'])} multimodal")
        print(f"Test:  {len(test_data['image_only'])} image-only, "
              f"{len(test_data['text_only'])} text-only, "
              f"{len(test_data['multimodal'])} multimodal")
        return train_data, test_data

    def _create_modal_split(self, dataset, img_only_ratio, text_only_ratio, limit=None):
        """Save images, build records for up to ``limit`` examples, and partition by modality.

        Each example must provide ``filename``, ``image`` (an object with
        ``.save(path)``) and ``captions`` (only the first is kept). Images
        already present in ``images_dir`` are not re-saved.
        """
        data = []
        # islice stops early, so examples beyond `limit` are never iterated.
        for example in itertools.islice(dataset, limit):
            base_name = os.path.basename(example["filename"])
            image = example["image"]
            caption = example["captions"][0]

            img_save_path = os.path.join(self.images_dir, base_name)
            if not os.path.exists(img_save_path):
                image.save(img_save_path)

            data.append({
                'image_id': os.path.splitext(base_name)[0],
                'image_path': f"images/{base_name}",  # path relative to the data directory
                'caption': caption,
            })

        random.shuffle(data)
        n = len(data)
        img_cut = int(img_only_ratio * n)
        text_cut = int((img_only_ratio + text_only_ratio) * n)

        img_only = data[:img_cut]
        text_only = data[img_cut:text_cut]
        multimodal = data[text_cut:]

        for item in img_only:
            item['modality'] = 'image_only'
        for item in text_only:
            item['modality'] = 'text_only'
        for item in multimodal:
            item['modality'] = 'multimodal'

        return {
            'image_only': img_only,
            'text_only': text_only,
            'multimodal': multimodal
        }
