import json
import argparse
from model import Model
from tqdm import tqdm
import time
import re
from util import is_irrelevant
import os

# Command-line interface. --data and --output are required because both are
# opened/stat'ed unconditionally; passing None there only produced an obscure
# TypeError.
parser = argparse.ArgumentParser(
    description="Generate a model output for each record's question and passages, "
    "checkpointing results to --output."
)
parser.add_argument(
    "--data", type=str, required=True,
    help="Path to the input JSON list of records.",
)
parser.add_argument(
    "--model", type=str,
    help="Model identifier passed to Model().",
)
parser.add_argument(
    "--output", type=str, required=True,
    help="Path of the output JSON; if it already exists, its records are loaded "
    "instead of --data and records that already have an output are skipped.",
)
parser.add_argument(
    "--num_docs", type=int, default=10,
    help="Maximum number of passages to include per record.",
)
parser.add_argument(
    "--prompt", type=str, default=None,
    help="Path to a prompt template formatted with {context} and {question} "
    "(and {examples} when few-shot examples are given). If omitted, records "
    "must contain 'full_prompt'.",
)
parser.add_argument(
    "--fewshot_examples", type=str, default=None,
    help="Optional path to a JSON object mapping example ids to records with "
    "'question', 'context' and 'output' fields.",
)
parser.add_argument(
    "--max_fewshot", type=int, default=1,
    help="Maximum number of few-shot examples per prompt.",
)

args = parser.parse_args()

def _load_json(path):
    """Parse the UTF-8 JSON file at ``path``, closing the handle afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Input records. If the output file already exists (e.g. from an interrupted
# run), its records replace the input so previously generated outputs are kept.
data = _load_json(args.data)
if os.path.exists(args.output):
    print("Output file already exists. Loading existing data.")
    data = _load_json(args.output)

model = Model(args.model)
fewshot_examples = _load_json(args.fewshot_examples) if args.fewshot_examples is not None else None

# Prompt template from --prompt. Without it, records are expected to carry
# their own "full_prompt". This is validated with a real exception (asserts
# vanish under -O), and only when there is at least one record to inspect.
prompt = None
if args.prompt is not None:
    with open(args.prompt, encoding="utf-8") as f:
        prompt = f.read().strip()
elif data and "full_prompt" not in data[0]:
    raise ValueError(
        "No --prompt given and the first record has no 'full_prompt' field."
    )

# Layout of one few-shot demonstration, including its trailing blank line.
_FEWSHOT_TEMPLATE = (
    "Example {number}:\n"
    "Question: {question}\n"
    "\n"
    "Passages:\n"
    "{context}\n"
    "\n"
    "Output:\n"
    "{output}\n\n"
)


def _format_passages(passages, num_docs):
    """Render the first ``num_docs`` passages as "Passage i:" sections.

    Passages flagged by ``is_irrelevant`` are dropped, but the numbering keeps
    each passage's original 1-based position. Every "- " substring is removed
    from the kept passages. Returns "" when nothing survives.
    """
    parts = []
    for i, passage in enumerate(passages[:num_docs]):
        if is_irrelevant(passage):
            continue
        passage = passage.replace("- ", "")
        parts.append(f"Passage {i+1}:\n{passage}\n\n")
    return "".join(parts)


def _first_nonempty_passage(raw_context):
    """Return the first non-empty passage of ``raw_context``, or None.

    A plain string is treated as one passage rather than being iterated
    character by character.
    """
    candidates = [raw_context] if isinstance(raw_context, str) else raw_context
    for passage in candidates:
        if passage != "":
            return passage
    return None


def _format_fewshot(examples, exclude_id, max_examples):
    """Concatenate up to ``max_examples`` demonstrations from ``examples``.

    ``examples`` maps ids (int-convertible) to dicts with "question",
    "context" and "output". The entry whose id equals ``exclude_id`` is
    skipped. NOTE(review): presumably those ids are positions in ``data``, so
    a record never sees itself; confirm. ``max_examples <= 0`` yields "".
    """
    blocks = []
    for ex_id, ex in examples.items():
        # Check the limit before adding, so a limit of 0 really means none.
        if len(blocks) >= max_examples:
            break
        if int(ex_id) == int(exclude_id):
            continue
        blocks.append(_FEWSHOT_TEMPLATE.format(
            number=len(blocks) + 1,
            question=ex["question"],
            context=ex["context"],
            output=ex["output"],
        ))
    return "".join(blocks)


def _save_checkpoint(records, path):
    """Write ``records`` to ``path`` as indented JSON, atomically.

    Dumps to a sibling temp file and then ``os.replace``s it over ``path``, so
    an interruption mid-write cannot leave a truncated file behind. The next
    run reloads this file to resume.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4)
    os.replace(tmp_path, path)


for k, dat in tqdm(enumerate(data), total=len(data)):
    # Records that already have an output (from a resumed run) are skipped.
    if dat.get("output") is not None:
        continue

    # Prefer summaries over raw passages when present. dat["context"] is
    # always read because it is also the fallback source below.
    context = dat["context"]
    if "summaries" in dat:
        context = dat["summaries"]
    elif "summaries_individual" in dat:
        context = dat["summaries_individual"]

    if isinstance(context, list):
        context = _format_passages(context, args.num_docs)

    if context == "":
        # Nothing usable: fall back to the first non-empty raw passage. The
        # string is built from scratch so nothing from earlier records leaks in.
        fallback = _first_nonempty_passage(dat["context"])
        if fallback is None:
            dat["output"] = ""
            # Persist now; otherwise a trailing run of such records is never saved.
            _save_checkpoint(data, args.output)
            continue
        context = f"Passage 1:\n{fallback}\n\n"

    # Without --prompt, records are expected to carry their own "full_prompt".
    # NOTE(review): assumes full_prompt is a str.format template with
    # {context}/{question} placeholders (or none); confirm with its producer.
    template = prompt if prompt is not None else dat["full_prompt"]
    if fewshot_examples is not None:
        examples = _format_fewshot(fewshot_examples, k, args.max_fewshot)
        cur_prompt = template.format(context=context, question=dat["question"], examples=examples)
    else:
        cur_prompt = template.format(context=context, question=dat["question"])

    dat["output"] = model.run(cur_prompt)

    # Checkpoint after every record so an interrupted run can resume.
    _save_checkpoint(data, args.output)

    
