import os
import re
import json
import argparse
import random
from langchain import LLMChain, OpenAI, PromptTemplate
from tqdm import tqdm
from base_prompt import *
from langchain.chat_models import ChatOpenAI

import openai

# Read the OpenAI key from the local key file, export it through the process
# environment, then hand the exported value to the openai client module.
with open('./keys/openai.key', 'r') as key_file:
    os.environ["OPENAI_API_KEY"] = key_file.read().strip()

openai.api_key = os.environ["OPENAI_API_KEY"]

def load_data(args):
    """Load the problems, the evaluation question ids and the few-shot ids.

    Reads ``problems.json`` and ``pid_splits.json`` from ``args.data_root`` and
    the ``"captions"`` mapping from ``args.caption_file``. Every problem gets a
    ``'caption'`` entry, which is the empty string when the caption file has
    no entry for that question id.

    Args:
        args: namespace providing ``data_root``, ``caption_file``,
            ``test_split``, ``test_number``, ``shot_qids`` and ``shot_number``.

    Returns:
        A ``(problems, qids, shot_qids)`` tuple: the problem dict keyed by
        question id, the question ids of the selected split (truncated to
        ``args.test_number`` when that is positive), and the question ids used
        as few-shot examples.

    Raises:
        AssertionError: if ``args.shot_number`` is outside ``[0, 32]`` while
            sampling, or an explicitly requested shot id is not in the
            training split.
    """
    # Context managers make sure the three JSON files are closed again.
    with open(os.path.join(args.data_root, 'problems.json')) as f:
        problems = json.load(f)
    with open(os.path.join(args.data_root, 'pid_splits.json')) as f:
        pid_splits = json.load(f)
    with open(args.caption_file) as f:
        captions = json.load(f)["captions"]

    for qid in problems:
        problems[qid]['caption'] = captions.get(qid, "")

    qids = pid_splits[str(args.test_split)]
    qids = qids[:args.test_number] if args.test_number > 0 else qids
    print(f"number of test problems: {len(qids)}\n")

    # pick up shot examples from the training set
    shot_qids = args.shot_qids
    train_qids = pid_splits['train']
    if shot_qids is None:
        assert 0 <= args.shot_number <= 32
        # Sample from the list (not a set) so a seeded run stays reproducible.
        shot_qids = random.sample(train_qids, args.shot_number)  # random sample
    else:
        shot_qids = [str(qid) for qid in shot_qids]
        train_qid_set = set(train_qids)  # O(1) membership checks
        for qid in shot_qids:
            assert qid in train_qid_set  # check shot_qids
    print("training question ids for prompting: ", shot_qids, "\n")

    return problems, qids, shot_qids


def get_causal_reasoning_result(model, prompt, args):
    """Run ``prompt`` through ``model`` and return the labelled response.

    The stripped model response is prefixed with ``"Observed Information: "``
    and each pair of consecutive newlines is replaced by a single newline.
    ``args`` is accepted but not used.
    """
    response = model.run({"input": prompt})
    labelled = "Observed Information: " + response.strip()
    return labelled.replace("\n\n", "\n")

def get_confidence(input_str):
    """Return the last belief letter reported in ``input_str``.

    A belief is a single capital letter on a line that starts with
    ``"Belief: "`` (the line must be preceded by a newline). When no such
    line is present the function falls back to ``'B'``.
    """
    belief_letters = re.findall(r"\nBelief: ([A-Z])", input_str)
    if not belief_letters:
        return 'B'
    return belief_letters[-1]
    
def extract_the_answer(output):
    """Pull the answer letter out of a model response.

    Returns the capital letter following ``"The answer is "`` when that phrase
    occurs exactly once in ``output``; otherwise (zero or several occurrences)
    returns the string ``"FAILED"``.
    """
    found = re.findall(r'The answer is ([A-Z])', output)
    return found[0] if len(found) == 1 else "FAILED"

