# import pdb
# pdb.set_trace()
import os
import json
import random
import numpy as np
import pandas as pd
from template import WikiMultiHop_FewShot, WikiMultiHop_Retrieval_Query_Template, Query_Responses, NQ_FewShot, NQ_Retrieval_Query_Template, NQ_Elicitation_Prompt, Knowledge_Prompt, MultiHop_FewShot, SingleHop_FewShot, QUERY_GEN_PROMPT_DENSE_NEW
from openai import OpenAI
from beir.retrieval.search.lexical import BM25Search
from beir.retrieval.search.lexical.elastic_search import ElasticSearch
from beir.retrieval.evaluation import EvaluateRetrieval
from typing import List, Dict, Tuple
import numpy as np
import tqdm
from argparse import ArgumentParser
import concurrent.futures
import torch
import jsonlines
from fastchat.model.model_adapter import get_conversation_template
import time
from flashrag.retriever import DenseRetriever
import re
import string

def bm25search_search(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, Tuple[str, str]], top_k: int, *args, **kwargs) -> Dict[str, Dict[str, float]]:
    """Batched lexical search that also returns each hit's title and text.

    Args:
        corpus: Documents to index; only used when ``self.initialize`` is truthy.
        queries: Maps a query id to a ``(query_text, filter_id)`` pair. The
            ``filter_id`` is forwarded to ``self.es.lexical_multisearch``.
        top_k: Number of hits requested per query.
        *args, **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        A dict mapping every query id to ``{corpus_id: (score, title, text)}``.
        A query with no hits maps to an empty dict, so callers can test
        ``result == {}`` instead of having to handle a missing key.
    """
    if self.initialize:
        self.index(corpus)
        # Pause after indexing; presumably gives the index time to become
        # searchable -- confirm against the BM25Search configuration.
        time.sleep(self.sleep_for)

    query_ids = list(queries.keys())
    filter_ids = [queries[qid][1] for qid in query_ids]
    query_texts = [queries[qid][0] for qid in query_ids]

    final_results: Dict[str, Dict[str, Tuple[float, str, str]]] = {}
    for start_idx in range(0, len(query_texts), self.batch_size):
        end_idx = start_idx + self.batch_size
        results = self.es.lexical_multisearch(
            texts=query_texts[start_idx:end_idx],
            filter_ids=filter_ids[start_idx:end_idx],
            top_hits=top_k)
        for query_id, hit in zip(query_ids[start_idx:end_idx], results):
            # Assign unconditionally so that hit-less queries still get an entry.
            final_results[query_id] = {
                corpus_id: (score, title, text)
                for corpus_id, score, title, text in hit['hits']
            }

    return final_results

# Monkey-patch: replace BM25Search.search with the function defined above.
BM25Search.search = bm25search_search

def get_queries_and_retrieval_result_specified(main_model, 
                                     rewrite_model, 
                                     bm25_retriever, 
                                     dense_retriever, 
                                     question,
                                     analysis,
                                     retrieved_ids,
                                     args,
                                     max_retrieved_document=1):
    """Generate a search query with the rewrite model and fetch unseen passages.

    Args:
        main_model: Unused by this function.
        rewrite_model: Client exposing ``chat.completions.create``; used to
            turn the question and analysis into search queries.
        bm25_retriever: Unused by this function.
        dense_retriever: Object whose ``search(query)`` returns a sequence of
            dicts with ``'id'`` and ``'contents'`` keys.
        question: Question text inserted into the rewrite prompt.
        analysis: Model analysis text inserted into the rewrite prompt.
        retrieved_ids: List of already used document ids. It is extended in
            place with the ids selected here.
        args: Unused by this function.
        max_retrieved_document: Maximum number of new passages to join into
            the returned document.

    Returns:
        ``{'query': <raw generated line>, 'document': <joined passage text>}``
        for the first generated query whose search returns any result, or the
        string ``'failed'`` when no query returns results.
    """
    generate_query_system = QUERY_GEN_PROMPT_DENSE_NEW
    generate_query_user_input = """Question: {}
    Model Analysis: {}"""

    response = rewrite_model.chat.completions.create(
        model="",
        messages=[
            {"role": "system", "content": generate_query_system},
            {"role": "user", "content": generate_query_user_input.format(question, analysis)}
        ],
        temperature=0.5,
        max_tokens=100,
        stop=['<|eot_id|>', '\n']
    )
    generate_queries = response.choices[0].message.content.strip().split('\n')

    random.shuffle(generate_queries)
    prefix = 'Query:'
    for q in generate_queries:
        # Strip the "Query:" label only when it is really there; slicing
        # unconditionally would eat the first characters of an unlabeled query.
        query = q[len(prefix):].strip() if q.startswith(prefix) else q.strip()
        retrieval_results = dense_retriever.search(query)
        if len(retrieval_results) == 0:
            continue

        selected_document = []
        document = []
        for hit in retrieval_results:
            doc_id = hit['id']
            if doc_id not in selected_document and doc_id not in retrieved_ids:
                selected_document.append(doc_id)
                # Keep only the last line of the stored contents.
                document.append(hit['contents'].split('\n')[-1])
            if len(selected_document) == max_retrieved_document:
                break

        # NOTE(review): when every hit was already in `retrieved_ids` this
        # still returns, with an empty 'document' -- confirm callers expect that.
        retrieved_ids.extend(selected_document)
        return {'query': q, 'document': ' '.join(document)}

    return 'failed'

def normalize_answer(s):
    """Return *s* lower-cased, without punctuation or English articles, and
    with every run of whitespace collapsed to a single space."""
    lowered = s.lower()
    # Drop every ASCII punctuation character in one pass.
    no_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    # Replace stand-alone articles by a space so neighbouring words stay apart.
    no_articles = re.sub(r'\b(a|an|the)\b', ' ', no_punctuation)
    return ' '.join(no_articles.split())

def elasticsearch_lexical_multisearch(self, texts: List[str], filter_ids: List[str] = None, top_hits: int = 10, skip: int = 0) -> Dict[str, object]:
    """Run several lexical queries through one Elasticsearch ``msearch`` call.

    Args:
        texts (List[str]): Query texts, one search per entry.
        filter_ids (List[str], optional): Per-query document id. When an entry
            is not None, that query is restricted to the document with this
            ``_id``. Defaults to no restriction for any query.
        top_hits (int): Number of hits to keep per query.
        skip (int, optional): Number of leading hits to drop. Defaults to 0.

    Returns:
        One ``self.hit_template(...)`` result per response, in response order.
        Each hit is an ``(id, score, title, text)`` tuple.
    """
    assert skip + top_hits <= 10000, "Elastic-Search Window too large, Max-Size = 10000"

    window = skip + top_hits
    per_query_filters = filter_ids or ([None] * len(texts))

    request = []
    for text, fid in zip(texts, per_query_filters):
        # Match the query text against both the title and the text field.
        match_clause = {
            "multi_match": {
                "query": text,
                "type": "best_fields",
                "fields": [self.title_key, self.text_key],
                "tie_breaker": 0.5
            }
        }
        if fid is None:
            query = match_clause
        else:
            query = {
                "bool": {
                    "must": match_clause,
                    "filter": {"term": {"_id": fid}}
                }
            }
        request.append({"index": self.index_name, "search_type": "dfs_query_then_fetch"})
        # "_source" is requested because title and text are read from it below.
        request.append({"_source": True, "query": query, "size": window})

    res = self.es.msearch(body=request)

    result = []
    for resp in res["responses"]:
        raw_hits = resp["hits"]["hits"][skip:] if 'hits' in resp else []

        hits = []
        for hit in raw_hits:
            # NOTE: `args` is not a parameter here; it is looked up in module
            # scope when a hit is processed.
            body_field = 'txt' if args.index_name == 'wikipedia_dpr' else 'text'
            hits.append((hit["_id"], hit['_score'], hit['_source']['title'], hit['_source'][body_field]))

        result.append(self.hit_template(es_res=resp, hits=hits))
    return result

# Monkey-patch: replace ElasticSearch.lexical_multisearch with the function defined above.
ElasticSearch.lexical_multisearch = elasticsearch_lexical_multisearch


def elasticsearch_hit_template(self, es_res: Dict[str, object], hits: List[Tuple[str, float]]) -> Dict[str, object]:
    """Wrap the hits of one search response together with summary metadata.

    Args:
        es_res (Dict[str, object]): A single Elasticsearch response.
        hits (List[Tuple[str, float]]): Hits already extracted from it.

    Returns:
        Dict[str, object]: ``{'meta': {'total', 'took', 'num_hits'}, 'hits': hits}``.
        ``total`` and ``took`` are None when the response lacks the
        corresponding key.
    """
    total = None
    if 'hits' in es_res:
        total = es_res['hits']['total']['value']

    took = None
    if 'took' in es_res:
        took = es_res['took']

    meta = {'total': total, 'took': took, 'num_hits': len(hits)}
    return {'meta': meta, 'hits': hits}

# Monkey-patch: replace ElasticSearch.hit_template with the function defined above.
ElasticSearch.hit_template = elasticsearch_hit_template


def load_data(data_path):
    """Load evaluation examples from *data_path*, optionally sub-sampled.

    A path containing ``'2wiki_dev_dataset.pt'`` is read with ``torch.load``
    and every record gets a ``'golden_answers'`` key holding
    ``[record['answer']]``. Any other path is read as JSON Lines (one JSON
    object per non-blank line).

    When the module-level ``args.max_data_size`` is not None, the global
    random generator is seeded with 2024 and at most that many examples are
    sampled; a value larger than the dataset keeps all examples.

    Args:
        data_path: Path of the dataset file.

    Returns:
        list: The loaded (and possibly sampled) examples.
    """
    if '2wiki_dev_dataset.pt' in data_path:
        # NOTE: torch.load unpickles its input; only use it on trusted files.
        data = []
        for record in torch.load(data_path):
            record['golden_answers'] = [record['answer']]
            data.append(record)
    else:
        with open(data_path, 'r', encoding='utf-8') as f:
            # Skip blank lines so a trailing newline does not break parsing.
            data = [json.loads(line) for line in f if line.strip()]

    # `args` is resolved from module scope, not passed in.
    if args.max_data_size is not None:
        random.seed(2024)
        # Clamp: random.sample raises ValueError when asked for more items
        # than the population holds.
        data = random.sample(data, min(args.max_data_size, len(data)))
    if data:
        print(data[0].keys())
    return data


