import json
import copy
from tqdm import trange, tqdm
import pandas as pd


import os
import copy
import numpy as np
from scipy.spatial.distance import cosine
from scipy import stats

import random
from datasets import load_dataset


from InstructorEmbedding import INSTRUCTOR
from sklearn.metrics.pairwise import cosine_similarity

import logging
import transformers
# Quiet the Hugging Face loggers: only messages at ERROR level or above are
# emitted by the tokenization, configuration and modeling utility modules.
for _hf_module_name in ("tokenization_utils", "configuration_utils", "modeling_utils"):
    getattr(transformers, _hf_module_name).logger.setLevel(logging.ERROR)



def evaluation(key_ref, corpus_scores, query_labels, dataset_name):
    """Print and return Recall@1/5/10 for one retrieval dataset.

    A query counts as a hit at threshold ``k`` when at least one of its
    reference corpus indices is among the ``k`` highest-scoring corpus
    entries. Results are reported overall, per query label ("part"), and as a
    macro average over the parts.

    Args:
        key_ref: Mapping from query index (anything ``int()`` accepts, e.g.
            the string keys of a JSON object) to either a single reference
            corpus index or a list/tuple/set/array of acceptable indices.
        corpus_scores: ``corpus_scores[q]`` is the sequence of scores of
            query ``q`` against every corpus entry; higher means more similar.
        query_labels: Label of every query, indexed by query index. Each
            label must be one of the parts of the dataset.
        dataset_name: Dataset identifier. Any name containing ``"source"``
            uses the single part ``"none"``; otherwise it must be one of
            'perspectrum', 'agnews', 'story', 'ambigqa', 'allsides',
            'exfever'.

    Returns:
        A dict with keys ``"thresholds"`` (the list ``[1, 5, 10]``),
        ``"overall"`` (recall per threshold over all entries of ``key_ref``),
        ``"parts"`` (part name -> recall per threshold, ``nan`` for a part
        without queries) and ``"macro_average"`` (mean of the per-part
        recalls over the parts that have at least one query).

    Raises:
        ValueError: If ``dataset_name`` is not recognised, or a label in
            ``query_labels`` is not a part of the dataset.
    """
    recall_threshold = [1, 5, 10]

    parts_by_dataset = {
        "perspectrum": ["support", "undermine", "general"],
        "agnews": ["subtopic", "location"],
        "story": ["analogy", "entity"],
        "ambigqa": ["perspective"],
        "allsides": ["left", "right", "center"],
        "exfever": ["SUPPORT", "REFUTE", "NOT ENOUGH INFO"],
    }
    if "source" in dataset_name:
        parts = ["none"]
    elif dataset_name in parts_by_dataset:
        parts = parts_by_dataset[dataset_name]
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            f"unknown dataset name {dataset_name!r}; expected a name containing "
            f"'source' or one of {sorted(parts_by_dataset)}"
        )

    # Number of queries carrying each label; denominator of the per-part recall.
    parts_size = [0] * len(parts)
    for label in query_labels:
        parts_size[parts.index(label)] += 1

    hit_counts = [0] * len(recall_threshold)
    part_hit_counts = [[0] * len(recall_threshold) for _ in parts]

    for query_id, refs in key_ref.items():
        query_index = int(query_id)
        # Corpus indices ordered from highest to lowest score; computed once
        # per query and sliced for every threshold.
        ranking = (-np.array(corpus_scores[query_index])).argsort()
        # A reference may be a single index or a collection of indices.
        if isinstance(refs, (list, tuple, set, frozenset, np.ndarray)):
            gold = list(refs)
        else:
            gold = [refs]
        part_index = parts.index(query_labels[query_index])

        for j, thresh in enumerate(recall_threshold):
            # important: finding one reference is enough, this can be modified
            hit = int(bool(np.isin(gold, ranking[:thresh]).any()))
            hit_counts[j] += hit
            part_hit_counts[part_index][j] += hit

    n_queries = len(key_ref)
    # With no evaluated queries the recall is reported as 0.0 rather than
    # failing with a division by zero.
    overall = [count / n_queries if n_queries else 0.0 for count in hit_counts]

    print("overall")
    for i, thresh in enumerate(recall_threshold):
        print(f"Recall@{thresh}:", overall[i])

    per_part = {}
    macro_threshs = [[] for _ in recall_threshold]

    for part, size, counts in zip(parts, parts_size, part_hit_counts):
        print(part)
        if size:
            part_recalls = [count / size for count in counts]
        else:
            # No query has this label: recall is undefined, so report nan and
            # leave the part out of the macro average.
            part_recalls = [float("nan")] * len(recall_threshold)
        per_part[part] = part_recalls

        for i, thresh in enumerate(recall_threshold):
            print(f"Recall@{thresh}:", part_recalls[i])
            if size:
                macro_threshs[i].append(part_recalls[i])

    macro_average = [
        sum(values) / len(values) if values else 0.0 for values in macro_threshs
    ]

    print("macro_average")
    for i, thresh in enumerate(recall_threshold):
        print(f"Recall@{thresh}:", macro_average[i])

    return {
        "thresholds": recall_threshold,
        "overall": overall,
        "parts": per_part,
        "macro_average": macro_average,
    }


