import json
import numpy as np
import os
import torch
import open_clip
from PIL import Image
from PIL.Image import DecompressionBombError
import csv
import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer

# ---- Run configuration ------------------------------------------------------
# Input / output locations.
JSONL_FILE = 'Wiki6M_ver_1_0_updated.jsonl'
OUTPUT_DIR = '/data2/ckddls1321/data/temp3/'
# Relative path: the index is written to the current working directory, not OUTPUT_DIR.
FAISS_INDEX_FILE = 'mixed_index_large.index'

# Per-chunk file names, formatted with the 0-based chunk index.
EMBEDDINGS_FILE_TEMPLATE = 'image_embeddings_chunk_{}.npy'
CSV_FILE_TEMPLATE = 'metadata_chunk_{}.csv'

# Pairs per forward pass, and JSONL records per saved chunk.
BATCH_SIZE = 16
CHUNK_SIZE = 100_000

# Default retriever: a lightweight setup, used for a fair comparison.
# Images are encoded with SigLIP2 (open_clip) and captions with gte-modernbert-base.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-SO400M-16-SigLIP2-384')
tokenizer = open_clip.get_tokenizer('hf-hub:timm/ViT-SO400M-16-SigLIP2-384')
# nn.Module.to() and .eval() both return the module itself, so this rebinds the same object.
image_model = image_model.to(device).eval()
text_model = SentenceTransformer("Alibaba-NLP/gte-modernbert-base", device=device)

# For better retrieval quality, the larger retriever below can be used instead:
# image_model, preprocess = open_clip.create_model_from_pretrained(
#     'hf-hub:timm/ViT-gopt-16-SigLIP2-384'
# )
# tokenizer = open_clip.get_tokenizer('hf-hub:timm/ViT-gopt-16-SigLIP2-384')
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# image_model.to(device)
# image_model.eval()

# Larger text model: Qwen3-Embedding-0.6B, loaded through SentenceTransformer
#text_model = SentenceTransformer(
#    "Qwen/Qwen3-Embedding-0.6B",
#    model_kwargs={
#        "torch_dtype": torch.bfloat16,
#        "attn_implementation": "flash_attention_2",
#        "device_map": "auto"
#    },
#    tokenizer_kwargs={"padding_side": "left"}
#)

# Ensure the output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

def save_metadata_to_csv(metadata, csv_file):
    """Write metadata rows to ``csv_file`` under an ``image_path,caption`` header.

    Args:
        metadata: Iterable of dicts keyed by ``"image_path"`` and ``"caption"``.
            Extra keys make ``csv.DictWriter`` raise ``ValueError``; missing keys
            are written as empty cells.
        csv_file: Destination path. An existing file is overwritten.
    """
    # Captions are free text and may contain non-ASCII characters. Pin UTF-8 so
    # the result does not depend on the platform's default locale encoding (and
    # matches the UTF-8 default pandas.read_csv uses when reading it back).
    with open(csv_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=["image_path", "caption"])
        writer.writeheader()
        writer.writerows(metadata)

def _embed_batch(image_paths, captions):
    """Embed one batch and report which input positions produced a row.

    An image that cannot be opened, decoded or preprocessed is reported with a
    printed message and skipped together with its caption.

    Args:
        image_paths: Sequence of image file paths.
        captions: Sequence of caption strings, parallel to ``image_paths``.

    Returns:
        ``(embeddings, kept)``. ``kept`` lists, in order, the positions of the
        inputs that were embedded. ``embeddings`` has one row per entry of
        ``kept``: the normalized image embedding concatenated with the normalized
        caption embedding. When nothing could be embedded, ``embeddings`` is
        ``np.empty((0,))`` and ``kept`` is empty.
    """
    images, texts, kept = [], [], []

    for pos, (path, caption) in enumerate(zip(image_paths, captions)):
        try:
            # Close the file handle explicitly; convert() returns an independent
            # in-memory image, so it stays usable after the `with` exits.
            with Image.open(path) as img:
                rgb = img.convert("RGB")
            images.append(preprocess(rgb))
        except (OSError, ValueError, DecompressionBombError) as e:
            print(f"Error processing image {path}: {e}")
            continue
        texts.append(caption)
        kept.append(pos)

    if not images:
        return np.empty((0,)), kept

    images_tensor = torch.stack(images).to(device)
    with torch.no_grad():
        img_feats = image_model.encode_image(images_tensor, normalize=True)
        # Alternative for the large index (RA-CM3 style): add the CLIP text embedding.
        # input_ids = tokenizer(texts, context_length=image_model.context_length).to(device)
        # img_feats = img_feats + image_model.encode_text(input_ids, normalize=True)
        txt_feats = text_model.encode(
            texts,
            normalize_embeddings=True,
            convert_to_tensor=True,
            convert_to_numpy=False,
        )
        mixed = torch.concat([img_feats, txt_feats], dim=-1)

    return mixed.cpu().numpy(), kept


