import torch
from torch import nn
from transformers import AutoModel, AutoProcessor
from einops import rearrange

class VisionEncoder(nn.Module):
    """Frozen pretrained vision model used to embed 3-D tensors as images.

    The checkpoint is loaded through ``AutoModel`` in float16 with
    ``device_map="auto"`` and flash-attention-2, its parameters are frozen,
    and it is kept in eval mode. Each ``embedding_*`` method turns an input
    of shape ``(B, M, L)`` into 3-channel, channels-last "images", runs them
    through the matching ``AutoProcessor`` and returns the model's image
    features with shape ``(B, M, emb)``.

    NOTE(review): ``B``/``M``/``L`` presumably mean batch / variates /
    sequence length -- confirm against the callers.
    NOTE(review): the processor is applied with its default settings (any
    rescaling/normalisation it performs assumes image-like values) --
    confirm this is intended for the data passed in.
    """

    def __init__(self, emb_dim, seq_len, ckpt="siglip2_base_patch16_naflex"):
        """Load the checkpoint and its processor, then freeze the model.

        Args:
            emb_dim: Stored on the instance; not used inside this class.
            seq_len: Stored on the instance; not used inside this class.
            ckpt: Checkpoint name or path handed to ``from_pretrained``.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.ckpt = ckpt
        self.venc_model = AutoModel.from_pretrained(
            ckpt,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="flash_attention_2",
        ).eval()
        self.venc_processor = AutoProcessor.from_pretrained(ckpt)
        self._no_grad()

    def _no_grad(self):
        """Disable gradients on every encoder parameter and set eval mode."""
        for param in self.venc_model.parameters():
            param.requires_grad = False
        self.venc_model.eval()

    def train(self, mode=True):
        """Set the training flag on this module but keep the encoder in eval.

        ``nn.Module.train`` recurses into every registered submodule, so
        without this override a ``.train()`` call on this wrapper (or on any
        parent module) would silently put the frozen encoder back into
        training mode.
        """
        super().train(mode)
        self.venc_model.eval()
        return self

    @staticmethod
    def _to_rgb(x):
        """Append a trailing channel axis and copy the values into 3 channels."""
        return x.unsqueeze(-1).repeat(1, 1, 1, 3)

    def _encode(self, images):
        """Preprocess channels-last images and return the model's image features."""
        # The layout is stated explicitly: when an image's first axis also has
        # size 1 or 3 the processor cannot tell which axis holds the channels.
        inputs = self.venc_processor(
            images=images,
            return_tensors="pt",
            input_data_format="channels_last",
        ).to(self.venc_model.device)
        with torch.no_grad():
            return self.venc_model.get_image_features(**inputs)

    def embedding_cc1(self, x):
        """Embed each batch item as one ``(M, L, 3)`` image.

        Args:
            x: Tensor of shape ``(B, M, L)``.

        Returns:
            float32 tensor of shape ``(B, M, emb)``; the single embedding
            computed per batch item is repeated along the ``M`` axis.
        """
        _, M, _ = x.shape
        embeddings = self._encode(self._to_rgb(x))
        embeddings = embeddings.unsqueeze(1).repeat(1, M, 1)
        return embeddings.float()

    def embedding_ci1(self, x, num_seg):
        """Embed every ``(b, m)`` row as a ``(num_seg, L // num_seg, 3)`` image.

        The last axis is cut into ``num_seg`` consecutive, equally sized
        segments that become the rows of the image.

        Args:
            x: Tensor of shape ``(B, M, L)``; ``L`` must be divisible by
                ``num_seg``.
            num_seg: Number of segments (image rows).

        Returns:
            Tensor of shape ``(B, M, emb)`` in the model's output dtype.
            NOTE(review): unlike ``embedding_cc1`` the result is not cast to
            float32 -- confirm whether that difference is intentional.
        """
        x = self._to_rgb(x)
        B, M, L, c = x.shape
        x = rearrange(x, 'B M (n p) c -> (B M) n p c', B=B, M=M, c=c, n=num_seg)
        embeddings = self._encode(x)
        return rearrange(embeddings, '(B M) emb -> B M emb', B=B, M=M)

    def embedding_ci2(self, x, heit=7):
        """Embed every ``(b, m)`` row as a ``(heit, L, 3)`` image.

        The row of length ``L`` is repeated ``heit`` times to form the image
        rows.

        Args:
            x: Tensor of shape ``(B, M, L)``.
            heit: Image height, i.e. how many times each row is stacked.

        Returns:
            Tensor of shape ``(B, M, emb)`` in the model's output dtype.
            NOTE(review): unlike ``embedding_cc1`` the result is not cast to
            float32 -- confirm whether that difference is intentional.
        """
        x = self._to_rgb(x)
        B, M, L, c = x.shape
        x = x.unsqueeze(2).repeat(1, 1, heit, 1, 1)
        x = rearrange(x, 'B M h L c -> (B M) h L c', B=B, M=M, L=L, c=c, h=heit)
        embeddings = self._encode(x)
        return rearrange(embeddings, '(B M) emb -> B M emb', B=B, M=M)