import torch
import math
import torch.nn as nn
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


class image_encoder_new_0(nn.Module):
    """Vision-encoder wrapper in which every backbone parameter is frozen.

    The wrapped ``model`` must expose ``conv1``, ``class_embedding``,
    ``positional_embedding``, ``ln_pre``, ``transformer``, ``ln_post`` and
    ``proj`` (presumably a CLIP ``VisionTransformer`` -- confirm with caller).

    Args:
        model: The vision backbone. Its parameters are modified in place:
            ``requires_grad`` is set to ``False`` on all of them.
    """

    def __init__(self, model):
        super().__init__()
        self.encoder = model

        # Freeze the complete backbone; nothing in it receives gradients.
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, images):
        """Encode ``images`` and return pre- and post-projection features.

        Args:
            images: Input passed straight to ``encoder.conv1``
                (assumed to be a batched image tensor -- TODO confirm).

        Returns:
            A tuple ``(img_feature, projected)`` where ``img_feature`` is
            ``ln_post`` applied to the first (class) token and ``projected``
            is ``img_feature @ encoder.proj``.
        """
        enc = self.encoder

        # Patch embedding; collapse everything after the channel axis into
        # one token axis and move the channel axis last.
        patches = enc.conv1(images)
        batch, width = patches.shape[0], patches.shape[1]
        tokens = patches.reshape(batch, width, -1).transpose(1, 2)

        # Adding the class embedding to a zero tensor broadcasts it to one
        # copy per sample, which is then prepended as token 0.
        cls_token = enc.class_embedding.to(tokens.dtype) + tokens.new_zeros(batch, 1, width)
        tokens = torch.cat([cls_token, tokens], dim=1)

        tokens = tokens + enc.positional_embedding.to(tokens.dtype)
        tokens = enc.ln_pre(tokens)

        # The transformer is fed with the first two axes swapped, and its
        # output is swapped back afterwards.
        hidden = enc.transformer(tokens.transpose(0, 1)).transpose(0, 1)

        img_feature = enc.ln_post(hidden[:, 0])
        projected = torch.matmul(img_feature, enc.proj)

        return img_feature, projected


class image_encoder(nn.Module):
    """Vision-encoder wrapper that freezes only the output projection.

    The wrapped ``model`` must expose ``conv1``, ``class_embedding``,
    ``positional_embedding``, ``ln_pre``, ``transformer``, ``ln_post`` and
    ``proj`` (presumably a CLIP ``VisionTransformer`` -- confirm with caller).

    Args:
        model: The vision backbone. The parameter registered under the exact
            name ``"proj"`` gets ``requires_grad = False``; all other
            parameters are left as they were.
    """

    def __init__(self, model):
        super().__init__()
        self.encoder = model
        for name, param in self.encoder.named_parameters():
            # Only the top-level "proj" matches; nested names such as
            # "sub.proj" are intentionally not touched by this comparison.
            if name != "proj":
                continue
            param.requires_grad = False

    def forward(self, images):
        """Encode ``images`` and return pre- and post-projection features.

        Args:
            images: Input passed straight to ``encoder.conv1``
                (assumed to be a batched image tensor -- TODO confirm).

        Returns:
            A tuple ``(img_feature, projected)`` where ``img_feature`` is
            ``ln_post`` applied to the first (class) token and ``projected``
            is ``img_feature @ encoder.proj``.
        """
        backbone = self.encoder

        # Patch embedding, flattened to (dim0, dim1, rest) and then
        # rearranged so the former dim1 becomes the last axis.
        feature_map = backbone.conv1(images)
        n_samples = feature_map.shape[0]
        n_channels = feature_map.shape[1]
        seq = feature_map.reshape(n_samples, n_channels, -1)
        seq = seq.transpose(1, 2)

        # Broadcast the class embedding to one row per sample by adding it
        # to zeros, then put it in front of the patch tokens.
        zeros = seq.new_zeros(n_samples, 1, n_channels)
        class_rows = backbone.class_embedding.to(seq.dtype) + zeros
        seq = torch.cat([class_rows, seq], dim=1)

        seq = backbone.ln_pre(seq + backbone.positional_embedding.to(seq.dtype))

        # Swap the first two axes for the transformer and swap them back.
        seq = seq.transpose(0, 1)
        seq = backbone.transformer(seq)
        seq = seq.transpose(0, 1)

        first_token = seq[:, 0]
        img_feature = backbone.ln_post(first_token)
        projected = torch.matmul(img_feature, backbone.proj)

        return img_feature, projected