def process_batch(image_paths, captions):
    """Embed a batch of (image, caption) pairs, skipping images that fail to load.

    Returns:
        Array with one row per successfully loaded image (image embedding
        concatenated with caption embedding), or ``np.empty((0,))`` when none
        could be loaded. Use ``_embed_batch`` to also learn which inputs were kept.
    """
    embeddings, _ = _embed_batch(image_paths, captions)
    return embeddings


def process_jsonl_in_chunks(jsonl_file, chunk_size):
    """Embed every record of ``jsonl_file`` and save the results chunk by chunk.

    Each JSONL record must provide ``"image_path"`` and ``"wikipedia_summary"``
    (used as the caption). Records are grouped into chunks of ``chunk_size``.
    For chunk ``k``, two files are written to ``OUTPUT_DIR``:
    ``CSV_FILE_TEMPLATE.format(k)`` and ``EMBEDDINGS_FILE_TEMPLATE.format(k)``.
    Row ``i`` of the CSV describes row ``i`` of the embedding array. Records whose
    image fails to load are dropped from both files.

    Chunks whose embedding file already exists are skipped, so an interrupted
    run can be resumed. The embedding file is written last, through a temporary
    file and an atomic rename. Its presence therefore implies the chunk's CSV is
    complete and the array is not truncated.
    """
    image_paths, captions = [], []
    # JSON text exchanged between systems is UTF-8 (RFC 8259); don't depend on
    # the locale's default encoding.
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, which json.loads would reject
            data = json.loads(line)
            image_paths.append(data["image_path"])
            captions.append(data["wikipedia_summary"])

    num_chunks = (len(image_paths) + chunk_size - 1) // chunk_size
    for chunk_idx in range(num_chunks):
        emb_file = os.path.join(OUTPUT_DIR, EMBEDDINGS_FILE_TEMPLATE.format(chunk_idx))
        csv_file = os.path.join(OUTPUT_DIR, CSV_FILE_TEMPLATE.format(chunk_idx))

        if os.path.exists(emb_file):
            print(f"Chunk {chunk_idx} already done, skipping.")
            continue

        start = chunk_idx * chunk_size
        imgs = image_paths[start:start + chunk_size]
        caps = captions[start:start + chunk_size]

        print(f"Processing chunk {chunk_idx+1}/{num_chunks}...")
        all_embs, meta = [], []
        for i in range(0, len(imgs), BATCH_SIZE):
            batch_imgs = imgs[i:i + BATCH_SIZE]
            batch_caps = caps[i:i + BATCH_SIZE]
            batch_emb, kept = _embed_batch(batch_imgs, batch_caps)
            if not kept:
                continue
            all_embs.append(batch_emb)
            # Record metadata only for rows that actually received an embedding,
            # so that CSV rows and embedding rows stay aligned one-to-one.
            meta.extend(
                {"image_path": batch_imgs[k], "caption": batch_caps[k]}
                for k in kept
            )

        if not all_embs:
            print(f"No valid embeddings for chunk {chunk_idx}")
            continue

        arr = np.vstack(all_embs)
        # Metadata first, embeddings last: the embedding file is the "chunk done"
        # marker checked above.
        save_metadata_to_csv(meta, csv_file)
        # np.save on a file object does not append ".npy", so the temporary name
        # is used verbatim. os.replace then publishes the file atomically.
        tmp_file = emb_file + ".tmp"
        with open(tmp_file, 'wb') as out:
            np.save(out, arr)
        os.replace(tmp_file, emb_file)
        print(f"Saved chunk {chunk_idx}: {emb_file}, {csv_file}")

