import os,zipfile
import torch
import clip
from tqdm import tqdm
from PIL import Image
import requests
from io import BytesIO
from datasets import load_dataset
from multiprocessing.pool import ThreadPool
import tarfile
import ast 
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import json, random
from pathlib import Path
from typing import List, Dict, Any, Optional
from torchvision import transforms
from glob import glob
# Torchvision image pipeline: Resize(256) -> CenterCrop(224) -> ToTensor ->
# per-channel Normalize using the ImageNet mean/std constants below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet mean
            std=[0.229, 0.224, 0.225],  # ImageNet std
        ),
    ]
)

def build_knowledge_base(data, data_dir, clip_model, preprocess, save_path, device="cuda"):
    """Encode image/caption pairs with CLIP and save them as a knowledge base.

    For every item, the image at ``os.path.join(data_dir, item["image_path"])``
    is opened, converted to RGB, run through ``preprocess`` and
    ``clip_model.encode_image``; the caption is tokenized and run through
    ``clip_model.encode_text``. Both embeddings are L2-normalised along the last
    dimension and moved to CPU before being collected.

    Args:
        data: Iterable of mappings with ``"image_path"`` (relative to
            ``data_dir``) and ``"caption"`` keys.
        data_dir: Directory that ``item["image_path"]`` is joined onto.
        clip_model: Model exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning a PIL image into a tensor for
            ``encode_image``.
        save_path: Destination file for ``torch.save``.
        device: Device the image/text inputs are moved to before encoding.

    Returns:
        The dict that was saved: ``"image_embeds"`` and ``"text_embeds"`` (the
        per-item embeddings concatenated along dim 0) and ``"meta"`` (a list of
        ``{"image_path", "caption"}`` dicts holding the joined image path).

    Raises:
        RuntimeError: If ``data`` yields no items.
    """
    image_embeds, text_embeds, meta_infos = [], [], []

    for item in tqdm(data, desc="Building Knowledge Base"):
        img_path = os.path.join(data_dir, item["image_path"])
        caption = item["caption"]

        # Close the underlying file promptly instead of leaking one handle per image.
        with Image.open(img_path) as pil_img:
            image = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(device)
        # The 77-character cut is kept so embeddings stay comparable with earlier
        # runs, but 77 characters can still exceed CLIP's 77-token context
        # (e.g. punctuation or non-Latin scripts); truncate=True prevents a crash.
        text = clip.tokenize([caption[:77]], truncate=True).to(device)

        with torch.no_grad():
            img_feat = clip_model.encode_image(image)
            txt_feat = clip_model.encode_text(text)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

        image_embeds.append(img_feat.cpu())
        text_embeds.append(txt_feat.cpu())
        meta_infos.append({"image_path": img_path, "caption": caption})

    if not image_embeds:
        raise RuntimeError("build_knowledge_base: `data` yielded no items to encode")

    kb = {
        "image_embeds": torch.cat(image_embeds, dim=0),
        "text_embeds": torch.cat(text_embeds, dim=0),
        "meta": meta_infos,
    }
    torch.save(kb, save_path)

    print(f"Knowledge base saved to: {save_path}")
    return kb
    


def build_clip_features(clip_model, preprocess, data, device):
    """Encode image/caption pairs with CLIP and return the stacked features.

    Each image is opened from ``item["image_path"]`` (used as-is), converted to
    RGB, passed through ``preprocess`` and ``clip_model.encode_image``; each
    caption is tokenized and passed through ``clip_model.encode_text``. Both
    embeddings are L2-normalised along the last dimension and moved to CPU.

    Args:
        clip_model: Model exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning a PIL image into a tensor for
            ``encode_image``.
        data: Iterable of mappings with ``"image_path"`` and ``"caption"`` keys.
        device: Device the image/text inputs are moved to before encoding.

    Returns:
        dict with ``"image_embeds"`` / ``"text_embeds"`` (per-item embeddings
        concatenated along dim 0) and ``"meta"`` (the input items, in order).

    Raises:
        RuntimeError: If ``data`` yields no items.
    """
    image_features, text_features, raw_entries = [], [], []

    for item in tqdm(data, desc="Extracting CLIP features"):
        # Convert to RGB so grayscale/RGBA/palette files work with any 3-channel
        # transform, and close the file handle once the pixels are loaded.
        with Image.open(item["image_path"]) as pil_img:
            image = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(device)
        # 77 characters can still exceed the 77-token context; let the tokenizer
        # truncate instead of raising.
        text = clip.tokenize([item["caption"][:77]], truncate=True).to(device)

        with torch.no_grad():
            img_feat = clip_model.encode_image(image)
            txt_feat = clip_model.encode_text(text)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

        image_features.append(img_feat.cpu())
        text_features.append(txt_feat.cpu())
        raw_entries.append(item)

    if not image_features:
        raise RuntimeError("build_clip_features: `data` yielded no items to encode")

    return {
        "image_embeds": torch.cat(image_features),
        "text_embeds": torch.cat(text_features),
        "meta": raw_entries,
    }

