import argparse
import random

from utils.cot.get_prompt import get_prompt
from utils.decoder import Decoder, answer_cleansing
from utils.fp_substitution import fp_substitute, get_nums_from_passage
from utils.solis.solis_solver import try_search
from utils.solis.helper import *


def demo(decoder: Decoder, x: str, CNT_SUM):
    """Run the Solis demo pipeline on one word problem ``x``.

    Pipeline:
      0. Decode the original question. The cleansed answer is not used
         afterwards; this step only aborts the run when decoding fails.
      1. Operand proposal -- not implemented yet (TODO).
      2. Build ``args.substitute_time`` substituted variants of ``x`` with
         ``fp_substitute`` and decode each one.
      3. Hand the original numbers and the substituted results to
         ``try_search`` (arithmetic relationship inversion).

    Args:
        decoder: object whose ``decode(args, text, CNT_SUM)`` produces the raw
            model output that ``answer_cleansing`` post-processes.
        x: question text.
        CNT_SUM: passed through unchanged to ``decoder.decode``.

    Returns:
        ``"Too many operands!"`` when ``get_nums_from_passage`` yields more
        than three numbers for ``x``; ``"Too Frequent!"`` when any
        decode/cleansing call raises (the exception is printed); otherwise the
        result of ``try_search``, which is also printed.
    """
    args = get_default_argument()
    # Seed from the parsed config so ``--seed`` is honoured; its default (123)
    # equals the value that used to be hard-coded here. Seeding still happens
    # before get_prompt() and fp_substitute() run.
    random.seed(args.seed)
    prompt_x = get_prompt()

    orig_nums, _ = get_nums_from_passage(x)
    if len(orig_nums) > 3:
        return "Too many operands!"

    # step 0, original predict
    orig_x = _qa_prompt(prompt_x, x)
    try:
        _decode_answer(decoder, args, orig_x, CNT_SUM)
    except Exception as e:
        # Any decode failure ends the demo. The message suggests these are
        # mostly API rate-limit errors -- NOTE(review): confirm against Decoder.
        print(e)
        return "Too Frequent!"

    # step 1, #TODO skip operand proposal
    # step 2, substitute
    fp_results = []
    for fp_data in fp_substitute(x, args.substitute_time):
        # Built outside the try so a malformed fp_data is not reported as a
        # decoding failure.
        fp_x = _qa_prompt(prompt_x, fp_data["Question"])
        try:
            fp_z = _decode_answer(decoder, args, fp_x, CNT_SUM)
        except Exception as e:
            print(e)
            return "Too Frequent!"
        fp_results.append({
            "fp_nums": fp_data["Alignments"],
            "fp_z": fp_z,
        })

    # step 3, arith relationship inversion
    solis_ret = try_search(args, orig_nums, fp_results)
    print(solis_ret)
    return solis_ret


def _qa_prompt(prompt: str, question: str) -> str:
    """Append ``question`` to the few-shot ``prompt`` in ``Q: ...\\nA:`` form."""
    return prompt + f"Q: {question}\nA:"


def _decode_answer(decoder, args, model_input: str, cnt_sum):
    """Decode ``model_input`` with ``decoder`` and return the cleansed answer."""
    return answer_cleansing(args, decoder.decode(args, model_input, cnt_sum))

def get_default_argument(argv=None):
    """Build the Solis configuration namespace.

    Args:
        argv: optional list of command-line tokens to parse. When ``None``
            (the default), ``sys.argv[1:]`` is parsed, as before.

    Returns:
        ``argparse.Namespace`` with ``seed``, ``api_time_interval``,
        ``max_length``, ``substitute_time``, ``dataset`` and
        ``direct_answer_trigger_for_fewshot``.

    Tokens the parser does not recognise are ignored, with a warning, rather
    than terminating the process. ``demo`` calls this function, so it may run
    inside a host whose own command line carries unrelated flags (notebook
    kernels, test runners, servers).
    """
    parser = argparse.ArgumentParser(description="Solis")
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--api_time_interval", type=float, default=2)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--substitute_time", type=int, default=5)
    parser.add_argument("--dataset", type=str, default="multiarith")
    parser.add_argument("--direct_answer_trigger_for_fewshot", type=str, default="The answer is")
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        import warnings

        # Surface ignored tokens so a mistyped flag is not silently dropped.
        warnings.warn(
            f"Solis: ignoring unrecognized arguments: {unknown}",
            stacklevel=2,
        )
    return args

if __name__ == "__main__":
    # Smoke run: feed each sample question through demo() using one shared
    # Decoder. The trailing 0 is the CNT_SUM value handed to Decoder.decode.
    samples = (
        "Nancy uploaded 41 pictures to Facebook. She put 37 pics into one album and put the rest into 2 different albums. How many pictures were in each album?",
    )
    solis_decoder = Decoder()
    for question in samples:
        demo(solis_decoder, question, 0)