import clip
import torch
import torch.nn as nn

class CLIPTextEncoder(nn.Module):
    """Sentence-level text encoder: frozen CLIP text tower plus a trainable head.

    The CLIP token embedding, positional embedding, transformer and final
    layer norm are reused with gradients disabled.  A 2-layer
    ``nn.TransformerEncoder`` followed by a ``LayerNorm`` (both trainable) is
    applied on top, and the feature at the end-of-text (EOT) position of each
    sequence is returned.

    Args:
        clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        clip_device: Device on which CLIP and the trainable head are placed.
    """

    def __init__(self, clip_version, clip_device='cuda'):
        super().__init__()
        self.clip_device = clip_device
        self.clip_version = clip_version
        self.load_and_freeze_clip(clip_version)

    def load_and_freeze_clip(self, clip_version):
        """Load CLIP, keep its text-side pieces frozen, and build the trainable head.

        Args:
            clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        """
        clip_model, _ = clip.load(clip_version, device=self.clip_device, jit=False)

        self.clip_token_embedding = clip_model.token_embedding
        self.clip_transformer = clip_model.transformer
        self.clip_positional_embedding = clip_model.positional_embedding
        self.clip_ln_final = clip_model.ln_final
        self.clip_dtype = clip_model.dtype

        for frozen_module in (self.clip_transformer,
                              self.clip_token_embedding,
                              self.clip_ln_final):
            frozen_module.requires_grad_(False)
        # The positional embedding is a bare nn.Parameter, so it is not covered
        # by the module loop above.  Left trainable it would be reported by
        # ``parameters()`` as requiring grad while never receiving one, because
        # ``forward`` evaluates the CLIP part under ``torch.no_grad()``.
        self.clip_positional_embedding.requires_grad_(False)

        # Trainable refinement head operating at CLIP's text width.
        width = self.clip_ln_final.normalized_shape[0]
        head_layer = nn.TransformerEncoderLayer(d_model=width,
                                                nhead=8,
                                                dim_feedforward=2048,
                                                dropout=0.1,
                                                activation="gelu",
                                                batch_first=True)
        self.clipTrans = nn.TransformerEncoder(head_layer, num_layers=2)
        self.clipln = nn.LayerNorm(width)
        self.clipTrans.to(self.clip_device)
        self.clipln.to(self.clip_device)

    @staticmethod
    def _select_eot_features(features, tokens):
        """Pick, for every sequence, the feature at its end-of-text position.

        The position is found with ``argmax`` over the token ids, which
        relies on the EOT token having the largest id in each sequence.

        Args:
            features: Tensor indexed as ``[batch, position, ...]``.
            tokens: Integer token ids indexed as ``[batch, position]``.

        Returns:
            Tensor indexed as ``[batch, ...]``.
        """
        # Build the batch index on the same device as the features so the
        # advanced indexing does not mix CPU and accelerator index tensors.
        batch_index = torch.arange(features.shape[0], device=features.device)
        return features[batch_index, tokens.argmax(dim=-1)]

    def forward(self, raw_text):
        """Encode text into one feature vector per input string.

        Args:
            raw_text: A string or list of strings; passed to ``clip.tokenize``
                with ``truncate=True``.

        Returns:
            A float32 tensor holding the refined feature at each sequence's
            end-of-text position.
        """
        with torch.no_grad():
            text = clip.tokenize(raw_text, truncate=True).to(self.clip_device)
            x = self.clip_token_embedding(text).type(self.clip_dtype)
            x = x + self.clip_positional_embedding.type(self.clip_dtype)
            # The CLIP transformer is fed sequence-first input, so swap the
            # first two axes before the call and swap them back afterwards.
            x = self.clip_transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
            clip_out = self.clip_ln_final(x).type(self.clip_dtype)

        # The trainable head is float32; upcast in case CLIP ran in half precision.
        if clip_out.dtype != torch.float32:
            clip_out = clip_out.to(torch.float32)
        out = self.clipln(self.clipTrans(clip_out))
        return self._select_eot_features(out, text)
    

