import os
import sys
from argparse import ArgumentParser
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from PIL import Image
import io
from datasets import load_dataset
from reward_distill.model.clip_mlp_model import MLP
from reward_models.reward_interface import UnifiedReward
from customize_scheduler.customize_euler import CustomEuler
from customize_pipeline.custom_sdxl_pipeline import CustomizeStableDiffusionXLPipeline
from tqdm import tqdm
import clip
from accelerate import Accelerator
from torch.nn import MSELoss
from torch.utils.data import DataLoader


import ipdb

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

if __name__ == "__main__":
    # Offline preprocessing script: for every row of the Pick-a-Pic train
    # split, decode both candidate images, build a 512x512 normalised tensor
    # for each, score each image against the caption with the "pickscore"
    # reward, and dump the whole list of records to a single .pt file.
    parser = ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # NOTE(review): `device` is built (which validates the --device string)
    # but is never passed to the reward model or to any tensor below.
    # Confirm whether UnifiedReward should receive it.
    device = torch.device(args.device)

    ground_truth = UnifiedReward("pickscore")

    # dataset = load_dataset("THUDM/ImageRewardDB", "1k")
    dataset = load_dataset("yuvalkirstain/pickapic_v1", num_proc=1)
    train_split = dataset['train']

    # Paths
    save_path = "processed_dataset.pt"

    # Preprocessing pipeline for the stored image tensors.
    preprocess = Compose([
        # Resize with an (h, w) tuple already yields exactly 512x512, so the
        # CenterCrop below is a no-op kept for parity with the original.
        Resize((512, 512)),
        CenterCrop(512),
        ToTensor(),
        Normalize([0.5], [0.5]),  # Maps ToTensor's [0, 1] output to [-1, 1]
    ])

    # The reward model is fed the original-resolution image converted with a
    # plain ToTensor (no resize / normalisation). Build the transform once
    # instead of twice per row.
    to_tensor = ToTensor()

    # NOTE(review): every processed image tensor is kept in this list until
    # the final torch.save, so peak memory grows linearly with the split
    # size. Consider sharded saves if the full split does not fit in RAM.
    processed_data = []
    with torch.no_grad():
        # Iterate rows directly: each row is materialised once, instead of
        # indexing dataset['train'][i] four times (each index re-reads the
        # whole row, including both image byte blobs).
        for row in tqdm(train_split, desc="Preprocessing dataset"):
            prompt = row['caption']
            label = row['label_0']

            # Process image 0
            image_0 = Image.open(io.BytesIO(row['jpg_0'])).convert("RGB")
            processed_image_0 = preprocess(image_0)
            score_0 = ground_truth.score(to_tensor(image_0), prompt)

            # Process image 1
            image_1 = Image.open(io.BytesIO(row['jpg_1'])).convert("RGB")
            processed_image_1 = preprocess(image_1)
            score_1 = ground_truth.score(to_tensor(image_1), prompt)

            processed_data.append({
                "image_0": processed_image_0,
                "image_1": processed_image_1,
                "prompt": prompt,
                "label": label,
                "score_0": score_0,
                "score_1": score_1,
            })

    # Save processed dataset as a .pt file
    torch.save(processed_data, save_path)
    print(f"Processed dataset saved to {save_path}")

    
