import json
import copy
from tqdm.notebook import trange, tqdm
from tabulate import tabulate
import pandas as pd

import openai

import os
import copy
import numpy as np
from scipy.spatial.distance import cosine
from scipy import stats

import seaborn as sns
import matplotlib.pyplot as plt

import torch

from transformers import AutoModel, AutoTokenizer,T5Tokenizer, T5ForConditionalGeneration
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer,DPRQuestionEncoder, DPRQuestionEncoderTokenizer

import logging
import transformers
transformers.tokenization_utils.logger.setLevel(logging.ERROR)
transformers.configuration_utils.logger.setLevel(logging.ERROR)
transformers.modeling_utils.logger.setLevel(logging.ERROR)

import random
from datasets import load_dataset

def prepare_story_data():
    """Build the story dataset with one GPT-generated sentence per story.

    Reads ``./dataset/story/StoryAnalogy.csv``, ranks its rows by
    ``relation - entity`` (largest first) and keeps the top 1000.  For every
    kept row, ``gpt-3.5-turbo`` is asked for a short sentence containing the
    entities of the row's ``s1`` text.  The ``[s1, s2, generated]`` triples
    are written to ``./dataset/story/pir_story.json`` and returned.

    Returns:
        list: one ``[s1, s2, generated_sentence]`` list per selected row.
    """
    # Placeholder: replace with a real OpenAI API key before running.
    openai.api_key = "your key"

    # NOTE(review): `template` is never read below; only `prompt` is used.
    template=[{"role": "user", "content": "Please respond with a creative short sentence you generate containing the entities in the following sentence:{sent}"}]
    prompt = "Please respond with a creative short sentence you generate containing the entities in the following sentence:{sent}"

    def ask_gpt(messages):
        """Send one chat request and return the text of the first choice."""
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            max_tokens=25,
            temperature=1,
            n=1,
            messages=messages,
        )
        first_choice = completion["choices"][0]
        return first_choice["message"]["content"]

    rows_by_index = pd.read_csv("./dataset/story/StoryAnalogy.csv").to_dict('index')
    row_ids = list(rows_by_index.keys())

    # Score every row, then take the 1000 highest-scoring positions.
    gaps = [rows_by_index[rid]["relation"] - rows_by_index[rid]["entity"] for rid in row_ids]
    top_positions = np.array(gaps).argsort()[::-1][:1000]
    chosen_rows = [rows_by_index[row_ids[pos]] for pos in top_positions]

    triples = []
    for row in tqdm(chosen_rows):
        request = [{"role": "user", "content": prompt.replace("{sent}", row["s1"])}]
        generated = ask_gpt(request).strip()
        triples.append([row["s1"], row["s2"], generated])

    with open("./dataset/story/pir_story.json", "w", encoding="utf-8") as out_file:
        json.dump(triples, out_file)

    return triples

def load_attrprompt_label(folder):
    """Load one list of labels per file found directly inside ``folder``.

    Every regular file is treated as one topic.  The topic name is the file
    name with the substrings ``.jsonl`` and ``.txt`` removed, and its entries
    are the lines of the file with surrounding whitespace stripped.
    Sub-directories are ignored.

    Files are visited in sorted name order: ``os.listdir`` returns entries in
    arbitrary order, which would otherwise make the order of ``topics`` (and
    of the keys in ``subtopics``) differ between machines.

    Args:
        folder: path of the directory to scan.

    Returns:
        tuple: ``(topics, subtopics)`` where ``topics`` is the list of topic
        names and ``subtopics`` maps each topic name to its list of lines.
    """
    topics = []
    subtopics = {}

    for filename in sorted(os.listdir(folder)):
        file_path = os.path.join(folder, filename)
        if not os.path.isfile(file_path):
            continue

        topic = filename.replace(".jsonl","").replace(".txt","")
        topics.append(topic)

        with open(file_path,"r",encoding="utf-8") as label_file:
            subtopics[topic] = [line.strip() for line in label_file]

    return topics, subtopics
    
    
