import torch
import torch.nn as nn
from PIL import Image
import clip


class RewardModel(nn.Module):
    """Feed-forward head that scores an (image embedding, text embedding) pair.

    The two embeddings are concatenated along the last dimension and pushed
    through a stack of linear layers with ReLU activations and dropout. The
    final layer is a sigmoid, so every score lies between 0 and 1.
    """

    def __init__(self, embed_dim=768):  # 768 is the ViT-L/14 embedding size
        super().__init__()
        # The layer order fixes the ``fc.<index>`` parameter names, so it must
        # stay as-is for saved weights to keep loading.
        layers = [
            # Input width is doubled because both embeddings are concatenated.
            nn.Linear(embed_dim * 2, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, image_embed, text_embed):
        """Return one score per pair; the last dimension of the result is 1."""
        joint_embedding = torch.cat([image_embed, text_embed], dim=-1)
        score = self.fc(joint_embedding)
        return score


def _to_pil(img):
    """Convert one image (PIL image, torch tensor, or array) to a PIL image."""
    if isinstance(img, Image.Image):
        # Already in the format ``preprocess`` is fed with; nothing to convert.
        return img
    if isinstance(img, torch.Tensor):
        # detach() so tensors that still require grad can be exported to numpy.
        arr = img.detach().cpu().numpy()
        if arr.ndim == 3 and arr.shape[0] == 3:  # assumed to be CHW
            arr = arr.transpose(1, 2, 0)  # -> HWC
        if arr.dtype != 'uint8':
            # Non-uint8 values are assumed to be in [0, 1]. Clip floats first
            # so out-of-range values saturate instead of wrapping around in
            # the uint8 cast. uint8 tensors are taken to be 0-255 already.
            if arr.dtype.kind == 'f':
                arr = arr.clip(0.0, 1.0)
            arr = (arr * 255).astype('uint8')
        return Image.fromarray(arr)
    return Image.fromarray(img)


def calculate_batch_scores(prompts, imgs, preprocess, clip_model, reward_model, device):
    """Score each (prompt, image) pair with CLIP embeddings and a reward model.

    Args:
        prompts: Text prompts handed to ``clip.tokenize`` (truncated to the
            CLIP context length); expected to pair up one-to-one with ``imgs``.
        imgs: Iterable of images. Each item may be a PIL image, a torch tensor
            (3-D tensors whose first dimension is 3 are treated as CHW; float
            values are assumed to be in [0, 1], uint8 values in 0-255), or an
            array accepted by ``Image.fromarray``.
        preprocess: Callable mapping a PIL image to a model-ready tensor
            (presumably the transform returned by ``clip.load``).
        clip_model: Model exposing ``encode_text`` and ``encode_image``.
        reward_model: Callable taking ``(image_embeddings, text_embeddings)``
            and returning scores with a trailing dimension of size 1.
        device: Device on which the inputs are placed.

    Returns:
        A 1-D tensor holding one score per pair. A batch of one yields shape
        ``(1,)`` rather than a 0-dim tensor, and an empty ``imgs`` yields an
        empty tensor.

    Note:
        Both ``reward_model`` and ``clip_model`` are switched to eval mode and
        are left in that mode when the function returns.
    """
    processed_imgs = [preprocess(_to_pil(img)) for img in imgs]
    if not processed_imgs:
        # Nothing to score; torch.stack would raise on an empty list.
        return torch.empty(0, device=device)

    # Stack first, then move the whole batch to the device in one transfer.
    imgs_tensor = torch.stack(processed_imgs).to(device)
    text_tokens = clip.tokenize(prompts, truncate=True).to(device)

    # Inference only: disable dropout and similar train-time behaviour.
    reward_model.eval()
    clip_model.eval()

    with torch.no_grad():
        # Cast to float32 before feeding the reward model.
        text_embeddings = clip_model.encode_text(text_tokens).float()
        image_embeddings = clip_model.encode_image(imgs_tensor).float()

        # One score per image/text pair, with a trailing dimension of 1.
        scores = reward_model(image_embeddings, text_embeddings)

    # Drop only the trailing score dimension so the batch axis always survives.
    return scores.squeeze(-1)