if __name__ == "__main__":

    # Path of the benchmark JSON file. It maps each dataset name to a dict
    # with the keys read below ("corpus", "queries", "source_queries",
    # "key_ref", "query_labels"). TODO: fill in before running.
    path = ""
    with open(path, "r", encoding="utf-8") as f:
        datasets = json.load(f)

    # Build "source_<name>" variants in which identical source queries are
    # merged into a single query whose references are the union of the
    # references of all its occurrences.
    # NOTE(review): these variants are built but not evaluated by the loop
    # further down, which iterates over `datasets`.
    source_datasets = {}

    for data_name, dataset in datasets.items():
        source_entry = {
            "corpus": dataset["corpus"],
            "queries": [],
            "source_queries": [],
            "perspectives": [],
            "key_ref": {},
            "query_labels": [],
        }
        source_datasets["source_" + data_name] = source_entry

        # source query text -> id of the merged query (dict lookup is O(1)).
        query_id_by_text = {}

        for i, query in enumerate(dataset["source_queries"]):
            refs = dataset["key_ref"][str(i)]
            # Always work on a fresh list: storing the original list and
            # extending it later would also modify dataset["key_ref"], which
            # is what gets evaluated below.
            refs = list(refs) if isinstance(refs, (list, tuple)) else [refs]

            if query not in query_id_by_text:
                query_id = str(len(source_entry["queries"]))
                query_id_by_text[query] = query_id
                source_entry["queries"].append(query)
                source_entry["source_queries"].append(query)
                source_entry["perspectives"].append("none")
                source_entry["query_labels"].append("none")
                source_entry["key_ref"][query_id] = refs
            else:
                # this source query already exists: merge its references
                source_entry["key_ref"][query_id_by_text[query]].extend(refs)

    model_name = "instructor-base"
    # Load the embedding model from the Hugging Face hub before using it.
    model = INSTRUCTOR("hkunlp/" + model_name)
    model.eval()

    for k, v in datasets.items():

        print("we are working on:", k)

        queries = v["queries"]
        corpus = v["corpus"]
        key_ref = v["key_ref"]
        query_labels = v["query_labels"]

        # INSTRUCTOR encodes [instruction, text] pairs; the instruction is
        # left empty here.
        ins_queries = [["", q] for q in queries]
        ins_corpus = [["", c] for c in corpus]

        query_embeddings = model.encode(ins_queries)
        corpus_embeddings = model.encode(ins_corpus)

        # Cosine similarity of every query against every corpus entry in one
        # vectorised call; converted to nested lists so it is JSON-serialisable.
        corpus_scores = cosine_similarity(query_embeddings, corpus_embeddings).tolist()

        with open(model_name + "_Xpq_" + k + "_scores.json", "w", encoding="utf-8") as f:
            json.dump(corpus_scores, f)

        evaluation(key_ref, corpus_scores, query_labels, k)



                
                    



