import os
import json
import torch
from PIL import Image
from tqdm import tqdm
from typing import Optional,Dict, List, Tuple

from models.llava_model import LlavaWrapper
from models.clip_model import CLIPWrapper
from utils.image_processing import get_image_path, bbox_to_xyxy, load_and_crop_image
from utils.clustering import cluster_embeddings, initialize_cache_centers
from config import DEFAULT_CACHE_SIZE, DEFAULT_CLIP_THRESHOLD, DEFAULT_DEVICE, DEFAULT_MOMENTUM

class LocalCacheBuilder:
    """Build and maintain a local cache of paired text/visual prototypes.

    Workflow implemented by this class:

    1. For every annotated box of every image, crop the region, ask LLaVA
       for a textual description and embed both the crop and the
       description with CLIP.
    2. Keep only the (text, image) pairs whose CLIP cosine similarity
       exceeds ``clip_threshold``.
    3. Cluster the kept text embeddings into ``cache_size`` groups and
       derive the initial cache centers (``build_cache``).
    4. Optionally refine the centers with momentum updates during training
       (``update_cache``).

    NOTE(review): the wrappers are used as if ``self.llava`` / ``self.clip``
    expose Hugging Face style ``.processor`` and ``.model`` attributes --
    confirm against ``LlavaWrapper`` / ``CLIPWrapper``.
    """

    # Prompt sent to LLaVA for every cropped region (kept verbatim).
    _DESCRIPTION_PROMPT = "<image>\nUSER: Generate a fine-grained textual description of the image (e.g., 'a zebra with black and white stripes\nASSISTANT:"
    # Boxes whose width or height is not larger than this are skipped.
    _MIN_BOX_SIDE = 10
    # Upper bound on the number of tokens LLaVA may generate per region.
    _MAX_NEW_TOKENS = 50

    def __init__(self, 
                 annotation_path: str,
                 image_dir: str,
                 momentum:Optional[float] = DEFAULT_MOMENTUM,
                 cache_size: Optional[int] = DEFAULT_CACHE_SIZE,
                 clip_threshold: Optional[float] = DEFAULT_CLIP_THRESHOLD,
                 device: str = DEFAULT_DEVICE):
        """Load the annotation file and instantiate the LLaVA / CLIP wrappers.

        Args:
            annotation_path: JSON file with top-level ``images`` and
                ``annotations`` lists (each image has an ``id``; each
                annotation has ``image_id`` and ``bbox``).
            image_dir: Directory handed to ``get_image_path`` to locate images.
            momentum: Weight of the old center in ``update_cache``.
                ``None`` falls back to ``DEFAULT_MOMENTUM``.
            cache_size: Number of clusters / cache entries.
                ``None`` falls back to ``DEFAULT_CACHE_SIZE``.
            clip_threshold: Minimum CLIP image/text cosine similarity for a
                pair to be kept. ``None`` falls back to
                ``DEFAULT_CLIP_THRESHOLD``.
            device: Device passed to both model wrappers and used for the cache.
        """
        self.annotation_path = annotation_path
        self.image_dir = image_dir
        # The parameters are typed Optional, so make an explicit None usable
        # instead of letting it crash later inside arithmetic/comparisons.
        self.momentum = DEFAULT_MOMENTUM if momentum is None else momentum
        self.cache_size = DEFAULT_CACHE_SIZE if cache_size is None else cache_size
        self.clip_threshold = (
            DEFAULT_CLIP_THRESHOLD if clip_threshold is None else clip_threshold
        )
        self.device = device

        # Load dataset
        with open(annotation_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.images = {img['id']: img for img in self.data['images']}
        self.annotations = self.data['annotations']

        # Index annotations by image once; scanning the full annotation list
        # for every image would be O(num_images * num_annotations).
        self._anns_by_image = {}
        for ann in self.annotations:
            self._anns_by_image.setdefault(ann['image_id'], []).append(ann)

        # Initialize models
        self.llava = LlavaWrapper(device)
        self.clip = CLIPWrapper(device)

        # Cache tensors; populated by build_cache() or load_cache_state().
        self.text_cache = None
        self.visual_cache = None

    def _require_cache(self) -> None:
        """Raise a clear error when the cache has not been built or loaded."""
        if self.text_cache is None or self.visual_cache is None:
            raise RuntimeError(
                "Cache is not initialized; call build_cache() or "
                "load_cache_state() first."
            )

    def _describe_region(self, roi) -> str:
        """Return LLaVA's description of a cropped region (text after 'ASSISTANT:')."""
        inputs = self.llava.processor(
            text=self._DESCRIPTION_PROMPT,
            images=roi, 
            return_tensors="pt"
        ).to(self.device)
        # max_new_tokens bounds only the generated continuation. A max_length
        # budget would also count the prompt (and image) tokens, which can
        # already exceed it and leave no room for the description.
        with torch.no_grad():
            generate_ids = self.llava.model.generate(
                **inputs, max_new_tokens=self._MAX_NEW_TOKENS
            )
        desc = self.llava.processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        return desc.split("ASSISTANT:")[-1].strip()

    def _process_single_image(self, img_id: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Process all annotations for a single image.

        Args:
            img_id: Key into ``self.images``.

        Returns:
            Two parallel lists ``(text_embeds, visual_embeds)`` holding the
            CPU CLIP embeddings of every region whose image/text similarity
            exceeds ``self.clip_threshold``. Both lists are empty when the
            image has no annotations or cannot be loaded.
        """
        img_info = self.images[img_id]

        # Nothing to describe: avoid opening and decoding the image at all.
        anns = self._anns_by_image.get(img_id, [])
        if not anns:
            return [], []

        try:
            img_path = get_image_path(img_info, self.image_dir)
            # Context manager closes the underlying file; convert() returns
            # an independent, fully loaded image.
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
        except Exception as e:
            # Deliberate best-effort: a broken/missing image is skipped.
            print(f"Skipping image_id={img_id}: {e}")
            return [], []

        bboxes = [bbox_to_xyxy(ann['bbox']) for ann in anns]

        valid_text_embeds, valid_visual_embeds = [], []
        for bbox in bboxes:
            # Skip regions that are too small to describe meaningfully.
            if (bbox[2] - bbox[0] <= self._MIN_BOX_SIDE
                    or bbox[3] - bbox[1] <= self._MIN_BOX_SIDE):
                continue
            roi = image.crop(bbox)

            # Generate a description with LLaVA
            desc = self._describe_region(roi)

            # Filter with CLIP and encode
            inputs = self.clip.processor(
                text=desc, 
                images=roi, 
                return_tensors="pt", 
                padding=True,
                # Guard against descriptions longer than CLIP's text context.
                truncation=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.clip.model(**inputs)
                sim = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds).item()
            if sim > self.clip_threshold:
                valid_text_embeds.append(outputs.text_embeds.cpu())
                valid_visual_embeds.append(outputs.image_embeds.cpu())

        return valid_text_embeds, valid_visual_embeds

    def build_cache(self, save_path: str = "local_cache.pth"):
        """Build the local cache from the whole dataset and save it.

        Args:
            save_path: Destination of the ``torch.save`` checkpoint, a dict
                with keys ``"text_cache"`` and ``"visual_cache"``.

        Raises:
            RuntimeError: If no (text, image) pair passed the CLIP filter,
                so there is nothing to cluster.
        """
        all_text_embeds = []
        all_visual_embeds = []

        # Process all images
        for img_id in tqdm(self.images, desc="Processing Images"):
            text_embeds, visual_embeds = self._process_single_image(img_id)
            if text_embeds:
                all_text_embeds.extend(text_embeds)
                all_visual_embeds.extend(visual_embeds)

        if not all_text_embeds:
            # torch.cat on an empty list fails with an opaque message.
            raise RuntimeError(
                "No region passed the CLIP similarity filter "
                f"(clip_threshold={self.clip_threshold}); cannot build the cache."
            )

        # Convert to tensors
        text_embeds = torch.cat(all_text_embeds).to(self.device)
        visual_embeds = torch.cat(all_visual_embeds).to(self.device)

        # Cluster embeddings
        cluster_ids = cluster_embeddings(text_embeds, self.cache_size)

        # Initialize cache centers
        self.text_cache, self.visual_cache = initialize_cache_centers(
            text_embeds, visual_embeds, cluster_ids, self.cache_size, self.device
        )

        # Save; create the target directory so a long run is not lost at
        # the very last step because of a missing folder.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save({
            "text_cache": self.text_cache.cpu(),
            "visual_cache": self.visual_cache.cpu()
        }, save_path)
        print(f"Cache saved to {save_path}")

    def _momentum_update(self, cache: torch.Tensor, new_embeds: torch.Tensor) -> torch.Tensor:
        """Blend ``new_embeds`` into their nearest ``cache`` rows and re-normalize.

        Each new embedding is assigned to the cache row with the highest
        dot-product similarity; every row that received at least one
        embedding becomes ``momentum * old + (1 - momentum) * mean(assigned)``.
        ``cache`` is modified in place and the L2-normalized result returned.
        """
        similarity = torch.mm(new_embeds, cache.T)  # [N, M]
        assigned = similarity.argmax(dim=1)  # [N]
        # torch.unique only yields ids that occur, so each mask is non-empty.
        for cluster_id in torch.unique(assigned):
            members = new_embeds[assigned == cluster_id]
            cache[cluster_id] = (
                self.momentum * cache[cluster_id]
                + (1 - self.momentum) * members.mean(dim=0)
            )
        # Keep the centers normalized
        return torch.nn.functional.normalize(cache, dim=-1)

    def update_cache(self, new_text_embeds, new_visual_embeds):
        """Momentum-update the cache (intended to be called during training).

        Args:
            new_text_embeds: New text features, shape ``[N, D]``.
            new_visual_embeds: New visual features, shape ``[N, D]``.

        Raises:
            RuntimeError: If the cache has not been built or loaded yet.
        """
        self._require_cache()
        # The cache is running state, not a learnable parameter: detach the
        # inputs and disable autograd so training graphs are not retained.
        with torch.no_grad():
            new_text_embeds = new_text_embeds.detach().to(self.device)
            new_visual_embeds = new_visual_embeds.detach().to(self.device)
            # Text and visual caches are assigned and updated independently.
            self.text_cache = self._momentum_update(self.text_cache, new_text_embeds)
            self.visual_cache = self._momentum_update(self.visual_cache, new_visual_embeds)

    def get_cache_state(self):
        """Return the current cache as CPU tensors (for saving checkpoints).

        Raises:
            RuntimeError: If the cache has not been built or loaded yet.
        """
        self._require_cache()
        return {
            "text_cache": self.text_cache.cpu(),
            "visual_cache": self.visual_cache.cpu()
        }

    def load_cache_state(self, state_dict):
        """Load a cache state dict (keys ``text_cache`` / ``visual_cache``) onto ``self.device``."""
        self.text_cache = state_dict["text_cache"].to(self.device)
        self.visual_cache = state_dict["visual_cache"].to(self.device)

if __name__ == "__main__":
    # Script entry point: build the cache with the settings below and
    # write it to "coco_local_cache.pth".
    builder_kwargs = {
        "annotation_path": "/data/baorunfeng/LLaVA-main/lvis_v1_val.json",
        "image_dir": "/data/baorunfeng/LLaVA-main/val2017",
        "cache_size": 100,
        "clip_threshold": 0.25,
    }
    builder = LocalCacheBuilder(**builder_kwargs)
    builder.build_cache(save_path="coco_local_cache.pth")
