import json
import os
import requests
import random

import time
import tiktoken
from xopen import xopen
from copy import deepcopy
from tqdm import tqdm
import numpy as np

from lost_in_the_middle.prompting import (
    Document,
    get_closedbook_qa_prompt,
    get_qa_prompt,
)
from models.prompt_compression import PromptCompressor

# Tokenizer shared by get_token_length and the token-count bookkeeping below.
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
# NOTE(review): the compressor is commented out here, yet several functions in
# this file call `llm_lingua(...)`; it has to be defined before they are used.
# llm_lingua = PromptCompressor(device_map="cuda:0")

import openai

# NOTE(review): placeholder key literal — presumably filled in before running;
# confirm a real key is never committed.
openai.api_key = "sk-"

def get_results(predict_path: str):
    """Score a predictions file with best-subspan exact match and print the mean.

    Args:
        predict_path: Path to a JSON file holding a mapping whose values are
            dicts with the keys ``"pred"``, ``"reference"`` and ``"idx"``.

    Returns:
        A list of ``[reference, prediction, idx]`` triples, one per example,
        where ``prediction`` is the raw prediction cut at the first blank line.

    Raises:
        statistics.StatisticsError: If the predictions file holds no examples.
    """
    from tqdm import tqdm
    import json
    import statistics
    from lost_in_the_middle.metrics import best_subspan_em
    METRICS = [(best_subspan_em, "best_subspan_em")]

    def get_metrics_for_example(example):
        gold_answers = example["answers"]
        model_answer = example["model_answer"]

        # NOTE: we take everything up to the first newline, since otherwise models could hack
        # the metric by simply copying the input context (as the gold answer is guaranteed
        # to occur in the input context).
        model_answer = model_answer.split("\n")[0].strip()

        example_metrics = {
            metric_name: metric(prediction=model_answer, ground_truths=gold_answers)
            for (metric, metric_name) in METRICS
        }
        return (example_metrics, example)

    # Close the predictions file deterministically instead of leaking the handle.
    with open(predict_path) as pred_file:
        preds = json.load(pred_file)

    res = []
    all_example_metrics = []
    for example in tqdm(preds.values()):
        pred, ref = example["pred"], example["reference"]
        pred = pred.split("\n\n")[0]
        res.append([ref, pred, example["idx"]])
        all_example_metrics.append(get_metrics_for_example({"answers": ref, "model_answer": pred}))

    # Average metrics across examples
    for (_, metric_name) in METRICS:
        average_metric_value = statistics.mean(
            example_metrics[metric_name] for (example_metrics, _) in all_example_metrics
        )
        print(f"{metric_name}: {average_metric_value}")
    return res


def get_token_length(text: str):
    """Return how many tokens `text` yields under the module-level `encoding`."""
    token_ids = encoding.encode(text)
    return len(token_ids)

