import torch
from PIL import Image
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
import os, sys
import matplotlib.pyplot as plt
from PIL import Image
import decord
import torch.nn as nn


class CLIP_VQA(nn.Module):
    """Video quality scorer built on a frozen PE-Core-L14-336 CLIP model.

    All backbone parameters are frozen. The trainable pieces are:

    * bottleneck adapters applied after every vision/text transformer block
      whose index is >= ``adapter_start_layer`` (both towers currently use
      ``image_adapters`` together with the shared ``all_shared_mlps``),
    * ``proj_adapter``, blended into the projected embeddings of both towers,
    * a learnable 3-token prompt context (``ctx``) shared by five prompts that
      range from "bad" to "excellent" quality,
    * the scalar blend weights and the ``score_proj`` head that maps the five
      video/prompt cosine similarities to a single score.

    Args:
        shared_dim: Bottleneck width of the adapters.
        adapter_start_layer: First block index (0-based) after which an
            adapter is applied. Defaults to 18.
    """

    def __init__(self, shared_dim=128, adapter_start_layer=18):
        super().__init__()
        # Load the pretrained CLIP model.
        self.clip_model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True, checkpoint_path="./PE-Core-L14-336.pt")
        self.clip_model = self.clip_model.cuda()
        self.vit = self.clip_model.visual
        self.text_encoder = self.clip_model.transformer

        # Encoder widths (1024 for both towers of this backbone).
        self.vit_dim = self.vit.conv1.out_channels
        self.text_dim = self.text_encoder.width

        # Freeze every backbone parameter; only the modules created below train.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.patch_size = 14
        # conv1 emits `width` channels, so this is the token width used when
        # flattening patches in forward_image.
        self.width = self.vit_dim
        self.adapter_start_layer = adapter_start_layer
        num_vit_blocks = len(self.vit.transformer.resblocks)

        # NOTE(review): not used by any forward pass (all_shared_mlps is used
        # instead); kept so state_dict keys stay compatible with checkpoints.
        self.shared_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(shared_dim, shared_dim),
                nn.GELU(),
                nn.Linear(shared_dim, shared_dim),
            )
            for _ in range(num_vit_blocks)
        ]).cuda()

        # Bottleneck MLP shared by every adapter, in both towers.
        self.all_shared_mlps = nn.Sequential(
            nn.Linear(shared_dim, shared_dim),
            nn.GELU(),
            nn.Linear(shared_dim, shared_dim),
        ).cuda()

        # One adapter per image-encoder block (also reused by the text tower).
        self.image_adapters = nn.ModuleList([
            nn.ModuleDict({
                'down_proj': nn.Linear(self.vit_dim, shared_dim),
                'up_proj': nn.Linear(shared_dim, self.vit_dim),
                'ln': nn.LayerNorm(self.vit_dim)
            })
            for _ in range(num_vit_blocks)
        ]).cuda()

        # NOTE(review): created but unused; forward_text uses image_adapters
        # (these have no 'ln' entry). Kept for state_dict compatibility.
        self.text_adapters = nn.ModuleList([
            nn.ModuleDict({
                'down_proj': nn.Linear(self.text_dim, shared_dim),
                'up_proj': nn.Linear(shared_dim, self.text_dim)
            })
            for _ in range(len(self.text_encoder.resblocks))
        ]).cuda()

        # NOTE(review): unused; kept for state_dict compatibility.
        self.adapters = nn.ModuleDict({
            'down_proj': nn.Linear(self.vit_dim, shared_dim),
            'up_proj': nn.Linear(shared_dim, self.vit_dim)
        }).cuda()

        # Tokenize the five quality prompts. The leading "X X X" tokens are
        # placeholders whose embeddings are replaced by the learnable ctx.
        self.tokenizer = transforms.get_text_tokenizer(self.clip_model.context_length)
        self.texts = self.tokenizer(["X X X a video of bad quality","X X X a video of poor quality","X X X a video of fair quality",
                                     "X X X a video of good quality","X X X a video of excellent quality"]).cuda()
        # (num_prompts, context_length, width); the embedding table is frozen,
        # so this carries no gradient.
        self.embedding = self.clip_model.token_embedding(self.texts)

        # Learnable context initialised from the placeholder embeddings of the
        # first prompt: (1, 3, width).
        self.ctx = nn.Parameter(self.embedding[0:1, 1:1 + 3].clone())
        # Fixed parts around the context: the start token (num_prompts, 1, width)
        # and everything after the placeholders (num_prompts, L - 4, width).
        self.register_buffer("prefix", self.embedding[:, :1, :].clone())
        self.register_buffer("suffix", self.embedding[:, 1 + 3:, :].clone())

        self.text_adapter_weight = nn.Parameter(torch.tensor(0.1))
        self.image_adapter_weight = nn.Parameter(torch.tensor(0.1))

        # Adapter on the pooled (pre-projection) feature of either tower.
        self.proj_adapter = nn.Sequential(
            nn.Linear(self.vit_dim, shared_dim),
            nn.Linear(shared_dim, shared_dim),
            nn.GELU(),
            nn.Linear(shared_dim, self.vit_dim),
        ).cuda()

        self.text_proj_adapter_weight = nn.Parameter(torch.tensor(0.1))
        self.image_proj_adapter_weight = nn.Parameter(torch.tensor(0.1))

        # Create the head directly on the target device. Calling `.cuda()` on a
        # Parameter returns a plain tensor, which would leave the head
        # unregistered (never trained, never saved).
        self.score_proj = nn.Parameter(torch.empty(5, 1, device=self.texts.device))
        nn.init.normal_(self.score_proj, std=5 ** -0.5)

    @property
    def prompts(self):
        """Prompt embeddings ``[prefix | ctx | suffix]`` for every quality level.

        The result is rebuilt on each access so it always reflects the current
        (trained) ``ctx`` and the module's current device.
        """
        num_prompts = self.prefix.shape[0]
        return torch.cat(
            [self.prefix, self.ctx.repeat(num_prompts, 1, 1), self.suffix],
            dim=1,
        )

    def _apply_adapter(self, adapter, x):
        """Run one bottleneck adapter: down-project, shared MLP, up-project, LayerNorm."""
        out = adapter['down_proj'](x)
        out = self.all_shared_mlps(out)
        out = adapter['up_proj'](out)
        return adapter['ln'](out)

    def forward_image(self, x):
        """Encode an image batch with the adapted vision tower.

        Re-implements the vision forward pass step by step so that an adapter
        can be added after each block from ``adapter_start_layer`` on.

        Args:
            x: Image batch, unpacked as (batch, channels, H, W). H and W are
                assumed to be multiples of the patch size — TODO confirm.

        Returns:
            Projected image embeddings, one row per image.
        """
        visual = self.vit
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        # Patchify and flatten to (batch, num_patches, width), then prepend CLS.
        x = visual.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
        cls_token = visual.class_embedding.view(1, 1, -1).expand(batch, -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        x = x + visual._sample_abs_posemb(grid_h, grid_w)
        visual.rope.update_grid(x.device, grid_h, grid_w)
        x = visual.ln_pre(x)

        for i, (block, adapter) in enumerate(zip(visual.transformer.resblocks, self.image_adapters)):
            x = block(x)
            if i >= self.adapter_start_layer:
                # Weighted residual adapter.
                x = x + self.image_adapter_weight * self._apply_adapter(adapter, x)

        x = visual.ln_post(x)
        x = visual.attn_pool(x).squeeze(1)
        # The projection adapter sees the pooled feature before projection;
        # its output is blended into the projected embedding.
        y = self.proj_adapter(x)
        x = x @ visual.proj
        return x + self.image_proj_adapter_weight * y

    def forward_video(self, video):
        """Encode videos as the mean of their per-frame image embeddings.

        Args:
            video: Tensor unpacked as (b, n, c, h, w): b videos of n frames.

        Returns:
            (b, embed_dim) video embeddings.
        """
        b, n, c, h, w = video.shape
        frm_feats = self.forward_image(video.reshape(b * n, c, h, w))
        return frm_feats.reshape(b, n, -1).mean(dim=1)

    def forward_text(self, prompts, text):
        """Encode prompt embeddings with the adapted text tower.

        Args:
            prompts: Token embeddings (num_prompts, seq_len, width), for
                example ``self.prompts``.
            text: Token ids matching ``prompts``. They give the sequence length
                and are passed to the backbone's ``text_global_pool``.

        Returns:
            Projected text embeddings, one row per prompt.
        """
        seq_len = text.shape[1]
        x = prompts + self.clip_model.positional_embedding[:seq_len]
        # The backbone's own text forward passes a causal mask to every block.
        # Keep it here, otherwise the frozen weights see attention patterns
        # they were never trained on.
        attn_mask = getattr(self.clip_model, "attn_mask", None)
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        # Run every text block, not just as many as there are vision-side
        # modules. NOTE(review): the text tower reuses image_adapters;
        # text_adapters stays unused.
        for i, block in enumerate(self.text_encoder.resblocks):
            x = block(x) if attn_mask is None else block(x, attn_mask=attn_mask)
            if self.adapter_start_layer <= i < len(self.image_adapters):
                x = x + self.text_adapter_weight * self._apply_adapter(self.image_adapters[i], x)

        x = self.clip_model.ln_final(x)
        pooled, _ = self.clip_model.text_global_pool(x, text, pool_type=self.clip_model.pool_type)
        # Mirror forward_image: adapter on the un-projected pooled feature, a
        # single projection, then a weighted blend with the text-side weight.
        y = self.proj_adapter(pooled)
        pooled = pooled @ self.clip_model.text_projection
        return pooled + self.text_proj_adapter_weight * y

    def forward(self, video):
        """Predict one quality score per video.

        Args:
            video: Frames unpacked as (b, n, c, h, w).

        Returns:
            Flattened scores, one per video.
        """
        text_features = self.forward_text(self.prompts, self.texts)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        video_features = self.forward_video(video)
        video_features = video_features / video_features.norm(dim=-1, keepdim=True)
        # Cosine similarity to each quality prompt, then the linear head.
        similarity = video_features @ text_features.T
        score = similarity @ self.score_proj
        return score.view(-1)

def preprocess_video(video_path, num_frames=8, transform=None, return_first_frame_for_demo=True):
    """
    Uniformly sample a fixed number of frames from a video and preprocess them.

    Parameters:
    - video_path: str, path to the video file.
    - num_frames: int, number of frames to sample. Defaults to 8; must be positive.
    - transform: callable applied to each sampled frame (given as a PIL image)
      that returns a tensor, e.g. a torchvision transform. If None, the raw
      uint8 frames are returned as a (num_frames, 3, H, W) tensor.
    - return_first_frame_for_demo: if True, also return the first sampled frame
      as the numpy array read from the video; otherwise return None in its place.

    Returns:
    - (video, first_frame): ``video`` stacks the per-frame results along a new
      leading dimension of size num_frames; ``first_frame`` is described above.

    Raises:
    - ValueError: if num_frames is not positive or the video has no frames.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    vr = decord.VideoReader(video_path)
    total_frames = len(vr)
    if total_frames == 0:
        raise ValueError(f"video has no frames: {video_path}")

    # Start index of each of num_frames equal segments. Indices repeat when
    # the video is shorter than num_frames.
    frame_indices = [int(i * (total_frames / num_frames)) for i in range(num_frames)]
    frames = vr.get_batch(frame_indices).asnumpy()  # (num_frames, H, W, 3)

    if transform is None:
        # No transform given: expose the raw frames channel-first.
        video = torch.from_numpy(frames).permute(0, 3, 1, 2)
    else:
        video = torch.stack([transform(Image.fromarray(frame)) for frame in frames], dim=0)

    first_frame = frames[0] if return_first_frame_for_demo else None
    return video, first_frame

if __name__ == '__main__':
    # Demo: score one local video. No trained adapter weights are loaded here,
    # so the adapters and the score head keep their initial values.
    model = CLIP_VQA()
    # Inference only: switch modules such as dropout to eval behaviour.
    model.eval()
    preprocess = transforms.get_image_transform(336)
    video, first_frame = preprocess_video("./3291.mp4", 8, transform=preprocess)
    # Add the batch dimension: (1, num_frames, C, H, W).
    video = video.unsqueeze(0).cuda()
    with torch.no_grad():
        score = model(video)
        print(score)