class text_encoder(nn.Module):
    """Text tower assembled from separately supplied components.

    Args:
        model: Callable used as the transformer; it receives the embedded
            sequence with its first two axes swapped.
        token_embedding: Callable mapping token ids to embeddings.
        positional_embedding: Tensor added to the token embeddings. Its
            ``requires_grad`` flag is set to ``False`` in place.
        ln_final: Normalisation applied to the transformer output.
        text_projection: Matrix the selected token feature is multiplied by.
    """

    def __init__(self, model, token_embedding, positional_embedding, ln_final, text_projection):
        super().__init__()
        self.transformer = model
        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection

        # The positional embedding is kept fixed during training.
        self.positional_embedding.requires_grad = False

    def forward(self, text):
        """Return one projected feature per row of ``text``.

        Args:
            text: Integer token ids (assumed 2-D, one sequence per row --
                TODO confirm against caller).

        Returns:
            The ``ln_final`` output at each row's ``argmax`` position,
            multiplied by ``text_projection``.
        """
        # All intermediate activations are forced to float32.
        embedded = self.token_embedding(text).to(torch.float32)
        embedded = embedded + self.positional_embedding.to(torch.float32)

        # Swap the first two axes for the transformer and swap them back.
        hidden = self.transformer(embedded.transpose(0, 1)).transpose(0, 1)
        hidden = self.ln_final(hidden).to(torch.float32)

        # Pick, for every row, the position holding the largest token id
        # (presumably the end-of-text token in CLIP's vocabulary -- confirm).
        row_idx = torch.arange(hidden.shape[0])
        col_idx = text.argmax(dim=-1)
        pooled = hidden[row_idx, col_idx]

        return torch.matmul(pooled, self.text_projection)


class text_encoder_0(nn.Module):
    """Text tower built from individually passed-in pieces.

    Args:
        model: Callable used as the transformer; it is invoked on the
            embedded sequence after its first two axes are exchanged.
        token_embedding: Callable turning token ids into embeddings.
        positional_embedding: Tensor summed with the token embeddings. This
            constructor sets its ``requires_grad`` to ``False`` in place.
        ln_final: Normalisation applied after the transformer.
        text_projection: Matrix applied to the selected token feature.
    """

    def __init__(self, model, token_embedding, positional_embedding, ln_final, text_projection):
        super().__init__()
        self.transformer = model
        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection

        # Exclude the positional embedding from gradient updates.
        self.positional_embedding.requires_grad = False

    def forward(self, text):
        """Compute a projected feature for each row of ``text``.

        Args:
            text: Integer token ids (assumed 2-D, one sequence per row --
                TODO confirm against caller).

        Returns:
            For each row, the ``ln_final`` output taken at that row's
            ``argmax`` position and multiplied by ``text_projection``.
        """
        fp32 = torch.float32

        # Token + positional embeddings, both cast to float32.
        seq = self.token_embedding(text).to(fp32)
        seq = seq + self.positional_embedding.to(fp32)

        # The transformer sees the first two axes exchanged; undo afterwards.
        seq = seq.transpose(0, 1)
        seq = self.transformer(seq)
        seq = seq.transpose(0, 1)

        seq = self.ln_final(seq).to(fp32)

        # Gather one vector per row at the index of the largest token id
        # (presumably the end-of-text token in CLIP's vocabulary -- confirm).
        rows = torch.arange(seq.shape[0])
        positions = text.argmax(dim=-1)
        selected = seq[rows, positions]

        return torch.matmul(selected, self.text_projection)

