#!/usr/bin/env python3
"""Given a data file with LM QA predictions made from compressed prompts, map each
prediction back onto the original (uncompressed) prompt and evaluate the result.
"""
import argparse
import json
import logging
import statistics
import sys
from copy import deepcopy

from tqdm import tqdm
from xopen import xopen
import torch
from lost_in_the_middle.metrics import best_subspan_em

logger = logging.getLogger(__name__)

# (metric function, metric name) pairs. Every example is scored with each function,
# called as metric(prediction=..., ground_truths=...), and averaged per name.
METRICS = [
    (best_subspan_em, "best_subspan_em"),
]

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Global tokenizer used by recover_v2 to tokenize prompts/answers and decode spans.
# NOTE(review): loaded (and possibly downloaded) when this module is imported;
# recover_v2 hard-codes token id 29871, which assumes this Llama-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")



def recover_v2(
    original_prompt: str,
    compressed_prompt: str,
    response: str,
):
    """Map a response generated from a compressed prompt back onto the original prompt.

    The response is split on single spaces and greedily grouped into maximal runs
    of consecutive words whose space-joined text occurs verbatim in
    ``compressed_prompt``. Each run is tokenized with the module-level
    ``tokenizer`` and aligned against the tokens of ``original_prompt``. If the
    best-aligned original span contains every character of the run, in order,
    the decoded original span replaces the run. Words that do not occur in
    ``compressed_prompt`` are kept unchanged.

    Args:
        original_prompt: The uncompressed prompt text.
        compressed_prompt: The compressed prompt the model actually saw.
        response: The model's answer text.

    Returns:
        The response with each recoverable run replaced by original-prompt text,
        re-joined with single spaces.
    """
    import bisect
    from collections import defaultdict

    # Largest allowed distance (in original-prompt token positions) between two
    # consecutively matched response tokens; tokens further away are skipped.
    max_gap = 10

    def lcs(a, b):
        """Return the length of the longest common subsequence of ``a`` and ``b``."""
        n, m = len(a), len(b)
        if n < m:
            a, b, n, m = b, a, m, n
        dp = [0] * (m + 1)
        for i in range(1, n + 1):
            prev = dp[0]
            for j in range(1, m + 1):
                temp = dp[j]
                if a[i - 1] == b[j - 1]:
                    dp[j] = prev + 1
                else:
                    dp[j] = max(dp[j], dp[j - 1])
                prev = temp
        return dp[-1]

    original_input_ids = tokenizer(original_prompt, add_special_tokens=False)["input_ids"]
    # Ascending positions of each token id in the original prompt. Built once per
    # call instead of rescanning the whole prompt for every matched run.
    token_positions = defaultdict(list)
    for pos, token_id in enumerate(original_input_ids):
        token_positions[token_id].append(pos)
    no_positions = []

    def match_from_compressed(response_word):
        """Return the original-prompt text best aligned with ``response_word``,
        or ``response_word`` itself when no acceptable alignment exists."""
        response_input_ids = tokenizer(response_word, add_special_tokens=False)["input_ids"]
        # Drop a leading 29871 token emitted when the text has no leading space
        # (presumably Llama-2's bare "▁" piece — TODO confirm for other tokenizers).
        if response_input_ids and response_input_ids[0] == 29871 and response_word[0] != " ":
            response_input_ids = response_input_ids[1:]
        n = len(response_input_ids)
        if n == 0:
            return response_word
        positions = [token_positions.get(t, no_positions) for t in response_input_ids]

        best_span, best_len, best_count = None, float("inf"), 1
        # Try every occurrence of the first response token as an anchor, then
        # greedily chain later response tokens that occur within max_gap.
        for start in positions[0]:
            end, count = start, 1
            for x in range(1, n):
                candidates = positions[x]
                k = bisect.bisect_right(candidates, end)
                if k >= len(candidates) or candidates[k] - end > max_gap:
                    continue
                count += 1
                end = candidates[k]
            span_len = end - start + 1
            # Prefer more matched tokens; break ties with the shorter original span.
            if count > best_count or (count == best_count and span_len < best_len):
                best_count, best_len, best_span = count, span_len, (start, end + 1)

        if best_span is None:
            return response_word
        recovered = tokenizer.decode(original_input_ids[best_span[0]:best_span[1]])
        # Accept the recovered text only if the response word is a character-level
        # subsequence of it (LCS length equals the word's length).
        return response_word if lcs(recovered, response_word) < len(response_word) else recovered

    def replace_dot(start, r):
        """If the word after index ``r`` is a "-" and the run
        ``response_words[start:r + 1]`` is directly followed by an en dash in
        ``compressed_prompt``, replace the "-" with the en dash in place and
        return True; otherwise return False."""
        # NOTE(review): this checks for "<run>–" with no space, while the run is
        # later re-joined as "<run> –"; kept as-is to preserve behavior.
        if response_words[r + 1] == "-" and " ".join(response_words[start:r + 1]) + "–" in compressed_prompt:
            response_words[r + 1] = "–"
            return True
        return False

    response_words = response.split(" ")
    num_words = len(response_words)
    recovered_response_words = []
    start = 0
    while start < num_words:
        if response_words[start] not in compressed_prompt:
            recovered_response_words.append(response_words[start])
            start += 1
            continue
        # Extend the run while the longer phrase still occurs in the compressed prompt.
        end = start
        while end + 1 < num_words and (
            " ".join(response_words[start:end + 2]) in compressed_prompt or replace_dot(start, end)
        ):
            end += 1
        recovered_response_words.append(match_from_compressed(" ".join(response_words[start:end + 1])))
        start = end + 1
    return " ".join(recovered_response_words)

