# implementation follows https://github.com/UCSB-NLP-Chang/llm_uncertainty
# @article{hou2023decomposing,
#   title={Decomposing Uncertainty for Large Language Models through Input Clarification Ensembling},
#   author={Hou, Bairu and Liu, Yujian and Qian, Kaizhi and Andreas, Jacob and Chang, Shiyu and Zhang, Yang},
#   journal={arXiv preprint arXiv:2311.08718},
#   year={2023}
# }
import argparse
import copy
import itertools
import os
import re

import jsonlines
from openai import OpenAI
from openai import OpenAIError
from tqdm import tqdm

from uncertainty.utils import LLM
from uncertainty.utils import PromptTemplate
# DeepSeek API key, taken from the environment so no secret lives in source.
# When the variable is unset this is "" (the previous hard-coded value).
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")

def build_tempate(instruction):
    """Build a prompt template string from an instruction.

    The result is *instruction*, a blank line, then ``Question: {query}``.
    The ``{query}`` placeholder is left unformatted for the caller to fill.
    *instruction* must be a ``str``; anything else raises ``TypeError``.
    """
    pieces = (instruction, "\n\n", "Question: ", "{query}")
    return "".join(pieces)

# Label at the start of a paraphrase line, e.g. "Rephrase 1:", "Rephrased 2:",
# "Rephrased question 10:" or a bare "Rephrase:". Longer alternatives come
# first so "Rephrased question" is not consumed by the shorter "Rephrase".
_PARAPHRASE_LABEL = re.compile(r"^(?:Rephrased questions?|Rephrased|Rephrase)\s*\d*\s*:?\s*")


def extract_paraphrases(model_ans: str):
    """Split a model answer into paraphrased questions and remaining lines.

    A line counts as a paraphrase when, after leading whitespace, it starts
    with a ``Rephrase`` / ``Rephrased`` / ``Rephrased question`` label,
    optionally followed by a number and a colon. The label is removed and the
    rest of the line is kept. Paraphrases that are empty or mention "specific"
    are dropped. The latter filters lines such as "Rephrase the question to be
    more specific". Scanning stops at the first line starting with
    ``"Question: "``.

    Args:
        model_ans: Raw text produced by the paraphrasing model.

    Returns:
        ``(extract_list, others)``: the paraphrases in order of appearance,
        and every other line seen before the stop line, unmodified.
    """
    extract_list = []
    others = []
    for line in model_ans.splitlines():
        text = line.lstrip()
        label = _PARAPHRASE_LABEL.match(text)
        if label is not None:
            ext = text[label.end():].rstrip()
            if not ext or "specific" in ext.lower():
                continue
            extract_list.append(ext)
        elif text.startswith("Question: "):
            # The model has started echoing the next prompt; ignore the rest.
            break
        else:
            others.append(line)
    return extract_list, others



def _question_key(dataset):
    """Return the example field holding the question: "question" for coqa, else "query"."""
    return "question" if dataset == "coqa" else "query"


def _dedupe(paraphrases):
    """Strip each paraphrase and drop empty strings and duplicates, keeping first-seen order."""
    return list(dict.fromkeys(p.strip() for p in paraphrases if p.strip()))


def _fit_to_size(paraphrases, sample_n, query):
    """Return exactly ``sample_n`` paraphrases.

    Extra paraphrases are truncated. If there are too few unique paraphrases,
    they are repeated cyclically until ``sample_n`` is reached. If none were
    extracted, an empty list is returned and a warning is printed.
    """
    if len(paraphrases) >= sample_n:
        return paraphrases[:sample_n]
    if not paraphrases:
        print(f"no paraphrase could be extracted for query {query}")
        return []
    print(f"there is repetition in paraphrase_prompts for query {query}")
    return list(itertools.islice(itertools.cycle(paraphrases), sample_n))


def _default_output_path(dataset, max_n):
    """Default output file next to the dataset, suffixed with ``max_n`` when given."""
    if max_n is not None:
        return f"./data/datasets/{dataset}/paraphrase_{max_n}.jsonl"
    return f"./data/datasets/{dataset}/paraphrase.jsonl"


def _parse_args():
    """Parse CLI options; exits with a usage error on an invalid output path."""
    parser = argparse.ArgumentParser(description='generation paraphrased query')
    parser.add_argument('-m', '--model', type=str, default="deepseek", help='api  or models used to generated paraphrased queries')
    parser.add_argument('-d', '--dataset', type=str, required=True, help='dataset to estimate')
    parser.add_argument("-n", "--paraphrase_num", type=int, default=10, help="max paraphrased per query")
    parser.add_argument("-o", "--output_path", type=str, help="path to save results")
    parser.add_argument("--max_n", type=int, help="max dataset example paraphrased")
    parser.add_argument("-b", "--batch_size", type=int, default=5, help="batch size for model generation")
    # Unrecognised flags are ignored rather than rejected.
    args, _ = parser.parse_known_args()
    if args.output_path is not None and not args.output_path.endswith("jsonl"):
        parser.error(f"given path should be a jsonl file, but '{args.output_path}' is given")
    return args


