import argparse
import json
import os
import numpy as np
import time
import backoff
import openai
import tiktoken
import vllm
from .utils import remove_error_tags, load_jsonlines, split_sentences, run_detection, completions_with_backoff

class FavaBaseline:
    """Prompt-based baseline for fine-grained hallucination detection/editing.

    Builds a one-shot tagging prompt (see ``get_prompt``) and sends it either
    to an OpenAI chat model (``use_gpt=True``, via ``completions_with_backoff``)
    or to a local model loaded with vLLM.
    """

    # Total token budget (prompt + completion) used to size the GPT request.
    context_window = 4000
    # Number of reference characters kept when the full prompt overflows.
    truncated_evidence_chars = 300

    def __init__(self,
                 model_name="meta-llama/Llama-2-7b-hf",
                 input_file=None,
                 output_file=None,
                 openai_key=None,
                 retrieval=False,
                 use_gpt=False,
                 max_new_tokens=1024,
                 do_sample=False,
                 temperature=1.0,
                 top_p=1.0,
                 context_window=4000):
        """Configure the backend.

        Args:
            model_name: OpenAI model name (``use_gpt=True``) or a model
                identifier passed to ``vllm.LLM``.
            input_file: Stored only; not read by this class.
            output_file: Stored only; not written by this class.
            openai_key: Assigned to ``openai.api_key`` when ``use_gpt`` is set.
            retrieval: Stored only. NOTE(review): not consulted by
                ``get_prompt``; the reference is included whenever present.
            use_gpt: Use the OpenAI backend instead of vLLM.
            max_new_tokens: vLLM ``max_tokens``.
            do_sample: If False, vLLM decodes with temperature 0 (greedy).
            temperature: vLLM temperature, used only when ``do_sample``.
            top_p: vLLM nucleus-sampling threshold.
            context_window: Token budget used to compute GPT ``max_tokens``.
        """
        self.model_name = model_name
        self.input_file = input_file
        self.output_file = output_file
        self.use_gpt = use_gpt
        self.lm = None
        self.sampling_params = None
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
        self.context_window = context_window
        if use_gpt:
            openai.api_key = openai_key
        else:
            self.lm = vllm.LLM(model=model_name)
            self.sampling_params = vllm.SamplingParams(
                temperature=self.temperature if self.do_sample else 0,
                top_p=self.top_p,
                max_tokens=self.max_new_tokens,
            )
        self.retrieval = retrieval
        self.openai_key = openai_key
        try:
            self.encoding = tiktoken.encoding_for_model(self.model_name)
        except Exception:
            # tiktoken only maps OpenAI model names; anything else (e.g. a
            # Hugging Face id) fails, so fall back to a generic encoding.
            # "cl100k_base" is an *encoding* name and must go through
            # get_encoding, not encoding_for_model.
            print("couldn't find model encodings, using 'cl100k_base' encodings")
            self.encoding = tiktoken.get_encoding("cl100k_base")

    # constructs prompt for LM, if evidence is not empty string, then includes it (retrieval)
    def get_prompt(self, output, evidence):
        """Build the one-shot error-tagging prompt for ``output``.

        When ``evidence`` is longer than one character, a reference is added
        both to the worked example and after the target passage. The prompt
        is printed before it is returned.
        """
        prompt = ""
        prompt += "Given a passage with factual errors, identify any <entity>, <relation>, <contradictory>, <subjective>, <unverifiable> or <invented> errors in the passage and add edits for <entity> and <relation> errors by inserting additional <mark></mark> or <delete></delete> tags.  If there are no errors, return the passage with no tags. Any changes to the original passage should be marked in <> tags. Below are the error definitions followed by examples of what you need to follow.\n"
        prompt += "Definitions:\n"
        prompt += "1. entity errors (<entity>): a small part of a sentence, often an entity (e.g., location name), is incorrect (usually 1-3 words). Entity errors often involve noun phrases or nouns.\n"
        prompt += "2. relational error (<relation>): a sentence is partially incorrect as a small part (usually 1 - 3 words). Relational errors often involve verbs and are often the opposite of what it should be.\n"
        prompt += "3. contradictory sentence error (<contradictory>): a sentence where the entire sentence is contradicted by the given reference, meaning the sentence can be proven false due to a contradiction with information in the passage.\n"
        prompt += "4. invented info error (< invented >): these errors refer to entities that are not known  or do not exist. This does not include fictional characters in books or movies. made-up info errors include phrases or sentences which have unknown entities or misleading information.\n"
        prompt += "5. subjective sentence (<subjective>): an entire sentence or phrase that is subjective and cannot be verified, so it should not be included.\n"
        prompt += "6. unverifiable sentence (<unverifiable>): a sentence where the whole sentence or phrase is unlikely to be factually grounded although it can be true, and the sentence cannot be confirmed nor denied using the reference given or internet search, it is often something personal or private and hence cannot be confirmed.\n"
        prompt += "Follow the given example exactly, your task is to create the edited completion with error tags <>:\n##\n"
        prompt += "Passage: Marooned on Mars is a science fiction novel aimed at a younger audience. It was written by Andy Weir and published by John C. Winston Co. in 1952, featuring illustrations by Alex Schomburg. It ended up having a readership of older boys despite efforts for it to be aimed at younger kids .The novel inspired the famous Broadway musical \"Stranded Stars,\" which won six Tony Awards. The novel tells a story of being stranded on the Purple Planet. I wish the novel had more exciting and thrilling plot twists.\n"
        if len(evidence) > 1:
            prompt += "Reference: Marooned on Mars is a juvenile science fiction novel written by American writer Lester del Rey. It was published by John C. Winston Co. in 1952 with illustrations by Alex Schomburg.\n"
        prompt += "Edited: Marooned on Mars is a science fiction novel aimed at a younger audience. It was written by <entity><mark>Lester del Rey</mark><delete>Andy Weir</delete></entity> and published by John C. Winston Co. in 1952, featuring illustrations by Alex Schomburg. <contradictory>It ended up having a readership of older boys despite efforts for it to be aimed at younger kids .</contradictory>. <invented>The novel inspired the famous Broadway musical \"Stranded Stars,\" which won six Tony Awards.</invented> The novel tells a story of being stranded on the <entity><mark>Red</mark><delete>Purple</delete></entity> Planet. <subjective>I wish the novel had more exciting and thrilling plot twists.</subjective>\n"
        prompt += "##\n"
        prompt += "##\n Now detect errors and include edits in the following passage like done in the example above. Include error tags <> for ANYTHING YOU CHANGE IN THE ORIGINAL PASSAGE.\n\n"
        prompt += "Passage: " + output
        if len(evidence) > 1:
            prompt += "\nReference: " + evidence
        prompt += "\nEdited: "
        print(prompt)
        return prompt

    def call_lm(self, prompt, tokens):
        """Send ``prompt`` to the configured backend.

        Args:
            prompt: Full prompt text.
            tokens: ``max_tokens`` for the GPT request. The vLLM path ignores
                it and uses ``self.sampling_params`` instead.

        Returns:
            The GPT path returns whatever ``completions_with_backoff`` returns
            (a chat-completion response). The vLLM path returns the generated
            text as a ``str``.
        """
        if self.use_gpt:
            return completions_with_backoff(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=tokens)
        outputs = self.lm.generate([prompt], self.sampling_params)
        return outputs[0].outputs[0].text

    @staticmethod
    def _response_text(response):
        """Extract generated text from either kind of ``call_lm`` result."""
        if isinstance(response, str):
            return response
        # Chat-completion shape; content may be None (e.g. filtered output).
        return response.choices[0].message.content or ""

    def get_baseline_pred(self, inst):
        """Run the baseline on one instance and record the edited passage.

        Args:
            inst: dict with ``"output"`` (passage to check) and
                ``"reference"`` (evidence text) keys; mutated in place.

        Returns:
            ``inst`` with ``"lm_edited"`` set to the model output, or to the
            original passage when the model returned fewer than 50 chars.
        """
        output = inst["output"]
        evidence = inst["reference"]
        prompt = self.get_prompt(output, evidence)
        # Token counts come from tiktoken, so for non-OpenAI models this is
        # only an approximation of the real prompt length.
        budget = self.context_window - len(self.encoding.encode(prompt))
        if budget < 0:
            # Prompt alone overflows the window: retry with truncated evidence.
            prompt = self.get_prompt(output, evidence[:self.truncated_evidence_chars])
            budget = self.context_window - len(self.encoding.encode(prompt))
        try:
            response = self.call_lm(prompt, budget)
        except Exception:
            # One retry after a pause (e.g. transient API errors).
            print("sleeping for 60s...")
            time.sleep(60)
            response = self.call_lm(prompt, budget)
        passage = self._response_text(response)
        if len(passage) < 50:
            # Treat very short completions as failures; keep the original.
            passage = output
        inst["lm_edited"] = passage

        return inst