def get_refinement_result(model, prompt, args):
    """Run the refinement ``prompt`` and return ``(answer, output)``.

    ``output`` is the stripped model response and ``answer`` is whatever
    ``extract_the_answer`` derives from it. ``args`` is accepted but not used.
    """
    raw_response = model.run({"input": prompt})
    refined_text = raw_response.strip()
    return extract_the_answer(refined_text), refined_text

def get_pred_idx(prediction, choices, options):
    """
    Map a predicted option letter (e.g. 'C') to its index (e.g. 2).

    Only the first ``len(choices)`` entries of ``options`` count as valid
    letters; for any other prediction a random index into ``choices`` is
    returned instead.
    """
    n_choices = len(choices)
    valid_letters = options[:n_choices]
    if prediction not in valid_letters:
        return random.choice(range(n_choices))
    return options.index(prediction)


def get_result_file(args, faithful_reasoning=False):
    """Build the path of a result JSON file from the run arguments.

    The file name is made of the label, test split, prompt format, shot number
    and seed. With ``faithful_reasoning=True`` the model directory gets a
    ``-fr`` suffix and ``args.version`` is appended to the file name.
    """
    stem = (f"{args.label}_{args.test_split}_{args.prompt_format}"
            f"_{args.shot_number}_seed_{args.seed}")
    if not faithful_reasoning:
        return f"{args.output_root}/{args.model}/{stem}.json"

    fr_model_dir = args.model + '-fr'
    return f"{args.output_root}/{fr_model_dir}/{stem}_{args.version}.json"


