import json
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import numpy as np
from global_vars import *

def get_embedding_model(name="", device="cuda"):
    """Load a SentenceTransformer embedding model.

    Args:
        name: Model name or local path forwarded to ``SentenceTransformer``.
        device: Device the model is loaded on. Defaults to ``"cuda"`` (the
            previously hard-coded value); pass e.g. ``"cpu"`` to run without
            a GPU.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(name, device=device)


def get_top1_matches(opinions, pool, model):
    """Find the (up to) 5 pool entries most similar to each opinion.

    Every opinion and every pool entry is embedded with
    ``model.encode(text, prompt_name="query")`` and each (opinion, entry)
    pair is scored with ``model.similarity(...).item()``.

    Args:
        opinions: Iterable of query texts.
        pool: Sequence of candidate texts to match against.
        model: Object exposing SentenceTransformer-style ``encode`` and
            ``similarity`` methods.

    Returns:
        One ``{"opinion": ..., "top5_matches": [...]}`` dict per opinion.
        ``top5_matches`` holds at most five pool entries, ordered by
        ASCENDING similarity (``np.argsort`` order), so the best match is
        the last element. Fewer than five are returned if the pool is smaller.
    """
    # The pool embeddings do not depend on the opinion, so compute them once
    # instead of re-encoding every comment for every opinion.
    pool_embeddings = [model.encode(comment, prompt_name="query") for comment in pool]

    out = []
    for opinion in opinions:
        query_embedding = model.encode(opinion, prompt_name="query")
        scores = [
            model.similarity(query_embedding, emb).item() for emb in pool_embeddings
        ]
        # Last five indices of an ascending argsort = five highest scores.
        top5_indices = np.argsort(scores)[-5:]
        out.append({
            "opinion": opinion,
            "top5_matches": [pool[idx] for idx in top5_indices],
        })
    return out



def find_best_match(kialo_path, output_path="", model_name=""):
    """Match every pro/con argument against its topic's pool of views.

    For each item loaded from ``kialo_path``, ``item["pros"]`` and
    ``item["cons"]`` are each passed to ``get_top1_matches`` with
    ``item["views"]`` as the candidate pool.

    Args:
        kialo_path: Path to a JSON file holding a list of dicts with keys
            ``"topic"``, ``"views"``, ``"pros"`` and ``"cons"``.
        output_path: File the accumulated results are re-written to after
            each item (a progress checkpoint). Defaults to ``""`` as before.
            NOTE(review): ``""`` looks like a placeholder; ``open("")`` fails.
        model_name: Model name forwarded to ``get_embedding_model``.

    Returns:
        A list of ``{"topic": ..., "pros": ..., "cons": ...}`` dicts, one per
        input item.
    """
    model = get_embedding_model(model_name)
    with open(kialo_path, encoding="utf-8") as f:
        kialo_data = json.load(f)

    new_data = []
    for item in tqdm(kialo_data):
        topic = item["topic"]
        pool = item["views"]
        new_data.append({
            "topic": topic,
            "pros": get_top1_matches(item["pros"], pool, model),
            "cons": get_top1_matches(item["cons"], pool, model),
        })
        # Checkpoint the whole result list after every item; the `with` block
        # closes (and flushes) the file each time instead of leaking a handle.
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(new_data, f, indent=4)

    return new_data

def get_avg_view_length(data):
    """Return the word count of the longest view in ``data["views"]``.

    NOTE(review): despite its name this computes the MAXIMUM view length,
    not the average; the name is kept so existing callers keep working.
    Words are counted with ``str.split()`` (whitespace-separated tokens).

    Args:
        data: Mapping with a ``"views"`` key holding an iterable of strings.

    Returns:
        The largest word count, or 0 when there are no views.
    """
    return max((len(view.split()) for view in data["views"]), default=0)

def remove_too_long_views(data, max_words=1200):
    """Drop overly long views from every item, modifying ``data`` in place.

    Args:
        data: Iterable of dicts, each with a ``"views"`` list of strings.
        max_words: Views with more than this many whitespace-separated words
            are removed. Defaults to 1200, the previously hard-coded limit.

    Returns:
        The same ``data`` object, for convenience.
    """
    for item in data:
        views = item["views"]
        # Build the filtered list first, then slice-assign so the original
        # list object is still updated in place. Calling views.remove() while
        # iterating views skips the element after each removal, which let
        # consecutive long views slip through.
        views[:] = [view for view in views if len(view.split()) <= max_words]
    return data

if __name__ == "__main__":
    # NOTE(review): the input and output paths are empty-string placeholders;
    # fill them in before running.
    data = find_best_match("")
    # Use a context manager so the output file is flushed and closed. The
    # former trailing exit() was redundant at the end of the script and relies
    # on the `site`-provided builtin.
    with open("", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