from PIL import ExifTags

def build_cc3m_knowledge_base(clip_model, preprocess, kb_dir="./cc3m_kb", num_samples=10000, num_workers=16, device="cuda"):
    """Build (or load a cached) CLIP knowledge base from a CC3M subset.

    Loads the ``pixparse/cc3m-wds`` train split, takes a seeded (seed 42)
    random subset of ``num_samples`` examples, writes each usable image to
    ``<kb_dir>/images/<index>.jpg`` and encodes every image/caption pair with
    ``build_clip_features``. The result is saved to ``<kb_dir>/cc3m_kb.pt``;
    if that file already exists it is loaded and returned without rebuilding.

    Args:
        clip_model: Model exposing ``encode_image`` and ``encode_text``.
        preprocess: Image transform applied before ``encode_image``.
        kb_dir: Directory for the cached ``.pt`` file and saved images.
        num_samples: Size of the random subset (clamped to the split size);
            ``None`` uses the whole split without shuffling.
        num_workers: Currently unused; kept for interface compatibility.
        device: Device the model inputs are moved to.

    Returns:
        dict with ``"image_embeds"``, ``"text_embeds"`` and ``"meta"``.

    Raises:
        RuntimeError: If no example has both an image and a caption.
    """
    os.makedirs(kb_dir, exist_ok=True)
    img_dir = os.path.join(kb_dir, "images")
    os.makedirs(img_dir, exist_ok=True)
    kb_path = os.path.join(kb_dir, "cc3m_kb.pt")

    if os.path.exists(kb_path):
        print("External knowledge base already exists. Loading...")
        return torch.load(kb_path)

    print("Loading CC3M dataset from HuggingFace...")
    ds = load_dataset("pixparse/cc3m-wds", split="train")
    if num_samples is not None:
        # Clamp so asking for more samples than exist does not make select() fail.
        ds = ds.shuffle(seed=42).select(range(min(num_samples, len(ds))))
    ds = ds.with_format("python")

    entries = []
    for i, example in enumerate(tqdm(ds, desc="Saving CC3M images")):
        image = example.get("jpg")
        caption = example.get("txt")
        if image is None or not caption:
            continue
        # JPEG cannot store modes such as RGBA or palette ("P"); normalise to RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        save_path = os.path.join(img_dir, f"{i}.jpg")
        image.save(save_path)
        entries.append({"image_path": save_path, "caption": caption})

    if not entries:
        raise RuntimeError("no CC3M samples with both an image and a caption")

    print("Extracting CLIP features...")
    # NOTE(review): this previously called `build_BioViL_features`, which is not
    # defined in this file; the CLIP extractor matches the log message and the
    # `clip_model` argument.
    kb = build_clip_features(clip_model, preprocess, entries, device)

    print(f"Saving knowledge base to {kb_path}...")
    torch.save(kb, kb_path)
    return kb


