# clip/clip_embedder.py
"""
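CLIP embedder wrapper. Supports two modes:
 - If the 'open_clip' library is installed and the requested model loads, that
   model is kept for real CLIP encoding. The real encoding path is not wired up
   yet, so the embed_* methods raise NotImplementedError in this mode.
 - Otherwise, fall back to a hash-seeded placeholder embedding (for tests).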

API:
  embed_text(list_of_prompts) -> torch.Tensor [N, D]
  embed_image_from_rgb(rgb_tensor) -> torch.Tensor [N, D]   (rgb in [0,1])
Outputs are L2-normalized.
"""
import hashlib
import warnings

import numpy as np
import torch

# Optional-dependency probe. The OpenAI `clip` package is tried FIRST; `open_clip`
# is only tried if that import fails. _HAS_CLIP is True when either one imports,
# although CLIPEmbedder below loads models exclusively through open_clip.
# The broad `except Exception` is deliberate: a broken optional install can fail
# at import time with errors other than ImportError, and must not break this module.
try:
    import clip as openai_clip  # OpenAI CLIP package; not referenced elsewhere in this file
    _HAS_CLIP = True
except Exception:
    try:
        import open_clip  # alternative: OpenCLIP
        _HAS_CLIP = True
    except Exception:
        _HAS_CLIP = False

class CLIPEmbedder:
    """Produce L2-normalized CLIP-style embeddings for text prompts and RGB images.

    If open_clip is importable and the requested model loads, ``self.model``
    holds it. The real encoding path is not implemented yet, so ``embed_text``
    and ``embed_image_from_rgb`` raise ``NotImplementedError`` in that mode.

    Otherwise ``self.model`` is ``None`` and both methods return placeholder unit
    vectors seeded from a SHA-256 digest of the input. Unlike the built-in
    ``hash()``, which is salted per process for ``str``/``bytes``, this makes the
    placeholders reproducible across interpreter runs.
    """

    def __init__(self, model_name="ViT-L/14", device="cpu", pretrained="laion2b_s34b_b79k"):
        """
        Args:
            model_name: open_clip model name (used only if open_clip is available).
            device: torch device the model and all returned tensors are placed on.
            pretrained: open_clip pretrained tag passed to
                ``create_model_and_transforms``.
        """
        # NOTE(review): as far as I know, open_clip's pretrained registry lists
        # 'laion2b_s34b_b79k' for ViT-B-32 and 'laion2b_s32b_b82k' for ViT-L-14
        # (verify with open_clip.list_pretrained()). If so, the defaults fail to
        # load and the class falls back to placeholder mode with a warning. The
        # default is kept as-is because a successful load would currently make
        # every embed_* call raise NotImplementedError.
        self.device = device
        self.model = None
        self.preprocess = None
        self.dtype = torch.float32
        if _HAS_CLIP:
            try:
                import open_clip
            except Exception:
                # _HAS_CLIP is also True when only the OpenAI `clip` package is
                # installed; this class only knows how to load via open_clip.
                open_clip = None
            if open_clip is not None:
                try:
                    self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                        model_name, pretrained=pretrained
                    )
                    self.model.to(self.device)
                    self.model.eval()  # inference only; disable train-mode behaviour
                except Exception as err:
                    # Best-effort: fall back to placeholders, but say so loudly
                    # instead of silently swallowing the failure.
                    warnings.warn(
                        f"open_clip could not load model {model_name!r} "
                        f"(pretrained={pretrained!r}): {err}; "
                        "falling back to placeholder embeddings.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    self.model = None
                    self.preprocess = None

        # Width of the placeholder embeddings.
        self.dim = 768

    def _seeded_unit_vector(self, key):
        """Return a float64 unit vector of length ``self.dim`` seeded by ``key`` (bytes)."""
        # SHA-256 is stable across runs; hash() of str/bytes is randomized per
        # process (PYTHONHASHSEED). Four digest bytes fit RandomState's uint32 seed.
        seed = int.from_bytes(hashlib.sha256(key).digest()[:4], "little")
        rng = np.random.RandomState(seed)
        v = rng.normal(size=(self.dim,))
        return v / (np.linalg.norm(v) + 1e-9)

    def _stack(self, vecs):
        """Stack 1-D numpy vectors into a float32 tensor [len(vecs), dim] on ``self.device``."""
        if not vecs:
            # np.vstack raises on an empty list; return a well-shaped empty batch.
            return torch.zeros((0, self.dim), dtype=torch.float32, device=self.device)
        return torch.from_numpy(np.vstack(vecs)).float().to(self.device)

    def _placeholder_text_to_vec(self, texts):
        """Deterministic placeholder: one SHA-256-seeded unit vector per prompt."""
        # surrogatepass so prompts with lone surrogates still encode instead of raising.
        return self._stack(
            [self._seeded_unit_vector(str(t).encode("utf-8", "surrogatepass")) for t in texts]
        )

    def _placeholder_image_to_vec(self, rgb_tensor):
        """Deterministic placeholder: one unit vector per image, seeded by its uint8 pixels."""
        # detach(): .numpy() refuses tensors that require grad. float(): numpy
        # has no bfloat16. Clip first because casting out-of-range floats to
        # uint8 is not well-defined in NumPy.
        pixels = rgb_tensor.detach().cpu().float().numpy()
        quantized = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
        return self._stack([self._seeded_unit_vector(img.tobytes()) for img in quantized])

    def embed_text(self, texts):
        """Embed a sequence of text prompts.

        Args:
            texts: sequence of prompts; one embedding row per element.
        Returns:
            torch.Tensor [N, dim], float32, L2-normalized, on ``self.device``
            (``[0, dim]`` for empty input).
        Raises:
            NotImplementedError: if a real open_clip model was loaded (not wired yet).
        """
        if self.model is not None:
            # TODO: implement with open_clip (tokenizer + model.encode_text).
            raise NotImplementedError("Real CLIP integration not yet wired. Install open_clip and reinitialize.")
        return self._placeholder_text_to_vec(texts)

    def embed_image_from_rgb(self, rgb_tensor):
        """Embed a batch of RGB images.

        Args:
            rgb_tensor: torch.Tensor [B, 3, H, W], values in [0, 1].
        Returns:
            torch.Tensor [B, dim], float32, L2-normalized, on ``self.device``.
        Raises:
            NotImplementedError: if a real open_clip model was loaded (not wired yet).
        """
        if self.model is not None:
            raise NotImplementedError("Real CLIP integration not yet wired. Install open_clip and reinitialize.")
        return self._placeholder_image_to_vec(rgb_tensor)