class CLIPTextEncoderV2(nn.Module):
    """Token-level text encoder built from a fully frozen CLIP text tower.

    Every token position is encoded; positions after the end-of-text (EOT)
    token are zeroed and flagged in the returned padding mask.

    Args:
        clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        clip_final_proj: When true, token features are multiplied by CLIP's
            ``text_projection`` matrix.
        clip_device: Device on which CLIP is loaded and tokens are placed.
    """

    def __init__(self, clip_version, clip_final_proj=False, clip_device='cuda'):
        super().__init__()
        self.clip_device = clip_device
        self.clip_version = clip_version
        self.clip_final_proj = clip_final_proj
        self.load_and_freeze_clip(clip_version)

    def load_and_freeze_clip(self, clip_version):
        """Load CLIP, disable gradients on all of it, and keep the text-side parts.

        Args:
            clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        """
        backbone, _ = clip.load(clip_version, device=self.clip_device, jit=False)

        # Freeze the whole model in one go; only text-side pieces are retained.
        backbone.requires_grad_(False)

        self.clip_token_embedding = backbone.token_embedding
        self.clip_transformer = backbone.transformer
        self.clip_positional_embedding = backbone.positional_embedding
        self.clip_ln_final = backbone.ln_final
        self.clip_dtype = backbone.dtype
        if self.clip_final_proj:
            self.text_projection = backbone.text_projection

    def forward(self, raw_text):
        """Encode text into per-token features and a padding mask.

        Args:
            raw_text: A string or list of strings; passed to ``clip.tokenize``
                with ``truncate=True``.

        Returns:
            A tuple ``(features, padding_mask)``: ``features`` is float32 with
            one vector per token position (zeros after EOT), and
            ``padding_mask`` is a boolean ``[batch_size, n_ctx]`` tensor that
            is True at positions after EOT.
        """
        text = clip.tokenize(raw_text, truncate=True).to(self.clip_device)
        dtype = self.clip_dtype

        embedded = self.clip_token_embedding(text).to(dtype)
        embedded = embedded + self.clip_positional_embedding.to(dtype)
        # Swap batch and sequence axes around the CLIP transformer call.
        hidden = self.clip_transformer(embedded.transpose(0, 1)).transpose(0, 1)
        clip_out = self.clip_ln_final(hidden).to(dtype)

        if self.clip_final_proj:
            clip_out = clip_out @ self.text_projection

        # argmax over token ids gives the [EOT] position of each sequence.
        eot_positions = text.argmax(dim=-1)

        # A position is kept when it lies at or before [EOT].
        positions = torch.arange(text.shape[1], device=text.device)
        keep = positions[None, :] <= eot_positions[:, None]  # [batch_size, n_ctx]

        # Zero the features of every position after [EOT].
        masked = clip_out * keep[..., None]

        return masked.float(), ~keep