def parse_args(argv=None):
    """Parse command-line options for the baseline script.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse
            read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,  # both backends need a concrete model name
        help="OpenAI or huggingface model name.",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,  # the script cannot run without input passages
        help="Input .json file containing input passages and references.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="output/result.json",
        help="Output .json file containing passages and edits.",
    )
    parser.add_argument(
        "--openai_key", type=str, default=None,
        help="OpenAI API key (used with --use_gpt)."
    )
    parser.add_argument(
        "--retrieval", action="store_true"
    )
    parser.add_argument(
        "--use_gpt", action="store_true",
        help="Query an OpenAI chat model instead of a local vLLM model.",
    )
    parser.add_argument("--max_new_tokens", type=int, default=1024,
                        help="Maximum tokens generated by the vLLM model.")
    parser.add_argument("--do_sample", action="store_true",
                        help="Sample with --temperature instead of greedy decoding (vLLM).")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature (vLLM, with --do_sample).")
    parser.add_argument("--top_p", type=float, default=1.0,
                        help="Nucleus-sampling threshold (vLLM).")
    return parser.parse_args(argv)
  
if __name__ == "__main__":
    # Load instances, run the baseline on each one, and dump the annotated
    # instances (each gains an "lm_edited" key) to the output JSON file.
    args = parse_args()
    fb = FavaBaseline(model_name=args.model_name,
                      input_file=args.input_file,
                      output_file=args.output_file,
                      openai_key=args.openai_key,
                      retrieval=args.retrieval,
                      use_gpt=args.use_gpt,
                      max_new_tokens=args.max_new_tokens,
                      do_sample=args.do_sample,
                      temperature=args.temperature,
                      top_p=args.top_p)
    with open(args.input_file, encoding="utf-8") as infile:
        instances = json.load(infile)

    results = [fb.get_baseline_pred(inst) for inst in instances]

    # The default output path is inside "output/", which may not exist yet;
    # create it so the results are not lost after all the model calls.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output_file, "w") as outfile:
        json.dump(results, outfile)