# energy/energy_fn.py
"""
Energy function E(x) and gradient computation for EGP.

E(x) = fidelity_term(x, x_target) + beta * sum_k ReLU(dot(zI(x), zT(nk)) - tau)
 - fidelity_term: e.g., L2 distance in latent or feature space to x_target (optional)
 - CLIP-repulsion: encourages image embedding zI(x) to have small similarity with negatives zT(nk).

This module supports two backends:
 - decode_then_clip: use VAE decoder -> RGB -> CLIP embed (most faithful, costliest)
 - projector: use a learned MLP projector mapping latent -> CLIP_embed (fast)

Returns energy scalar and gradient wrt x (torch autograd).
"""
import torch
import torch.nn.functional as F

class EnergyFunction:
    """Energy E(x) for EGP and its gradient with respect to the latent x.

    E(x) = fidelity(x, x_target) + beta * mean_b sum_k ReLU(<zI(x_b), zT(n_k)> - tau)

    * fidelity: squared L2 distance to ``x_target`` summed over dim 1 and
      averaged over the remaining dimensions (optional term).
    * repulsion: hinge penalty on the similarity between the L2-normalised
      image embedding and every negative text embedding (optional term).

    The image embedding comes from ``projector`` when one is supplied,
    otherwise from ``vae_decoder`` followed by ``clip_embedder``.
    """

    def __init__(self, clip_embedder, vae_decoder=None, projector=None, beta=2.5, tau=0.25, device="cpu"):
        """
        clip_embedder: object exposing ``embed_image_from_rgb(rgb)``; only used
            by the decode->CLIP backend.
        vae_decoder: optional object exposing ``decode(latent)``; used when no
            projector is given.
        projector: optional callable mapping latent -> CLIP embedding (fast
            path); takes precedence over ``vae_decoder``.
        beta: weight of the repulsion term.
        tau: similarity threshold below which a negative contributes nothing.
        device: stored on the instance; not otherwise used by this class.
        """
        self.clip = clip_embedder
        self.vae = vae_decoder
        self.projector = projector
        self.beta = beta
        self.tau = tau
        self.device = device

    def _image_embedding(self, x_latent):
        """
        Compute the (un-normalised) image embedding for a batch of latents.

        x_latent: torch.Tensor whose first dimension is the batch.
        Returns: embeddings as returned by the selected backend
            (assumed [B, dim] -- the caller normalises along dim 1).
        Raises: RuntimeError if neither a projector nor a VAE decoder is set.
        """
        if self.projector is not None:
            # Fast path: the projector maps latents straight to embeddings.
            return self.projector(x_latent)
        if self.vae is not None:
            # Faithful path: decode to RGB, then embed with CLIP.
            # NOTE(review): presumably [B, 3, H, W] in [0, 1] -- confirm
            # against the decoder / embedder contracts.
            rgb = self.vae.decode(x_latent)
            return self.clip.embed_image_from_rgb(rgb)
        raise RuntimeError("Either projector or vae_decoder must be provided to compute image embeddings.")

    def energy_and_grad(self, x_latent, x_target_latent=None, negatives_text_embs=None):
        """
        Compute the energy and its gradient with respect to ``x_latent``.

        Params:
            x_latent: torch.Tensor [B, D]. It does not need ``requires_grad``
                and is not modified; differentiation happens on a detached copy.
            x_target_latent: torch.Tensor broadcastable to x_latent, or None
                (fidelity anchor).
            negatives_text_embs: torch.Tensor [K, dim] (text embeddings of the
                negative prompts) or None. They are used as given (no
                re-normalisation), so they are presumably unit-norm already.
        Returns:
            energy: detached scalar torch.Tensor (batch *mean* of the
                per-sample energies); zero when neither term is active.
            grad: torch.Tensor with the same shape as x_latent.
        Raises:
            RuntimeError: if negatives are given but no embedding backend
                (projector or VAE decoder) is configured.
        """
        # Differentiate w.r.t. a fresh leaf so the caller's tensor keeps its
        # own requires_grad flag and autograd history untouched.
        x = x_latent.detach().requires_grad_(True)

        # Sampling loops commonly run under torch.no_grad(); re-enable graph
        # recording locally so the gradient can still be computed.
        with torch.enable_grad():
            terms = []

            # Fidelity term: squared L2 distance to the anchor.
            if x_target_latent is not None:
                target = x_target_latent.to(x.device)
                terms.append((x - target).pow(2).sum(dim=1).mean())

            # CLIP repulsion: relu(<zI(x), zT(n_k)> - tau), summed over the
            # negatives and averaged over the batch. The (possibly costly)
            # embedding is only computed when it is actually needed.
            if negatives_text_embs is not None:
                emb_n = F.normalize(self._image_embedding(x), dim=1)  # [B, dim]
                # Match device *and* dtype so e.g. fp16 text embeddings can be
                # combined with fp32 image embeddings.
                zt = negatives_text_embs.to(device=emb_n.device, dtype=emb_n.dtype)
                sims = torch.matmul(emb_n, zt.t())  # [B, K]
                hinge = F.relu(sims - self.tau)  # [B, K]
                terms.append(self.beta * hinge.sum(dim=1).mean())

            if not terms:
                # No active term: the energy is constant, so its gradient is zero.
                return x.new_zeros(()), torch.zeros_like(x_latent)

            energy = terms[0]
            for extra in terms[1:]:
                energy = energy + extra

            grad = torch.autograd.grad(energy, x, retain_graph=False, create_graph=False)[0]

        # The graph has been freed by autograd.grad; hand back a plain value.
        return energy.detach(), grad