def parse_flickr30k_csv(csv_path, image_dir, num_samples=None, split="train"):
    """Collect image/caption entries for one split of a Flickr30k annotation CSV.

    The CSV must provide ``raw`` (a list literal of caption strings),
    ``filename`` and ``split`` columns. Only rows whose ``split`` equals
    ``split`` are considered; the first caption of each row is kept. Rows whose
    image file is missing under ``image_dir``, whose ``raw`` cell is empty, or
    whose caption list is empty are skipped.

    Args:
        csv_path: Path to the annotation CSV.
        image_dir: Directory containing the image files named in ``filename``.
        num_samples: Maximum number of entries to return; ``None`` means no
            limit (``0`` returns an empty list).
        split: Value of the ``split`` column to keep.

    Returns:
        list of ``{"image_path", "caption"}`` dicts in CSV order.
    """
    df = pd.read_csv(csv_path)
    df = df[df["split"] == split]

    entries = []
    for raw, filename in zip(df["raw"], df["filename"]):
        if num_samples is not None and len(entries) >= num_samples:
            break
        image_path = os.path.join(image_dir, filename)
        if not os.path.exists(image_path):
            continue
        # Empty cells are read by pandas as NaN (a float), not as a string.
        if not isinstance(raw, str):
            continue
        # literal_eval only evaluates Python literals, so untrusted CSVs are safe here.
        captions = ast.literal_eval(raw)
        if not captions:
            continue
        entries.append({"image_path": image_path, "caption": captions[0]})

    print(f"Collected {len(entries)} image-caption pairs for split='{split}'.")
    return entries
def extract_zip(zip_path, extract_to):
    """Unpack ``zip_path`` into ``extract_to`` unless that directory already has content.

    Extraction happens when ``extract_to`` does not exist or exists but is
    empty; otherwise a message is printed and nothing is touched.

    Args:
        zip_path: Path of the ``.zip`` archive.
        extract_to: Target directory for the archive members.
    """
    already_populated = os.path.exists(extract_to) and bool(os.listdir(extract_to))
    if already_populated:
        print(f"Image directory already exists: {extract_to}")
        return

    print(f"Extracting {zip_path}...")
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_to)
    print(f"Extracted to: {extract_to}")

def build_flickr30k_knowledge_base(
    clip_model, preprocess,
    kb_dir="./flickr30k_kb",
    device="cuda",
    num_samples=None,
    num_workers=16,
    archive_path="./flickr30k_kb/flickr30k-images.zip",
    caption_csv="./flickr30k_kb/flickr_annotations_30k.csv"
):
    """Build (or load a cached) CLIP knowledge base from the Flickr30k train split.

    Extracts ``archive_path`` into ``<kb_dir>/images`` (unless already
    populated), reads train-split captions from ``caption_csv`` and encodes the
    image/caption pairs with ``build_clip_features``. The result is saved to
    ``<kb_dir>/flickr30k_kb.pt``; if that file exists it is loaded instead.

    Args:
        clip_model: Model exposing ``encode_image`` and ``encode_text``.
        preprocess: Image transform applied before ``encode_image``.
        kb_dir: Directory for the cached ``.pt`` file and extracted images.
        device: Device the model inputs are moved to.
        num_samples: Maximum number of pairs to use; ``None`` means all.
        num_workers: Currently unused; kept for interface compatibility.
        archive_path: Zip archive with the Flickr30k images.
        caption_csv: Flickr30k annotation CSV (``raw``/``filename``/``split``).

    Returns:
        dict with ``"image_embeds"``, ``"text_embeds"`` and ``"meta"``.

    Raises:
        RuntimeError: If no image-caption pair could be collected.
    """
    os.makedirs(kb_dir, exist_ok=True)
    img_dir = os.path.join(kb_dir, "images")
    kb_path = os.path.join(kb_dir, "flickr30k_kb.pt")

    if os.path.exists(kb_path):
        print("Flickr30k knowledge base already exists. Loading...")
        return torch.load(kb_path)

    extract_zip(archive_path, img_dir)
    print("Parsing captions...")
    # Assumes the archive unpacks into a top-level "flickr30k-images/" folder — TODO confirm.
    images_root = os.path.join(img_dir, "flickr30k-images")
    entries = parse_flickr30k_csv(caption_csv, images_root, num_samples=num_samples, split="train")

    print(f"Total image-caption pairs: {len(entries)}")
    if not entries:
        raise RuntimeError(f"no Flickr30k image-caption pairs found under {images_root}")

    # NOTE(review): this previously called `build_clip_large_features`, which is
    # not defined in this file; `build_clip_features` is model-agnostic and works
    # with whichever CLIP variant is passed in.
    kb = build_clip_features(clip_model, preprocess, entries, device=device)

    # Persist exactly the three documented keys.
    torch.save({
        "image_embeds": kb["image_embeds"],
        "text_embeds": kb["text_embeds"],
        "meta": kb["meta"],
    }, kb_path)

    print(f"Saved Flickr30k knowledge base to: {kb_path}")
    return kb

