# -*- coding: gb18030 -*-
import random
import json
from sentence_transformers import SentenceTransformer, util, CrossEncoder 
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from PIL import Image
import open_clip
import os
import shutil
import torch.nn as nn
# Expose only physical GPUs 0 and 3 to this process; inside the process they
# are then addressed as cuda:0 and cuda:1 respectively.
# NOTE(review): this is set after `import torch`; it presumably only takes
# effect because CUDA has not been initialised yet at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,3" 
device0 = torch.device("cuda:0")  # not referenced by name in the visible code
device1 = torch.device("cuda:1")  # device the reward model is moved to


# CLIP model used to compute the image-text matching score (CLIPScore).
# NOTE(review): the original note here mentioned removing paragraphs with a
# CLIPScore lower than a threshold of 17.0; the visible code only computes and
# stores the score, it does not filter anything.
clip_tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-bigG-14', pretrained='laion2b_s39b_b160k')
# `.cuda()` places the model on the current default CUDA device.
# Alternative placement kept from the original: clip_model.to(device1)
clip_model = clip_model.cuda()
# open_clip hands the model back in train mode; switch to eval so that any
# train-only behaviour (dropout, stochastic depth, ...) is disabled for scoring.
clip_model.eval()

def get_clip_score(image, text):
    """Return the CLIP image-text matching score, rounded to 4 decimals.

    The caption is split on '.' and at most the first 10 fragments are
    encoded. Image and text embeddings are L2-normalised, the text embeddings
    are averaged, and the score is 100 times the dot product of the image
    embedding with that average.

    Args:
        image: An image accepted by the module-level ``preprocess`` transform
            (the script passes a PIL image).
        text: Caption string.

    Returns:
        The score as a Python float rounded to 4 decimal places.
    """
    # Follow the model's actual device instead of assuming the current default
    # CUDA device, so moving ``clip_model`` elsewhere does not break scoring.
    device = next(clip_model.parameters()).device

    # NOTE(review): fragments can be empty or whitespace-only (e.g. after a
    # trailing period); they are still encoded and averaged. Left unchanged so
    # scores stay comparable with previously computed ones.
    sentences = text.split('.')[:10]

    # Inference only: never build an autograd graph, whatever the caller does.
    with torch.no_grad():
        pixels = preprocess(image).unsqueeze(0).to(device)
        image_features = clip_model.encode_image(pixels)
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

        tokens = clip_tokenizer(sentences).to(device)
        text_features = clip_model.encode_text(tokens)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
        # Average the per-fragment embeddings into a single text vector.
        text_features = text_features.mean(dim=0, keepdim=True)

        similarity = (100.0 * image_features @ text_features.T).sum(dim=-1)

    return round(similarity.item(), 4)




# Reward model used for the human-preference score; the sequence-classification
# model and its tokenizer are loaded from the same checkpoint, and the model is
# then moved to device1.
reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
tokenizer = AutoTokenizer.from_pretrained(reward_name)
rank_model = rank_model.to(device1)
# NOTE(review): the original note here read "full reward score: 10"; the code
# returns the model's raw logit and does not clamp or rescale it.
def get_reward_score(question, answer):
    """Return the reward model's score for a (question, answer) pair.

    The pair is tokenised together, passed through ``rank_model`` on
    ``device1``, and the first row of the output logits is returned.

    Args:
        question: Prompt text.
        answer: Response text to be scored.

    Returns:
        The logit as a Python float rounded to 4 decimal places.
    """
    inputs = tokenizer(question, answer, return_tensors='pt').to(device1)
    # Inference only: never build an autograd graph, whatever the caller does.
    with torch.no_grad():
        logits = rank_model(**inputs).logits
    # `.item()` already copies the scalar to the host, so no explicit
    # `.cpu().detach()` is needed.
    return round(logits[0].item(), 4)

max_caption_length = 1009  # minigpt4


def caption_length_score(answer, max_length=None):
    """Return a length score proportional to the caption's character count.

    A caption of exactly ``max_length`` characters scores 10. The value is
    not clamped, so longer captions score above 10.

    Args:
        answer: Caption string (anything supporting ``len``).
        max_length: Length that maps to a score of 10. Defaults to the
            module-level ``max_caption_length``, looked up at call time.

    Returns:
        ``10 * len(answer) / max_length`` as a float.
    """
    if max_length is None:
        # Use the shared constant instead of a duplicated magic number.
        max_length = max_caption_length
    return 10 * len(answer) / max_length

def final_score(image, question, answer):
    """Return the sum of the CLIP, reward and length scores for one caption.

    Args:
        image: Image forwarded to ``get_clip_score``.
        question: Prompt forwarded to ``get_reward_score``.
        answer: Caption scored by all three components.

    Returns:
        The unweighted sum of the three component scores.
    """
    # Evaluated in the same order as before: CLIP, reward, then length.
    components = (
        get_clip_score(image, answer),
        get_reward_score(question, answer),
        caption_length_score(answer),
    )
    return sum(components)



# Load the captions to score.
with open('filter_cap.json', 'r', encoding='gb18030') as f:
    data_origin = json.load(f)

# Load the GPT-4 scores.
# NOTE(review): `gpt_annotations` is not used by the active code below; it was
# only referenced by a disabled combined score of the form
# final_score(img, question, caption) + gpt_annotations[i]["gpt_score"] / 4.
with open('gpt4score/gpt_final_score.json', 'r', encoding='gb18030') as f:
    gpt_data = json.load(f)

origin_annotations = data_origin['annotations']
gpt_annotations = gpt_data['annotations']

# The same prompt is used for every caption, so build it once.
question = "Describe this image in detail."

new_annotations = []
with torch.no_grad():
    for annotation in origin_annotations:
        image_id = annotation["image_id"]
        caption = annotation["caption"]

        # Close the image file as soon as the CLIP score is computed so file
        # handles do not accumulate over the whole dataset.
        with Image.open(f"image/{image_id}.jpg") as img:
            clip_score = get_clip_score(img, caption)

        # Key order is kept as before so the JSON layout is unchanged.
        new_annotations.append({
            "image_id": image_id,
            "caption": caption,
            "clip_score": clip_score,
            "reward_score": get_reward_score(question, caption),
            "length_score": caption_length_score(caption),
        })

# Order the output by numeric image id.
final = sorted(new_annotations, key=lambda x: int(x['image_id']))
new_data = {'annotations': final}

with open('full_score_data.json', 'w', encoding='utf-8') as f:
    json.dump(new_data, f, indent=4, ensure_ascii=False)