import os
import json
import torch
from PIL import Image
from tqdm import tqdm
from typing import Dict, List, Tuple
import shutil 

from .image_processing import get_image_path, bbox_to_xyxy, load_and_crop_image
from .clustering import initialize_cache_centers , cluster_embeddings
from .config import DEFAULT_CACHE_SIZE, DEFAULT_CLIP_THRESHOLD, DEFAULT_DEVICE, DEFAULT_MOMENTUM
from torchvision import transforms
import torch.nn.functional as F

from .inference import EmbeddingDecoder  
from transformers import CLIPTokenizer
import logging
from datetime import datetime, timedelta, timezone
from sentence_transformers import SentenceTransformer
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, "../..")))
from .llava_model import LlavaWrapper
from .clip_model import CLIPWrapper

# Fixed UTC+8 offset used for log file names and log record timestamps.
BEIJING_TZ = timezone(timedelta(hours=8))

def setup_logger(log_dir: str = "you_path") -> str:
    """Send root-logger output to a fresh, timestamped log file.

    Every handler currently attached to the root logger is detached and
    replaced by a single file handler, and the root level is set to INFO.
    File handlers installed by a previous call to this function are closed
    so repeated calls do not leak open files; handlers owned by other code
    are only detached, never closed.

    Args:
        log_dir: Directory that receives the log file; created if missing.
            The default is the placeholder path the module shipped with.

    Returns:
        Path of the newly created log file.
    """
    os.makedirs(log_dir, exist_ok=True)
    stamp = datetime.now(BEIJING_TZ).strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(log_dir, f"cluster_log_{stamp}.log")

    class BeijingFormatter(logging.Formatter):
        """Formatter that renders record timestamps in BEIJING_TZ."""

        def formatTime(self, record, datefmt=None):
            moment = datetime.fromtimestamp(record.created, BEIJING_TZ)
            return moment.strftime(datefmt) if datefmt else moment.isoformat()

    # Explicit UTF-8: messages logged by this module contain emoji, which
    # cannot be encoded by many platform-default encodings.
    handler = logging.FileHandler(log_path, mode="w", encoding="utf-8")
    handler.setFormatter(
        BeijingFormatter(
            fmt="%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    # Marker used by later calls to recognise (and close) our own handlers.
    handler._installed_by_setup_logger = True

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    for stale in list(root.handlers):
        root.removeHandler(stale)
        if getattr(stale, "_installed_by_setup_logger", False):
            stale.close()
    root.addHandler(handler)

    return log_path

class LocalCacheBuilder:
    """Build and maintain paired text / visual embedding caches for image regions.

    For every annotated bounding box the region is cropped, described by the
    LLaVA wrapper, and the (description, crop) pair is embedded with the CLIP
    wrapper. Pairs whose image/text cosine similarity exceeds
    ``clip_threshold`` are kept. Kept embeddings are either clustered into
    the initial cache centres or folded into an existing cache with a
    momentum update.

    The annotation file is JSON with a top-level ``images`` list (entries
    carry ``id``) and an ``annotations`` list (entries carry ``image_id`` and
    ``bbox``).
    """

    # Boxes whose width or height is at most this many pixels are skipped.
    _MIN_BOX_SIDE = 10
    # Generation budget for each region description.
    _MAX_NEW_TOKENS = 50
    # NOTE(review): the example inside this prompt is missing its closing
    # quote/parenthesis. It is left untouched on purpose, because editing
    # the prompt changes the generated descriptions.
    _SINGLE_IMAGE_PROMPT = "<image>\nUSER: Generate a fine-grained textual description of the image (e.g., 'a zebra with black and white stripes\nASSISTANT:"
    _BATCH_PROMPT = "<image>\nUSER: Describe this.\nASSISTANT:"

    def __init__(self,
                 annotation_path: str,
                 image_dir: str,
                 momentum: float = DEFAULT_MOMENTUM,
                 cache_size: int = DEFAULT_CACHE_SIZE,
                 clip_threshold: float = DEFAULT_CLIP_THRESHOLD,
                 device: str = DEFAULT_DEVICE):
        """Load the annotation file and instantiate the LLaVA and CLIP wrappers.

        Args:
            annotation_path: Path of the JSON annotation file.
            image_dir: Directory passed to ``get_image_path`` to locate images.
            momentum: Weight of the old centre in ``update_cache``.
            cache_size: Maximum number of cache centres.
            clip_threshold: Minimum image/text cosine similarity for a region
                to be kept.
            device: Device the models and caches live on.
        """
        self.annotation_path = annotation_path
        self.image_dir = image_dir
        self.momentum = momentum
        self.cache_size = cache_size
        self.clip_threshold = clip_threshold
        self.device = device

        # Load dataset
        with open(annotation_path, 'r', encoding="utf-8") as f:
            self.data = json.load(f)
        self.images = {img['id']: img for img in self.data['images']}
        self.annotations = self.data['annotations']
        # image_id -> annotations, built on first use by _annotations_for.
        self._ann_index = None

        # Initialize models
        self.llava = LlavaWrapper(device)
        self.clip = CLIPWrapper(device)

        # Initialize cache
        self.text_cache = None
        self.visual_cache = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _annotations_for(self, image_id):
        """Return the annotations of ``image_id`` in file order.

        The lookup table is built once from ``self.annotations`` instead of
        rescanning the full list for every image. Reset ``self._ann_index``
        to ``None`` if ``self.annotations`` is replaced afterwards.
        """
        index = getattr(self, "_ann_index", None)
        if index is None:
            index = {}
            for ann in self.annotations:
                index.setdefault(ann["image_id"], []).append(ann)
            self._ann_index = index
        return index.get(image_id, [])

    def _usable_boxes(self, image_id):
        """Return the xyxy boxes of ``image_id`` that are large enough to describe."""
        boxes = []
        for ann in self._annotations_for(image_id):
            box = bbox_to_xyxy(ann["bbox"])
            too_narrow = box[2] - box[0] <= self._MIN_BOX_SIDE
            too_short = box[3] - box[1] <= self._MIN_BOX_SIDE
            if too_narrow or too_short:
                continue
            boxes.append(box)
        return boxes

    def _embed_region(self, roi, prompt, log_description=False):
        """Describe ``roi`` with LLaVA and embed the (description, roi) pair with CLIP.

        Returns:
            ``(similarity, text_embeds, image_embeds)`` where the embeddings
            are CPU tensors taken from the CLIP output and ``similarity`` is
            their cosine similarity as a Python float.
        """
        llava_inputs = self.llava.processor(
            text=prompt,
            images=roi,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            generated = self.llava.model.generate(
                **llava_inputs, max_new_tokens=self._MAX_NEW_TOKENS
            )
        decoded = self.llava.processor.batch_decode(generated, skip_special_tokens=True)[0]
        # The decoded text echoes the prompt; keep only the assistant's answer.
        desc = decoded.split("ASSISTANT:")[-1].strip()
        if log_description:
            logging.info(desc)

        clip_inputs = self.clip.processor(
            text=desc,
            images=roi,
            return_tensors="pt",
            padding=True
        ).to(self.device)
        with torch.no_grad():
            out = self.clip.model(**clip_inputs)
            sim = torch.cosine_similarity(out.image_embeds, out.text_embeds).item()
        return sim, out.text_embeds.cpu(), out.image_embeds.cpu()

    def _require_cache(self):
        """Raise a descriptive error when the caches have not been set yet."""
        if self.text_cache is None or self.visual_cache is None:
            raise RuntimeError(
                "Cache is not initialized; call process_batch_feature(..., init=True), "
                "build_cache() or load_cache_state() first."
            )

    def _initialize_cache(self, text_embeds, visual_embeds, save_path=None):
        """Cluster the text embeddings and set both caches to the cluster centres."""
        # Never request more clusters than there are samples.
        actual_K = min(self.cache_size, text_embeds.shape[0])
        cluster_ids_np = cluster_embeddings(text_embeds, actual_K)
        cluster_ids = torch.tensor(cluster_ids_np, device=self.device)

        self.text_cache, self.visual_cache = initialize_cache_centers(
            text_embeds, visual_embeds, cluster_ids, actual_K, self.device
        )

        if save_path:
            torch.save(self.get_cache_state(), save_path)
            logging.info(f"✅ Cache initialized and saved to {save_path}")

    def _momentum_step(self, cache, new_embeds):
        """Blend ``new_embeds`` into their most similar slots and re-normalize.

        Each new embedding is assigned to the cache row with the largest dot
        product. Every row that received at least one embedding becomes
        ``momentum * row + (1 - momentum) * mean(assigned)``. The cache is
        modified in place before the normalized copy is returned.
        """
        assignments = torch.mm(new_embeds, cache.T).max(dim=1).indices   # [N]
        for slot in torch.unique(assignments):
            assigned = new_embeds[assignments == slot]
            cache[slot] = (
                self.momentum * cache[slot] +
                (1 - self.momentum) * assigned.mean(dim=0)
            )
        return torch.nn.functional.normalize(cache, dim=-1)

    # ------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------
    def _process_single_image(self, img_id: int):
        """Extract the embeddings of all usable regions of one dataset image.

        Returns:
            ``(text_embeds, visual_embeds)``: two parallel lists of CPU
            tensors, both empty when the image cannot be opened or no region
            passes the similarity threshold.
        """
        img_info = self.images[img_id]
        try:
            img_path = get_image_path(img_info, self.image_dir)
            logging.debug(img_path)
            # Context manager so the underlying file is closed right away.
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
        except Exception as e:
            # Best effort: an unreadable image must not abort the whole run.
            logging.warning(f"Skipping image_id={img_id}: {str(e)}")
            return [], []

        valid_text_embeds, valid_visual_embeds = [], []
        for bbox in self._usable_boxes(img_id):
            roi = image.crop(bbox)
            sim, text_embeds, visual_embeds = self._embed_region(
                roi, self._SINGLE_IMAGE_PROMPT, log_description=True
            )
            if sim > self.clip_threshold:
                logging.info(sim)
                valid_text_embeds.append(text_embeds)
                valid_visual_embeds.append(visual_embeds)

        return valid_text_embeds, valid_visual_embeds

    def _process_batchsize_image(self, max_images: int = 300):
        """Extract embeddings from the first ``max_images`` images of the dataset.

        Returns:
            ``(text_embeds, visual_embeds)`` concatenated along dim 0 on
            ``self.device``, or ``(None, None)`` when nothing was extracted.
        """
        all_text_embeds, all_visual_embeds = [], []
        image_ids = list(self.images)[:max_images]

        for img_id in tqdm(image_ids, desc=f"Processing First {max_images} Images"):
            try:
                text_embeds, visual_embeds = self._process_single_image(img_id)
            except Exception:
                # Keep going, but record the traceback instead of dropping it.
                logging.exception(f"Failed to process image_id={img_id}")
                continue
            if text_embeds:
                all_text_embeds.extend(text_embeds)
                all_visual_embeds.extend(visual_embeds)

        if not all_text_embeds or not all_visual_embeds:
            logging.warning("❌ No valid embeddings extracted. Skipping cache build.")
            return None, None
        text_embeds = torch.cat(all_text_embeds).to(self.device)
        visual_embeds = torch.cat(all_visual_embeds).to(self.device)
        return text_embeds, visual_embeds

    def build_cache(self, save_path: str = None, max_images: int = 300):
        """Initialize the cache from the first ``max_images`` dataset images.

        Args:
            save_path: Where to ``torch.save`` the cache; skipped when empty
                or ``None``.
            max_images: Number of images (in annotation-file order) to scan.

        Returns:
            ``True`` when the cache was built, ``False`` when no region
            produced a usable embedding.
        """
        text_embeds, visual_embeds = self._process_batchsize_image(max_images=max_images)
        if text_embeds is None or visual_embeds is None:
            return False
        self._initialize_cache(text_embeds, visual_embeds, save_path)
        return True

    def process_batch_feature(self, samples, targets, save_path=None, init=False):
        """Extract region embeddings from a training batch and feed the cache.

        Args:
            samples: Batch object exposing the image tensors as ``.tensors``.
            targets: One mapping per image; ``target["image_id"]`` must be a
                single-element tensor.
            save_path: Optional path the resulting cache is saved to.
            init: When true, (re)initialize the cache by clustering the
                extracted embeddings; otherwise apply a momentum update to
                the existing cache.

        Raises:
            RuntimeError: ``init`` is false and no cache exists yet.
        """
        from torchvision.transforms import ToPILImage
        all_text_embeds, all_visual_embeds = [], []
        to_pil = ToPILImage()

        for img_tensor, target in zip(samples.tensors, targets):
            image_id = int(target["image_id"].item())
            boxes = self._usable_boxes(image_id)
            # NOTE(review): boxes come from the annotation file while the
            # pixels come from the batch tensor. If the data loader resizes,
            # pads or normalizes images, crops may not match the annotated
            # regions; confirm against the loader.
            image_pil = to_pil(img_tensor.cpu())

            for bbox in boxes:
                roi = image_pil.crop(bbox)
                sim, text_embeds, visual_embeds = self._embed_region(roi, self._BATCH_PROMPT)
                if sim > self.clip_threshold:
                    all_text_embeds.append(text_embeds)
                    all_visual_embeds.append(visual_embeds)

        if not all_text_embeds:
            logging.warning("⚠️ No valid features extracted from batch.")
            return

        text_embeds = torch.cat(all_text_embeds).to(self.device)
        visual_embeds = torch.cat(all_visual_embeds).to(self.device)

        if init:
            self._initialize_cache(text_embeds, visual_embeds, save_path)
        else:
            self.update_cache(text_embeds, visual_embeds, save_path)
            logging.info(f"🔁 Cache updated with {text_embeds.shape[0]} new samples.")

    # ------------------------------------------------------------------
    # Cache maintenance
    # ------------------------------------------------------------------
    def update_cache(self, new_text_embeds, new_visual_embeds, save_path: str = None):
        """
        Momentum update of the cache using new embeddings.

        :param new_text_embeds: New text embeddings, shape [N, D]
        :param new_visual_embeds: New visual embeddings, shape [N, D]
        :param save_path: Optional path the updated cache is saved to
            (skipped when empty or None).
        :raises RuntimeError: if the cache has not been initialized or loaded.
        """
        self._require_cache()

        # The cache is running state, not a trainable parameter: keep the
        # update out of autograd so no graph accumulates across calls.
        with torch.no_grad():
            new_text_embeds = new_text_embeds.to(self.device)
            new_visual_embeds = new_visual_embeds.to(self.device)

            # Text and visual caches are assigned and updated independently.
            self.text_cache = self._momentum_step(self.text_cache, new_text_embeds)
            self.visual_cache = self._momentum_step(self.visual_cache, new_visual_embeds)

        if save_path:
            torch.save(self.get_cache_state(), save_path)
            logging.info(f"💾 Updated cache saved to {save_path}")

        # Log number of active (non-zero) slots in the cache
        with torch.no_grad():
            active_text = (self.text_cache.norm(dim=1) > 1e-6).sum().item()
            active_visual = (self.visual_cache.norm(dim=1) > 1e-6).sum().item()
            logging.info(f"[UpdateCache] Active text slots: {active_text}/{self.cache_size}, "
                        f"visual slots: {active_visual}/{self.cache_size}")

    def get_cache_state(self):
        """Return CPU copies of both caches as a ``torch.save``-ready dict.

        Raises:
            RuntimeError: if the cache has not been initialized or loaded.
        """
        self._require_cache()
        return {
            "text_cache": self.text_cache.cpu(),
            "visual_cache": self.visual_cache.cpu()
        }

    def load_cache_state(self, state_dict):
        """Restore both caches from ``state_dict`` and move them to ``self.device``.

        Raises:
            KeyError: if ``text_cache`` or ``visual_cache`` is missing.
        """
        missing = [key for key in ("text_cache", "visual_cache") if key not in state_dict]
        if missing:
            raise KeyError(f"Cache state is missing keys: {missing}")
        self.text_cache = state_dict["text_cache"].to(self.device)
        self.visual_cache = state_dict["visual_cache"].to(self.device)


def main() -> None:
    """Build the initial cache from the first images of the dataset.

    The paths below are placeholders and must be filled in before running.
    The cache is built from the builder's existing extraction method and
    the module's clustering helpers.
    """
    setup_logger()
    logging.info("🚀 Start cache building and analysis...")
    builder = LocalCacheBuilder(
        annotation_path="",
        image_dir="",
        cache_size=300,
        clip_threshold=0.7,
    )

    save_path = ""
    max_images = 2000

    text_embeds, visual_embeds = builder._process_batchsize_image(max_images=max_images)
    if text_embeds is None or visual_embeds is None:
        logging.warning("No valid embeddings extracted; cache was not built.")
        return

    # Never request more clusters than there are samples.
    actual_k = min(builder.cache_size, text_embeds.shape[0])
    cluster_ids = torch.tensor(
        cluster_embeddings(text_embeds, actual_k), device=builder.device
    )
    builder.text_cache, builder.visual_cache = initialize_cache_centers(
        text_embeds, visual_embeds, cluster_ids, actual_k, builder.device
    )

    # An empty path means "do not persist".
    if save_path:
        torch.save(builder.get_cache_state(), save_path)
        logging.info(f"✅ Cache initialized and saved to {save_path}")


if __name__ == "__main__":
    main()