def wiki_test_thread(main_model, 
                         retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Answer one example with an iterative retrieve-and-refine conversation.

    Phase 1 (up to 10 rounds): retrieve for the current query, feed the first
    not-yet-used document to the model as a user turn, and read the reply. A
    reply containing "Refined Query:" sets the next query; a reply containing
    "final answer" ends the phase.

    Phase 2 (up to 5 rounds, only while the last reply still contains
    "Refined Query:"): the document is generated by the model itself from
    ``Knowledge_Prompt`` instead of being retrieved.

    Args:
        main_model: Client exposing ``completions.create``.
        retriever: Object exposing ``retrieve(corpus, queries)``.
        data: Indexable collection of examples; ``data[id]`` must provide
            'question', 'answer' and 'answer_id'.
        args: Namespace providing ``verbose`` and ``main_model``.
        id: Index of the example to process.
        trace: Shared list; one record is appended per processed example.
        exception: Shared list; not written by this function (the handler at
            the bottom is commented out).
        success: Shared list; a record is appended when the gold answer is a
            case-insensitive substring of the extracted final answer.

    Returns:
        None. Results are reported through ``trace`` and ``success``.
    """

    i = data[id]
    print('Processing id: ', id)

    # Document ids already shown to the model, so they are not reused.
    retrieved_ids = []
    queries = [i['question']]
    template = get_conversation_template('llama-3')

    max_iter = 10
    current_iter = 1

    first_model_output = None

    while max_iter > 0:

        results = []
        document = None
        if queries is not None:
            # Queries are passed as {index: (query_text, None)}; the second
            # tuple element is the filter id consumed by the patched search.
            retrieved_results : Dict[str, Dict[str, Tuple[float, str]]] = retriever.retrieve(
                        None, dict(zip(range(len(queries)), list(zip(queries, ([None] * len(queries)))))))
        if args.verbose:
            print(retrieved_results[0].items())
        for query_id, retrieved_result in retrieved_results.items():

            if retrieved_result == {}:
                continue
            retrieved_documents = list(retrieved_result.keys())
            document = None
            in_flag = False
            # Take the first document that has not been used yet; element [2]
            # of the stored value is used as the document text.
            for doc_id in retrieved_documents:
                if doc_id not in retrieved_ids:
                    retrieved_ids.append(doc_id)
                    document = retrieved_result[doc_id][2]
                    break
                else:
                    continue
        # NOTE(review): `document` is still None here when no unused document
        # was found, in which case `.strip()` below raises AttributeError.
        if current_iter == 1:
            template.append_message(template.roles[0], 
                                    "Question: "+i['question'].strip()+
                                    "\n\nRetrieved Document_{}: ".format(current_iter)+document.strip())
        else:
            template.append_message(template.roles[0], 
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())
        prompt = template.get_prompt()
        if args.verbose:
            print('input', prompt)
        first_model_output = main_model.completions.create(
            model=args.main_model,
            prompt= prompt,
            temperature=0.0,
            max_tokens=200,
            stop=['<|eot_id|>']
        ).choices[0].text.strip()

        if args.verbose:
            print('output:', first_model_output)

        # The membership test is case-insensitive but the split below is
        # case-sensitive on 'Refined Query:'.
        if 'Refined Query:'.lower() in first_model_output.lower():
            queries = [first_model_output.split('Refined Query:')[-1].strip()]
            if '<|eot_id|>' in queries[0]:
                queries = [queries[0].split('<|eot_id|>')[0].strip()]
            current_iter += 1 
        elif 'final answer' in first_model_output.lower():
            template.append_message(template.roles[1], 
                                first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
            break
        else:
            # Neither marker found: log and retry with the same query.
            print('Exception: Follow Failed')
            print(template.get_prompt())
            print(first_model_output)
        # The assistant turn is the text after the assistant header; `[1]`
        # raises IndexError when the output does not contain that header.
        template.append_message(template.roles[1], 
                                first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
        max_iter -= 1

    # Phase 2: the model still asks for a refined query, so let it write the
    # "retrieved" document itself.
    max_iter = 5
    while 'Refined Query:' in first_model_output and max_iter > 0:
        
        query = first_model_output.split('Refined Query:')[-1].strip()
        if '<|eot_id|>' in query:
            query = query.split('<|eot_id|>')[0].strip()

        document_prompt = Knowledge_Prompt.format(i['question'], query)
        if args.verbose:
            print("Knowledge prompt:", document_prompt)
        document = main_model.completions.create(
            model = args.main_model,
            prompt = document_prompt,
            temperature=0.0,
            max_tokens=200,
            stop=['\n\n']
        ).choices[0].text.strip()
        if args.verbose:
            print("Document:", document)

        template.append_message(template.roles[0], 
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())
        prompt = template.get_prompt()
        if args.verbose:
            print('input:', prompt)
        first_model_output = main_model.completions.create(
            model=args.main_model,
            prompt= prompt,
            temperature=0.0,
            max_tokens=150
        ).choices[0].text.strip()

        template.append_message(template.roles[1], 
                                first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
        max_iter -= 1
        current_iter+=1
    all_output = template.get_prompt()
    if args.verbose:
        print(all_output)

    trace.append({
        'id': id,
        'question': i['question'],
        'trace': all_output,
        'golden_answer': i['answer'],
        'answer_id': i['answer_id']
    })
    # The answer is whatever follows the last 'Final Answer:' in the full
    # conversation, up to the next end-of-turn token.
    answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()
    if i['answer'].lower() in answer.lower():
        print('Success!')
        success.append({
        'id': id,
        'question': i['question'],
        'trace': all_output,
        'golden_answer': i['answer'],
        'answer_id': i['answer_id']
    })
    print('processed id: ', id)
    # Running rate over all examples recorded so far in the shared lists.
    print('success_rate:', len(success)/len(trace))
    # except Exception as e:
    #     print('Exception!', e)
    #     exception.append({
    #         'id': id,
    #         'question': i['question'],
    #     })

def wiki_test(main_model, retriever, data, args):
    """Run ``wiki_test_thread`` over every example in a thread pool and save the traces.

    Args:
        main_model: Forwarded to each worker.
        retriever: Forwarded to each worker.
        data: Sized, indexable collection of examples.
        args: Namespace providing ``workers`` (pool size) and ``save_path``
            (JSON Lines output file), and forwarded to each worker.

    Returns:
        tuple: ``(success, trace)`` lists as filled in by the workers.
    """
    trace = []
    exception = []
    success = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(wiki_test_thread,
                            main_model,
                            retriever,
                            data,
                            args,
                            idx,
                            trace,
                            exception,
                            success): idx
            for idx in range(len(data))
        }
        print('submit done')

        # Track completion (not submission) and surface worker failures,
        # which would otherwise vanish with the discarded futures.
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            err = future.exception()
            if err is not None:
                print('Exception!', futures[future], repr(err))
                exception.append({'id': futures[future], 'error': repr(err)})

    with jsonlines.open(args.save_path, 'w') as f:
        f.write_all(trace)
    print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def nq_test_thread(main_model, 
                         retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Answer one NQ-style example with an iterative retrieve-and-refine conversation.

    Phase 1 (up to 5 rounds): retrieve for the current query, feed the first
    not-yet-used document to the model as a user turn, and read the reply. A
    reply containing "Refined Query:" sets the next query; a reply containing
    "final answer" ends the phase.

    Phase 2 (up to 5 rounds, only while the last reply still contains
    "Refined Query:"): the document is generated by the model itself from
    ``Knowledge_Prompt`` instead of being retrieved.

    Args:
        main_model: Client exposing ``completions.create``.
        retriever: Object exposing ``retrieve(corpus, queries)``.
        data: Indexable collection of examples; ``data[id]`` must provide
            'question' and either 'answer' or 'short_answers' (iterated as a
            collection of strings).
        args: Namespace providing ``verbose`` and ``main_model``.
        id: Index of the example to process.
        trace: Shared list; one record is appended per scored example.
        exception: Shared list; a record is appended when anything raises.
        success: Shared list; a record is appended when any gold answer is a
            case-insensitive substring of the extracted final answer.

    Returns:
        None. Results are reported through the shared lists.
    """
    try:
        i = data[id]
        print('Processing id: ', id)

        # Document ids already shown to the model, so they are not reused.
        retrieved_ids = []
        queries = [i['question']+'?']
        template = get_conversation_template('llama-3')

        max_iter = 5
        current_iter = 1

        first_model_output = None

        while max_iter > 0:

            results = []
            document = None
            if queries is not None:
                # Queries are passed as {index: (query_text, None)}; the second
                # tuple element is the filter id consumed by the patched search.
                retrieved_results : Dict[str, Dict[str, Tuple[float, str]]] = retriever.retrieve(
                            None, dict(zip(range(len(queries)), list(zip(queries, ([None] * len(queries)))))))
            if args.verbose:
                print(retrieved_results[0].items())
            for query_id, retrieved_result in retrieved_results.items():
                # No documents found
                if retrieved_result == {}:
                    continue
                retrieved_documents = list(retrieved_result.keys())
                document = None
                in_flag = False
                # Take the first document that has not been used yet; element
                # [2] of the stored value is used as the document text.
                for doc_id in retrieved_documents:
                    if doc_id not in retrieved_ids:
                        retrieved_ids.append(doc_id)
                        document = retrieved_result[doc_id][2]
                        break

            # NOTE(review): `document` is still None here when no unused
            # document was found; `.strip()` then raises AttributeError, which
            # the outer except turns into an `exception` record.
            if current_iter == 1:
                template.append_message(template.roles[0], 
                                        "Question: "+i['question'].strip()+'?'+
                                        "\n\nRetrieved Document_{}: ".format(current_iter)+document.strip())
            else:
                template.append_message(template.roles[0], 
                                        "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()

            if args.verbose:
                print('output:', first_model_output)

            if 'Refined Query:' in first_model_output:
                queries = [first_model_output.split('Refined Query:')[-1].strip()]
                if '<|eot_id|>' in queries[0]:
                    queries = [queries[0].split('<|eot_id|>')[0].strip()]
                current_iter += 1 
            elif 'final answer' in first_model_output.lower():
                template.append_message(template.roles[1], 
                                    first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
                break
            else:
                # Neither marker found: log and retry with the same query.
                print('Exception: Follow Failed')
            # The assistant turn is the text after the assistant header; `[1]`
            # raises IndexError when the output does not contain that header.
            template.append_message(template.roles[1], 
                                    first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
            max_iter -= 1

        # Phase 2: the model still asks for a refined query, so let it write
        # the "retrieved" document itself.
        max_iter = 5
        while 'Refined Query:' in first_model_output and max_iter > 0:
            
            query = first_model_output.split('Refined Query:')[-1].strip()
            if '<|eot_id|>' in query:
                query = query.split('<|eot_id|>')[0].strip()

            document_prompt = Knowledge_Prompt.format(i['question'], query)
            if args.verbose:
                print("Knowledge prompt:", document_prompt)
            document = main_model.completions.create(
                model = args.main_model,
                prompt = document_prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['\n\n']
            ).choices[0].text.strip()
            if args.verbose:
                print("Document:", document)

            template.append_message(template.roles[0], 
                                        "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input:', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= prompt,
                temperature=0.0,
                max_tokens=150
            ).choices[0].text.strip()

            template.append_message(template.roles[1], 
                                    first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
            max_iter -= 1
            current_iter+=1
        all_output = template.get_prompt()
        if args.verbose:
            print(all_output)

        # The answer is whatever follows the last 'Final Answer:' in the full
        # conversation, up to the next end-of-turn token.
        answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()

        if 'answer' in i.keys():
            for k in i['answer']:

                    # Remove the space before punctuation in the gold answer
                    # (e.g. "a , b" -> "a, b") before the substring match.
                    for p in ['.', ',', '?', '!', ':', ';']:
                        k = k.replace(' ' + p, p)
                    
                    if k.lower() in answer.lower():
                        success.append({
                            'id': id,
                            'question': i['question'],
                            'trace': all_output,
                            'golden_answer': i['answer'],
                        })
                        break
            trace.append({
                'id': id,
                'question': i['question'],
                'trace': all_output,
                'golden_answer': i['answer'],
            })
        elif 'short_answers' in i.keys():
            for k in i['short_answers']:

                # Same punctuation spacing fix as in the 'answer' branch.
                for p in ['.', ',', '?', '!', ':', ';']:
                    k = k.replace(' ' + p, p)
                
                if k.lower() in answer.lower():
                    success.append({
                        'id': id,
                        'question': i['question'],
                        'trace': all_output,
                        'golden_answer': i['short_answers'],
                    })
                    break
            trace.append({
                'id': id,
                'question': i['question'],
                'trace': all_output,
                'golden_answer': i['short_answers'],
            })
        else:
            # Nothing is appended to `trace` in this case.
            print('No golden answer found')
        print('processed id: ', id)
        # Running rate over the shared lists; raises ZeroDivisionError (caught
        # below) if `trace` is still empty at this point.
        print('success_rate:', len(success)/len(trace))
    except Exception as e:
        print('Exception!', e)
        # NOTE(review): `i` is unbound here if `data[id]` itself raised.
        exception.append({
            'id': id,
            'question': i['question'],
        })

def nq_test(main_model, retriever, data, args):
    """Run ``nq_test_thread`` over every example in a thread pool and save the traces.

    Args:
        main_model: Forwarded to each worker.
        retriever: Forwarded to each worker.
        data: Sized, indexable collection of examples.
        args: Namespace providing ``workers`` (pool size) and ``save_path``
            (JSON Lines output file), and forwarded to each worker.

    Returns:
        tuple: ``(success, trace)`` lists as filled in by the workers.
    """
    trace = []
    exception = []
    success = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(nq_test_thread,
                            main_model,
                            retriever,
                            data,
                            args,
                            idx,
                            trace,
                            exception,
                            success): idx
            for idx in range(len(data))
        }
        print('submit done')

        # Track completion (not submission) and surface any failure that
        # escapes the worker, which would otherwise vanish with the future.
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            err = future.exception()
            if err is not None:
                print('Exception!', futures[future], repr(err))
                exception.append({'id': futures[future], 'error': repr(err)})

    with jsonlines.open(args.save_path, 'w') as f:
        f.write_all(trace)
    print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def few_shot_wiki_test_thread(main_model, 
                         retriever, 
                         rewrite_model,
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Answer one example with a few-shot prompt that alternates retrieval and intermediate answers.

    The prompt starts as ``WikiMultiHop_FewShot`` and grows by string
    concatenation. In each of up to 10 rounds a retrieved document is
    appended, the main model writes an intermediate answer, and -- if that
    answer contains any phrase from ``Query_Responses`` -- the rewrite model
    proposes new queries for the next round. If the last intermediate answer
    still contains such a phrase after the loop, one extra round uses
    knowledge generated by the main model instead of retrieval. A final
    completion is then requested and scored.

    Args:
        main_model: Client exposing ``completions.create``.
        retriever: Object exposing ``retrieve(corpus, queries)``.
        rewrite_model: Client exposing ``chat.completions.create``.
        data: Indexable collection of examples; ``data[id]`` must provide
            'question', 'answer' and 'answer_id'.
        args: Namespace providing ``verbose`` and ``main_model``.
        id: Index of the example to process.
        trace: Shared list; one record is appended per completed example.
        exception: Shared list; a record is appended when anything raises.
        success: Shared list; a record is appended when the gold answer is a
            case-insensitive substring of the extracted final answer.

    Returns:
        None. Results are reported through the shared lists.
    """

    i = data[id]
    print('Processing id: ', id)
    try:
        # Document ids already placed in the prompt, so they are not reused.
        retrieved_ids = []
        queries = [i['question']]

        max_iter = 10
        current_iter = 1

        template = WikiMultiHop_FewShot
        first_model_output = None
        last_query = queries[0]
        while max_iter > 0:

            results = []
            document = None
            if queries is not None:
                # Queries are passed as {index: (query_text, None)}; the second
                # tuple element is the filter id consumed by the patched search.
                retrieved_results : Dict[str, Dict[str, Tuple[float, str]]] = retriever.retrieve(
                            None, dict(zip(range(len(queries)), list(zip(queries, ([None] * len(queries)))))))
            if args.verbose:
                print(retrieved_results[0].items())
            for query_id, retrieved_result in retrieved_results.items():

                if retrieved_result == {}:
                    continue
                retrieved_documents = list(retrieved_result.keys())
                document = None
                in_flag = False
                for doc_id in retrieved_documents:
                    if doc_id not in retrieved_ids:
                        retrieved_ids.append(doc_id)
                        document = retrieved_result[doc_id][2]
                        break
                    else:
                        continue
                # `in_flag` is never set to True in this loop, so this branch
                # does not run and every query's results are visited.
                if in_flag:
                    last_query = queries[query_id]
                    break
            if document is None:
                if args.verbose:
                    print('No document found')

                # Fallback: pick one query at random and take its first unused
                # document. random.choice raises IndexError if
                # `retrieved_results` is empty.
                random_choose_query_id = random.choice(range(len(retrieved_results)))
                keys = list(retrieved_results[random_choose_query_id].keys())
                for doc_id in keys:
                    if doc_id not in retrieved_ids:
                        document = retrieved_results[random_choose_query_id][doc_id][2]
                        retrieved_ids.append(doc_id)
                        break
                query_id = random_choose_query_id
                last_query = queries[query_id]

            assert document is not None, 'No document found'

            if current_iter == 1:
                template += '\nQuestion: ' + i['question'].strip() + '\n'
                template += '\nRetrieved Document_{}: {}'.format(current_iter, document.strip()) + '\n'

            else:
                template += '\nRefined Query: ' + queries[query_id].strip() + '\n'
                template += '\nRetrieved Document_{}: {}'.format(current_iter, document.strip()) + '\n'
            
            template += '\nIntermediate Answer_{}: '.format(current_iter)

            main_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= template,
                temperature=0.0,
                max_tokens=100,
            )
            
            # Keep only the first paragraph of the completion.
            first_model_output = main_model_output.choices[0].text.split('\n\n')[0]
            
            template += first_model_output.strip() + '\n'
            if args.verbose:
                print(first_model_output)
            # A phrase from Query_Responses in the answer triggers another
            # retrieval round.
            retrieve_flag = False
            for j in Query_Responses:
                if j.lower() in first_model_output.lower():
                    retrieve_flag = True
                    break
            if retrieve_flag:
                inst = WikiMultiHop_Retrieval_Query_Template.format(first_model_output, last_query)

                rewrite_model_output =  rewrite_model.chat.completions.create(
                    model="xxx",
                    messages = [
                        {"role": "user", "content": inst}
                    ],
                    temperature=0.0,
                    max_tokens=100
                )
                # print(rewrite_model_output.choices[0].message.content)
                rewrite_queries = rewrite_model_output.choices[0].message.content.strip().split('\n')
                
                # Drops the first len('Query 1: ') characters of every line,
                # whether or not the line actually starts with that label.
                queries = [q.strip()[len('Query 1: '):] for q in rewrite_queries]
                if args.verbose:
                    print(queries)
                current_iter += 1
            else:
                queries = None
                break
            max_iter -= 1
        # Check whether the last intermediate answer still asks for retrieval.
        still_retrieval_flag = False
        for j in Query_Responses:
            if j.lower() in first_model_output.lower():
                still_retrieval_flag = True
                break
        if still_retrieval_flag:
            print('still_retrieval')
            # NOTE(review): this wiki routine uses the NQ_* templates here and
            # appends '?' to the question -- confirm this is intended.
            inst = NQ_Retrieval_Query_Template.format(i['question']+'?', first_model_output, last_query)

            rewrite_model_output =  rewrite_model.chat.completions.create(
                model="",
                messages = [
                    {"role": "user", "content": inst}
                ],
                temperature=0.0,
                max_tokens=100
            )

            rewrite_queries = rewrite_model_output.choices[0].message.content.strip().split('\n')
            
            query = [q.strip()[len('Query 1: '):] for q in rewrite_queries][0]
            if args.verbose:
                # Prints `queries` from the loop above, not the new `query`.
                print(queries)
            
            template += '\nRefined Query: ' + query.strip() + '\n'

            # Ask the main model to write the supporting knowledge itself.
            elicit_prompt = NQ_Elicitation_Prompt.format(i['question'],
                                                        first_model_output,
                                                        last_query)
            elicit_knowledge = main_model.completions.create(
                model=args.main_model,
                prompt= elicit_prompt,
                temperature=0.0,
                max_tokens=100,
            )
            elicit_knowledge = elicit_knowledge.choices[0].text.split('\n\n')[0].strip()
            template += '\nRetrieved Document_{}: '.format(current_iter) + elicit_knowledge + '\n'
            
            if args.verbose:
                print('elicit knowledge: ', elicit_knowledge)
            template += '\nIntermediate Answer_{}: Based on the Retrieved Document_{}, '.format(current_iter, current_iter)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= template,
                temperature=0.0,
                max_tokens=100,
            )
            first_model_output = first_model_output.choices[0].text.split('\n\n')[0].strip()
            template += first_model_output + '\n'
        
        final_output = main_model.completions.create(
            model=args.main_model,
            prompt= template,
            temperature=0.0,
            max_tokens=100,
        )

        final_output = final_output.choices[0].text.split('\n\n')[0].strip()

        template = template + '\n' + final_output
        # Only the part after the last '###\n\n' separator is stored;
        # presumably that strips the few-shot examples -- verify against the
        # template definition.
        trace.append({
            'id': id,
            'trace': template.split('###\n\n')[-1],
            'answer_id': i['answer_id'],
            'golden_answer': i['answer']
        })
        print(trace[-1])
        answer = final_output.split('Final Answer:')[-1].strip()
        if i['answer'].lower() in answer.lower():
            print('Success!')
            success.append({
            'id': id,
            'trace': template.split('###\n\n')[-1],
            'answer_id': i['answer_id'],
            'golden_answer': i['answer']
        })
        print('processed id: ', id)
        print('success_rate:', len(success)/len(trace))
    except Exception as e:
        print('Exception!', e)
        exception.append({
            'template': template,
            'id': id,
            'question': i['question'],
        })
        # NOTE(review): the path is an empty string, so this open() raises
        # FileNotFoundError from inside the handler -- supply a real path.
        with open('', 'w') as f:
            json.dump(exception, f)
    # `trace` is shared between workers, so the last element may belong to
    # another example; raises IndexError if `trace` is still empty.
    print(trace[-1])
    print('processed id: ', id)

def few_shot_wiki_test(main_model, 
                       retriever, 
                       rewrite_model,
                         data, 
                         args):
    """Run ``few_shot_wiki_test_thread`` over every example in a thread pool and save the traces.

    Args:
        main_model: Forwarded to each worker.
        retriever: Forwarded to each worker.
        rewrite_model: Forwarded to each worker.
        data: Sized, indexable collection of examples.
        args: Namespace providing ``workers`` (pool size) and ``save_path``
            (JSON Lines output file), and forwarded to each worker.

    Returns:
        tuple: ``(success, trace)`` lists as filled in by the workers.
    """
    trace = []
    exception = []
    success = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(few_shot_wiki_test_thread,
                            main_model,
                            retriever,
                            rewrite_model,
                            data,
                            args,
                            idx,
                            trace,
                            exception,
                            success): idx
            for idx in range(len(data))
        }
        print('submit done')

        # Track completion (not submission) and surface any failure that
        # escapes the worker, which would otherwise vanish with the future.
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            err = future.exception()
            if err is not None:
                print('Exception!', futures[future], repr(err))
                exception.append({'id': futures[future], 'error': repr(err)})

    with jsonlines.open(args.save_path, 'w') as f:
        f.write_all(trace)
    print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def few_shot_nq_test_thread(main_model, 
                         retriever, 
                        rewrite_model,
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Answer one sample with few-shot prompting and iterative retrieval.

    The prompt starts from ``NQ_FewShot``. For at most three rounds the
    function retrieves a not-yet-used document for the current queries,
    appends it to the prompt and asks ``main_model`` for an intermediate
    answer. If that answer contains any phrase from ``Query_Responses``,
    ``rewrite_model`` is asked for refined queries and another round runs.
    If the last intermediate answer still contains such a phrase, the main
    model is asked to write the missing knowledge itself (via
    ``NQ_Elicitation_Prompt``) and that text is used in place of a
    retrieved document. A final completion produces the answer.

    Args:
        main_model: client exposing ``completions.create``.
        retriever: object exposing ``retrieve(corpus, queries)``.
        rewrite_model: client exposing ``chat.completions.create``.
        data: indexable collection of sample dicts.
        args: namespace providing ``main_model`` and ``verbose``.
        id: index of the sample in ``data``.
        trace: shared list; one record is appended per answered sample.
        exception: shared list; one record is appended per failure.
        success: shared list; a record is appended when a normalized golden
            answer occurs in the text after ``'Final Answer:'``.

    Returns:
        None. Results are reported through the shared lists.
    """
    i = data[id]
    print('Processing id: ', id)
    # Bound before the try block so the except handler can always report it.
    template = NQ_FewShot
    try:
        retrieved_ids = []
        queries = [i['question']+'?']
        historical_queries = []

        max_iter = 3
        current_iter = 1

        first_model_output = None
        last_query = queries[0]
        while max_iter > 0:
            document = None
            if queries is not None:
                # Queries are keyed by their position; the second tuple slot
                # is the filter id, which is not used here.
                retrieved_results = retriever.retrieve(
                    None, {idx: (q, None) for idx, q in enumerate(queries)})
            for query_id, retrieved_result in retrieved_results.items():
                # No documents found
                if retrieved_result == {}:
                    continue
                if queries[query_id] not in historical_queries:
                    historical_queries.append(queries[query_id])
                    in_flag = False
                    # Take the first document that has not been used yet.
                    for doc_id in list(retrieved_result.keys()):
                        if doc_id not in retrieved_ids:
                            document = retrieved_result[doc_id][2]
                            retrieved_ids.append(doc_id)
                            in_flag = True
                            break
                    if in_flag:
                        last_query = queries[query_id]
                        break

            if document is None:
                if args.verbose:
                    print('No document found')
                # Fall back to a randomly chosen query and take its first
                # unused document, even if that query was already issued.
                random_choose_query_id = random.choice(range(len(retrieved_results)))
                keys = list(retrieved_results[random_choose_query_id].keys())
                for doc_id in keys:
                    if doc_id not in retrieved_ids:
                        document = retrieved_results[random_choose_query_id][doc_id][2]
                        retrieved_ids.append(doc_id)
                        break
                query_id = random_choose_query_id
                last_query = queries[query_id]

            assert document is not None, 'No document found'
            if args.verbose:
                print("Document: ", document)
            if current_iter == 1:
                template += '\nQuestion: ' + i['question'].strip() + '\n'
                template += '\nRetrieved Document_{}: {}'.format(current_iter, document.strip()) + '\n'

            else:
                template += '\nRefined Query: ' + queries[query_id].strip() + '\n'
                template += '\nRetrieved Document_{}: {}'.format(current_iter, document.strip()) + '\n'

            template += '\nIntermediate Answer_{}: Based on the Retrieved Document_{}, '.format(current_iter, current_iter)
            main_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= template,
                temperature=0.0,
                max_tokens=100,
            )

            # Keep only the first paragraph of the completion.
            first_model_output = main_model_output.choices[0].text.split('\n\n')[0]

            template += first_model_output.strip() + '\n'
            if args.verbose:
                print("model_output:", first_model_output)
            # A phrase from Query_Responses in the answer triggers another
            # retrieval round.
            retrieve_flag = any(j.lower() in first_model_output.lower()
                                for j in Query_Responses)
            if retrieve_flag:
                inst = NQ_Retrieval_Query_Template.format(i['question']+'?', first_model_output, last_query)

                # NOTE(review): the model name is an empty string in this
                # file; confirm the serving endpoint accepts that.
                rewrite_model_output =  rewrite_model.chat.completions.create(
                    model="",
                    messages = [
                        {"role": "user", "content": inst}
                    ],
                    temperature=0.0,
                    max_tokens=100
                )
                rewrite_queries = rewrite_model_output.choices[0].message.content.strip().split('\n')

                # Drops a fixed-width prefix the size of 'Query 1: ' from
                # every returned line.
                queries = [q.strip()[len('Query 1: '):] for q in rewrite_queries]
                if args.verbose:
                    print("queries:", queries)
                current_iter += 1
            else:
                queries = None
                break
            max_iter -= 1

        still_retrieval_flag = any(j.lower() in first_model_output.lower()
                                   for j in Query_Responses)
        if still_retrieval_flag:
            print('still_retrieval')
            inst = NQ_Retrieval_Query_Template.format(i['question']+'?', first_model_output, last_query)

            rewrite_model_output =  rewrite_model.chat.completions.create(
                model="",
                messages = [
                    {"role": "user", "content": inst}
                ],
                temperature=0.0,
                max_tokens=100
            )
            rewrite_queries = rewrite_model_output.choices[0].message.content.strip().split('\n')

            query = [q.strip()[len('Query 1: '):] for q in rewrite_queries][0]
            if args.verbose:
                # Show the query that was just produced, not the stale list
                # left over from the retrieval loop.
                print(query)

            template += '\nRefined Query: ' + query.strip() + '\n'

            elicit_prompt = NQ_Elicitation_Prompt.format(i['question'],
                                                        first_model_output,
                                                        last_query)
            elicit_knowledge = main_model.completions.create(
                model=args.main_model,
                prompt= elicit_prompt,
                temperature=0.0,
                max_tokens=100,
            )
            elicit_knowledge = elicit_knowledge.choices[0].text.split('\n\n')[0].strip()
            # The model-written text stands in for a retrieved document.
            template += '\nRetrieved Document_{}: '.format(current_iter) + elicit_knowledge + '\n'

            if args.verbose:
                print('elicit knowledge: ', elicit_knowledge)

            template += '\nIntermediate Answer_{}: Based on the Retrieved Document_{}, '.format(current_iter, current_iter)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= template,
                temperature=0.0,
                max_tokens=100,
            )
            first_model_output = first_model_output.choices[0].text.split('\n\n')[0].strip()
            template += first_model_output + '\n'

        final_output = main_model.completions.create(
            model=args.main_model,
            prompt= template,
            temperature=0.0,
            max_tokens=100,
        )

        final_output = final_output.choices[0].text.split('\n\n')[0].strip()

        template = template + '\n' + final_output
        answer = final_output.split('Final Answer:')[-1].strip()
        if 'golden_answers' in i.keys():
            # NOTE(review): the records read i['answer'] although the check
            # above is on 'golden_answers'; confirm the sample schema has both.
            for k in i['golden_answers']:
                k = normalize_answer(k)

                if k.lower() in answer.lower():
                    success.append({
                        'id': id,
                        'question': i['question'],
                        'trace': template,
                        'golden_answer': i['answer'],
                    })
                    break
            record = {
                'id': id,
                'question': i['question'],
                'trace': template.split('###\n\n')[-1],
                'golden_answer': i['answer'],
            }
            trace.append(record)
            # Report only when a record was appended: reading trace[-1] or
            # dividing by len(trace) otherwise fails on an empty list or
            # shows another worker's record.
            print(record)
            print('success_rate:', len(success)/len(trace))
        else:
            print('No golden answer found.')
    except Exception as e:
        print('Exception!', e)
        exception.append({
            'template': template,
            'id': id,
            'question': i['question'],
        })
        # NOTE(review): the dump path is an empty string in this file and
        # open('') raises; guard it so the handler itself cannot crash.
        try:
            with open('', 'w') as f:
                json.dump(exception, f)
        except OSError as dump_error:
            print('Exception dump failed:', dump_error)
    print('processed id: ', id)

def few_shot_nq_test(main_model, 
                         retriever, 
                         rewrite_model,
                         data, 
                         args):
    """Run ``few_shot_nq_test_thread`` for every sample index on a thread pool.

    One job per index of ``data`` is submitted to a pool with
    ``args.workers`` workers. The jobs share the ``trace``, ``exception``
    and ``success`` lists. After the pool has drained, ``trace`` is written
    to ``args.save_path`` as JSON lines.

    Returns:
        The ``(success, trace)`` lists, matching the other drivers in this
        file (this function previously returned nothing).
    """
    trace = []
    exception = []
    success = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        for id in tqdm.trange(len(data)):
            executor.submit(few_shot_nq_test_thread, 
                            main_model, 
                            retriever, 
                            rewrite_model,
                            data, 
                            args, 
                            id, 
                            trace, 
                            exception, 
                            success)
        # Leaving the ``with`` block waits for all submitted jobs.
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as f:
        f.write_all(trace)
        print('Result Saved!')

    return success, trace

def test_thread(main_model, 
                         retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Answer one sample with a chat-formatted model and iterative retrieval.

    A ``llama-3`` conversation template is filled turn by turn. In each of
    up to ``args.retrieval_max_iter`` rounds, passages for the current
    query are appended as a user turn and the model's reply is appended as
    an assistant turn. A reply containing ``Refined Query:`` starts another
    round with that query; a reply containing ``final answer`` ends the
    loop. If the last reply still asks for a refined query, up to
    ``args.elicit_max_iter`` further rounds let the model write the
    document itself via ``Knowledge_Prompt``.

    Args:
        main_model: client exposing ``completions.create``.
        retriever: exposes ``retrieve`` (used when ``args.search_engine`` is
            ``'bm25'``) or ``search`` (any other value).
        data: indexable collection of sample dicts with ``question`` and
            ``golden_answers``.
        args: namespace providing ``retrieval_max_iter``,
            ``elicit_max_iter``, ``search_engine``, ``num_passages``,
            ``main_model`` and ``verbose``.
        id: index of the sample in ``data``.
        trace: shared list; one record is appended per answered sample.
        exception: shared list; one record is appended per failure.
        success: shared list; a record is appended when a golden answer is
            a case-insensitive substring of the final answer.

    Returns:
        None. Results are reported through the shared lists.
    """
    try:
        i = data[id]
        print('Processing id: ', id)

        retrieved_ids = []
        queries = [i['question']]
        template = get_conversation_template('llama-3')

        max_iter = args.retrieval_max_iter
        current_iter = 1

        first_model_output = None

        while max_iter > 0:
            document = None
            if args.search_engine == 'bm25':
                if queries is not None:
                    # Queries are keyed by position; the second tuple slot is
                    # the filter id, which is not used here.
                    retrieved_results = retriever.retrieve(
                        None, {idx: (q, None) for idx, q in enumerate(queries)})
                if args.verbose:
                    print(retrieved_results[0].items())

                current_passages = 0
                documents = []
                for query_id, retrieved_result in retrieved_results.items():
                    # No documents found
                    if retrieved_result == {}:
                        continue
                    retrieved_documents = list(retrieved_result.keys())
                    for doc_id in retrieved_documents:
                        if doc_id not in retrieved_ids:
                            retrieved_ids.append(doc_id)
                            documents.append(retrieved_result[doc_id][2])
                            current_passages += 1
                            if current_passages >= args.num_passages:
                                break
                        else:
                            continue
                document = ' '.join(documents)
            else:
                documents = []
                retrieval_results = retriever.search(queries[0])
                for result in retrieval_results:
                    if result['id'] not in retrieved_ids:
                        retrieved_ids.append(result['id'])
                        # Only the last line of 'contents' is used as passage text.
                        documents.append(result['contents'].split('\n')[-1])
                    if len(documents) >= args.num_passages:
                        break
                document = ' '.join(documents)

            if current_iter == 1:
                template.append_message(template.roles[0], 
                                        "Question: "+i['question'].strip()+
                                        "\n\nRetrieved Document_{}: ".format(current_iter)+document.strip())
            else:
                template.append_message(template.roles[0], 
                                        "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()

            if args.verbose:
                print('output:', first_model_output)
            if 'Refined Query:'.lower() in first_model_output.lower():
                queries = [first_model_output.split('Refined Query:')[-1].strip()]
                if '<|eot_id|>' in queries[0]:
                    queries = [queries[0].split('<|eot_id|>')[0].strip()]
                current_iter += 1 
            elif 'final answer' in first_model_output.lower():
                # The completion is expected to carry its own assistant
                # header; the text between header and end-of-turn is stored.
                template.append_message(template.roles[1], 
                                    first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
                break
            else:
                print('Exception: Follow Failed')
                print(template.get_prompt())
                print(first_model_output)
            template.append_message(template.roles[1], 
                                    first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
            max_iter -= 1

        max_iter = args.elicit_max_iter 
        while 'Refined Query:' in first_model_output and max_iter > 0:

            query = first_model_output.split('Refined Query:')[-1].strip()
            if '<|eot_id|>' in query:
                query = query.split('<|eot_id|>')[0].strip()

            # Ask the model itself for the knowledge the query is after.
            document_prompt = Knowledge_Prompt.format(i['question'], query)
            if args.verbose:
                print("Knowledge prompt:", document_prompt)
            document = main_model.completions.create(
                model = args.main_model,
                prompt = document_prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['\n', '<|eot_id|>']
            ).choices[0].text.strip()
            if args.verbose:
                print("Document:", document)
            print('Document:', document)
            template.append_message(template.roles[0], 
                                        "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input:', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt= prompt,
                temperature=0.0,
                max_tokens=150
            ).choices[0].text.strip()

            template.append_message(template.roles[1], 
                                    first_model_output.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip())
            max_iter -= 1
            current_iter+=1
        all_output = template.get_prompt()
        if args.verbose:
            print(all_output)

        # Build the record completely before publishing it: other workers
        # append to ``trace`` concurrently, so ``trace[-1]`` after the append
        # is not guaranteed to be this worker's record.
        record = {
            'id': id,
            'question': i['question'],
            'trace': all_output,
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)
        answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()

        for ans in i['golden_answers']:
            if ans.lower() in answer.lower():
                print('Success!')
                success.append({
                'id': id,
                'question': i['question'],
                'trace': all_output,
                'golden_answer': i['golden_answers'],
                })
                break

        print(record)
        print('processed id: ', id)
        print('success_rate:', len(success)/len(trace))
    except Exception as e:
        print('Exception!', e)
        # Keep a record of the failure so it is not lost in the console output.
        exception.append({
            'id': id,
            'error': repr(e),
        })

def few_shot_test_thread(main_model, 
                         retriever, 
                        rewrite_model,
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Answer one sample with a few-shot prompt, analysis and iterative retrieval.

    The prompt starts from ``MultiHop_FewShot`` or ``SingleHop_FewShot``
    (chosen by ``args.few_shot_template``). The first round asks the model
    for an analysis of the question, retrieves a document through
    ``get_queries_and_retrieval_result_specified`` and asks for an
    intermediate answer. Later rounds repeat retrieval while the latest
    answer contains a phrase from ``Query_Responses``, for at most
    ``args.max_iter`` rounds in total. If the answer still contains such a
    phrase, up to ``args.elicit_max_iter`` rounds replace retrieval with
    text written by the model via ``NQ_Elicitation_Prompt``. A final
    completion produces the answer.

    Args:
        main_model: client exposing ``completions.create``.
        retriever: passed through to the retrieval helper.
        rewrite_model: client exposing ``chat.completions.create``.
        data: indexable collection of sample dicts with ``question`` and
            ``golden_answers``.
        args: namespace providing ``num_passages``, ``max_iter``,
            ``elicit_max_iter``, ``few_shot_template`` and ``main_model``.
        id: index of the sample in ``data``.
        trace: shared list; one record is appended per sample, including
            samples that raised.
        exception: shared list; not written by this function.
        success: shared list; a record is appended when the normalized final
            answer equals a normalized golden answer.

    Returns:
        None. Results are reported through the shared lists.
    """
    i = data[id]
    print('Processing id: ', id)
    max_retrieved_document = args.num_passages
    # Bound before the try block so the except handler can always build a
    # trace record, even when the failure happens before the prompt exists.
    template = ''
    try:
        retrieved_ids = []
        queries = [i['question']]

        max_iter = args.max_iter
        current_iter = 1

        if args.few_shot_template == 'multihop':
            template = MultiHop_FewShot
        elif args.few_shot_template == 'singlehop':
            template = SingleHop_FewShot
        else:
            assert False, 'Few shot template not found'
        last_query = queries[0]

        while max_iter > 0:
            if current_iter == 1:
                template += '\nQuestion: ' + i['question'].strip() + '\n'
                template += '\nAnalysis: '

                analysis = main_model.completions.create(
                    model=args.main_model,
                    prompt= template,
                    temperature=0.5,
                    max_tokens=100,
                    stop=['<|eot_id|>', '\n']
                ).choices[0].text.strip()

                # Up to five attempts; a dict result ends the retries.
                for _ in range(5):
                    retrieval_result = get_queries_and_retrieval_result_specified(main_model, 
                                                                        rewrite_model, 
                                                                        None, 
                                                                        retriever, 
                                                                        i['question'],
                                                                        analysis,
                                                                        retrieved_ids,
                                                                        args,
                                                                        max_retrieved_document=max_retrieved_document)
                    if isinstance(retrieval_result, dict):
                        break

                last_query, document = retrieval_result['query'], retrieval_result['document']
                template += analysis + '\n'
                # The helper's query starts with a 'Query:' label that is
                # dropped before it goes into the prompt.
                template += '\nInitial Query: ' + last_query[len('Query:'):].strip() + '\n'
                template += '\nRetrieved Document_{}: {}'.format(current_iter, document.strip()) + '\n'
                template += '\nIntermediate Answer_{}: '.format(current_iter)
                analysis = main_model.completions.create(
                    model=args.main_model,
                    prompt= template,
                    temperature=0.0,
                    max_tokens=100,
                    stop=['<|eot_id|>', '\n']
                ).choices[0].text.strip()
                template += analysis + '\n'
                current_iter += 1
            else:
                # A phrase from Query_Responses in the latest answer triggers
                # another retrieval round.
                retrieve_flag = any(j.lower() in analysis.lower()
                                    for j in Query_Responses)
                if retrieve_flag:
                    for _ in range(5):
                        retrieval_result = get_queries_and_retrieval_result_specified(main_model, 
                                                                            rewrite_model, 
                                                                            None, 
                                                                            retriever, 
                                                                            i['question'],
                                                                            analysis,
                                                                            retrieved_ids,
                                                                            args,
                                                                            max_retrieved_document=max_retrieved_document,)
                        if isinstance(retrieval_result, dict):
                            break

                    last_query, document = retrieval_result['query'], retrieval_result['document']
                    template += '\nRefined Query: ' + last_query[len('Query:'):].strip() + '\n'
                    template += '\nRetrieved Document_{}: {}'.format(current_iter, document.strip()) + '\n'
                    template += '\nIntermediate Answer_{}: '.format(current_iter)
                    analysis = main_model.completions.create(
                        model=args.main_model,
                        prompt= template,
                        temperature=0.0,
                        max_tokens=100,
                        stop=['<|eot_id|>', '\n']
                    ).choices[0].text.strip()
                    template += analysis + '\n'
                    current_iter += 1
                else:
                    break
            max_iter -= 1
        still_retrieval_flag = any(j.lower() in analysis.lower()
                                   for j in Query_Responses)
        max_iter = args.elicit_max_iter
        while still_retrieval_flag and max_iter > 0:
            inst = NQ_Retrieval_Query_Template.format(i['question']+'?', analysis, last_query)
            # NOTE(review): the model name is an empty string in this file;
            # confirm the serving endpoint accepts that.
            rewrite_model_output =  rewrite_model.chat.completions.create(
                model="",
                messages = [
                    {"role": "user", "content": inst}
                ],
                temperature=0.0,
                max_tokens=100
            )
            rewrite_queries = rewrite_model_output.choices[0].message.content.strip().split('\n')
            # ``query`` is the first rewritten query as a string. Indexing it
            # again would keep only its first character, so use it whole.
            query = [q.strip()[len('Query 1: '):] for q in rewrite_queries][0]
            template += '\nRefined Query: ' + query.strip() + '\n'
            last_query = query
            elicit_prompt = NQ_Elicitation_Prompt.format(i['question'], 
                                                         analysis,
                                                        last_query)
            elicit_knowledge = main_model.completions.create(
                model=args.main_model,
                prompt= elicit_prompt,
                temperature=0.0,
                max_tokens=100,
            )
            elicit_knowledge = elicit_knowledge.choices[0].text.split('\n\n')[0].strip()

            # The model-written text stands in for a retrieved document.
            template += '\nRetrieved Document_{}: '.format(current_iter) + elicit_knowledge + '\n'
            template += '\nIntermediate Answer_{}: Based on the Retrieved Document_{}, '.format(current_iter, current_iter)
            analysis = main_model.completions.create(
                model=args.main_model,
                prompt= template,
                temperature=0.0,
                max_tokens=100,
            ).choices[0].text.split('\n\n')[0].strip()
            template += analysis + '\n'
            current_iter += 1
            max_iter -= 1
            still_retrieval_flag = any(j.lower() in analysis.lower()
                                       for j in Query_Responses)
        final_output = main_model.completions.create(
            model=args.main_model,
            prompt= template,
            temperature=0.0,
            max_tokens=100,
            stop=['<|eot_id|>', '\n']
        )

        final_output = final_output.choices[0].text.strip()

        template = template + '\n' + final_output

        if 'final answer' in final_output.lower():
            answer = final_output.split('Final Answer: ')[-1].split('<|eot_id|>')[0].strip()
            for ans in i['golden_answers']:
                if normalize_answer(ans.lower()) ==  normalize_answer(answer.lower()):
                    success.append({
                        'id': id,
                        'data': template.split('###\n\n')[-1]
                    })
                    break
        # Build the record completely before publishing it: other workers
        # append to ``trace`` concurrently, so ``trace[-1]`` after the append
        # is not guaranteed to be this worker's record.
        record = {
            'id': id,
            'question': i['question'],
            'trace': template.split('###\n\n')[-1],
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)

        print(record)
        print('success_rate', len(success) / len(trace))

    except Exception as e:
        print('Exception!', e)
        # Failed samples still get a trace record with whatever prompt text
        # had been assembled when the error occurred.
        record = {
            'id': id,
            'question': i['question'],
            'trace': template.split('###\n\n')[-1],
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)
    print('processed id: ', id)

def few_shot_test(main_model,
                  retriever,
                  rewrite_model,
                  data,
                  args):
    """Run ``few_shot_test_thread`` for every sample index on a thread pool.

    Jobs are submitted to a pool with ``args.workers`` workers and share
    the ``trace``, ``exception`` and ``success`` lists. Once the pool has
    drained, ``trace`` is written to ``args.save_path`` as JSON lines.

    Returns:
        The ``(success, trace)`` lists populated by the workers.
    """
    trace, exception, success = [], [], []

    def submit_job(pool, sample_idx):
        # Every job receives the same models, data and result lists; only
        # the sample index differs.
        return pool.submit(few_shot_test_thread,
                           main_model,
                           retriever,
                           rewrite_model,
                           data,
                           args,
                           sample_idx,
                           trace,
                           exception,
                           success)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        for sample_idx in tqdm.trange(len(data)):
            submit_job(executor, sample_idx)
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as writer:
        writer.write_all(trace)
        print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def test(main_model, retriever, data, args):
    """Run ``test_thread`` for every sample index on a thread pool.

    Jobs are submitted to a pool with ``args.workers`` workers and share
    the ``trace``, ``exception`` and ``success`` lists. Once the pool has
    drained, ``trace`` is written to ``args.save_path`` as JSON lines.

    Returns:
        The ``(success, trace)`` lists populated by the workers.
    """
    success = []
    exception = []
    trace = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        progress = tqdm.trange(len(data))
        for sample_idx in progress:
            job_args = (main_model, retriever, data, args,
                        sample_idx, trace, exception, success)
            executor.submit(test_thread, *job_args)
        # Leaving the ``with`` block waits for all submitted jobs.
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as writer:
        writer.write_all(trace)
        print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def naive_test_thread(main_model,
                        data,
                        args,
                        id,
                        trace,
                        exception,
                        success):
    """Answer one sample from the model's own knowledge, without retrieval.

    The question is placed into a fixed chat-formatted prompt and sent to
    ``main_model.completions.create``; the stripped completion text is
    stored as the sample's trace.

    Args:
        main_model: client exposing ``completions.create``.
        data: indexable collection of sample dicts with ``question`` and
            ``golden_answers`` (``answer_id`` is copied when present).
        args: namespace providing ``main_model``.
        id: index of the sample in ``data``.
        trace: shared list; exactly one record is appended.
        exception: unused here; kept for a uniform worker signature.
        success: unused here; kept for a uniform worker signature.

    Returns:
        None. The result is appended to ``trace``.
    """
    i = data[id]

    template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nAnswer the question based on your own knowledge. Only give me the answer and do not output any other words.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nQuestion: {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
    prompt = template.format(i['question'])

    main_model_output = main_model.completions.create(
            model = args.main_model,
            prompt=prompt,
            temperature=0.0,
            max_tokens=200,
            stop=['<|eot_id|>', '\n']
        ).choices[0].text.strip()

    # Build the record completely before publishing it: other workers append
    # to ``trace`` concurrently, so ``trace[-1]`` after the append is not
    # guaranteed to be this worker's record.
    record = {
        'id': id,
        'question': i['question'],
        'trace': main_model_output,
        'golden_answers': i['golden_answers'],
    }
    if 'answer_id' in i.keys():
        record['answer_id'] = i['answer_id']
    trace.append(record)
def naive_test(main_model, data, args):
    """Run ``naive_test_thread`` for every sample index on a thread pool.

    Jobs are submitted to a pool with ``args.workers`` workers and share
    the ``trace``, ``exception`` and ``success`` lists. Once the pool has
    drained, ``trace`` is written to ``args.save_path`` as JSON lines.

    Returns:
        The ``(success, trace)`` lists populated by the workers.
    """
    success, exception, trace = [], [], []
    # The result lists are passed to every job in this fixed order.
    collectors = (trace, exception, success)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        for sample_idx in tqdm.trange(len(data)):
            executor.submit(naive_test_thread,
                            main_model,
                            data,
                            args,
                            sample_idx,
                            *collectors)
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as writer:
        writer.write_all(trace)
        print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace


def standard_test_thread(main_model,
                        data,
                        args,
                        id,
                        trace,
                        exception,
                        success):
    """Answer one sample from a single retrieval round (standard RAG).

    The question is used as the only query. Retrieved documents that have
    not been seen yet are formatted as ``Doc N(Title: ...) text`` lines and
    placed in the system message; the inner loop stops after five
    documents. The stripped completion text is stored as the sample's trace.

    Args:
        main_model: client exposing ``completions.create``.
        data: indexable collection of sample dicts with ``question`` and
            ``golden_answers`` (``answer_id`` is copied when present).
        args: namespace providing ``main_model``.
        id: index of the sample in ``data``.
        trace: shared list; exactly one record is appended.
        exception: unused here; kept for a uniform worker signature.
        success: unused here; kept for a uniform worker signature.

    Returns:
        None. The result is appended to ``trace``.
    """
    i = data[id]

    template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{}\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nQuestion: {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
    document_template = """Doc {}(Title: {}) {}\n"""
    system = """Answer the question based on the given document.Only give me the answer and do not output any other words.
    The following are given documents.
    
    {}"""
    queries = [i['question']]
    # NOTE(review): ``retriever`` is not a parameter of this function, so it
    # must exist as a module-level name at call time; confirm where it is set.
    if queries is not None:
        # Queries are keyed by position; the second tuple slot is the filter
        # id, which is not used here.
        retrieved_results = retriever.retrieve(
            None, {idx: (q, None) for idx, q in enumerate(queries)})
    retrieved_ids = []
    documents=""
    current_passages=0
    for query_id, retrieved_result in retrieved_results.items():
        # No documents found
        if retrieved_result == {}:
            continue
        retrieved_documents = list(retrieved_result.keys())
        for doc_id in retrieved_documents:
            if doc_id not in retrieved_ids:
                retrieved_ids.append(doc_id)
                # Each hit is indexed as (score, title, text).
                documents+=document_template.format(current_passages+1,
                                                    retrieved_result[doc_id][1],
                                                    retrieved_result[doc_id][2])
                current_passages += 1
                if current_passages >= 5:
                    break
            else:
                continue
    system = system.format(documents)
    prompt = template.format(system,i['question'])
    main_model_output = main_model.completions.create(
            model = args.main_model,
            prompt=prompt,
            temperature=0.0,
            max_tokens=200,
            stop=['<|eot_id|>', '\n']
        ).choices[0].text.strip()

    # Build the record completely before publishing it: other workers append
    # to ``trace`` concurrently, so ``trace[-1]`` after the append is not
    # guaranteed to be this worker's record.
    record = {
        'id': id,
        'question': i['question'],
        'trace': main_model_output,
        'golden_answers': i['golden_answers'],
    }
    if 'answer_id' in i.keys():
        record['answer_id'] = i['answer_id']
    trace.append(record)

    print('Processed id: ', id)

def standard_test(main_model, data, args):
    """Run ``standard_test_thread`` over every example of ``data`` in a thread pool.

    Each worker receives the shared ``trace``, ``exception`` and ``success``
    lists and is expected to append its results to them.  Once all workers
    are done, ``trace`` is written to ``args.save_path`` as JSON Lines.

    Args:
        main_model: Model client, forwarded unchanged to every worker.
        data: Indexable collection of examples; workers get an index into it.
        args: Options object.  ``args.workers`` sizes the pool and
            ``args.save_path`` names the output file.

    Returns:
        The tuple ``(success, trace)`` of the shared lists given to the workers.
    """
    trace = []
    exception = []
    success = []

    # Leaving the ``with`` block shuts the executor down and waits for all
    # submitted tasks, so ``trace`` is complete before it is written below.
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        for id in tqdm.trange(len(data)):
            executor.submit(standard_test_thread,
                            main_model,
                            data,
                            args,
                            id,
                            trace,
                            exception,
                            success)
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as f:
        f.write_all(trace)
        print('Result Saved!')
    # The previous ``return useful, trace`` referenced a name that is never
    # defined in this function; return ``success`` like the sibling drivers.
    return success, trace




def adaptive_test_thread(main_model, 
                         bm25_retriever, 
                         dense_retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Worker: answer ``data[id]`` through a retrieve-and-refine conversation.

    Phase 1 (at most ``args.retrieval_max_iter`` rounds): the model is
    prompted with the conversation so far.  Output containing
    ``Refined Query:`` or ``Initial Query:`` is treated as a search query;
    a query containing ``[BM25]`` is sent to ``bm25_retriever``, any other to
    ``dense_retriever``.  Passages not seen before are joined and fed back as
    a ``Retrieved Document_<n>`` user turn.  Output containing
    ``final answer`` ends the phase.

    Phase 2 (at most ``args.elicit_max_iter`` rounds): while the latest model
    output still contains ``Refined Query:``, the "document" is generated by
    the model itself from ``Knowledge_Prompt`` instead of being retrieved.

    The full conversation prompt is stored in a record appended to ``trace``;
    if the text after the last ``Final Answer:`` matches any golden answer
    (after ``normalize_answer``), a record is also appended to ``success``.
    On any exception a record holding the conversation so far is appended to
    ``trace`` instead.

    Args:
        main_model: Client exposing ``completions.create(...)``.
        bm25_retriever: Object exposing ``retrieve(corpus, queries)``.
        dense_retriever: Object exposing ``search(query)``.
        data: Indexable collection of examples with ``question`` and
            ``golden_answers`` (and optionally ``answer_id``).
        args: Options object (``main_model``, ``retrieval_max_iter``,
            ``elicit_max_iter``, ``num_passages``, ``verbose``).
        id: Index of the example to process.
        trace: Shared list receiving one result record per example.
        exception: Shared list; not used by this function.
        success: Shared list receiving records of exactly-matched answers.
    """
    try:
        i = data[id]
        print('Processing id: ', id)

        def assistant_turn(raw):
            # Keep only the text between the assistant header and the
            # end-of-turn marker.  The header must be present in ``raw``;
            # otherwise ``[1]`` raises IndexError, which the outer handler
            # records as a failed example.
            return raw.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip()

        # Ids of passages already shown to the model, across all rounds.
        retrieved_ids = []
        queries = [i['question']]
        template = get_conversation_template('llama-3')
        template.set_system_message('Answer the question by retrieving external knowledge. Extract useful information from each retrieved document. If the information is insufficient or irrelevant, refine your query and search again until you are able to answer the question.')
        template.append_message(template.roles[0], "Question: "+i['question'].strip())

        max_iter = args.retrieval_max_iter
        current_iter = 0

        first_model_output = None

        # ---- Phase 1: retrieval rounds -------------------------------------
        while max_iter > 0:
            prompt = template.get_prompt()
            if args.verbose:
                print('input', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()

            lowered = first_model_output.lower()
            if 'refined query:' in lowered:
                marker = 'Refined Query:'
            elif 'initial query:' in lowered:
                marker = 'Initial Query:'
            else:
                marker = None

            if marker is not None:
                # NOTE(review): detection above is case-insensitive while this
                # split is case-sensitive; with different casing the whole
                # output becomes the query.
                queries = [first_model_output.split(marker)[-1].split('<|eot_id|>')[0].strip()]
                current_iter += 1
            elif 'final answer' in lowered:
                template.append_message(template.roles[1], assistant_turn(first_model_output))
                break
            else:
                # Unrecognised output: log it and fall through, which re-uses
                # the previous query for this round.
                print('Exception: Follow Failed')
                print(template.get_prompt())
                print(first_model_output)
            template.append_message(template.roles[1], assistant_turn(first_model_output))

            if '[BM25]' in queries[0]:
                queries[0] = queries[0].replace('[BM25]', '').strip()
                # The retriever takes {query_id: (query_text, filter_id)}.
                retrieved_results = bm25_retriever.retrieve(
                    None, {idx: (text, None) for idx, text in enumerate(queries)})
                if args.verbose:
                    print(retrieved_results[0].items())

                current_passages = 0
                documents = []
                for query_id, retrieved_result in retrieved_results.items():
                    # No documents found
                    if retrieved_result == {}:
                        continue
                    for doc_id in list(retrieved_result.keys()):
                        if doc_id not in retrieved_ids:
                            retrieved_ids.append(doc_id)
                            # Element 2 of each hit is the passage text.
                            documents.append(retrieved_result[doc_id][2])
                            current_passages += 1
                            if current_passages >= args.num_passages:
                                break
                document = ' '.join(documents)
            else:
                queries[0] = queries[0].replace('[Dense]', '').strip()
                documents = []
                retrieval_results = dense_retriever.search(queries[0])
                for result in retrieval_results:
                    if result['id'] not in retrieved_ids:
                        retrieved_ids.append(result['id'])
                        # Only the last line of ``contents`` is used.
                        documents.append(result['contents'].split('\n')[-1])
                    if len(documents) >= args.num_passages:
                        break
                document = ' '.join(documents)

            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())

            max_iter -= 1

        first_model_output = ""
        if max_iter == 0:
            # Retrieval budget used up without a final answer: ask once more.
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=template.get_prompt(),
                temperature=0.0,
                max_tokens=150,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()
            template.append_message(template.roles[1], assistant_turn(first_model_output))

        # ---- Phase 2: model-generated documents ----------------------------
        max_iter = args.elicit_max_iter
        while 'Refined Query:' in first_model_output and max_iter > 0:
            current_iter += 1
            query = first_model_output.split('Refined Query:')[-1].split('<|eot_id|>')[0].strip()

            document_prompt = Knowledge_Prompt.format(i['question'], query)
            if args.verbose:
                print("Knowledge prompt:", document_prompt)
            document = main_model.completions.create(
                model=args.main_model,
                prompt=document_prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['\n', '<|eot_id|>']
            ).choices[0].text.strip()
            if args.verbose:
                print("Document:", document)
            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input:', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=150
            ).choices[0].text.strip()

            template.append_message(template.roles[1], assistant_turn(first_model_output))
            max_iter -= 1

        all_output = template.get_prompt()
        if args.verbose:
            print(all_output)

        # Build the record completely before publishing it.  Appending first
        # and then patching/printing ``trace[-1]`` is unsafe under the thread
        # pool: the last element may belong to another worker.
        record = {
            'id': id,
            'question': i['question'],
            'trace': all_output,
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)

        answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()

        for ans in i['golden_answers']:
            if normalize_answer(ans.lower()) == normalize_answer(answer.lower()):
                print('Success!')
                success.append({
                    'id': id,
                    'question': i['question'],
                    'trace': all_output,
                    'golden_answer': i['golden_answers'],
                })
                break

        print(record)
        print('processed id: ', id)
        print('success_rate:', len(success)/len(trace))
    except Exception as e:
        # Keep a record of the partial conversation so the example still
        # appears in the output file.
        trace.append({
            'id': id,
            'question': i['question'],
            'trace': template.get_prompt(),
            'golden_answers': i['golden_answers'],
        })
        print('Exception!', e)



def adaptive_test(main_model, bm25_retriever, dense_retriever, data, args):
    """Fan ``adaptive_test_thread`` out over all examples and save the traces.

    One task per index of ``data`` is submitted to a thread pool of
    ``args.workers`` threads.  After the pool has drained, the shared
    ``trace`` list is written to ``args.save_path`` as JSON Lines.

    Returns:
        The tuple ``(success, trace)`` of the lists shared with the workers.
    """
    trace, exception, success = [], [], []

    # Arguments that are identical for every task, split around the index.
    leading = (main_model, bm25_retriever, dense_retriever, data, args)
    shared_lists = (trace, exception, success)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as pool:
        for example_idx in tqdm.trange(len(data)):
            pool.submit(adaptive_test_thread, *leading, example_idx, *shared_lists)
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as writer:
        writer.write_all(trace)
        print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def dense_test_thread(main_model, 
                         bm25_retriever, 
                         dense_retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success,
                         test_retrieve_cnt):
    """Worker: answer ``data[id]`` with iterative dense retrieval only.

    Phase 1 (at most ``args.retrieval_max_iter`` rounds): model output
    containing ``Query:`` is sent to ``dense_retriever.search``; passages not
    seen before are fed back as a ``Retrieved Document_<n>`` user turn.
    Output containing ``final answer`` ends the phase.

    Phase 2 (at most ``args.elicit_max_iter`` rounds): while the latest model
    output still contains ``Refined Query:``, the "document" is generated by
    the model itself from ``Knowledge_Prompt``.

    A result record is appended to ``trace``; an exact match against any
    golden answer (after ``normalize_answer``) also appends to ``success``.

    Args:
        main_model: Client exposing ``completions.create(...)``.
        bm25_retriever: Not used by this function.
        dense_retriever: Object exposing ``search(query)``.
        data: Indexable collection of examples with ``question`` and
            ``golden_answers`` (and optionally ``answer_id``).
        args: Options object (``main_model``, ``retrieval_max_iter``,
            ``elicit_max_iter``, ``num_passages``, ``verbose``).
        id: Index of the example to process.
        trace: Shared list receiving one result record per example.
        exception: Shared list; not used by this function.
        success: Shared list receiving records of exactly-matched answers.
        test_retrieve_cnt: One-element list used as a shared counter;
            element 0 is incremented once per dense search.  The increment
            is not guarded by a lock here.

    Returns:
        ``'success'`` when the example was processed, ``'fail'`` when an
        exception was caught.
    """
    try:
        i = data[id]
        print('Dense Processing id: ', id)

        def assistant_turn(raw):
            # Keep only the text between the assistant header and the
            # end-of-turn marker.  The header must be present in ``raw``;
            # otherwise ``[1]`` raises IndexError, which the outer handler
            # records as a failed example.
            return raw.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip()

        # Ids of passages already shown to the model, across all rounds.
        retrieved_ids = []
        queries = [i['question']]
        template = get_conversation_template('llama-3')
        template.set_system_message('Answer the question by retrieving external knowledge. Extract useful information from each retrieved document. If the information is insufficient or irrelevant, refine your query and search again until you are able to answer the question.')
        template.append_message(template.roles[0], "Question: "+i['question'].strip())

        max_iter = args.retrieval_max_iter
        current_iter = 0

        first_model_output = None

        # ---- Phase 1: retrieval rounds -------------------------------------
        while max_iter > 0:
            prompt = template.get_prompt()
            if args.verbose:
                print('input', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()
            print(first_model_output)

            lowered = first_model_output.lower()
            if 'query:' in lowered:
                # NOTE(review): detection is case-insensitive while this
                # split is case-sensitive; with different casing the whole
                # output becomes the query.
                queries = [first_model_output.split('Query:')[-1].split('<|eot_id|>')[0].strip()]
                current_iter += 1
            elif 'final answer' in lowered:
                template.append_message(template.roles[1], assistant_turn(first_model_output))
                break
            else:
                # Unrecognised output: log it and fall through, which re-uses
                # the previous query for this round.
                print('Exception: Follow Failed')
                print(template.get_prompt())
                print(first_model_output)
            template.append_message(template.roles[1], assistant_turn(first_model_output))

            queries[0] = queries[0].replace('[Dense]', '').strip()
            documents = []
            retrieval_results = dense_retriever.search(queries[0])
            test_retrieve_cnt[0] = test_retrieve_cnt[0]+1
            for result in retrieval_results:
                if result['id'] not in retrieved_ids:
                    retrieved_ids.append(result['id'])
                    # Only the last line of ``contents`` is used.
                    documents.append(result['contents'].split('\n')[-1])
                if len(documents) >= args.num_passages:
                    break
            document = ' '.join(documents)

            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())

            max_iter -= 1

        first_model_output = ""
        if max_iter == 0:
            # Retrieval budget used up without a final answer: ask once more.
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=template.get_prompt(),
                temperature=0.0,
                max_tokens=150,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()
            template.append_message(template.roles[1], assistant_turn(first_model_output))

        # ---- Phase 2: model-generated documents ----------------------------
        max_iter = args.elicit_max_iter
        while 'Refined Query:' in first_model_output and max_iter > 0:
            current_iter += 1
            query = first_model_output.split('Refined Query:')[-1].split('<|eot_id|>')[0].strip()

            document_prompt = Knowledge_Prompt.format(i['question'], query)
            if args.verbose:
                print("Knowledge prompt:", document_prompt)
            document = main_model.completions.create(
                model=args.main_model,
                prompt=document_prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['\n', '<|eot_id|>']
            ).choices[0].text.strip()
            if args.verbose:
                print("Document:", document)
            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input:', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=150
            ).choices[0].text.strip()

            template.append_message(template.roles[1], assistant_turn(first_model_output))
            max_iter -= 1

        all_output = template.get_prompt()
        if args.verbose:
            print(all_output)

        # Build the record completely before publishing it.  Appending first
        # and then patching/printing ``trace[-1]`` is unsafe under the thread
        # pool: the last element may belong to another worker.
        record = {
            'id': id,
            'question': i['question'],
            'trace': all_output,
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)

        answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()

        for ans in i['golden_answers']:
            if normalize_answer(ans.lower()) == normalize_answer(answer.lower()):
                print('Success!')
                success.append({
                    'id': id,
                    'question': i['question'],
                    'trace': all_output,
                    'golden_answer': i['golden_answers'],
                })
                break

        print(record)
        print('processed id: ', id)
        print('success_rate:', len(success)/len(trace))
        return 'success'
    except Exception as e:
        # Keep a record of the partial conversation so the example still
        # appears in the output file.
        trace.append({
            'id': id,
            'question': i['question'],
            'trace': template.get_prompt(),
            'golden_answers': i['golden_answers'],
        })
        print('Exception!', e)
        return 'fail'


def dense_test(main_model, bm25_retriever, dense_retriever, data, args):
    """Fan ``dense_test_thread`` out over all examples and save the traces.

    One task per index of ``data`` is submitted to a thread pool of
    ``args.workers`` threads.  Every future is then waited on so that an
    exception escaping a worker is printed.  The shared retrieval counter is
    printed right after submission and again after the pool has drained.
    Finally ``trace`` is written to ``args.save_path`` as JSON Lines.

    Returns:
        The tuple ``(success, trace)`` of the lists shared with the workers.
    """
    trace, exception, success = [], [], []

    # One-element list so that workers can update the count in place.
    test_retrieve_cnt = [0]

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as pool:
        pending = [
            pool.submit(
                dense_test_thread,
                main_model,
                bm25_retriever,
                dense_retriever,
                data,
                args,
                example_idx,
                trace,
                exception,
                success,
                test_retrieve_cnt,
            )
            for example_idx in tqdm.trange(len(data))
        ]

        print('Submit done')
        print(test_retrieve_cnt)

        # ``result()`` re-raises whatever escaped the worker.
        for finished in concurrent.futures.as_completed(pending):
            try:
                finished.result()
            except Exception as e:
                print(f"An exception occurred: {e}")

    print(test_retrieve_cnt)

    with jsonlines.open(args.save_path, 'w') as writer:
        writer.write_all(trace)
        print('Result Saved!')
    return success, trace


def adaptive_json_test_thread(main_model, 
                         bm25_retriever, 
                         dense_retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Worker: answer ``data[id]`` with JSON-formatted retrieval queries.

    Phase 1 (at most ``args.retrieval_max_iter`` rounds): the text after
    ``Refined Query:`` / ``Initial Query:`` in the model output is parsed as
    JSON with the keys ``Query`` and ``Retriever``.  ``Retriever == 'BM25'``
    routes the query to ``bm25_retriever``, anything else to
    ``dense_retriever``.  Passages not seen before are fed back as a
    ``Retrieved Document_<n>`` user turn.  Output containing ``final answer``
    ends the phase, and so does any error while parsing or retrieving.

    Phase 2 (at most ``args.elicit_max_iter`` rounds): while the latest model
    output still contains ``Refined Query:``, the "document" is generated by
    the model itself from ``Knowledge_Prompt``.

    A result record is appended to ``trace``; an exact match against any
    golden answer (after ``normalize_answer``) also appends to ``success``.

    Args:
        main_model: Client exposing ``completions.create(...)``.
        bm25_retriever: Object exposing ``retrieve(corpus, queries)``.
        dense_retriever: Object exposing ``search(query)``.
        data: Indexable collection of examples with ``question`` and
            ``golden_answers`` (and optionally ``answer_id``).
        args: Options object (``main_model``, ``retrieval_max_iter``,
            ``elicit_max_iter``, ``num_passages``, ``verbose``).
        id: Index of the example to process.
        trace: Shared list receiving one result record per example.
        exception: Shared list; not used by this function.
        success: Shared list receiving records of exactly-matched answers.
    """
    try:
        i = data[id]
        print('Processing id: ', id)

        def assistant_turn(raw):
            # Keep only the text between the assistant header and the
            # end-of-turn marker.  The header must be present in ``raw``;
            # otherwise ``[1]`` raises IndexError, which the outer handler
            # records as a failed example.
            return raw.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip()

        # Ids of passages already shown to the model, across all rounds.
        retrieved_ids = []
        queries = [i['question']]
        template = get_conversation_template('llama-3')
        template.set_system_message('Answer the question by retrieving external knowledge. Extract useful information from each retrieved document. If the information is insufficient or irrelevant, refine your query and search again until you are able to answer the question.')
        template.append_message(template.roles[0], "Question: "+i['question'].strip())

        max_iter = args.retrieval_max_iter
        current_iter = 0

        first_model_output = None

        # ---- Phase 1: retrieval rounds -------------------------------------
        while max_iter > 0:
            prompt = template.get_prompt()
            if args.verbose:
                print('input', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()

            lowered = first_model_output.lower()
            if 'refined query:' in lowered:
                marker = 'Refined Query:'
            elif 'initial query:' in lowered:
                marker = 'Initial Query:'
            else:
                marker = None

            if marker is not None:
                # NOTE(review): detection above is case-insensitive while this
                # split is case-sensitive; with different casing the whole
                # output becomes the query.
                queries = [first_model_output.split(marker)[-1].split('<|eot_id|>')[0].strip()]
                current_iter += 1
            elif 'final answer' in lowered:
                template.append_message(template.roles[1], assistant_turn(first_model_output))
                break
            else:
                # Unrecognised output: log it and fall through, which re-uses
                # the previous query for this round.
                print('Exception: Follow Failed')
                print(template.get_prompt())
                print(first_model_output)
            template.append_message(template.roles[1], assistant_turn(first_model_output))

            try:
                queries_json = json.loads(queries[0])
                queries[0] = queries_json['Query']
                if queries_json['Retriever'] == 'BM25':
                    # The retriever takes {query_id: (query_text, filter_id)}.
                    retrieved_results = bm25_retriever.retrieve(
                        None, {idx: (text, None) for idx, text in enumerate(queries)})
                    if args.verbose:
                        print(retrieved_results[0].items())

                    current_passages = 0
                    documents = []
                    for query_id, retrieved_result in retrieved_results.items():
                        # No documents found
                        if retrieved_result == {}:
                            continue
                        for doc_id in list(retrieved_result.keys()):
                            if doc_id not in retrieved_ids:
                                retrieved_ids.append(doc_id)
                                # Element 2 of each hit is the passage text.
                                documents.append(retrieved_result[doc_id][2])
                                current_passages += 1
                                if current_passages >= args.num_passages:
                                    break
                    document = ' '.join(documents)
                else:
                    retrieval_results = dense_retriever.search(queries[0])
                    documents = []
                    for result in retrieval_results:
                        if result['id'] not in retrieved_ids:
                            retrieved_ids.append(result['id'])
                            # Only the last line of ``contents`` is used.
                            documents.append(result['contents'].split('\n')[-1])
                        if len(documents) >= args.num_passages:
                            break
                    document = ' '.join(documents)
            except Exception:
                # Was a bare ``except:``; narrowed so that interpreter-exit
                # signals are not swallowed.  Retrieval errors also land here
                # and end the retrieval phase.
                print('Exception: queries[0] is not json format')
                print(queries[0])
                break

            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())

            max_iter -= 1

        first_model_output = ""
        if max_iter == 0:
            # Retrieval budget used up without a final answer: ask once more.
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=template.get_prompt(),
                temperature=0.0,
                max_tokens=150,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()
            template.append_message(template.roles[1], assistant_turn(first_model_output))

        # ---- Phase 2: model-generated documents ----------------------------
        max_iter = args.elicit_max_iter
        while 'Refined Query:' in first_model_output and max_iter > 0:
            current_iter += 1
            query = first_model_output.split('Refined Query:')[-1].split('<|eot_id|>')[0].strip()

            document_prompt = Knowledge_Prompt.format(i['question'], query)
            if args.verbose:
                print("Knowledge prompt:", document_prompt)
            document = main_model.completions.create(
                model=args.main_model,
                prompt=document_prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['\n', '<|eot_id|>']
            ).choices[0].text.strip()
            if args.verbose:
                print("Document:", document)
            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input:', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=150
            ).choices[0].text.strip()

            template.append_message(template.roles[1], assistant_turn(first_model_output))
            max_iter -= 1

        all_output = template.get_prompt()
        if args.verbose:
            print(all_output)

        # Build the record completely before publishing it.  Appending first
        # and then patching/printing ``trace[-1]`` is unsafe under the thread
        # pool: the last element may belong to another worker.
        record = {
            'id': id,
            'question': i['question'],
            'trace': all_output,
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)

        answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()

        for ans in i['golden_answers']:
            if normalize_answer(ans.lower()) == normalize_answer(answer.lower()):
                print('Success!')
                success.append({
                    'id': id,
                    'question': i['question'],
                    'trace': all_output,
                    'golden_answer': i['golden_answers'],
                })
                break

        print(record)
        print('processed id: ', id)
        print('success_rate:', len(success)/len(trace))
    except Exception as e:
        # Keep a record of the partial conversation so the example still
        # appears in the output file.
        trace.append({
            'id': id,
            'question': i['question'],
            'trace': template.get_prompt(),
            'golden_answers': i['golden_answers'],
        })
        print('Exception!', e)



def adaptive_json_test(main_model, bm25_retriever, dense_retriever, data, args):
    """Fan ``adaptive_json_test_thread`` out over all examples and save the traces.

    One task per index of ``data`` is submitted to a thread pool of
    ``args.workers`` threads.  After the pool has drained, the shared
    ``trace`` list is written to ``args.save_path`` as JSON Lines.

    Returns:
        The tuple ``(success, trace)`` of the lists shared with the workers.
    """
    trace, exception, success = [], [], []

    # Arguments that are identical for every task, split around the index.
    leading = (main_model, bm25_retriever, dense_retriever, data, args)
    shared_lists = (trace, exception, success)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as pool:
        for example_idx in tqdm.trange(len(data)):
            pool.submit(adaptive_json_test_thread, *leading, example_idx, *shared_lists)
        print('submit done')

    with jsonlines.open(args.save_path, 'w') as writer:
        writer.write_all(trace)
        print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace

def bm25_test_thread(main_model, 
                         retriever, 
                         data, 
                         args,
                         id,
                         trace,
                         exception,
                         success):
    """Worker: answer ``data[id]`` with iterative retrieval via ``retriever.retrieve``.

    Phase 1 (at most ``args.retrieval_max_iter`` rounds): model output
    containing ``Query:`` is sent to ``retriever.retrieve``; passages not seen
    before are fed back as a ``Retrieved Document_<n>`` user turn.  Output
    containing ``final answer`` ends the phase.

    Phase 2 (at most ``args.elicit_max_iter`` rounds): while the latest model
    output still contains ``Refined Query:``, the "document" is generated by
    the model itself from ``Knowledge_Prompt``.

    A result record is appended to ``trace``; an exact match against any
    golden answer (after ``normalize_answer``) also appends to ``success``.

    Args:
        main_model: Client exposing ``completions.create(...)``.
        retriever: Object exposing ``retrieve(corpus, queries)``.
        data: Indexable collection of examples with ``question`` and
            ``golden_answers`` (and optionally ``answer_id``).
        args: Options object (``main_model``, ``retrieval_max_iter``,
            ``elicit_max_iter``, ``num_passages``, ``verbose``).
        id: Index of the example to process.
        trace: Shared list receiving one result record per example.
        exception: Shared list; not used by this function.
        success: Shared list receiving records of exactly-matched answers.
    """
    try:
        i = data[id]
        # The label used to read "Dense Processing id", copied from the
        # dense worker; this worker only calls ``retriever.retrieve``.
        print('BM25 Processing id: ', id)

        def assistant_turn(raw):
            # Keep only the text between the assistant header and the
            # end-of-turn marker.  The header must be present in ``raw``;
            # otherwise ``[1]`` raises IndexError, which the outer handler
            # records as a failed example.
            return raw.split('<|start_header_id|>assistant<|end_header_id|>')[1].split('<|eot_id|>')[0].strip()

        # Ids of passages already shown to the model, across all rounds.
        retrieved_ids = []
        queries = [i['question']]
        template = get_conversation_template('llama-3')
        template.set_system_message('Answer the question by retrieving external knowledge. Extract useful information from each retrieved document. If the information is insufficient or irrelevant, refine your query and search again until you are able to answer the question.')
        template.append_message(template.roles[0], "Question: "+i['question'].strip())

        max_iter = args.retrieval_max_iter
        current_iter = 0

        first_model_output = None

        # ---- Phase 1: retrieval rounds -------------------------------------
        while max_iter > 0:
            prompt = template.get_prompt()
            if args.verbose:
                print('input', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()
            print(first_model_output)

            lowered = first_model_output.lower()
            if 'query:' in lowered:
                # NOTE(review): detection is case-insensitive while this
                # split is case-sensitive; with different casing the whole
                # output becomes the query.
                queries = [first_model_output.split('Query:')[-1].split('<|eot_id|>')[0].strip()]
                current_iter += 1
            elif 'final answer' in lowered:
                template.append_message(template.roles[1], assistant_turn(first_model_output))
                break
            else:
                # Unrecognised output: log it and fall through, which re-uses
                # the previous query for this round.
                print('Exception: Follow Failed')
                print(template.get_prompt())
                print(first_model_output)
            template.append_message(template.roles[1], assistant_turn(first_model_output))

            # A '[Dense]' tag in the query is stripped before searching.
            queries[0] = queries[0].replace('[Dense]', '').strip()
            # The retriever takes {query_id: (query_text, filter_id)}.
            retrieved_results = retriever.retrieve(
                None, {idx: (text, None) for idx, text in enumerate(queries)})
            current_passages = 0
            documents = []
            for query_id, retrieved_result in retrieved_results.items():
                # No documents found
                if retrieved_result == {}:
                    continue
                for doc_id in list(retrieved_result.keys()):
                    if doc_id not in retrieved_ids:
                        retrieved_ids.append(doc_id)
                        # Element 2 of each hit is the passage text.
                        documents.append(retrieved_result[doc_id][2])
                        current_passages += 1
                        if current_passages >= args.num_passages:
                            break
            document = ' '.join(documents)

            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())

            max_iter -= 1

        first_model_output = ""
        if max_iter == 0:
            # Retrieval budget used up without a final answer: ask once more.
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=template.get_prompt(),
                temperature=0.0,
                max_tokens=150,
                stop=['<|eot_id|>']
            ).choices[0].text.strip()
            template.append_message(template.roles[1], assistant_turn(first_model_output))

        # ---- Phase 2: model-generated documents ----------------------------
        max_iter = args.elicit_max_iter
        while 'Refined Query:' in first_model_output and max_iter > 0:
            current_iter += 1
            query = first_model_output.split('Refined Query:')[-1].split('<|eot_id|>')[0].strip()

            document_prompt = Knowledge_Prompt.format(i['question'], query)
            if args.verbose:
                print("Knowledge prompt:", document_prompt)
            document = main_model.completions.create(
                model=args.main_model,
                prompt=document_prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['\n', '<|eot_id|>']
            ).choices[0].text.strip()
            if args.verbose:
                print("Document:", document)
            template.append_message(template.roles[0],
                                    "Retrieved Document_{}: ".format(current_iter)+document.strip())
            prompt = template.get_prompt()
            if args.verbose:
                print('input:', prompt)
            first_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=150
            ).choices[0].text.strip()

            template.append_message(template.roles[1], assistant_turn(first_model_output))
            max_iter -= 1

        all_output = template.get_prompt()
        if args.verbose:
            print(all_output)

        # Build the record completely before publishing it.  Appending first
        # and then patching/printing ``trace[-1]`` is unsafe under the thread
        # pool: the last element may belong to another worker.
        record = {
            'id': id,
            'question': i['question'],
            'trace': all_output,
            'golden_answers': i['golden_answers'],
        }
        if 'answer_id' in i.keys():
            record['answer_id'] = i['answer_id']
        trace.append(record)

        answer = all_output.split('Final Answer:')[-1].split('<|eot_id|>')[0].strip()

        for ans in i['golden_answers']:
            if normalize_answer(ans.lower()) == normalize_answer(answer.lower()):
                print('Success!')
                success.append({
                    'id': id,
                    'question': i['question'],
                    'trace': all_output,
                    'golden_answer': i['golden_answers'],
                })
                break

        print(record)
        print('processed id: ', id)
        print('success_rate:', len(success)/len(trace))
    except Exception as e:
        # Keep a record of the partial conversation so the example still
        # appears in the output file.
        trace.append({
            'id': id,
            'question': i['question'],
            'trace': template.get_prompt(),
            'golden_answers': i['golden_answers'],
        })
        print('Exception!', e)

def bm25_test(main_model, bm25_retriever, data, args):
    """Run ``bm25_test_thread`` over every sample in a thread pool and save the traces.

    Args:
        main_model: Client object handed unchanged to each worker.
        bm25_retriever: Retriever object handed unchanged to each worker.
        data: Indexable collection of samples; workers receive the whole
            collection plus the index of the sample they must process.
        args: Parsed options; ``args.workers`` sizes the pool and
            ``args.save_path`` is the JSON-lines output file.

    Returns:
        A ``(success, trace)`` tuple of the lists the workers appended to.
        ``trace`` is in completion order, not input order.
    """
    trace = []
    exception = []
    success = []

    # Keep every future so failures that escape a worker are not lost.
    submitted = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        for sample_id in tqdm.trange(len(data)):
            future = executor.submit(bm25_test_thread,
                                     main_model,
                                     bm25_retriever,
                                     data,
                                     args,
                                     sample_id,
                                     trace,
                                     exception,
                                     success)
            submitted[future] = sample_id
        print('submit done')

    # Leaving the ``with`` block has already waited for all workers, so
    # ``Future.exception()`` returns immediately here.
    for future, sample_id in submitted.items():
        error = future.exception()
        if error is not None:
            print('Worker failed for id', sample_id, ':', repr(error))
            exception.append({'id': sample_id, 'error': repr(error)})

    with jsonlines.open(args.save_path, 'w') as f:
        f.write_all(trace)
        print('Result Saved!')

    print('Total training data: ', len(trace))
    return success, trace





def _extract_retrieved_documents(trace_text):
    """Return the retrieved-document bodies found in a saved conversation trace.

    The trace is split on blank lines. A segment is kept when it mentions
    "retrieved document" and does not mention "answer" (case-insensitive).
    For each kept segment, only the part before the first ``<|eot_id|>`` is
    used and the leading "Retrieved Document...:" label is removed.
    """
    label = re.compile(r'retrieved document[^:]*:', flags=re.IGNORECASE)
    bodies = []
    for segment in trace_text.split('\n\n'):
        lowered = segment.lower()
        if 'retrieved document' not in lowered or 'answer' in lowered:
            continue
        head = segment.split('<|eot_id|>')[0]
        match = label.search(head)
        if match is not None:
            # Cut only the label. Splitting on the *last* colon would drop
            # everything in the document that precedes an inner colon.
            bodies.append(head[match.end():])
        else:
            # No recognisable label in this part: keep the previous behaviour.
            bodies.append(head.split(':')[-1])
    return bodies


def standard_test_concat_thread(main_model,
                        data,
                        args,
                        id,
                        trace,
                        exception,
                        success):
    """Answer one question from the documents recorded in its earlier trace.

    The documents found in ``data[id]['trace']`` are concatenated into the
    system prompt, the model is queried once, and the stripped completion is
    appended to ``trace``.

    Args:
        main_model: Client exposing ``completions.create(...)``.
        data: Indexable collection of samples with ``question``, ``trace``
            and ``golden_answers`` keys (``answer_id`` is optional).
        args: Parsed options; ``args.main_model`` is the model name.
        id: Index of the sample to process.
        trace: Shared list receiving one result record on success.
        exception: Shared list receiving ``{'id', 'error'}`` on failure.
        success: Shared list; not modified here.

    Returns:
        None. Errors are printed and recorded, never raised to the caller.
    """
    try:
        sample = data[id]

        template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{}\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nQuestion: {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
        document_template = """Doc {} {}\n"""
        # The embedded indentation is part of the prompt the model sees;
        # it is kept exactly as it was.
        system = """Answer the question based on the given document.Only give me the answer and do not output any other words.
        The following are given documents.
        
        {}"""

        documents = "".join(
            document_template.format(rank, body)
            for rank, body in enumerate(
                _extract_retrieved_documents(sample['trace']), start=1)
        )
        prompt = template.format(system.format(documents), sample['question'])

        main_model_output = main_model.completions.create(
                model=args.main_model,
                prompt=prompt,
                temperature=0.0,
                max_tokens=200,
                stop=['<|eot_id|>', '\n']
            ).choices[0].text.strip()
        print(main_model_output)

        record = {
            'id': id,
            'question': sample['question'],
            'trace': main_model_output,
            'golden_answers': sample['golden_answers'],
        }
        if 'answer_id' in sample:
            record['answer_id'] = sample['answer_id']
        trace.append(record)

        print('Processed id: ', id)
    except Exception as e:
        print(e)
        # Record the failure so the caller can tell which samples are missing.
        exception.append({'id': id, 'error': repr(e)})

def standard_concat_test(main_model, data, args):
    """Fan ``standard_test_concat_thread`` out over every sample and save the results.

    One task per index of ``data`` is submitted to a pool of 100 threads.
    After all tasks finish, the collected records are written as JSON lines
    to ``args.save_path``.

    Returns:
        A ``(success, trace)`` tuple of the lists shared with the workers.
    """
    collected = []
    failures = []
    hits = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=100) as pool:
        for sample_idx in tqdm.tqdm(range(len(data))):
            worker_args = (main_model, data, args, sample_idx,
                           collected, failures, hits)
            pool.submit(standard_test_concat_thread, *worker_args)
        print('submit done')

    # The pool has been joined at this point, so ``collected`` is final.
    with jsonlines.open(args.save_path, 'w') as sink:
        sink.write_all(collected)
        print('Result Saved!')
    return hits, collected




# Script entry point: build the retriever(s) selected by --search_engine,
# load the dataset, create the model clients and run the evaluation
# selected by --retrieve_mode.
if __name__ == '__main__':
    argparser = ArgumentParser()
    # NOTE(review): placeholder value, immediately overwritten by parse_args()
    # below. No add_argument() calls are visible here although the code reads
    # args.search_engine, args.index_name, etc. -- confirm the argument
    # definitions exist.
    args = ...
    args = argparser.parse_args()



    # Retriever setup. Which names get bound depends on the engine:
    #   'bm25' / 'dense' -> ``retriever``
    #   'adaptive'       -> ``bm25_retriever`` and ``dense_retriever``
    # Any other value binds nothing.
    if args.search_engine == 'bm25':
        retriever = EvaluateRetrieval(
            BM25Search(index_name=args.index_name, hostname='localhost', initialize=False, number_of_shards=1),
            k_values=[args.retrieval_top_k],)
    elif args.search_engine == 'dense':
        print('loading dense retriever')
        # 'retrieval_topk' is fixed at 50 here; it does not follow
        # args.retrieval_top_k, which is used for the BM25 retriever.
        retrieval_config = {
            'retrieval_method': 'e5',
            'retrieval_model_path': '',
            'retrieval_query_max_length': 256,
            'retrieval_use_fp16': True,
            'retrieval_topk': 50,
            'retrieval_batch_size': 32,
            'index_path': args.dense_index_path,
            'corpus_path': args.dense_corpus_path,
            'save_retrieval_cache': False,
            'use_retrieval_cache': False,
            'retrieval_cache_path': None,
            'use_reranker': False,
            'faiss_gpu': False,
            'use_sentence_transformer': False,
            'retrieval_pooling_method': 'mean'
        }

        retriever = DenseRetriever(retrieval_config)
    elif args.search_engine == 'adaptive':
        print('loading bm25 retriever')
        bm25_retriever = EvaluateRetrieval(
            BM25Search(index_name=args.index_name, hostname='localhost', initialize=False, number_of_shards=1),
            k_values=[args.retrieval_top_k],)
        
        print('loading dense retriever')
        # Same settings as the dict in the 'dense' branch above.
        retrieval_config = {
            'retrieval_method': 'e5',
            'retrieval_model_path': '',
            'retrieval_query_max_length': 256,
            'retrieval_use_fp16': True,
            'retrieval_topk': 50,
            'retrieval_batch_size': 32,
            'index_path': args.dense_index_path,
            'corpus_path': args.dense_corpus_path,
            'save_retrieval_cache': False,
            'use_retrieval_cache': False,
            'retrieval_cache_path': None,
            'use_reranker': False,
            'faiss_gpu': False,
            'use_sentence_transformer': False,
            'retrieval_pooling_method': 'mean'
        }

        dense_retriever = DenseRetriever(retrieval_config)

        

    data = load_data(args.data_path)

    # Initialize OpenAI API
    # Both clients use the placeholder key "EMPTY" and the base URLs given on
    # the command line.

    main_model = OpenAI(
        base_url=args.main_model_url,
        api_key="EMPTY",
    )

    # Only passed to few_shot_test below.
    rewrite_model = OpenAI(
        base_url=args.rewrite_model_url,
        api_key="EMPTY",
    )
    # Dispatch on the evaluation mode.
    # NOTE(review): 'rf', 'few_shot' and 'bm25' read ``retriever``, while
    # 'adaptive', 'dense' and 'adaptive_json' read ``bm25_retriever`` and
    # ``dense_retriever``. A mode paired with a search engine that did not
    # bind those names fails with NameError -- confirm the intended pairings.
    if args.retrieve_mode == 'rf':
        useful, trace = test(main_model, retriever, data, args)
    elif args.retrieve_mode == 'naive':
        useful, trace = naive_test(main_model, data, args)
    elif args.retrieve_mode == 'standard':
        useful, trace = standard_concat_test(main_model, data, args)
    elif args.retrieve_mode == 'few_shot':
        useful, trace = few_shot_test(main_model, retriever, rewrite_model, data, args)
    elif args.retrieve_mode == 'adaptive':
        useful, trace = adaptive_test(main_model, bm25_retriever, dense_retriever, data, args)
    elif args.retrieve_mode == 'dense':
        useful, trace = dense_test(main_model, bm25_retriever, dense_retriever, data, args)
    elif args.retrieve_mode == 'adaptive_json':
        useful, trace = adaptive_json_test(main_model, bm25_retriever, dense_retriever, data, args)
    elif args.retrieve_mode == 'bm25':
        useful, trace = bm25_test(main_model, retriever, data, args)
    