def save_results(result_file, acc, correct, count, shot_qids, args, causal_reasoning_outputs, results, outputs, refinement_causal_reasoning_outputs):
    """Write the current run state to ``result_file`` as indented JSON.

    The parent directory of ``result_file`` is created when it does not exist
    yet, so a checkpoint is not lost just because the output directory was
    never created.

    Args:
        result_file: destination path of the JSON file (overwritten).
        acc: accuracy value to store under ``'acc'``.
        correct: value to store under ``'correct'``.
        count: value to store under ``'count'``.
        shot_qids: few-shot question ids, stored under ``'shot_qids'``.
        args: argument namespace; stored as a dict via ``vars(args)``.
        causal_reasoning_outputs: stored under ``'causal_reasoning_outputs'``.
        results: stored under ``'results'``.
        outputs: stored under ``'outputs'``.
        refinement_causal_reasoning_outputs: stored under
            ``'refinement_causal_reasoning_outputs'``.
    """
    data = {
        'acc': acc,
        'correct': correct,
        'count': count,
        'shot_qids': shot_qids,
        'args': vars(args),
        'causal_reasoning_outputs': causal_reasoning_outputs,
        'results': results,
        'outputs': outputs,
        'refinement_causal_reasoning_outputs': refinement_causal_reasoning_outputs,
    }

    # A bare file name has no directory part; only create real directories.
    out_dir = os.path.dirname(result_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(result_file, 'w') as f:
        json.dump(data, f, indent=2, separators=(',', ': '))


def parse_args():
    """Parse the command-line arguments of the script.

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv``.

    Note:
        ``--shot_qids`` takes one or more whitespace-separated question ids
        (``--shot_qids 12 345``). It used ``type=list`` before, which makes
        argparse call ``list()`` on the raw string and split it into single
        characters, so multi-digit ids could not be expressed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='../data/scienceqa')
    parser.add_argument('--output_root', type=str, default='../results')
    parser.add_argument('--caption_file', type=str, default='../data/instruct_captions.json')
    parser.add_argument('--model', type=str, default='gpt-3.5-turbo')
    # type=list splits the raw string into characters ("ABCDE" -> one letter
    # per option). That is harmless for single-letter options, so it is kept
    # for command-line compatibility.
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
    # user options
    parser.add_argument('--label', type=str, default='exp0')
    parser.add_argument('--version', type=str, default='v0')
    parser.add_argument('--test_split', type=str, default='val', choices=['test', 'val', 'minival'])
    parser.add_argument('--test_number', type=int, default=10, help='Chatgpt is expensive. -1 for whole val/test set')
    parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
    parser.add_argument('--save_every', type=int, default=10, help='Save the result with every n examples.')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--prompt_format',
                        type=str,
                        default='CQM-A',
                        choices=[
                            'CQM-A', 'CQM-LA', 'CQM-EA', 'CQM-LEA', 'CQM-ELA', 'CQM-AL', 'CQM-AE', 'CQM-ALE', 'QCM-A',
                            'QCM-LA', 'QCM-EA', 'QCM-LEA', 'QCM-ELA', 'QCM-AL', 'QCM-AE', 'QCM-ALE', 'QCML-A', 'QCME-A',
                            'QCMLE-A', 'QCLM-A', 'QCEM-A', 'QCLEM-A', 'QCML-AE'
                        ],
                        help='prompt format template')
    parser.add_argument('--shot_number', type=int, default=3, help='Number of n-shot training examples.')
    # nargs='+' collects whole ids; each id stays a string.
    parser.add_argument('--shot_qids', type=str, nargs='+', default=None, help='Question indexes of shot examples')
    parser.add_argument('--seed', type=int, default=10, help='random seed')
    # GPT-3 settings
    parser.add_argument('--engine', type=str, default='gpt-3.5-turbo')
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Script entry point. For every test question it:
    #   1. runs a causal-reasoning check on the previously predicted output,
    #   2. if the reported belief letter is 'C' or later, asks the model for a
    #      refined answer and runs the causal-reasoning check on that too,
    #   3. keeps the original prediction unless the refined output gets an
    #      earlier belief letter than the original one.
    # Results are checkpointed to the faithful-reasoning result file.

    args = parse_args()
    print('====Input Arguments====')
    print(json.dumps(vars(args), indent=2, sort_keys=False))

    random.seed(args.seed)

    problems, qids, shot_qids = load_data(args)  # problems, test question ids, shot example ids

    result_file = get_result_file(args)
    fr_result_file = get_result_file(args, faithful_reasoning=True)

    # The baseline result file is a hard prerequisite. Check for it before
    # reading so a missing file is reported explicitly, then parse it once.
    if not os.path.exists(result_file):
        raise FileNotFoundError(f"baseline result file not found: {result_file}")
    print("# The result file exists! We will load the check point!!!")
    with open(result_file) as f:
        check_point = json.load(f)
    predicted_outputs = check_point["outputs"]
    ori_results = check_point['results']
    # statistic
    correct = sum(1 for qid in qids if ori_results[qid] == problems[qid]["answer"])
    print(f"initial results: {correct}/{len(qids)}, correct: {correct}, acc: {round(correct / len(qids)*100, 2)}%")

    # Resume from an existing faithful-reasoning checkpoint, or start fresh.
    if os.path.exists(fr_result_file):
        print("# The result file exists! We will load the check point!!!")
        with open(fr_result_file) as f:
            check_point = json.load(f)
        acc = check_point['acc']
        correct = check_point['correct']
        causal_reasoning_outputs = check_point['causal_reasoning_outputs']
        fr_results = check_point['results']
        fr_refinement_outputs = check_point['outputs']
        refinement_causal_reasoning_outputs = check_point.get('refinement_causal_reasoning_outputs', {})

        # statistic
        if len(fr_results) >= len(qids):
            correct = sum(1 for qid in qids if fr_results[qid] == problems[qid]["answer"])
            print(f"faithful reasoning results: {correct}/{len(qids)}, correct: {correct}, acc: {round(correct / len(qids)*100, 2)}%")

    else:
        correct = 0
        causal_reasoning_outputs = {}
        fr_refinement_outputs = {}
        fr_results = {}
        refinement_causal_reasoning_outputs = {}

    # NOTE(review): temperature is pinned to 0 here and args.temperature is
    # not consulted, unlike the other sampling settings - confirm intended.
    llm = ChatOpenAI(temperature=0,
                 model=args.engine,
                 max_tokens=args.max_tokens,
                 top_p=args.top_p,
                 frequency_penalty=args.frequency_penalty,
                 presence_penalty=args.presence_penalty,
                 request_timeout=120)
    prompt_template = PromptTemplate.from_template("{input}")
    model = LLMChain(prompt=prompt_template, llm=llm, verbose=False)
    for i, qid in enumerate(qids):
        # Already handled in a previous (checkpointed) run.
        if qid in fr_results:
            continue

        choices = problems[qid]["choices"]
        answer = problems[qid]["answer"]  # 0, 1, ..., 4
        label = args.options[answer]  # 'A', ..., 'E'
        ori_pred_idx = ori_results[qid]

        # Causal-reasoning check of the original predicted output.
        prompt = build_causal_reasoning_prompt(problems, predicted_outputs[qid], qid, args)
        causal_reasoning_output = get_causal_reasoning_result(model, prompt, args)
        init_belief = get_confidence(causal_reasoning_output)
        causal_reasoning_outputs[qid] = causal_reasoning_output

        # Belief letters are compared as strings; 'C' or later triggers a
        # refinement (presumably later letters mean lower confidence - the
        # scale is defined by the prompt in base_prompt, confirm there).
        if init_belief >= 'C':
            prompt = build_refinement_prompt(problems, predicted_outputs[qid], shot_qids, qid, causal_reasoning_output, args)
            prediction, refinement_output = get_refinement_result(model, prompt, args)

            # Evaluate the refined output with the same causal-reasoning check.
            prompt = build_causal_reasoning_prompt(problems, refinement_output, qid, args)
            refinement_causal_reasoning_output = get_causal_reasoning_result(model, prompt, args)
            refinement_belief = get_confidence(refinement_causal_reasoning_output)

            fr_refinement_outputs[qid] = refinement_output
            refinement_causal_reasoning_outputs[qid] = refinement_causal_reasoning_output
            if refinement_belief >= init_belief:
                # The refinement did not get an earlier belief letter:
                # fall back to the original prediction.
                prediction = extract_the_answer(predicted_outputs[qid])
                pred_idx = ori_pred_idx
            else:
                pred_idx = get_pred_idx(prediction, choices, args.options)  # 0, 1, ..., 4
            fr_results[qid] = pred_idx
        else:
            # No refinement requested: keep the original prediction and store
            # empty placeholders for the refinement outputs.
            prediction = extract_the_answer(predicted_outputs[qid])
            fr_refinement_outputs[qid] = ''
            refinement_causal_reasoning_outputs[qid] = ''

            pred_idx = ori_pred_idx  # 0, 1, ..., 4
            fr_results[qid] = pred_idx

        print(f'\n===prediction: {pred_idx}, ground-truth: {answer}\n===')
        if pred_idx == answer:
            correct += 1

        acc = correct / len(fr_results) * 100

        if args.debug or i < 3:
            print("##################################")
            print(build_refinement_prompt(problems, predicted_outputs[qid], shot_qids, qid, causal_reasoning_output, args), "\n")
            print("# labeled answer:", label)
            print("# predicted answer:", prediction)
            print("# predicted index:", pred_idx)
            print("# causal reasoning output:", causal_reasoning_output)
            # The short-circuit keeps refinement_belief / refinement_output
            # from being read when no refinement ran for this question.
            if init_belief >= 'C' and refinement_belief < init_belief:
                print("# predicted refinement output:", refinement_output)

        # Checkpoint every save_every questions and after the last one.
        if (i + 1) % args.save_every == 0 or (i + 1) == len(qids):
            print(f"{len(fr_results)}/{len(qids)}, correct: {correct}, acc: {round(acc, 2)}%, saving to {fr_result_file}")
            save_results(fr_result_file, acc, correct, i + 1, shot_qids, args, causal_reasoning_outputs, fr_results, fr_refinement_outputs, refinement_causal_reasoning_outputs)