def load_perspectrum_main(datasets):
    """Add the ``perspectrum`` retrieval task to ``datasets`` and return it.

    Loads ``dataset_split_v1.0.json``, ``perspective_pool_v1.0.json`` and
    ``perspectrum_with_answers_v1.0.json`` from ``path`` (empty here, i.e. the
    working directory) and keeps the entries whose split is ``"test"``.

    For every kept entry, each perspective whose ``stance_label_3`` is
    ``SUPPORT``, ``UNDERMINE`` or ``not-a-perspective`` yields one query; the
    texts of its ``pids`` are appended to the corpus and recorded as that
    query's references.  A final "general" query per entry references every
    corpus item added for that entry.

    Args:
        datasets: dict that receives the new ``"perspectrum"`` entry.

    Returns:
        dict: the same ``datasets`` object.
    """
    # perspectrum
    bucket = {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    datasets["perspectrum"] = bucket

    path = ""
    version = "_v1.0"
    raw_data = {}
    for name in ("dataset_split","perspective_pool","perspectrum_with_answers"):
        with open(path+name+version+".json","r",encoding="utf-8") as f:
            raw_data[name] = json.load(f)

    claims = raw_data["perspectrum_with_answers"]
    pool = raw_data["perspective_pool"]

    # id -> position lookups into the two lists above
    cid_map = {entry["cId"]: pos for pos, entry in enumerate(claims)}
    pid_map = {entry["pId"]: pos for pos, entry in enumerate(pool)}

    # we work on the test data
    sampled_args = [
        claims[cid_map[int(cid)]]
        for cid, split in raw_data["dataset_split"].items()
        if split == "test"
    ]

    # stance label -> (query prefix, perspective text, query label)
    stance_specs = {
        "SUPPORT": ("Find a claim that supports the argument: ",
                    "a claim that supports the argument",
                    "support"),
        "UNDERMINE": ("Find a claim that opposes the argument: ",
                      "a claim that opposes the argument",
                      "undermine"),
        "not-a-perspective": ("Find a claim that relates to the argument: ",
                              "a claim that relates to the argument",
                              "not-a-perspective"),
    }

    def open_query(prefix, perspective_text, label, arg_text):
        """Register a new query and return its (empty) reference list."""
        query_id = len(bucket["queries"])
        bucket["queries"].append(prefix+arg_text)
        bucket["source_queries"].append(arg_text)
        bucket["perspectives"].append(perspective_text)
        bucket["query_labels"].append(label)
        bucket["key_ref"][query_id] = []
        return bucket["key_ref"][query_id]

    for arg in sampled_args:
        arg_text = arg["text"]
        start_corpus_len = len(bucket["corpus"])

        for perspective in arg["perspectives"]:
            spec = stance_specs.get(perspective["stance_label_3"])
            if spec is None:
                # any other stance label produces no query
                continue

            refs = open_query(spec[0], spec[1], spec[2], arg_text)
            for pid in perspective["pids"]:
                bucket["corpus"].append(pool[pid_map[pid]]["text"])
                refs.append(len(bucket["corpus"])-1)

        # general relevance: everything added to the corpus for this entry
        general_refs = open_query("Find a claim relates to the argument: ",
                                  "a claim that relates to the argument",
                                  "general",
                                  arg_text)
        general_refs.extend(range(start_corpus_len, len(bucket["corpus"])))

    return datasets


def load_agnews_main(datasets):
    """Add the ``agnews`` subtopic / location retrieval task to ``datasets``.

    Every item of the ``yyu/agnews-attrprompt`` train split is indexed under
    the key ``"{_id}_{length}_{location}_{subtopics}_{style}"``.  Items whose
    ``length`` is 0 are used as query sources (indices above ``data_scope``
    are skipped).  For each query source two queries may be produced:

    * subtopic: candidates share the query's ``_id``, ``location`` and
      ``style`` and have ``length`` 1;
    * location: candidates share the query's ``_id``, ``subtopics`` and
      ``style`` and have ``length`` 1.

    In both cases the candidates are shuffled with a fixed seed; the first is
    added to the corpus as the reference and the second as a distractor.

    Candidate keys are matched on whole ``_``-separated fields.  Matching on
    a bare string prefix/suffix would let, for example, location ``1`` also
    match locations ``10``, ``11``, ... and subtopic ``3`` match ``13``.

    Args:
        datasets: dict that receives the new ``"agnews"`` entry.

    Returns:
        dict: the same ``datasets`` object.
    """
    data_scope = 2000

    bucket = {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    datasets["agnews"] = bucket
    attrprompt = load_dataset("yyu/agnews-attrprompt", data_files="attrprompt-v1.jsonl", split = 'train')
    attrprompt = list(attrprompt)

    # Placeholders: folders holding the subtopic and location label files.
    topic_path = ""
    loc_path = ""

    _, subtopics = load_attrprompt_label(topic_path)
    _, loc_places = load_attrprompt_label(loc_path)
    labels = ["world", "sports", "business", "sci_tech"]

    topic_template = "Find a news article relates to {subtopic} and similar the following news: {query}"
    location_template = "Find a news article happened in {location} and similar to the following news: {query}"

    segs = ["_id", "length", "location", "subtopics", "style"]

    ref_to_index = {} # label_length_location_subtopics_style->index
    query_candidates = []

    for i, item in enumerate(attrprompt):
        if item["length"] == 0:
            query_candidates.append(item)
        ref_to_index["_".join(str(item[seg]) for seg in segs)] = i

    for i, query in enumerate(tqdm(query_candidates)):

        if i > data_scope:
            continue

        # ---- subtopic ----
        query_id = len(bucket["queries"])
        matched_subtopics = subtopics[labels[query["_id"]]]

        # Trailing/leading "_" force whole-field matches.
        key_prefix = "_".join(str(x) for x in [query["_id"], "1", query["location"]]) + "_"
        key_suffix = "_" + str(query["style"])
        available_cands = [
            k for k in ref_to_index
            if k.startswith(key_prefix) and k.endswith(key_suffix)
        ]

        if len(available_cands) < 2: # no enough candidates
            continue

        random.Random(42).shuffle(available_cands)
        target_ref = available_cands[0]
        interfere_ref = available_cands[1]
        target_index = int(target_ref.split("_")[3])

        # no defined subtopics; ">=" also guards against an IndexError below
        if target_index >= len(matched_subtopics):
            continue

        target_subtopic = matched_subtopics[target_index]

        topic_query = topic_template.replace("{query}",query["text"]).replace("{subtopic}",target_subtopic)
        bucket["queries"].append(topic_query)
        bucket["source_queries"].append(query["text"])
        bucket["perspectives"].append("a news article relates to "+target_subtopic)
        bucket["query_labels"].append("subtopic")
        bucket["key_ref"][query_id] = [len(bucket["corpus"])]
        bucket["corpus"].append(attrprompt[ref_to_index[target_ref]]["text"]) # correct one
        bucket["corpus"].append(attrprompt[ref_to_index[interfere_ref]]["text"]) # wrong one

        # ---- location ----
        query_id = len(bucket["queries"])

        key_prefix = "_".join(str(x) for x in [query["_id"], "1"]) + "_"
        key_suffix = "_" + "_".join(str(x) for x in [query["subtopics"], query["style"]])
        available_cands = [
            k for k in ref_to_index
            if k.startswith(key_prefix) and k.endswith(key_suffix)
        ]

        if len(available_cands) < 2: # no enough candidates
            continue

        random.Random(42).shuffle(available_cands)
        target_ref = available_cands[0]
        interfere_ref = available_cands[1]
        target_index = int(target_ref.split("_")[2])
        target_location = loc_places["location"][target_index]

        location_query = location_template.replace("{query}",query["text"]).replace("{location}",target_location)
        bucket["queries"].append(location_query)
        bucket["source_queries"].append(query["text"])
        bucket["perspectives"].append("a news article happened in "+target_location)
        bucket["query_labels"].append("location")
        bucket["key_ref"][query_id] = [len(bucket["corpus"])]
        bucket["corpus"].append(attrprompt[ref_to_index[target_ref]]["text"]) # correct one
        bucket["corpus"].append(attrprompt[ref_to_index[interfere_ref]]["text"]) # wrong one

    return datasets


def load_story_main(datasets):
    """Add the ``story`` retrieval task to ``datasets`` and return it.

    Reads the triples stored in ``./dataset/story/pir_story.json``.  Each
    triple produces two queries built from its first element: an "analogy"
    query whose reference is the second element and an "entity" query whose
    reference is the third element.

    Args:
        datasets: dict that receives the new ``"story"`` entry.

    Returns:
        dict: the same ``datasets`` object.
    """
    bucket = {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    datasets["story"] = bucket

    with open("./dataset/story/pir_story.json","r",encoding="utf-8") as f:
        our_story_dataset = json.load(f)

    # (query template, perspective text, query label, position of the reference in the triple)
    variants = (
        ("Find a story that is an analogy to the following story: {query}",
         "the analogy of the story",  # not sure about this
         "analogy",
         1),
        ("Find a story that is with similar entities with the following story: {query}",
         "similar entities of the story",  # not sure about this
         "entity",
         2),
    )

    for item in tqdm(our_story_dataset):
        source_text = item[0]
        for template, perspective_text, label, ref_pos in variants:
            query_id = len(bucket["queries"])
            bucket["queries"].append(template.replace("{query}",source_text).strip())
            bucket["source_queries"].append(source_text)
            bucket["perspectives"].append(perspective_text)
            bucket["query_labels"].append(label)
            bucket["key_ref"][query_id] = [len(bucket["corpus"])]
            bucket["corpus"].append(item[ref_pos])

    return datasets


def load_ambigqa_main(datasets):
    """Add the ``ambigqa`` retrieval task to ``datasets`` and return it.

    Reads a JSON list from ``./you path`` (placeholder path).  For every
    annotation of type ``multipleQAs`` the item's articles are split into
    lines.  For each QA pair, every line containing the pair's first answer
    yields a paragraph made of the previous line, the line itself and the
    next line with everything after its last ``.`` removed.

    An annotation contributes queries only when more than two of its QA pairs
    found more than one paragraph; each such pair then adds its question as a
    query with its first paragraph as the reference.

    Neighbouring lines are bounds-checked: the first line has no previous
    line (rather than wrapping to the last one) and the last line has no next
    line (rather than raising ``IndexError``).

    Args:
        datasets: dict that receives the new ``"ambigqa"`` entry.

    Returns:
        dict: the same ``datasets`` object.
    """
    bucket = {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    datasets["ambigqa"] = bucket

    with open("./you path","r",encoding="utf-8") as f:
        ambigqa = json.load(f)

    for item in ambigqa:
        for anno in item["annotations"]:
            if anno["type"] != "multipleQAs":
                continue

            entries = []
            for article in item["articles_plain_text"]:
                entries.extend(article.split("\n"))

            # one list of paragraphs per QA pair, in the same order as qaPairs
            paragraphs = []
            for pair in anno["qaPairs"]:
                answer = pair["answer"][0]
                matches = []
                for j, entry in enumerate(entries):
                    if answer not in entry:
                        continue
                    previous_entry = entries[j-1] if j > 0 else ""
                    next_entry = entries[j+1] if j+1 < len(entries) else ""
                    # drop whatever follows the last "." of the next line
                    trimmed_next = ".".join(next_entry.split(".")[:-1])
                    matches.append(previous_entry + entry + trimmed_next)
                paragraphs.append(matches)

            # NOTE(review): the original comment said para_count "should > 1",
            # but the code has always required > 2; the code is kept as is.
            para_count = sum(1 for paras in paragraphs if len(paras) > 1)
            if para_count <= 2:
                continue

            for t, paras in enumerate(paragraphs):
                if len(paras) <= 1:
                    continue
                pair_question = anno["qaPairs"][t]["question"]
                query_id = len(bucket["queries"])
                bucket["queries"].append(pair_question)
                bucket["source_queries"].append(item["question"])
                # what the disambiguated question adds to the original one
                bucket["perspectives"].append(pair_question.replace(item["question"],""))
                bucket["query_labels"].append("perspective")
                bucket["key_ref"][query_id] = [len(bucket["corpus"])]
                bucket["corpus"].append(paras[0])

    return datasets


def load_allsides_main(datasets, directory=''):
    """Add the ``allsides`` retrieval task to ``datasets`` and return it.

    Every regular file in ``directory`` is parsed as one JSON article with
    (at least) the keys ``topic``, ``bias_text`` and ``content``.  Per topic,
    at most two articles per ``bias_text`` value are kept, in file-name
    order; topics with fewer than three kept articles are dropped.  Each kept
    article becomes one query asking for an article of that bias on the
    topic, with the article's ``content`` as its reference.

    Files are read in sorted name order so that the "first two per side"
    selection does not depend on the arbitrary order of ``os.listdir``.

    Args:
        datasets: dict that receives the new ``"allsides"`` entry.
        directory: folder holding the article JSON files.  Defaults to the
            empty placeholder the function previously hard-coded.

    Returns:
        dict: the same ``datasets`` object.

    Raises:
        KeyError: if an article's ``bias_text`` is not ``left``, ``right`` or
            ``center``.
    """
    bucket = {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    datasets["allsides"] = bucket

    raw_dic = {} # topic->news

    for filename in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, filename)
        # checking if it is a file
        if not os.path.isfile(file_path):
            continue
        with open(file_path,encoding="utf-8") as jsonf:
            this_item = json.load(jsonf)
        raw_dic.setdefault(this_item["topic"], []).append(this_item)

    print(len(raw_dic))

    fine_dic = {}
    for topic, articles in raw_dic.items():
        kept_sides = []
        kept_articles = []
        for article in articles:
            # keep at most two articles per side
            if kept_sides.count(article["bias_text"]) < 2:
                kept_sides.append(article["bias_text"])
                kept_articles.append(article)

        # NOTE(review): the original comment asked for ">2 sides on a topic",
        # but this threshold counts kept articles, not distinct sides (two
        # "left" plus one "right" passes).  Behaviour kept as is; confirm intent.
        if len(kept_sides) >= 3:
            fine_dic[topic] = kept_articles

    shared = "find a news article [side] on the topic: "
    query_mapping = {"left":shared.replace("[side]","from the left wing"),
                     "right":shared.replace("[side]","from the right wing"),
                     "center":shared.replace("[side]","that is neutral")
                     }

    # convert the fine-dic to pir dataset format
    for articles in fine_dic.values():
        for article in articles:
            topic_text = article["topic"].replace("_"," ")
            side = article["bias_text"]

            query_id = len(bucket["queries"])
            bucket["queries"].append(query_mapping[side]+topic_text)
            bucket["source_queries"].append(topic_text)
            bucket["perspectives"].append("a news article biased towards: "+side)
            bucket["query_labels"].append(side)
            bucket["key_ref"][query_id] = [len(bucket["corpus"])]
            bucket["corpus"].append(article["content"])

    return datasets


def load_exfever_main(datasets, csv_path=""):
    """Add the ``exfever`` retrieval task to ``datasets`` and return it.

    Reads a CSV with the columns ``golden entity``, ``label``, ``claim`` and
    ``explanation``.  Only rows labelled ``SUPPORT`` or ``REFUTE`` are used.
    The ``golden entity`` string is split on ``,`` after removing ``[``,
    ``]`` and ``'``; a first pass counts how often each resulting entity
    string occurs.  In the second pass a row is kept when the summed counts
    of its entities reach 50 and no earlier kept row had the same raw
    ``golden entity`` string.  Each kept row adds one query (prefix chosen by
    label, followed by the claim) whose reference is the row's explanation.

    Args:
        datasets: dict that receives the new ``"exfever"`` entry.
        csv_path: location of the CSV file.  Defaults to the empty
            placeholder the function previously hard-coded.

    Returns:
        dict: the same ``datasets`` object.
    """
    bucket = {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    datasets["exfever"] = bucket

    raw_data = pd.read_csv(csv_path)

    data_scope = 1000000
    valid_labels = ("SUPPORT", "REFUTE")
    query_mapping = {"SUPPORT":"Find an explanation that supports the claim: ",
                     "REFUTE":"Find an explanation that refutes the claim: "
                     }

    def split_entities(raw):
        """Strip list punctuation from the raw cell and split it on commas."""
        # NOTE(review): pieces are not whitespace-stripped, so "A" and " A"
        # are counted as different entities.  Kept as is so the selection
        # does not change; confirm whether that is intended.
        for char in ["[","]","'"]:
            raw = raw.replace(char,"")
        return raw.split(",")

    # First pass: how often does each entity string occur among usable rows?
    entity_count = {}
    for i, (item, label) in enumerate(zip(raw_data["golden entity"], raw_data["label"])):
        if i >= data_scope or label not in valid_labels:
            continue
        for entity in split_entities(item):
            entity_count[entity] = entity_count.get(entity, 0) + 1

    # Second pass: pick rows and emit queries.
    seen_entity_cells = set()
    columns = zip(raw_data["golden entity"], raw_data["label"],
                  raw_data["claim"], raw_data["explanation"])
    for i, (item, label, claim, explanation) in enumerate(columns):
        if i >= data_scope or label not in valid_labels:
            continue

        # NOTE(review): the original comment spoke of entities that "appear
        # once", while the threshold actually applied is a summed count of 50.
        this_count = sum(entity_count[x] for x in split_entities(item))
        if this_count < 50 or item in seen_entity_cells:
            continue
        seen_entity_cells.add(item)

        prefix = query_mapping[label]

        # sampling from 17946 in total
        query_id = len(bucket["queries"])
        bucket["queries"].append(prefix+claim)
        bucket["source_queries"].append(claim)
        # drop the trailing ": " and the leading "Find an"
        bucket["perspectives"].append(prefix[:-2].replace("Find an",""))
        bucket["query_labels"].append(label)
        bucket["key_ref"][query_id] = [len(bucket["corpus"])]
        bucket["corpus"].append(explanation)

    return datasets
        
        
                            
def dataset_main():
    """Build every retrieval task and return them in a single dict.

    Each loader adds one entry to the shared dict and hands it back; the
    loaders run in the order listed below.

    Returns:
        dict: task name -> task data, as filled in by the loaders.
    """
    loaders = (
        load_perspectrum_main,  # perspectrum: pro vs. con
        load_agnews_main,       # agnews: topic vs. location
        load_story_main,        # story_gen: analogy vs. entity similarity
        load_ambigqa_main,      # ambigqa: question specific perspective
        load_allsides_main,     # allsides: left, center, and right wings (can scale to 1,000 questions)
        load_exfever_main,      # support or refute, where claims share entities
    )

    datasets = {}
    for loader in loaders:
        datasets = loader(datasets)

    return datasets


from rank_bm25 import BM25Okapi

def bm25_main(datasets):
    """Score every query against its task corpus with BM25 and evaluate.

    For each task in ``datasets`` the corpus and the queries are tokenised by
    splitting on single spaces.  The per-query score arrays are written to
    ``bm25_<task>_scores.json`` and then passed to ``evaluation`` together
    with the task's references and query labels.

    Args:
        datasets: dict mapping a task name to a dict with the keys
            ``queries``, ``corpus``, ``key_ref`` and ``query_labels``.
    """
    # corpuses,key_refs = corpus_building(datasets)

    for task_name, task in datasets.items():
        print("we are working on:",task_name)

        queries, corpus = task["queries"], task["corpus"]
        key_ref, query_labels = task["key_ref"], task["query_labels"]

        bm25 = BM25Okapi([doc.split(" ") for doc in corpus])

        # one array of per-document scores for every query
        corpus_scores = [bm25.get_scores(query.split(" ")) for query in tqdm(queries)]

        with open("bm25_"+task_name+"_scores.json","w",encoding="utf-8") as f:
            json.dump([scores.tolist() for scores in corpus_scores],f)

        evaluation(key_ref, corpus_scores, query_labels, task_name)
        print()


def create_embeddings(tokenizer, model, texts):
    """Embed ``texts`` with ``model`` and return its ``pooler_output`` rows.

    The model is moved to CUDA when available, otherwise to the CPU.  Inputs
    shorter than one batch (17 texts) are encoded in a single call with
    ``max_length=80``; longer inputs are encoded batch by batch.

    Batches are cut with ``range(0, len(texts), batch_size)`` so no empty
    trailing batch is sent to the tokenizer when ``len(texts)`` is an exact
    multiple of the batch size.

    Args:
        tokenizer: callable returning a mapping of tensors that supports
            ``.to(device)``.
        model: callable model whose output exposes ``pooler_output``.
        texts: sequence of strings to embed.

    Returns:
        list: one embedding (list of floats) per text, in input order.  If a
        batch fails inside the model call in the batched path it is skipped
        and a warning is logged, so the result can then be shorter than
        ``texts``.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    batch_size = 17 #29

    model.to(device)

    embeddings = []
    if len(texts) == 0:
        return embeddings

    # naive batching
    if len(texts) < batch_size:
        # NOTE(review): only this path truncates to 80 tokens; the batched
        # path below uses the tokenizer's default limit.  Kept as is.
        inputs = tokenizer(texts,max_length=80, padding=True, truncation=True, return_tensors="pt")
        inputs = inputs.to(device)
        with torch.no_grad():
            batch_embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
        embeddings.extend(batch_embeddings.detach().cpu().tolist())
        del batch_embeddings
        torch.cuda.empty_cache()
        return embeddings

    for batch_start in trange(0, len(texts), batch_size):
        batch_end = min(len(texts), batch_start+batch_size)
        batch_texts = texts[batch_start:batch_end]

        inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
        inputs = inputs.to(device)

        with torch.no_grad():
            try:
                batch_embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
                embeddings.extend(batch_embeddings.detach().cpu().tolist())

                # save cuda memory
                del batch_embeddings
                del inputs
                torch.cuda.empty_cache()
            except Exception:
                # Best effort: skip the batch, but say so instead of
                # swallowing the error silently.
                logging.getLogger(__name__).warning(
                    "broken embeddings: skipped texts[%d:%d]",
                    batch_start, batch_end, exc_info=True)

    return embeddings


def contriever_main(datasets):
    """Define the Contriever embedding helpers.

    As far as this section of the file shows, the body only defines the two
    nested helpers below; it does not use ``datasets`` or return a value.

    Args:
        datasets: task dict; not read by the code visible here.
    """
    # corpuses,key_refs = corpus_building(datasets)

    def mean_pooling(token_embeddings, mask):
        """Average the token embeddings of each sequence over unmasked positions."""
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]

        return sentence_embeddings

    def contriever_embeddings(texts, tokenizer, model):
        """Return one mean-pooled embedding (list of floats) per text.

        Texts are encoded in batches of 29 on the CPU.  Every batch is
        mean-pooled over the model's first output, so the kind of embedding
        no longer depends on how many texts are passed in (the batched path
        used ``pooler_output`` before).  A batch that fails inside the model
        call is skipped with a logged warning, in which case the result is
        shorter than ``texts``.
        """
        # device = torch.device('cuda')
        device = torch.device('cpu')
        # create tokenized inputs
        batch_size = 29

        model.to(device)
        embeddings = []
        if len(texts) == 0:
            return embeddings

        # naive batching; only show a progress bar when there are several batches
        if len(texts) < batch_size:
            batch_starts = range(0, len(texts), batch_size)
        else:
            batch_starts = trange(0, len(texts), batch_size)

        for batch_start in batch_starts:
            batch_end = min(len(texts), batch_start+batch_size)
            batch_texts = texts[batch_start:batch_end]

            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
            inputs = inputs.to(device)

            with torch.no_grad():
                try:
                    outputs = model(**inputs)
                    batch_embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
                    embeddings.extend(batch_embeddings.detach().cpu().tolist())
                    del batch_embeddings
                    del inputs
                    torch.cuda.empty_cache()
                except Exception:
                    # Best effort: skip the batch, but report it.
                    logging.getLogger(__name__).warning(
                        "broken embeddings: skipped texts[%d:%d]",
                        batch_start, batch_end, exc_info=True)

        return embeddings
    

# datasets = dataset_main()