class CLIPTextEncoderV3(nn.Module):
    """Token-level text encoder: frozen CLIP text tower plus a trainable head.

    CLIP's text-side weights are reused with gradients disabled.  A 2-layer
    ``nn.TransformerEncoder`` followed by a ``LayerNorm`` (both trainable)
    refines the per-token features.  Positions after the end-of-text (EOT)
    token are zeroed and flagged in the returned padding mask; the EOT
    position itself is kept.

    Args:
        clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        clip_final_proj: When true, refined features are multiplied by CLIP's
            ``text_projection`` matrix.
        clip_device: Device on which CLIP and the trainable head are placed.
    """

    def __init__(self, clip_version, clip_final_proj=False, clip_device='cuda'):
        super().__init__()
        self.clip_device = clip_device
        self.clip_version = clip_version
        self.clip_final_proj = clip_final_proj
        self.load_and_freeze_clip(clip_version)

    def load_and_freeze_clip(self, clip_version):
        """Load CLIP, keep its text-side pieces frozen, and build the trainable head.

        Args:
            clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        """
        clip_model, _ = clip.load(clip_version, device=self.clip_device, jit=False)

        self.clip_token_embedding = clip_model.token_embedding
        self.clip_transformer = clip_model.transformer
        self.clip_positional_embedding = clip_model.positional_embedding
        self.clip_ln_final = clip_model.ln_final
        self.clip_dtype = clip_model.dtype

        for frozen_module in (self.clip_transformer,
                              self.clip_token_embedding,
                              self.clip_ln_final):
            frozen_module.requires_grad_(False)
        # The positional embedding is a bare nn.Parameter, so the module loop
        # above does not reach it.  ``forward`` does not run under no_grad, so
        # without this it would silently be fine-tuned.
        self.clip_positional_embedding.requires_grad_(False)

        if self.clip_final_proj:
            self.text_projection = clip_model.text_projection
            # Also a bare CLIP parameter; keep it frozen like the rest of CLIP.
            self.text_projection.requires_grad_(False)

        # Trainable refinement head operating at CLIP's text width.
        width = self.clip_ln_final.normalized_shape[0]
        head_layer = nn.TransformerEncoderLayer(d_model=width,
                                                nhead=8,
                                                dim_feedforward=2048,
                                                dropout=0.1,
                                                activation="gelu",
                                                batch_first=True)
        self.clipTrans = nn.TransformerEncoder(head_layer, num_layers=2)
        self.clipln = nn.LayerNorm(width)
        self.clipTrans.to(self.clip_device)
        self.clipln.to(self.clip_device)

    @staticmethod
    def _valid_token_mask(tokens):
        """Return a boolean mask that is True up to and including the EOT position.

        The EOT position is found with ``argmax`` over the token ids, which
        relies on the EOT token having the largest id in each sequence.

        Args:
            tokens: Integer token ids of shape ``[batch_size, n_ctx]``.

        Returns:
            Boolean tensor of shape ``[batch_size, n_ctx]``.
        """
        eot_positions = tokens.argmax(dim=-1)
        lengths = eot_positions + 1  # include the [EOT] token in the length
        positions = torch.arange(tokens.shape[1], device=tokens.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def forward(self, raw_text):
        """Encode text into refined per-token features and a padding mask.

        Args:
            raw_text: A string or list of strings; passed to ``clip.tokenize``
                with ``truncate=True``.

        Returns:
            A tuple ``(features, padding_mask)``: ``features`` is float32 with
            one vector per token position (zeros after EOT), and
            ``padding_mask`` is a boolean ``[batch_size, n_ctx]`` tensor that
            is True at positions after EOT.
        """
        text = clip.tokenize(raw_text, truncate=True).to(self.clip_device)
        x = self.clip_token_embedding(text).type(self.clip_dtype)
        x = x + self.clip_positional_embedding.type(self.clip_dtype)
        # The CLIP transformer is fed sequence-first input, so swap the first
        # two axes before the call and swap them back afterwards.
        x = self.clip_transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
        clip_out = self.clip_ln_final(x).type(self.clip_dtype)

        # The trainable head is float32; upcast in case CLIP ran in half precision.
        if clip_out.dtype != torch.float32:
            clip_out = clip_out.to(torch.float32)
        out = self.clipln(self.clipTrans(clip_out))
        if self.clip_final_proj:
            # ``out`` is float32 while the projection keeps CLIP's dtype;
            # matmul does not promote dtypes, so cast the projection first.
            out = out @ self.text_projection.to(out.dtype)

        mask = self._valid_token_mask(text)  # [batch_size, n_ctx]

        # Zero the features of every position after [EOT].
        out = out * mask.unsqueeze(-1)

        return out.to(torch.float32), ~mask
    

class CLIPTextEncoderV4(nn.Module):
    """Token-level text encoder: frozen CLIP text tower plus a trainable head.

    CLIP's text-side weights are reused with gradients disabled.  A 2-layer
    ``nn.TransformerEncoder`` followed by a ``LayerNorm`` (both trainable)
    refines the per-token features.  The end-of-text (EOT) position and
    everything after it are zeroed and flagged in the returned padding mask,
    so only the tokens strictly before EOT are kept.

    Args:
        clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        clip_final_proj: When true, refined features are multiplied by CLIP's
            ``text_projection`` matrix.
        clip_device: Device on which CLIP and the trainable head are placed.
    """

    def __init__(self, clip_version, clip_final_proj=False, clip_device='cuda'):
        super().__init__()
        self.clip_device = clip_device
        self.clip_version = clip_version
        self.clip_final_proj = clip_final_proj
        self.load_and_freeze_clip(clip_version)

    def load_and_freeze_clip(self, clip_version):
        """Load CLIP, keep its text-side pieces frozen, and build the trainable head.

        Args:
            clip_version: Model name or checkpoint path forwarded to ``clip.load``.
        """
        clip_model, _ = clip.load(clip_version, device=self.clip_device, jit=False)

        self.clip_token_embedding = clip_model.token_embedding
        self.clip_transformer = clip_model.transformer
        self.clip_positional_embedding = clip_model.positional_embedding
        self.clip_ln_final = clip_model.ln_final
        self.clip_dtype = clip_model.dtype

        for frozen_module in (self.clip_transformer,
                              self.clip_token_embedding,
                              self.clip_ln_final):
            frozen_module.requires_grad_(False)
        # The positional embedding is a bare nn.Parameter, so the module loop
        # above does not reach it.  ``forward`` does not run under no_grad, so
        # without this it would silently be fine-tuned.
        self.clip_positional_embedding.requires_grad_(False)

        if self.clip_final_proj:
            self.text_projection = clip_model.text_projection
            # Also a bare CLIP parameter; keep it frozen like the rest of CLIP.
            self.text_projection.requires_grad_(False)

        # Trainable refinement head operating at CLIP's text width.
        width = self.clip_ln_final.normalized_shape[0]
        head_layer = nn.TransformerEncoderLayer(d_model=width,
                                                nhead=8,
                                                dim_feedforward=2048,
                                                dropout=0.1,
                                                activation="gelu",
                                                batch_first=True)
        self.clipTrans = nn.TransformerEncoder(head_layer, num_layers=2)
        self.clipln = nn.LayerNorm(width)
        self.clipTrans.to(self.clip_device)
        self.clipln.to(self.clip_device)

    @staticmethod
    def _valid_token_mask(tokens):
        """Return a boolean mask that is True strictly before the EOT position.

        The EOT position is found with ``argmax`` over the token ids, which
        relies on the EOT token having the largest id in each sequence.

        Args:
            tokens: Integer token ids of shape ``[batch_size, n_ctx]``.

        Returns:
            Boolean tensor of shape ``[batch_size, n_ctx]``.
        """
        eot_positions = tokens.argmax(dim=-1)
        lengths = eot_positions  # the [EOT] token itself is excluded
        positions = torch.arange(tokens.shape[1], device=tokens.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def forward(self, raw_text):
        """Encode text into refined per-token features and a padding mask.

        Args:
            raw_text: A string or list of strings; passed to ``clip.tokenize``
                with ``truncate=True``.

        Returns:
            A tuple ``(features, padding_mask)``: ``features`` is float32 with
            one vector per token position (zeros from EOT onwards), and
            ``padding_mask`` is a boolean ``[batch_size, n_ctx]`` tensor that
            is True at the EOT position and every position after it.
        """
        text = clip.tokenize(raw_text, truncate=True).to(self.clip_device)
        x = self.clip_token_embedding(text).type(self.clip_dtype)
        x = x + self.clip_positional_embedding.type(self.clip_dtype)
        # The CLIP transformer is fed sequence-first input, so swap the first
        # two axes before the call and swap them back afterwards.
        x = self.clip_transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
        clip_out = self.clip_ln_final(x).type(self.clip_dtype)

        # The trainable head is float32; upcast in case CLIP ran in half precision.
        if clip_out.dtype != torch.float32:
            clip_out = clip_out.to(torch.float32)
        out = self.clipln(self.clipTrans(clip_out))
        if self.clip_final_proj:
            # ``out`` is float32 while the projection keeps CLIP's dtype;
            # matmul does not promote dtypes, so cast the projection first.
            out = out @ self.text_projection.to(out.dtype)

        mask = self._valid_token_mask(text)  # [batch_size, n_ctx]

        # Zero the features at [EOT] and at every position after it.
        out = out * mask.unsqueeze(-1)

        return out.to(torch.float32), ~mask