def merge_and_create_faiss_index(output_dir, index_file):
    """Build a flat inner-product FAISS index from every embedding chunk in ``output_dir``.

    Chunk files (``image_embeddings_chunk*.npy``) are added in lexicographic
    file-name order, so ``..._chunk_10.npy`` comes before ``..._chunk_2.npy``.
    This is the same order used by ``merge_csv_files`` for the metadata chunks.
    The index is written to ``index_file``, a path independent of ``output_dir``.

    Chunks are added one at a time. This avoids holding all chunks plus a stacked
    copy of them in memory alongside the index storage.

    Raises:
        ValueError: if a chunk is not 2-D or the chunks differ in width.
    """
    print("Merging all embeddings...")
    chunk_files = sorted(
        fname for fname in os.listdir(output_dir)
        if fname.startswith("image_embeddings_chunk") and fname.endswith(".npy")
    )
    if not chunk_files:
        print("No embeddings found.")
        return

    paths = [os.path.join(output_dir, fname) for fname in chunk_files]
    # First pass: memory-map each chunk to read its shape without loading the data.
    shapes = [np.load(p, mmap_mode='r').shape for p in paths]
    if any(len(s) != 2 for s in shapes) or len({s[1] for s in shapes}) != 1:
        raise ValueError(
            f"Embedding chunks must be 2-D with a common width, got shapes {shapes}"
        )

    total = sum(s[0] for s in shapes)
    dim = shapes[0][1]
    print(f"Total embeddings: {(total, dim)}")

    print("Building FAISS IndexFlatIP...")
    idx = faiss.IndexFlatIP(dim)
    for p in paths:
        # FAISS works on C-contiguous float32 arrays.
        idx.add(np.ascontiguousarray(np.load(p), dtype=np.float32))
    faiss.write_index(idx, index_file)
    print(f"FAISS index saved to {index_file}")

def merge_csv_files(output_dir):
    """Concatenate all ``metadata_chunk_*.csv`` files in ``output_dir``.

    Files are read in lexicographic file-name order, so ``metadata_chunk_10.csv``
    comes before ``metadata_chunk_2.csv``. The same ordering is used by
    ``merge_and_create_faiss_index`` for the embedding chunks. The result is
    written to ``<output_dir>/mixed_metadata.csv``. That name does not match the
    chunk pattern, so a re-run does not merge the output into itself. If no
    chunk files exist, nothing is written.
    """
    print("Merging metadata CSVs...")
    chunk_files = sorted(
        fname for fname in os.listdir(output_dir)
        if fname.startswith("metadata_chunk_") and fname.endswith(".csv")
    )
    if not chunk_files:
        print("No metadata CSVs found.")
        return

    # Keep every cell as the literal text that was written. With pandas' default
    # NA and type inference, captions such as "NA" or "None" would become NaN
    # (written back as empty cells), and numeric-looking text such as "1.50"
    # could be reformatted.
    dfs = [
        pd.read_csv(
            os.path.join(output_dir, fname),
            dtype=str,
            keep_default_na=False,
            encoding='utf-8',
        )
        for fname in chunk_files
    ]
    merged = pd.concat(dfs, ignore_index=True)
    out = os.path.join(output_dir, "mixed_metadata.csv")
    merged.to_csv(out, index=False)
    print(f"Merged metadata saved to {out}")

def main():
    """Run the pipeline: embed the JSONL in chunks, build the FAISS index, merge the metadata."""
    # Stages run strictly in this order; each is called with its positional arguments.
    pipeline = (
        (process_jsonl_in_chunks, (JSONL_FILE, CHUNK_SIZE)),
        (merge_and_create_faiss_index, (OUTPUT_DIR, FAISS_INDEX_FILE)),
        (merge_csv_files, (OUTPUT_DIR,)),
    )
    for stage, args in pipeline:
        stage(*args)
    print("Done.")


if __name__ == "__main__":
    main()