def main(
    input_path,
    output_path,
    prompt_path,
    original_prompt_path="prompt/loss_in_middle/full_20_0.json",
):
    """Recover, score, and optionally save LM QA predictions.

    Args:
        input_path: File opened with ``xopen`` containing one JSON object per line,
            each with at least ``"model_answer"`` and ``"answers"``.
        output_path: If truthy, each example is written here (one JSON object per
            line) augmented with ``metric_<name>`` fields.
        prompt_path: JSON list whose i-th entry holds the compressed prompt for
            line i of ``input_path`` at ``["prompt"]["compressed_prompt"]``.
        original_prompt_path: JSON list whose i-th entry holds the original prompt
            for line i of ``input_path`` at ``["prompt"]``.
    """
    with open(original_prompt_path, encoding="utf-8") as f:
        prompts = json.load(f)
    with open(prompt_path, encoding="utf-8") as f:
        compressed_prompts = json.load(f)

    all_examples = []
    with xopen(input_path) as fin:
        for idx, line in tqdm(enumerate(fin)):
            input_example = json.loads(line)
            original_prompt = prompts[idx]["prompt"]
            compressed_prompt = compressed_prompts[idx]["prompt"]["compressed_prompt"]
            # Remove any "<|im_end|>" end-of-turn marker before recovery.
            pred = input_example["model_answer"].replace("<|im_end|>", "")
            input_example["model_answer"] = recover_v2(original_prompt, compressed_prompt, pred)
            all_examples.append(input_example)

    # Compute metrics sequentially, one example at a time.
    logger.info("Computing metrics")
    all_example_metrics = [get_metrics_for_example(example) for example in tqdm(all_examples)]

    # Average metrics across examples (statistics.mean raises on empty input).
    if not all_example_metrics:
        logger.warning("No examples read from %s; skipping metric averaging", input_path)
    else:
        for (_, metric_name) in METRICS:
            average_metric_value = statistics.mean(
                example_metrics[metric_name] for (example_metrics, _) in all_example_metrics
            )
            logger.info(f"{metric_name}: {average_metric_value}")

    if output_path:
        with xopen(output_path, "w") as f:
            for (example_metrics, example) in all_example_metrics:
                example_with_metrics = deepcopy(example)
                for metric_name, metric_value in example_metrics.items():
                    example_with_metrics[f"metric_{metric_name}"] = metric_value
                f.write(json.dumps(example_with_metrics) + "\n")


def get_metrics_for_example(example):
    """Score one example's model answer against its gold answers.

    Only the first line of ``example["model_answer"]``, stripped of surrounding
    whitespace, is scored. The gold answer is guaranteed to occur in the input
    context, so scoring the full output would let a model hack the metric by
    copying the context.

    Returns:
        A ``(metrics, example)`` tuple. ``metrics`` maps each name in ``METRICS``
        to that metric's value; ``example`` is the same object that was passed in.
    """
    gold = example["answers"]
    first_line = example["model_answer"].split("\n", 1)[0].strip()

    scores = {
        name: metric_fn(prediction=first_line, ground_truths=gold)
        for (metric_fn, name) in METRICS
    }
    return scores, example


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)s - %(module)s - %(levelname)s - %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    # The compressed prompts are indexed by line number of --input-path.
    parser.add_argument(
        "--prompt-path",
        help="Path to JSON file with the compressed prompts, aligned line-by-line with --input-path.",
        required=True,
    )
    parser.add_argument("--input-path", help="Path to data with model predictions and answers.", required=True)
    parser.add_argument(
        "--output-path",
        help="Path to write data with model predictions, answers, and scores.",
    )
    args = parser.parse_args()

    logger.info("running %s", " ".join(sys.argv))
    main(
        args.input_path,
        args.output_path,
        args.prompt_path,
    )
    logger.info("finished running %s", sys.argv[0])