def _paraphrase_with_api(client, model, prompt_template, data, q_key, sample_n, max_try_time=3):
    """Paraphrase each example's question through a chat-completions API.

    Up to ``max_try_time`` requests are made per example. Requesting stops as
    soon as ``sample_n`` unique paraphrases are collected. A failed request
    is reported and counts as one attempt instead of aborting the run.

    Returns:
        Deep copies of the examples, each with a ``"paraphrases"`` list added.
    """
    results = []
    for ex in tqdm(data):
        query = ex[q_key]
        prompt = prompt_template.build_prompt({"query": query})
        paraphrases = []
        for _attempt in range(max_try_time):
            if len(paraphrases) >= sample_n:
                break
            try:
                chat_completion = client.chat.completions.create(
                    messages=prompt,
                    model=model,
                    temperature=1.0,
                    max_tokens=512,
                    n=1,
                )
            except OpenAIError as err:
                print(f"API call failed for query {query}: {err}")
                continue
            # content can be None (e.g. filtered responses); treat as empty.
            answer = chat_completion.choices[0].message.content or ""
            extraction, _others = extract_paraphrases(answer)
            paraphrases = _dedupe(paraphrases + extraction)

        result = copy.deepcopy(ex)
        result["paraphrases"] = _fit_to_size(paraphrases, sample_n, query)
        results.append(result)
    return results


def _paraphrase_with_local_model(model_name, prompt_template, data, q_key, sample_n, batch_size):
    """Paraphrase each example's question with a locally loaded model via ``LLM``.

    All prompts are generated in one ``LLM.lm_generate`` call. Each prompt
    requests 3 sampled responses, and their paraphrases are pooled.

    Returns:
        Deep copies of the examples, each with a ``"paraphrases"`` list added.
    """
    generation_kwargs = {
        "batch_size": batch_size,  # per device batch size
        "temperature": 1.0,
        "top_p": 0.9,
        "do_sample": True,
        "max_new_tokens": 1000,
        "num_responses_per_prompt": 3,
    }
    tokenization_kwargs = {
        "padding": "longest",
        "truncation": True,
        "padding_side": "left",
        "truncation_side": "left",
    }
    prompts = [prompt_template.build_prompt({"query": ex[q_key]}) for ex in data]
    model_key = LLM.initial_lm(model_name, None, verbose=True, tokenizer_kwargs=tokenization_kwargs)
    lm, tokenizer = LLM.loaded_llms[model_key]
    outputs = LLM.lm_generate(lm, prompts, generate_kwargs=generation_kwargs, tokenizer=tokenizer, tokenizer_kwargs=tokenization_kwargs, verbose=True)

    results = []
    for ex, responses in zip(data, outputs["responses"]):
        # A single response may come back as a bare string.
        if isinstance(responses, str):
            responses = [responses]
        paraphrases = []
        for response in responses:
            extraction, _others = extract_paraphrases(response)
            paraphrases += extraction

        new_ex = copy.deepcopy(ex)
        new_ex["paraphrases"] = _fit_to_size(_dedupe(paraphrases), sample_n, ex[q_key])
        results.append(new_ex)
    return results


def main():
    """Generate paraphrases for a dataset's test split and write them as JSONL."""
    args = _parse_args()

    client = None
    if args.model == "openai":
        # Default OpenAI endpoint; the key is read from OPENAI_API_KEY.
        client = OpenAI()
        model = "gpt-4o"
    elif args.model == "deepseek":
        client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
        model = "deepseek-chat"
    else:
        model = args.model

    system_path = "./uncertainty/uncertainty_estimation/paraphrase_system.txt"
    dataset = args.dataset
    sample_n = args.paraphrase_num
    q_key = _question_key(dataset)
    dataset_path = f"./data/datasets/{dataset}/test.jsonl"
    if args.output_path is None:
        output_path = _default_output_path(dataset, args.max_n)
    else:
        output_path = args.output_path

    with open(system_path, "r", encoding="utf-8") as f:
        system_message = f.read().strip()

    with jsonlines.open(dataset_path, "r") as reader:
        data = list(reader)
    if args.max_n is not None:
        data = data[:args.max_n]

    template = build_tempate("")
    prompt_template = PromptTemplate(model, template=template, system_message=system_message)

    print(f"start to call model {args.model}")
    if client is not None:
        all_results = _paraphrase_with_api(client, model, prompt_template, data, q_key, sample_n)
    else:
        all_results = _paraphrase_with_local_model(model, prompt_template, data, q_key, sample_n, args.batch_size)
    print("QUERY PARAPHRASE FINISHED!!!")

    # A custom --output_path may point into a directory that does not exist yet.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with jsonlines.open(output_path, "w") as writer:
        writer.write_all(all_results)


if __name__ == "__main__":
    main()
        
    