def _load_captions_map(ann_dir):
    captions_map = {}
    json_files = sorted(glob(os.path.join(ann_dir, "*.json")))
    if not json_files:
        raise FileNotFoundError(f"{ann_dir} no *.json")

    for jf in json_files:
        with open(jf, "r", encoding="utf-8") as f:
            data = json.load(f)

        for key, value in (data.items() if isinstance(data, dict) else []):
            if not isinstance(value, list):
                continue
            for rec in value:
                fn = rec.get("filename")
                if not fn:
                    continue
                caps = []
                for kk in ("raw", "raw_1", "raw_2", "raw_3", "raw_4"):
                    v = rec.get(kk)
                    if isinstance(v, str) and v.strip():
                        caps.append(v.strip())
                if caps:
                    captions_map[fn] = caps
    if not captions_map:
        raise RuntimeError(f"{ann_dir} no caption")
    return captions_map

def _key_to_filename(hf_key: str) -> str:
    
    base = os.path.basename(hf_key)
    return base + ".jpg"

def build_nwpu_knowledge_base(clip_model, preprocess,
                              kb_dir="./nwpu_kb",
                              data_dir="./NWPU-Captions",
                              ann_dir="./NWPU-Captions",
                              num_samples=10000,
                              device="cuda"):
    """Build (or load a cached) CLIP knowledge base from NWPU-Captions.

    Captions come from the ``*.json`` annotation files in ``ann_dir``; images
    come from the ``KhangTruong/NWPU-Caption`` train split and are written to
    ``<data_dir>/images/<name>.jpg`` unless that file already exists. Each
    image is paired with its first caption and the pairs are encoded with
    ``build_clip_features``. The result is saved to ``<kb_dir>/nwpu_kb.pt``;
    if that file exists it is loaded (onto ``device``) instead.

    Args:
        clip_model: Model exposing ``encode_image`` and ``encode_text``.
        preprocess: Image transform applied before ``encode_image``.
        kb_dir: Directory for the cached ``.pt`` file.
        data_dir: Directory under which the ``images/`` folder is created.
        ann_dir: Directory holding the caption ``*.json`` files.
        num_samples: Size of the seeded (seed 42) random subset, clamped to the
            split size; ``None`` uses the whole split.
        device: Device the model inputs (and a cached KB) are placed on.

    Returns:
        dict with ``"image_embeds"``, ``"text_embeds"`` and ``"meta"``.

    Raises:
        RuntimeError: If no dataset item could be matched to a caption and image.
    """
    os.makedirs(kb_dir, exist_ok=True)
    image_save_dir = os.path.join(data_dir, "images")
    print(image_save_dir)
    os.makedirs(image_save_dir, exist_ok=True)
    kb_path = os.path.join(kb_dir, "nwpu_kb.pt")

    if os.path.exists(kb_path):
        print("Knowledge base already exists. Loading...")
        return torch.load(kb_path, map_location=device)

    captions_map = _load_captions_map(ann_dir)
    dataset = load_dataset("KhangTruong/NWPU-Caption", split="train")
    if num_samples is not None:
        dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))

    entries = []
    skipped_no_key = 0
    skipped_no_caption = 0
    skipped_no_image = 0
    saved_images = 0

    for item in tqdm(dataset, desc="🖼️ Processing"):
        hf_key = item.get("__key__")
        # Without a key there is no filename to match against the annotations.
        if not hf_key:
            skipped_no_key += 1
            continue

        filename = _key_to_filename(hf_key)
        caps = captions_map.get(filename)
        if not caps:
            skipped_no_caption += 1
            continue

        img_save_path = os.path.join(image_save_dir, filename)
        if not os.path.exists(img_save_path):
            # Only a missing file requires the in-memory image.
            pil_img = item.get("jpg")
            if pil_img is None:
                skipped_no_image += 1
                continue
            pil_img.save(img_save_path)
            saved_images += 1

        entries.append({
            "image_path": img_save_path,
            "caption": caps[0].strip(),
        })

    print(
        f"Collected {len(entries)} entries (saved {saved_images} new images; "
        f"skipped {skipped_no_caption} without caption, {skipped_no_key} without key, "
        f"{skipped_no_image} without image)"
    )
    if not entries:
        raise RuntimeError("no avalible sample")

    kb = build_clip_features(clip_model, preprocess, entries, device=device)

    print(f"save in {kb_path}")
    torch.save(kb, kb_path)
    return kb