import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import clip
from typing import List, Optional, Union


class CLIPLoss(nn.Module):
    """
    Compute a CLIP-based loss to encourage alignment between generated images and text prompts.

    - Supports two loss types:
      - "cosine": 1 - cosine_similarity(image_features, text_features)
      - "contrastive": Symmetric cross-entropy on CLIP logits

    Usage:
        loss_fn = CLIPLoss(model_name="ViT-B/16", device="cuda", loss_type="cosine")
        loss = loss_fn(images, prompts)

    Args:
        model_name: CLIP model name, e.g. "ViT-B/16", "ViT-L/14".
        device: torch device for the CLIP model. Defaults to CUDA when available, otherwise CPU.
        loss_type: "cosine" or "contrastive".
        image_size: Spatial size to which images are resized for CLIP (default 224 for ViT-B/16).
        download_root: Directory passed to ``clip.load`` for locating/caching checkpoints.
            ``None`` lets ``clip.load`` use its own default location.
        reduction: How per-sample losses are combined: "mean" (default, scalar),
            "sum" (scalar) or "none" (a [B] tensor).

    Notes:
        The "contrastive" loss treats text i as the only positive for image i. It is
        degenerate for a batch of one (the cross-entropy over a single class is always 0)
        and when every text is identical, e.g. a single broadcast prompt (the
        image-to-text term is then the constant log(B)).
    """

    # Checkpoint directory that was hard-coded before ``download_root`` became a parameter.
    _DEFAULT_DOWNLOAD_ROOT = "/ossfs/workspace/mmface/diffusers/examples/flux-control/CLIPLoss/ckpts"

    def __init__(
        self,
        model_name: str = "ViT-B/16",
        device: Optional[Union[str, torch.device]] = None,
        loss_type: str = "cosine",
        image_size: int = 224,
        download_root: Optional[str] = _DEFAULT_DOWNLOAD_ROOT,
        reduction: str = "mean",
    ) -> None:
        super().__init__()
        # Validate cheap arguments before paying for the checkpoint load.
        if loss_type not in {"cosine", "contrastive"}:
            raise ValueError("loss_type must be 'cosine' or 'contrastive'")
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be 'mean', 'sum' or 'none'")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model, _ = clip.load(model_name, device=self.device, download_root=download_root)
        # Freeze CLIP params but allow gradients to flow to inputs through the graph
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

        self.loss_type = loss_type
        self.reduction = reduction
        self.image_size = image_size

        # CLIP normalization constants (OpenAI CLIP)
        self.register_buffer("clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073], device=self.device, dtype=torch.float32).view(1, 3, 1, 1))
        self.register_buffer("clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711], device=self.device, dtype=torch.float32).view(1, 3, 1, 1))

    def train(self, mode: bool = True) -> "CLIPLoss":
        """Set train/eval mode on this module but always keep the frozen CLIP model in eval mode.

        ``nn.Module.train`` recurses into submodules, so without this override a parent
        module's ``.train()`` would switch CLIP (including any normalization/dropout
        layers it has) back to training behavior.
        """
        super().train(mode)
        self.model.eval()
        return self

    def _runtime_device(self) -> torch.device:
        """Return the device the module currently lives on.

        Buffers follow ``.to()`` / ``.cuda()`` calls, whereas ``self.device`` only records
        the device chosen at construction time.
        """
        return self.clip_mean.device

    @torch.no_grad()
    def _tokenize(self, texts: Union[str, List[str], torch.Tensor]) -> torch.Tensor:
        """Turn prompts into a [N, context_length] token tensor on the module's device.

        Accepts a single string, a list/tuple of strings (tokenized with truncation), or
        an already tokenized tensor (a 1-D tensor is treated as a single prompt).
        """
        if isinstance(texts, str):
            texts = [texts]
        if isinstance(texts, (list, tuple)):
            tokens = clip.tokenize(list(texts), truncate=True)
        elif isinstance(texts, torch.Tensor):
            tokens = texts
        else:
            raise TypeError("texts must be List[str] or a token tensor")
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        return tokens.to(self._runtime_device())

    def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Expect images as float tensor in shape [B, 3, H, W] with values in [-1,1] or [0,1].
        Resize to CLIP input size and apply CLIP normalization.

        The input range is detected heuristically: if any value is negative the batch is
        treated as [-1, 1], mapped to [0, 1] and clamped; otherwise it is used as-is. A
        [-1, 1] batch that happens to contain no negative values is therefore read as [0, 1].
        The resize goes straight to a square, so non-square inputs are distorted rather than
        center-cropped.

        Raises:
            ValueError: if ``images`` is not a non-empty [B, 3, H, W] tensor.
        """
        if images.dim() != 4 or images.size(1) != 3:
            raise ValueError("images must have shape [B, 3, H, W]")
        if images.size(0) == 0:
            raise ValueError("images must contain at least one sample")

        # Move to the module's current device and float32; ``.to`` is differentiable,
        # so gradients still reach the caller's tensor.
        images = images.to(self._runtime_device(), dtype=torch.float32)

        # Map from [-1, 1] to [0, 1] if necessary
        if images.min() < 0.0:
            images = ((images + 1.0) / 2.0).clamp(0.0, 1.0)

        images = F.interpolate(images, size=(self.image_size, self.image_size), mode="bicubic", align_corners=False, antialias=True)

        # Normalize with the CLIP mean/std buffers
        return (images - self.clip_mean) / self.clip_std

    def _reduce(self, per_sample: torch.Tensor) -> torch.Tensor:
        """Apply the configured reduction to a [B] tensor of per-sample losses."""
        if self.reduction == "mean":
            return per_sample.mean()
        if self.reduction == "sum":
            return per_sample.sum()
        return per_sample

    def forward(self, images: torch.Tensor, texts: Union[str, List[str], torch.Tensor]) -> torch.Tensor:
        """
        Args:
            images: [B, 3, H, W] float tensor in [-1,1] or [0,1]
            texts: a prompt string, List[str] (prompts) or an already tokenized tensor
                [B, context_length]. A single prompt is broadcast across the batch.

        Returns:
            Scalar tensor loss for reduction "mean"/"sum"; a [B] tensor for "none".

        Raises:
            ValueError: if the number of texts is neither 1 nor B, or the image batch is
                malformed/empty.
        """
        batch_size = images.size(0)
        tokenized = self._tokenize(texts)
        num_texts = tokenized.size(0)
        if num_texts != batch_size and num_texts != 1:
            raise ValueError("Number of texts must match batch size or be 1.")

        proc_images = self._preprocess_images(images)

        # Forward through CLIP to get features/logits. Do not disable grad so that
        # the loss can backpropagate to the input images via the computational graph.
        # Results are cast to float32 so the loss is computed in full precision even if
        # the CLIP model itself runs in a lower precision.
        if self.loss_type == "cosine":
            image_features = F.normalize(self.model.encode_image(proc_images).float(), dim=-1)
            # A single prompt is encoded once; its [1, D] features broadcast over the batch.
            text_features = F.normalize(self.model.encode_text(tokenized).float(), dim=-1)
            per_sample = 1.0 - (image_features * text_features).sum(dim=-1)
        else:  # contrastive
            if num_texts != batch_size:
                tokenized = tokenized.expand(batch_size, -1)
            logits_per_image, logits_per_text = self.model(proc_images, tokenized)
            labels = torch.arange(batch_size, device=logits_per_image.device)
            loss_i = F.cross_entropy(logits_per_image.float(), labels, reduction="none")
            loss_t = F.cross_entropy(logits_per_text.float(), labels, reduction="none")
            per_sample = 0.5 * (loss_i + loss_t)
        return self._reduce(per_sample)

if __name__ == "__main__":
    # Smoke test: score one face image against a descriptive prompt with both loss types.
    from PIL import Image

    # Select the device once and hand it to both the input tensor and CLIPLoss, so the
    # demo also runs on CPU-only machines and the two never disagree.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text = "This person is attractive and has brown hair, mouth slightly open, high cheekbones, wavy hair, and bangs"

    # CLIPLoss requires 3-channel input; convert("RGB") also handles grayscale/RGBA files.
    # The context manager closes the file once the converted copy has been loaded.
    with Image.open("test_data/face/29996.jpg") as pil_image:
        image = pil_image.convert("RGB")

    # Convert PIL to tensor batch [1, 3, H, W] in [0,1], then replicate to a batch of 3.
    img_tensor = TF.to_tensor(image).unsqueeze(0).to(device)
    img_tensor = img_tensor.repeat(3, 1, 1, 1)

    # Cosine loss. `.mean()` makes the print work whether the loss comes back
    # already reduced (scalar) or per-sample ([B]).
    clip_loss_cos = CLIPLoss(model_name="ViT-B/16", device=device, loss_type="cosine")
    loss_cos = clip_loss_cos(img_tensor, [text, text, text])
    print("CLIPLoss (cosine):", loss_cos.detach().mean().item())

    # Contrastive loss. NOTE: with three identical images and one broadcast prompt, every
    # logit is equal, so this prints ~log(3) regardless of the image content. It only
    # checks that the code path runs; it is not a meaningful alignment score.
    clip_loss_con = CLIPLoss(model_name="ViT-B/16", device=device, loss_type="contrastive")
    loss_con = clip_loss_con(img_tensor, [text])
    print("CLIPLoss (contrastive):", loss_con.detach().mean().item())