from factscore.factscorer import FactScorer
import os
import json
import tqdm
import jsonlines
import numpy as np
import openai
import backoff
import time
import argparse
import vllm
import spacy
from .baseline import FavaBaseline
from .utils import remove_error_tags, load_jsonlines, split_sentences, run_detection

# FAVA editing prompt. ``{evidence}`` receives the numbered reference block
# and ``{output}`` the passage to be checked; the model continues after
# "[Edited] " with its edited version of the passage.
INPUT = (
    "Read the following references:\n"
    "{evidence}\n"
    "Please identify all the errors in the following text using the "
    "information in the references provided and suggest edits if necessary:\n"
    "[Text] {output}\n"
    "[Edited] "
)
# spaCy English pipeline, loaded once at import time.
# NOTE(review): not referenced in this file's visible code; presumably used
# by code elsewhere — confirm before removing.
nlp = spacy.load("en_core_web_sm")

class FavaEvaluation(object):
    """Generate edited passages with FAVA or a baseline and score them.

    When ``"fava"`` occurs in ``model_name`` the passages are edited by that
    checkpoint through vLLM; otherwise ``FavaBaseline`` produces the edits.
    The edits are then scored with FActScore (``metric="factscore"``) or
    compared with gold annotations (``metric="detection"``).
    """

    def __init__(self,
                 model_name="meta-llama/Llama-2-7b-hf",
                 input_file=None,
                 output_file=None,
                 metric="factscore",
                 openai_key=None,
                 retrieval=False,
                 use_gpt=False,
                 max_new_tokens=1024,
                 do_sample=False,
                 temperature=1.0,
                 top_p=1.0):
        """Store the evaluation configuration.

        Args:
            model_name: HF model name/path; names containing "fava" are run
                with vLLM, anything else goes through ``FavaBaseline``.
            input_file: JSON file holding a list of instance dicts.
            output_file: destination path, forwarded to ``FavaBaseline``.
            metric: "factscore" or "detection".
            openai_key: OpenAI key, used by FActScore and GPT baselines.
            retrieval: forwarded to ``FavaBaseline``.
            use_gpt: forwarded to ``FavaBaseline``; also sets ``openai.api_key``.
            max_new_tokens: generation budget for the FAVA model.
            do_sample: sample instead of greedy decoding for the FAVA model.
            temperature: sampling temperature (only used when ``do_sample``).
            top_p: nucleus-sampling threshold for the FAVA model.
        """
        self.model_name = model_name
        self.input_file = input_file
        self.output_file = output_file
        self.use_gpt = use_gpt
        # The vLLM engine and its sampling parameters are created lazily on
        # first use (see ``_get_lm``). Loading eagerly here wasted GPU memory
        # on the baseline path and crashed when ``model_name`` was None.
        self.lm = None
        self.sampling_params = None
        if use_gpt:
            openai.api_key = openai_key
        self.retrieval = retrieval
        self.openai_key = openai_key
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
        self.metric = metric

    def _get_lm(self):
        """Return the shared vLLM engine, loading it on first use."""
        if self.lm is None:
            self.lm = vllm.LLM(model=self.model_name)
        return self.lm

    def _get_sampling_params(self):
        """Return vLLM sampling parameters built from the configured options."""
        if self.sampling_params is None:
            self.sampling_params = vllm.SamplingParams(
                # greedy decoding unless sampling was explicitly requested
                temperature=self.temperature if self.do_sample else 0,
                top_p=self.top_p,
                max_tokens=self.max_new_tokens,
            )
        return self.sampling_params

    @staticmethod
    def _collect_evidences(inst):
        """Return the reference texts of one instance.

        Texts come from ``inst["ctxs"][*]["text"]`` followed by
        ``inst["evidence"]`` when present.
        """
        evidences = [c["text"] for c in (inst.get("ctxs") or [])]
        if "evidence" in inst:
            evidences.append(inst["evidence"])
        return evidences

    @staticmethod
    def _format_evidence(evidences):
        """Render references as ``"Reference [k]: text\\n"`` lines, k from 1."""
        return "".join(
            "Reference [{}]: {}\n".format(idx, text)
            for idx, text in enumerate(evidences, start=1)
        )

    @staticmethod
    def _score_with_retry(fs, title, text, retries=1, wait=60):
        """Compute FActScore for one passage, retrying after a pause.

        Args:
            fs: object exposing ``get_score(topics, generations, gamma=...)``.
            title: topic passed to FActScore.
            text: passage to score (error tags already removed).
            retries: number of extra attempts after the first failure.
            wait: seconds to sleep before each retry.

        Returns:
            The ``get_score`` result, or -1 if every attempt raised.
        """
        for attempt in range(retries + 1):
            try:
                return fs.get_score([title], [text], gamma=10)
            except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
                if attempt == retries:
                    return -1
                print("sleeping for {}s...".format(wait))
                time.sleep(wait)
        return -1

    def collect_lm_preds(self, instances):
        """Return one edited passage (a string) per instance, in order.

        Args:
            instances: list of instance dicts; each needs ``"output"`` for the
                FAVA path, and whatever ``FavaBaseline`` expects otherwise.
        """
        if not instances:
            return []
        if "fava" in self.model_name:
            prompts = [
                INPUT.format_map({
                    "evidence": self._format_evidence(self._collect_evidences(inst)),
                    "output": inst["output"],
                })
                for inst in instances
            ]
            generations = self._get_lm().generate(prompts, self._get_sampling_params())
            # vLLM returns RequestOutput objects; keep the first completion's text
            return [gen.outputs[0].text for gen in generations]

        fb = FavaBaseline(model_name=self.model_name,
                          input_file=self.input_file,
                          output_file=self.output_file,
                          openai_key=self.openai_key,
                          retrieval=self.retrieval,
                          use_gpt=self.use_gpt)
        outputs = []
        for inst in instances:
            # shallow copy so baseline bookkeeping keys don't leak into the results
            pred = fb.get_baseline_pred(dict(inst))
            outputs.append(pred["lm_edited"])
        return outputs

    def run_eval(self):
        """Generate edits for ``input_file`` and attach metric results.

        Returns:
            The list of instances, each with ``"edited_output"`` and, for the
            factscore metric, ``"factscore"`` (-1 when scoring failed). Returns
            an empty list when no model or input file is configured.
        """
        if self.model_name is None:
            print("no model given")
            return []
        if self.input_file is None:
            print("no input file given")
            return []

        with open(self.input_file) as f:
            instances = json.load(f)
        outputs = self.collect_lm_preds(instances)

        if self.metric == "factscore":
            fs = FactScorer(model_name="retrieval+ChatGPT", openai_key=self.openai_key)
            for i, output in enumerate(outputs):
                title = instances[i]["title"]
                score = self._score_with_retry(fs, title, remove_error_tags(output))
                instances[i]["edited_output"] = output
                instances[i]["factscore"] = score
        elif self.metric == "detection":
            for i, output in enumerate(outputs):
                instances[i]["edited_output"] = output
                pred_sentences = split_sentences(output)
                gold_sentences = split_sentences(instances[i]["annotated"])
                # compare only the sentences both sides have
                n_sentences = min(len(pred_sentences), len(gold_sentences))
                instances = run_detection(
                    gold_sentences[:n_sentences], pred_sentences[:n_sentences], instances, i
                )

        return instances


