from typing import Optional, Callable
from torchvision.datasets import CocoDetection
from omegaconf import OmegaConf
import os
import pandas as pd
import json
from PIL import Image
import io
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch

class CocoCaptionsDict(CocoDetection):
    """COCO captions dataset restricted to a train or val subset of COCO val2017.

    Both splits are built from the ``val2017`` images and
    ``annotations/captions_val2017.json`` under ``root``. Which images belong
    to each split is read from ``root/coco_split.yaml``: its ``train`` list
    selects the train split and its ``test`` list selects the val split. The
    list entries are used as positional indices into the ``ids`` list created
    by ``CocoDetection``.
    """

    def __init__(
            self,
            split,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
            tokenizer=None,
    ) -> None:
        """Load the val2017 caption annotations and keep only ``split``'s images.

        Args:
            split: ``'train'`` or ``'val'``.
            root: Dataset root containing ``val2017/``, ``annotations/`` and
                ``coco_split.yaml``.
            transform, target_transform, transforms: Forwarded unchanged to
                ``CocoDetection``.
            tokenizer: Stored as ``self.tokenizer``; not used by any code in
                this class.

        Raises:
            ValueError: If ``split`` is not ``'train'`` or ``'val'``.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        ann_file = os.path.join(root, 'annotations/captions_val2017.json')
        self.split = split
        conf = OmegaConf.load(os.path.join(root, 'coco_split.yaml'))
        image_root = os.path.join(root, 'val2017')
        # The yaml's 'test' list is what this class exposes as the 'val' split.
        self._train_ids = conf['train']
        self._val_ids = conf['test']
        self.tokenizer = tokenizer

        super().__init__(image_root, ann_file, transform, target_transform, transforms)

        # Split entries are positions into the full id list built by the base
        # class (NOTE(review): presumably sorted COCO image ids — confirm),
        # not COCO image ids themselves.
        positions = self._train_ids if split == 'train' else self._val_ids
        self.ids = [self.ids[i] for i in positions]
        self._init_tokenize_captions()

    def _init_tokenize_captions(self):
        """Hook for pre-tokenizing captions; currently a no-op."""
        pass

    def save_images_and_metadata(self, output_dir):
        """Export this split's images and a ``metadata.jsonl`` caption index.

        Every image is written to ``<output_dir>/<image_id>.jpg``. For each
        image that has at least one caption annotation, one JSON line
        ``{"file_name": ..., "text": ...}`` is appended, holding the first
        caption.

        Args:
            output_dir: Destination directory; created if missing.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)
        metadata_file = os.path.join(output_dir, 'metadata.jsonl')

        with open(metadata_file, 'w', encoding='utf-8') as f:
            for idx, img_id in enumerate(self.ids):
                img, target = self._load_image_and_target(idx)

                file_name = f'{img_id}.jpg'
                img.save(os.path.join(output_dir, file_name))

                captions = [ann['caption'] for ann in target]
                # NOTE(review): images with no caption annotations are still
                # saved above but get no metadata line.
                if captions:
                    metadata = {
                        "file_name": file_name,
                        "text": captions[0],
                    }
                    f.write(json.dumps(metadata) + '\n')

    def _load_image_and_target(self, idx):
        """Return ``(image, target)`` for position ``idx`` in ``self.ids``.

        Uses the base class's ``_load_image`` and ``_load_target`` helpers and
        skips the transforms that ``__getitem__`` would apply.
        """
        img_id = self.ids[idx]
        image = self._load_image(img_id)
        target = self._load_target(img_id)
        return image, target


# Re-caption every image in the Flickr parquet shards with BLIP and write the
# result to a single new parquet file.
# NOTE(review): this runs at import time, so importing this module just to use
# CocoCaptionsDict also loads BLIP and captions the whole dataset. Consider
# moving it under an ``if __name__ == "__main__":`` guard.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BlipProcessor.from_pretrained("../../SecMI-LDM/checkpoint/blip")
model = BlipForConditionalGeneration.from_pretrained("../../SecMI-LDM/checkpoint/blip").to(device)

folder_path = '../../SecMI-LDM/datasets/datasets/Flickr'

# Sort the listing: os.listdir order is arbitrary, and it determines the row
# order of combined_df and therefore of the output parquet.
dfs = [
    pd.read_parquet(os.path.join(folder_path, fname))
    for fname in sorted(os.listdir(folder_path))
    if fname.endswith('.parquet')
]
if not dfs:
    # Fail with a clear message instead of pandas' "No objects to concatenate".
    raise FileNotFoundError(f"no .parquet files found in {folder_path}")

combined_df = pd.concat(dfs, ignore_index=True)

new_captions = []
# Each 'image' cell is indexed with 'bytes' to get the encoded image data.
# Iterating the single column avoids building a Series per row (iterrows).
for image_dict in combined_df['image']:
    with Image.open(io.BytesIO(image_dict['bytes'])) as image:
        inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    new_captions.append(processor.decode(outputs[0], skip_special_tokens=True))

# Copy the source frame and set its 'caption' column (added or overwritten)
# to the generated captions, in row order.
new_df = combined_df.copy()
new_df['caption'] = new_captions

new_file_path = '../../SecMI-LDM/datasets/datasets/Flickr_blip.parquet'
new_df.to_parquet(new_file_path)

print(f"save: {new_file_path}")