def get_dataset_retrieval(path: str, doc_num: int = 20, d_idx: int = 0, eval_results=None):
    """Measure where a retriever ranks the document line at position `d_idx`.

    For every example the QA prompt is built and split on blank lines: the
    first chunk is taken as the instruction, the last as the question, and the
    lines in between as the document corpus. The corpus is ranked against the
    question and the rank of line `d_idx` is recorded. Recall@k for k = 1..20
    is printed afterwards.

    The ranker is chosen by editing the call inside the loop; the other nested
    rankers are kept as ready-made alternatives.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"`` and ``"ctxs"``.
        doc_num: Not used by this function.
        d_idx: Index, within the document lines, of the line whose rank is
            measured.
        eval_results: Optional iterable of ``[reference, prediction, idx]``
            triples (the shape `get_results` returns). When given, a per-rank
            exact-match breakdown is printed as well. The original code read
            these from an undefined name and always failed at that point.

    Returns:
        None. Results are printed.
    """
    import gzip
    from collections import defaultdict

    from sentence_transformers import SentenceTransformer, util

    def get_distance_bm25(corpus, query):
        from rank_bm25 import BM25Okapi
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        doc_scores = bm25.get_scores(query.split(" "))
        # Higher BM25 score is better, so rank by descending score.
        return (-doc_scores).argsort().tolist().index(d_idx)

    def get_distance_gzip(corpus, query):
        def get_score(x, y):
            # Compression-based distance: lower means more similar.
            cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
            cxy = len(gzip.compress(f"{x} {y}".encode()))
            return (cxy - min(cx, cy)) / max(cx, cy)

        doc_scores = [get_score(doc, query) for doc in corpus]
        return np.argsort(doc_scores).tolist().index(d_idx)

    def get_distance_sentbert(corpus, query):
        doc_embeds = model.encode(corpus)
        query = model.encode(query)
        # Negate so that argsort puts the highest dot product first.
        doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
        return np.argsort(doc_scores).tolist().index(d_idx)

    def get_distance_sentbert_bge(corpus, query):
        doc_embeds = model.encode(list(corpus), normalize_embeddings=True)
        query = model.encode(query, normalize_embeddings=True)
        doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
        return np.argsort(doc_scores).tolist().index(d_idx)

    def get_distance_openai(corpus, query):
        import openai
        openai.api_key = ""
        openai.api_version = '2023-05-15'
        deployment_name = "embed"

        def get_embed(text):
            try:
                return openai.Embedding.create(input=[text.replace("\n", " ")], engine=deployment_name)['data'][0]['embedding']
            except openai.error.RateLimitError:
                time.sleep(5)
                return get_embed(text)

        doc_embeds = [get_embed(i) for i in corpus]
        query = get_embed(query)
        doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
        # Unlike the other rankers this returns the full ranking (best first),
        # which is the input shape `ensemble` expects.
        return np.argsort(doc_scores).tolist()

    def ensemble(rank1, rank2):
        # Sum each document's position in both rankings; lower total is better.
        c = defaultdict(int)
        for idx, ii in enumerate(rank1):
            c[ii] += idx
        for idx, ii in enumerate(rank2):
            c[ii] += idx
        return [ii for ii, _ in sorted(c.items(), key=lambda x: x[1])].index(d_idx)

    model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
    # model = SentenceTransformer('all-mpnet-base-v2')
    # model = SentenceTransformer('BAAI/bge-large-en')

    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            question = c[-1]
            corpus = "\n".join(c[1:-1]).split("\n")

            # idx = get_distance_openai(corpus, question)
            res.append(get_distance_sentbert(corpus, question))

    if not res:
        # Nothing was read; avoid dividing by zero below.
        return

    for idx in range(1, 21):
        hits = sum(1 for rank in res if rank < idx)
        print("R@{},{:.2f}".format(idx, hits / len(res) * 100))

    if eval_results is None:
        return

    from lost_in_the_middle.metrics import best_subspan_em

    for idx in range(1, 21):
        # Example ids whose target document was ranked exactly idx-th.
        get_idxs = {jj for jj, rank in enumerate(res) if rank == idx - 1}
        ans = [best_subspan_em(ii[1], ii[0]) for ii in eval_results if int(ii[-1]) in get_idxs]
        # Print both the number of exact matches and the percentage.
        print("R@{},{},{:.2f}".format(idx, sum(ans), sum(ans) / len(ans) * 100 if ans else 0))

def get_dataset(path: str, doc_num: int = 20, idx: int = 0):
    """Build the uncompressed QA prompt for every example and dump them as JSON.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.

    Writes ``prompt/loss_in_middle/full_{doc_num}_{idx}.json``, a list of
    ``{"id", "prompt", "answer"}`` records where ``id`` is the line number.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in enumerate(f):
            input_example = json.loads(jj)
            question = input_example["question"]
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                question,
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )
            # prompt = get_closedbook_qa_prompt(question)
            res.append({"id": ii, "prompt": prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/full_{doc_num}_{idx}.json", "w") as out_file:
        json.dump(res, out_file)


def get_dataset_selective_context(path: str, doc_num: int = 20, idx: int = 0):
    """Compress every full QA prompt with `sc` and dump the results as JSON.

    NOTE(review): `sc` is not defined in the visible part of this file; it is
    presumably a Selective-Context compressor created elsewhere — confirm it
    exists at module scope before calling this function.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.

    Writes ``prompt/loss_in_middle/selective_context_{doc_num}_{idx}_2x.json``.
    Each record's ``"prompt"`` is a dict with the compressed text and the
    token counts before and after compression.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            y = sc(prompt, reduce_ratio=0.55)
            # The first element of the result is used as the compressed text.
            compressed_text = y[0].replace("<s>", "").strip()
            compressed_prompt = {
                "compressed_prompt": compressed_text,
                "origin_tokens": len(encoding.encode(prompt)),
                "compressed_tokens": len(encoding.encode(compressed_text)),
            }
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/selective_context_{doc_num}_{idx}_2x.json", "w") as out_file:
        json.dump(res, out_file)