def parse_args():
    """Build the FAVA evaluation command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace carrying the model name, input/output paths,
        metric, OpenAI key, retrieval/GPT flags and decoding options.
    """
    cli = argparse.ArgumentParser()

    # (flag, default, help) for the plain string options, in CLI order
    string_options = (
        ("--model_name", None, "Huggingface model name or path."),
        ("--input_file", None,
         "Input .json files containing input passages and references and/or titles if using factscore."),
        ("--output_file", "output/result.jsonl",
         "Output .jsonl file containing passages and scores."),
        ("--metric", "factscore", "Metric for evaluation."),
        ("--openai_key", None, "OpenAI key for factscore."),
    )
    for flag, default, help_text in string_options:
        cli.add_argument(flag, type=str, default=default, help=help_text)

    # boolean switches
    switch_options = (
        ("--retrieval", "Whether to use retrieval or not during prediction."),
        ("--use_gpt", "Whether using openai model for pred."),
    )
    for flag, help_text in switch_options:
        cli.add_argument(flag, action="store_true", help=help_text)

    # decoding options
    cli.add_argument("--max_new_tokens", type=int, default=1024)
    cli.add_argument("--do_sample", action="store_true")
    cli.add_argument("--temperature", type=float, default=1.0)
    cli.add_argument("--top_p", type=float, default=1.0)

    return cli.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Create the output directory. os.path.dirname() is "" for a bare file
    # name, and os.makedirs("") raises, so only create a real directory part.
    if args.output_file is not None:
        output_dir = os.path.dirname(args.output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
    fe = FavaEvaluation(
                 model_name=args.model_name,
                 input_file=args.input_file,
                 output_file=args.output_file,
                 metric=args.metric,
                 openai_key=args.openai_key,
                 retrieval=args.retrieval,
                 use_gpt=args.use_gpt,
                 max_new_tokens=args.max_new_tokens,
                 do_sample=args.do_sample,
                 temperature=args.temperature,
                 top_p=args.top_p)

    instances = fe.run_eval()
    # write results as JSON Lines: one instance per line
    with open(args.output_file, "w") as f:
        for inst in instances:
            f.write(json.dumps(inst) + "\n")
    print("done")
    
    