#%%
import torch
from torch import nn, optim
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
import json
from PIL import Image
import io
import os
from tqdm import tqdm
from transformers import BlipProcessor, BlipForImageTextRetrieval
import traceback

#%%
# Configuration: local paths and training hyperparameters
model_path = "pretrained_frameworks/blip-itm-base-coco"  # pretrained BLIP checkpoint directory
data_path = "downloaded_datatset/HumanEdit"  # dataset location handed to load_dataset
json_path = "downloaded_datatset/target_description.json"  # image id -> target description file
batch_size = 32
epochs = 20
lr = 1e-5
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

#%%
# Load the BLIP processor and retrieval model from the local checkpoint
processor = BlipProcessor.from_pretrained(model_path)
model = BlipForImageTextRetrieval.from_pretrained(model_path)
model = model.to(device)  # move the weights onto the selected device

#%%
# Load target descriptions
# The file is read explicitly as UTF-8 so that non-ASCII captions decode the
# same way on every platform instead of depending on the locale's default.
with open(json_path, "r", encoding="utf-8") as f:
    desc_data = json.load(f)

# A list is treated as records carrying "img_id" and "tar_desc" keys and is
# folded into a dict; any other JSON value is used as-is (assumed to already
# be a mapping from image id to description).
if isinstance(desc_data, list):
    imgid2desc = {item["img_id"]: item["tar_desc"] for item in desc_data}
else:
    imgid2desc = desc_data

#%%
# Dataset
class StreamingCLIPDataset(IterableDataset):
    """Iterable dataset yielding image/caption pairs for contrastive training.

    Each sample of ``hf_dataset`` is matched against ``imgid2desc`` through its
    id. Samples without a description, with an unsupported image payload, or
    that raise while being read/decoded are skipped, so a single bad record
    cannot abort a pass over the data.

    Args:
        hf_dataset: Iterable of samples indexable by ``id_key``/``image_key``.
        imgid2desc: Mapping from image id to its caption text.
        processor: Stored on the instance; not used during iteration.
        image_key: Sample key holding the image (raw bytes or a PIL image).
        id_key: Sample key holding the image id.

    Yields:
        dict with keys ``"image"`` (RGB PIL image) and ``"text"`` (caption).
    """

    # Upper bound on tracebacks printed per pass: enough to make a
    # systematically broken dataset visible without flooding the log.
    max_reported_errors = 5

    def __init__(self, hf_dataset, imgid2desc, processor, image_key="OUTPUT_IMG", id_key="IMAGE_ID"):
        super().__init__()
        self.dataset = hf_dataset
        self.imgid2desc = imgid2desc
        self.image_key = image_key
        self.id_key = id_key
        self.processor = processor

    def __iter__(self):
        reported = 0
        for sample in self.dataset:
            try:
                img_id = sample[self.id_key]
                if img_id not in self.imgid2desc:
                    continue

                image_data = sample[self.image_key]
                if isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                elif isinstance(image_data, Image.Image):
                    image = image_data.convert("RGB")
                else:
                    # Neither raw bytes nor a PIL image: nothing we can decode.
                    continue

                caption = self.imgid2desc[img_id]
            except Exception:
                # Best-effort streaming: skip the broken sample, but surface
                # the first few failures instead of hiding them entirely.
                if reported < self.max_reported_errors:
                    traceback.print_exc()
                    reported += 1
                continue

            # Yield outside the try block so exceptions raised by the consumer
            # are never mistaken for a bad sample and swallowed.
            yield {
                "image": image,
                "text": caption
            }

#%%
# Collate function to handle PIL and text lists
def collate_fn(batch):
    """Regroup a list of per-sample dicts into parallel lists.

    Args:
        batch: Sequence of dicts, each with an "image" and a "text" entry.

    Returns:
        dict with "images" and "texts" lists, in the same order as ``batch``.
        Nothing is stacked or tensorised here.
    """
    def pluck(field):
        # Collect one field across every sample, preserving batch order.
        return [sample[field] for sample in batch]

    collated = {}
    collated["images"] = pluck("image")
    collated["texts"] = pluck("text")
    return collated

#%%
# Build the training pipeline: raw dataset -> image/caption stream -> batches
dataset_raw = load_dataset(data_path, split="train", streaming=False)
train_dataset = StreamingCLIPDataset(
    hf_dataset=dataset_raw,
    imgid2desc=imgid2desc,
    processor=processor,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

#%%
# Loss and optimizer
loss_fn = nn.CrossEntropyLoss()
# Hand AdamW only the parameters that still require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=lr)

#%%
# Training loop: symmetric image<->text contrastive objective over each batch
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    # train_dataset defines no __len__, so batches are counted while iterating
    # in order to report a true per-batch average at the end of the epoch.
    num_batches = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images = batch["images"]
        texts = batch["texts"]

        # Preprocess images and texts into tensors on the training device
        image_inputs = processor(images=images, return_tensors="pt").to(device)
        pixel_values = image_inputs["pixel_values"]

        text_inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(device)
        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        # Use the first position of each encoder's last hidden state as the
        # embedding (presumably the [CLS] token — confirm against the model).
        image_embeds = model.vision_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]
        text_embeds = model.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]

        # L2-normalize so the dot product below is a cosine similarity
        image_features = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_features = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # Pairwise similarities; the i-th image is the positive for the i-th text.
        # NOTE(review): the logits are raw cosines with no temperature/logit
        # scale applied before the cross-entropy — confirm this is intended.
        logits_per_image = image_features @ text_features.T
        logits_per_text = logits_per_image.T
        targets = torch.arange(image_features.size(0), device=device)
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Drop per-batch tensors and release cached GPU memory between steps
        del image_inputs, text_inputs, pixel_values, input_ids, attention_mask
        torch.cuda.empty_cache()

    # Report the mean loss per batch; nan signals that the epoch saw no data.
    average_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{epochs} - Average Loss: {average_loss:.4f}")

#%%
# Persist the fine-tuned weights
save_dir = "finetuned_models/blip_finetuned"
save_path = f"{save_dir}/blip_finetuned_{epochs}.pth"

# Create the output directory (no error if it is already there).
os.makedirs(save_dir, exist_ok=True)

# Only the state_dict is written, not the full model object.
weights = model.state_dict()
torch.save(weights, save_path)
print(f"Model saved to '{save_path}'")

# %%