def get_dataset_unrelated(path: str, doc_num: int = 20, idx: int = 0, k=1):
    """Keep only the first `k` document lines of every prompt and dump as JSON.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents. The shortened prompt is instruction, first `k` document lines
    and question, joined by blank lines.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.
        k: Number of leading document lines to keep.

    Writes ``prompt/loss_in_middle/unrelated_{doc_num}_{idx}_{k}.json``.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstrations = "\n".join(c[1:-1]).split("\n")
            demonstration = "\n".join(demonstrations[0:k])
            compressed_prompt = {"compressed_prompt": instruction + "\n\n" + demonstration + "\n\n" + question}

            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/unrelated_{doc_num}_{idx}_{k}.json", "w") as out_file:
        json.dump(res, out_file)


# NOTE(review): this statement runs at import time. In the visible file
# `llm_lingua` is only assigned in a commented-out line, and `demonstrations`,
# `instruction` and `question` are not assigned at module scope before this
# point, so this looks like a stray scratch line that raises NameError —
# confirm, then remove it or move it into a function.
compressed_prompt = llm_lingua(demonstrations, instruction, question, 0.5, use_sentence_level_filter=False, condition_in_question="none", reorder_demonstrations=False, dynamic_demonstration_compression_ratio=0, demonstration_budget="+100")

def get_dataset_baselines(path: str, doc_num: int = 20, idx: int = 0):
    """Compress every prompt with the ``sentbert`` ranking baseline and dump as JSON.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents handed to `llm_lingua`.

    After writing, the function also reads the existing
    ``bm25_20_{i}_4x.json`` files for i in 0, 4, 9, 14, 19 and prints, for
    each, the count of prompts that still contain ``Document [i + 1]`` divided
    by 2655.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.

    Writes ``prompt/loss_in_middle/sentbert_{doc_num}_{idx}_2x.json``.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstration = "\n".join(c[1:-1])

            compressed_prompt = llm_lingua(
                demonstration.split("\n"),
                instruction,
                question,
                0.5,
                use_sentence_level_filter=False,
                condition_in_question="none",
                reorder_demonstrations=False,
                dynamic_demonstration_compression_ratio=0,
                rank_method="sentbert",
                use_token_level_filter=False,
                demonstration_budget="+0",
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/sentbert_{doc_num}_{idx}_2x.json", "w") as out_file:
        json.dump(res, out_file)

    for i in [0, 4, 9, 14, 19]:
        with open(f"prompt/loss_in_middle/bm25_20_{i}_4x.json") as in_file:
            data = json.load(in_file)
        marker = f"Document [{i + 1}]"
        print(sum(marker in ii["prompt"]["compressed_prompt"] for ii in data) / 2655)

def get_dataset_baselines_reorder(path: str, doc_num: int = 20, idx: int = 0):
    """Compress every prompt with the ``openai`` ranking baseline plus reordering.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents handed to `llm_lingua`.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.

    Writes ``prompt/loss_in_middle/openai_{doc_num}_{idx}_2x_reorder.json``.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstration = "\n".join(c[1:-1])

            compressed_prompt = llm_lingua(
                demonstration.split("\n"),
                instruction,
                question,
                0.5,
                use_sentence_level_filter=False,
                condition_in_question="after",
                reorder_demonstrations=True,
                dynamic_demonstration_compression_ratio=0,
                rank_method="openai",
                use_token_level_filter=False,
                demonstration_budget="+0",
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/openai_{doc_num}_{idx}_2x_reorder.json", "w") as out_file:
        json.dump(res, out_file)


def get_dataset_llmlingua(path: str, doc_num: int = 20, idx: int = 0):
    """Compress every prompt with the reordering LLMLingua setting and dump as JSON.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents handed to `llm_lingua`.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.

    Writes ``prompt/loss_in_middle/llmlingua_{doc_num}_{idx}_4x_reorder.json``.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstration = "\n".join(c[1:-1])

            compressed_prompt = llm_lingua(
                demonstration.split("\n"),
                instruction,
                question,
                0.75,
                use_sentence_level_filter=False,
                condition_in_question="none",
                reorder_demonstrations=True,
                dynamic_demonstration_compression_ratio=0,
                condition_compare=False,
                demonstration_budget="+100",
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/llmlingua_{doc_num}_{idx}_4x_reorder.json", "w") as out_file:
        json.dump(res, out_file)

def get_dataset_ours_sbert(path: str, doc_num: int = 20, idx: int = 0):
    """Compress every prompt with the ``ours_sentbert`` ranking, original order.

    NOTE(review): a second function named `get_dataset_ours_sbert` is defined
    later in this file and replaces this one once the module has loaded.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents handed to `llm_lingua`.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstration = "\n".join(c[1:-1])

            compressed_prompt = llm_lingua(
                demonstration.split("\n"),
                instruction,
                question,
                0.75,
                use_sentence_level_filter=False,
                condition_in_question="after_condition",
                reorder_demonstrations="original",
                dynamic_demonstration_compression_ratio=0.2,
                condition_compare=True,
                demonstration_budget="+400",
                token_budget_ratio=1.0,
                rank_method="ours_sentbert",
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/ours_sbert_{doc_num}_{idx}_4x_dem_400_after_add_prompt1_dy02dem_condition_ori_order.json", "w") as out_file:
        json.dump(res, out_file)

def get_demonstration_ppl(demonstration):
    """Return the mean min-max-normalised per-token perplexity of each line.

    Per-token perplexities are requested from `llm_lingua` for the
    demonstration alone and for ``question + demonstration``. Both the plain
    values and their difference are min-max normalised, then averaged over
    each line's token span (line token count plus one).

    NOTE(review): `question` and `llm_lingua` are read from an enclosing or
    module scope; neither is assigned at module level in the visible file —
    confirm they exist before calling.

    Args:
        demonstration: Newline-separated demonstration text.

    Returns:
        A list with one mean normalised perplexity per line. The matching
        means of the question-conditioned difference are computed but not
        returned, as in the original code.
    """
    demonstrations = demonstration.split("\n")
    # +1 per line, presumably for the newline token between lines — TODO confirm.
    tokens = [llm_lingua.get_token_length(ii, False) + 1 for ii in demonstrations]
    ppl = llm_lingua.get_ppl(demonstration, "token")
    ppl2 = llm_lingua.get_ppl(question + demonstration, "token")
    q_tokens = llm_lingua.get_token_length(question, False)
    # Drop the question's own tokens so both sequences line up.
    ppl3 = ppl - ppl2[q_tokens:]
    ppl = (ppl - ppl.min()) / (ppl.max() - ppl.min())
    ppl3 = (ppl3 - ppl3.min()) / (ppl3.max() - ppl3.min())

    # These accumulators were previously never initialised inside the function.
    res, res2 = [], []
    pre = 0
    for t in tokens:
        res.append(ppl[pre:pre + t].mean())
        res2.append(ppl3[pre:pre + t].mean())
        pre += t
    return res



def get_dataset_ours_sbert(path: str, doc_num: int = 20, idx: int = 0, y=""):
    """Compress every prompt with the ``ours_sentbert`` ranking and dump as JSON.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents handed to `llm_lingua`.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.
        y: Passed through as ``reorder_demonstrations`` and appended to the
            output file name.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstration = "\n".join(c[1:-1])

            compressed_prompt = llm_lingua(
                demonstration.split("\n"),
                instruction,
                question,
                0.75,
                use_sentence_level_filter=False,
                condition_in_question="after_condition",
                reorder_demonstrations=y,
                dynamic_demonstration_compression_ratio=0.2,
                condition_compare=True,
                demonstration_budget="+400",
                token_budget_ratio=1.00,
                rank_method="ours_sentbert",
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/ours_sbert_{doc_num}_{idx}_4x_dem_400_after_add_prompt1_dy02dem_condition_sent100_{y}.json", "w") as out_file:
        json.dump(res, out_file)

def get_dataset_ours(path: str, doc_num: int = 20, idx: int = 0):
    """Compress every prompt with the 0.95-ratio "ours" setting and dump as JSON.

    NOTE(review): a second function named `get_dataset_ours` is defined later
    in this file and replaces this one once the module has loaded.

    The full QA prompt is split on blank lines: the first chunk is used as the
    instruction, the last as the question, and the lines in between as the
    documents handed to `llm_lingua`.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the output file.
        idx: Only used to name the output file.
    """
    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            input_example = json.loads(jj)
            documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]

            prompt = get_qa_prompt(
                input_example["question"],
                documents,
                mention_random_ordering=False,
                query_aware_contextualization=False,
            )

            c = prompt.split("\n\n")
            instruction, question = c[0], c[-1]
            demonstration = "\n".join(c[1:-1])

            compressed_prompt = llm_lingua(
                demonstration.split("\n"),
                instruction,
                question,
                0.95,
                use_sentence_level_filter=False,
                condition_in_question="after_condition",
                reorder_demonstrations="original",
                dynamic_demonstration_compression_ratio=0.4,
                condition_compare=True,
                demonstration_budget="+100",
                token_budget_ratio=1.05,
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": input_example["answers"]})

    # Use a context manager so the output is flushed and closed deterministically.
    with open(f"prompt/loss_in_middle/ours_{doc_num}_{idx}_4x_dem_13_after_add_prompt1_dy04dem_compare_condition_sent105_ori_sort.json", "w") as out_file:
        json.dump(res, out_file)


def get_dataset_ours(path: str, doc_num: int = 20, idx: int = 0):
    """Produce two compressed datasets: a 200-example ablation and a full run.

    Pass 1 re-seeds the global `random` generator with 181, shuffles the ids
    0..2654, keeps the examples whose line number is among the first 200
    shuffled ids, compresses them with the ``none_condition`` setting and
    writes ``prompt/loss_in_middle/wo_coarse_level_{idx}.json``.

    Pass 2 compresses every example with the sentence-level filter and
    ``force_demonstrations_number=2`` and writes
    ``prompt/loss_in_middle/ours_{doc_num}_{idx}_10x_dem_sent_after_force2.json``.

    NOTE(review): `llm_lingua` is only assigned in a commented-out line in the
    visible file; it must be defined at module scope before calling this.

    Args:
        path: JSON-lines file (opened with `xopen`) whose records carry
            ``"question"``, ``"ctxs"`` and ``"answers"``.
        doc_num: Only used to name the second output file.
        idx: Only used to name the output files.
    """

    def _split_example(line):
        # Parse one record and split its QA prompt into instruction, document lines, question.
        input_example = json.loads(line)
        documents = [Document.from_dict(ctx) for ctx in deepcopy(input_example["ctxs"])]
        prompt = get_qa_prompt(
            input_example["question"],
            documents,
            mention_random_ordering=False,
            query_aware_contextualization=False,
        )
        c = prompt.split("\n\n")
        demonstration = "\n".join(c[1:-1])
        return c[0], demonstration.split("\n"), c[-1], input_example["answers"]

    idxs = list(range(2655))
    random.seed(181)
    random.shuffle(idxs)
    # Build the membership set once instead of re-slicing the list per line.
    selected = set(idxs[:200])

    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            if ii not in selected:
                continue
            instruction, demonstrations, question, answers = _split_example(jj)
            compressed_prompt = llm_lingua(
                demonstrations,
                instruction,
                question,
                0.55,
                use_sentence_level_filter=False,
                condition_in_question="none_condition",
                reorder_demonstrations="original",
                dynamic_demonstration_compression_ratio=0.4,
                condition_compare=True,
                demonstration_budget="+100",
                token_budget_ratio=1.05,
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": answers})
    with open(f"prompt/loss_in_middle/wo_coarse_level_{idx}.json", "w") as out_file:
        json.dump(res, out_file)

    res = []
    with xopen(path) as f:
        for ii, jj in tqdm(enumerate(f), total=2655):
            instruction, demonstrations, question, answers = _split_example(jj)
            compressed_prompt = llm_lingua(
                demonstrations,
                instruction,
                question,
                0.9,
                use_sentence_level_filter=True,
                condition_in_question="after",
                force_demonstrations_number=2,
            )
            res.append({"id": ii, "prompt": compressed_prompt, "answer": answers})

    with open(f"prompt/loss_in_middle/ours_{doc_num}_{idx}_10x_dem_sent_after_force2.json", "w") as out_file:
        json.dump(res, out_file)

def recover(
    # self,
    original_prompt: str,
    compressed_prompt: str,
    response: str,
):
    """Expand response phrases copied from the compressed prompt back to the original.

    The response is split on whitespace. Each maximal run of words whose
    space-joined text is a substring of `compressed_prompt` is tokenized and
    matched, as an in-order subsequence, against the tokens of
    `original_prompt`; the shortest matching token span is decoded and used in
    place of the run. Words that do not occur in `compressed_prompt`, and runs
    with no match, are kept unchanged.

    NOTE(review): `tokenizer` is read from an enclosing or module scope and is
    not defined in the visible file — confirm it exists before calling.

    Args:
        original_prompt: The uncompressed prompt text.
        compressed_prompt: The compressed prompt the model actually saw.
        response: The model output to recover.

    Returns:
        The recovered response, pieces joined by single spaces.
    """
    def match_from_compressed(response_words):
        response_input_ids = tokenizer(response_words, add_special_tokens=False)["input_ids"]
        if not response_input_ids:
            # Nothing to align; previously this raised IndexError.
            return response_words
        n = len(response_input_ids)
        first_id = response_input_ids[0]
        starts_ids = [idx for idx in range(M) if original_input_ids[idx] == first_id]
        res, res_min = None, float("inf")
        for l in starts_ids:
            # Greedy in-order match starting at l; y walks the original tokens.
            x, y = 0, l
            while x < n and y < M:
                if response_input_ids[x] == original_input_ids[y]:
                    x += 1
                y += 1
                if y - l >= res_min:
                    # Already no shorter than the best span found so far.
                    break
            if x == n and y - l < res_min:
                res_min = y - l
                res = (l, y)
        if res is None:
            return response_words
        return tokenizer.decode(original_input_ids[res[0]:res[1]])

    response_words = response.split()

    original_input_ids = tokenizer(original_prompt, add_special_tokens=False)['input_ids']
    N, M = len(response_words), len(original_input_ids)
    recovered_response_words = []
    l = 0
    while l < N:
        # Substring test against the compressed prompt text, not a word-list lookup.
        if response_words[l] not in compressed_prompt:
            recovered_response_words.append(response_words[l])
            l += 1
            continue
        # Extend the run while the longer phrase still occurs in the compressed prompt.
        r = l
        while r + 1 < N and " ".join(response_words[l:r+2]) in compressed_prompt:
            r += 1
        match_words = match_from_compressed(" ".join(response_words[l:r+1]))
        recovered_response_words.append(match_words)
        l = r + 1
    return " ".join(recovered_response_words)

def recover_v2(
    # self,
    original_prompt: str,
    compressed_prompt: str,
    response: str,
):
    """Expand response phrases copied from the compressed prompt back to the original.

    The response is split on single spaces. Each maximal run of words whose
    space-joined text is a substring of `compressed_prompt` is tokenized and
    aligned against the tokens of `original_prompt`, tolerating gaps of up to
    10 tokens between consecutive matches. The decoded original span replaces
    the run only when the run's characters all appear, in order, in that
    span. Everything else is kept unchanged.

    NOTE(review): `tokenizer` is read from an enclosing or module scope and is
    not defined in the visible file — confirm it exists before calling.

    Args:
        original_prompt: The uncompressed prompt text.
        compressed_prompt: The compressed prompt the model actually saw.
        response: The model output to recover.

    Returns:
        The recovered response, pieces joined by single spaces.
    """
    import bisect
    from collections import defaultdict
    def lcs(X, Y):  
        # Length of the longest common subsequence of X and Y, using one DP row.
        n, m = len(X), len(Y)  
        if n < m:  
            # Keep the shorter sequence on the column axis.
            X, Y, n, m = Y, X, m, n
    
        dp = [0] * (m + 1)
        for i in range(1, n + 1):  
            prev = dp[0]  
            for j in range(1, m + 1):  
                temp = dp[j]  
                if X[i - 1] == Y[j - 1]:  
                    dp[j] = prev + 1  
                else:  
                    dp[j] = max(dp[j], dp[j - 1])  
                prev = temp  
        return dp[-1] 

    def match_from_compressed(response_word):
        response_input_ids = tokenizer(response_word, add_special_tokens=False)["input_ids"]
        # NOTE(review): 29871 is presumably the tokenizer's standalone leading-space
        # token id (LLaMA-style vocabulary) — TODO confirm. It is dropped when the
        # text itself does not start with a space.
        if response_input_ids and response_input_ids[0] == 29871 and response_word[0] != " ":
            response_input_ids = response_input_ids[1:]
        # response_c maps each response token id to the sorted positions where it
        # occurs in the original prompt.
        response_set, response_c = set(response_input_ids), defaultdict(list)
        for idx in range(M):
            if original_input_ids[idx] in response_set:
                response_c[original_input_ids[idx]].append(idx)
        res, res_min, res_c = None, float("inf"), 1
        n = len(response_input_ids)
        if n == 0:
            return response_word
        # Try every occurrence of the first response token as a span start.
        for l in response_c[response_input_ids[0]]:
            x, y, c = 0, l, 1
            flag = True  # unused; left over from the commented-out strict-match variant
            for x in range(1, n):
                # First occurrence of this token strictly after position y.
                idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
                # Skip the token if it never occurs again or is more than 10 tokens ahead.
                if idx >= len(response_c[response_input_ids[x]]) or response_c[response_input_ids[x]][idx] - y > 10:
                    continue
                c += 1
                    # flag = False
                    # break
                y = response_c[response_input_ids[x]][idx]
            # if not flag:
            #     break
            # Prefer the span matching the most tokens; on ties prefer the shorter span.
            if c > res_c:
                res_c = c
                res_min = y - l + 1
                res = (l, y + 1)
            elif c == res_c and y - l + 1 < res_min:
                res_min = y - l + 1
                res = (l, y + 1)

        if res is None:
            return response_word
        # while l > 0 and not tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
        #     l -= 1
        # while r < M - 1 and not tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
        #     l -= 1
        y = tokenizer.decode(original_input_ids[res[0]:res[1]])
        # Accept the decoded span only if every character of response_word
        # appears in it in order; otherwise keep the response text.
        return response_word if lcs(y, response_word) < len(response_word) else y

    def replace_dot(r, compressed_prompt):
        # Reads `l` from the enclosing function. If the next word is a hyphen and
        # the current run directly followed by an en dash occurs in the compressed
        # prompt, rewrite that word in place to the en dash.
        if response_words[r + 1] == "-" and " ".join(response_words[l:r+1]) + "–" in compressed_prompt:
            response_words[r + 1] = "–"
            return True
        return False


    response_words = response.split(" ")
    compressed_prompt_words = compressed_prompt.split()  # unused
    original_prompt_words = original_prompt.split(" ")  # unused

    # original_tokens = self.tokenizer(original_prompt, return_offsets_mapping=True, add_special_tokens=False)

    original_input_ids = tokenizer(original_prompt, add_special_tokens=False)['input_ids']
    N, M = len(response_words), len(original_input_ids)
    recovered_response_words = []
    l = 0
    while l < N:
        # Substring test against the compressed prompt text, not a word-list lookup.
        if response_words[l] not in compressed_prompt:
            recovered_response_words.append(response_words[l])
            l += 1
            continue
        r = l
        # Extend the run while the longer phrase occurs in the compressed prompt;
        # replace_dot may mutate response_words[r + 1] as a side effect.
        while r + 1 < N and (" ".join(response_words[l:r+2]) in compressed_prompt or replace_dot(r, compressed_prompt)):
            r += 1
        # print(" ".join(response_words[l:r+1]))
        match_words = match_from_compressed(" ".join(response_words[l:r+1]))
        # print(l, r, " ".join(response_words[l:r+1]), "@", match_words)
        recovered_response_words.append(match_words)
        l = r + 1
    return " ".join(recovered_response_words)

def get_pred(preds: dict, prompt_path: str):
    """Attach a recovered prediction to every entry of `preds`.

    For each entry the prediction is trimmed to its first line, passed through
    `recover_v2` together with the matching original and compressed prompts,
    and stored under the key ``"pred_covered"``.

    Args:
        preds: Mapping from example index (convertible with ``int``) to a dict
            holding at least ``"pred"`` and ``"reference"``. Mutated in place.
        prompt_path: JSON file with the compressed prompts; each record's
            ``"prompt"`` must be a dict with a ``"compressed_prompt"`` key.

    Returns:
        The same `preds` mapping.
    """
    # Close both prompt files deterministically instead of leaking the handles.
    with open(prompt_path) as compressed_file:
        compressed_prompts = json.load(compressed_file)
    with open("prompt/loss_in_middle/full_20_0.json") as full_file:
        prompts = json.load(full_file)

    for idx, example in tqdm(preds.items()):
        # Accessing "reference" is kept so that a missing key still raises KeyError.
        pred, _ref = example["pred"].strip(), example["reference"]
        pred = pred.split("\n\n")[0].split("\n")[0].strip()
        original_prompt = prompts[int(idx)]["prompt"]
        compressed_prompt = compressed_prompts[int(idx)]["prompt"]["compressed_prompt"]
        example["pred_covered"] = recover_v2(original_prompt, compressed_prompt, pred)
    return preds

def analysis_results(predict_path: str, prompt_path: str):
    """Recover predictions against the original prompts, score them, print the mean.

    Each prediction is trimmed to its first line, stripped of ``<|im_end|>``
    and passed through `recover_v2`. A second recovery is attempted on a copy
    with the space after ``, . ! ?`` removed; if that attempt changes the
    text, its result is used instead.

    Args:
        predict_path: JSON file mapping example index to a dict with
            ``"pred"`` and ``"reference"``.
        prompt_path: JSON file with the compressed prompts; each record's
            ``"prompt"`` must be a dict with a ``"compressed_prompt"`` key.

    Returns:
        A list of ``[reference, idx, trimmed_prediction, recovered_prediction]``.

    Raises:
        statistics.StatisticsError: If the predictions file holds no examples.
    """
    from tqdm import tqdm
    import json
    import statistics
    from lost_in_the_middle.metrics import best_subspan_em
    METRICS = [(best_subspan_em, "best_subspan_em")]

    def get_metrics_for_example(example):
        gold_answers = example["answers"]
        model_answer = example["model_answer"]

        example_metrics = {
            metric_name: metric(prediction=model_answer, ground_truths=gold_answers)
            for (metric, metric_name) in METRICS
        }
        return (example_metrics, example)

    # Close all three input files deterministically instead of leaking the handles.
    with open(predict_path) as pred_file:
        preds = json.load(pred_file)
    with open(prompt_path) as compressed_file:
        compressed_prompts = json.load(compressed_file)
    with open("prompt/loss_in_middle/full_20_0.json") as full_file:
        prompts = json.load(full_file)

    pairs = []
    all_example_metrics = []
    for idx, example in tqdm(preds.items()):
        pred, ref = example["pred"].strip(), example["reference"]
        pred = pred.split("\n\n")[0].split("\n")[0].strip()
        original_prompt = prompts[int(idx)]["prompt"]
        compressed_prompt = compressed_prompts[int(idx)]["prompt"]["compressed_prompt"]
        pred = pred.replace("<|im_end|>", "")
        pred2 = recover_v2(original_prompt, compressed_prompt, pred)
        # Retry with the space after punctuation removed; keep the retry only
        # if recovery actually changed the text.
        squeezed = pred2.replace(", ", ",").replace(". ", ".").replace("! ", "!").replace("? ", "?")
        pred3 = recover_v2(original_prompt, compressed_prompt, squeezed)
        if pred3 != squeezed:
            pred2 = pred3
        pairs.append([ref, idx, pred, pred2])
        all_example_metrics.append(get_metrics_for_example({"answers": ref, "model_answer": pred2}))

    # Average metrics across examples
    for (_, metric_name) in METRICS:
        average_metric_value = statistics.mean(
            example_metrics[metric_name] for (example_metrics, _) in all_example_metrics
        )
        print(f"{metric_name}: {average_metric_value}")
    return pairs

# Driver: send every prompt of one experiment file to the completion endpoint
# and checkpoint the answers to disk after each request.
# NOTE(review): `LLMClient` is not imported or defined in the visible file and
# `llm_client` is not used below — confirm it is still needed.
llm_client = LLMClient()

model = "gpt-3.5-turbo"
done, res = set(), {}
position = "14"
# threshold = f"full_20_{position}"
threshold = f"ours_20_{position}_4x_dem_400_after_add_prompt1_dy04dem_compare_condition_ori_order"
with open(f"prompt/loss_in_middle/{threshold}.json") as _prompt_file:
    data = json.load(_prompt_file)

# output_path = f'outputs/gpt_3.5_turbo_{threshold}_text_llama2.txt'
output_path = f'outputs_loss_in_middle/{threshold}_text_llama2_nocon_breath.txt'
# Resume from an earlier run: answered ids are the (string) keys of `res`.
if os.path.exists(output_path):
    with open(output_path) as _checkpoint_file:
        res = json.load(_checkpoint_file)

for demonstration in tqdm(data, total=len(data)):
    idx, prompt, answer = [demonstration[k] for k in ["id", "prompt", "answer"]]
    if isinstance(prompt, dict):
        prompt = prompt["compressed_prompt"]
    if str(idx) in res:
        continue

    # message = [
    #     {"role": "user", "content":  prompt},
    # ]

    request_data = {
        # "messages": message,
        "prompt": prompt,
        "max_tokens": 100,
        "temperature": 0,
        "top_p": 1,
        "n": 1,
        "stream": False,
        "stop": "\r\n",
    }

    response = openai.Completion.create(
        model=model,
        **request_data,
    )

    # torch.save(response, "debug.pt")
    if "choices" not in response:
        print("TOO FAST")
        if "error" in response:
            print(response["error"]["message"])
        time.sleep(62)
        # The failed example is skipped for this run; it is picked up again on
        # the next run because it is missing from the checkpoint.
        continue
    # ans_model = response["choices"][0]["message"]["content"]
    ans_model = response["choices"][0]["text"]
    print(idx, position, ans_model.replace("\n", " "))
    # Store under the string id so in-memory keys match the checkpoint's keys
    # and the `str(idx) in res` resume check above.
    res[str(idx)] = {
        "idx": str(idx),
        "position": position,
        "pred": ans_model,
        "reference": answer,
    }
    # Checkpoint after every answer; close the file so the data reaches disk.
    with open(output_path, "w") as _checkpoint_file:
        json.dump(res, _checkpoint_file)
    time.sleep(8)
