import argparse
import logging
import random
import yaml
from tqdm import tqdm
from utils import *
from configs import *
from frameworks import *
from lm import *

def get_configs(path):
    """Load a YAML configuration file and return the parsed document.

    Args:
        path: Filesystem path of the YAML file to read.

    Returns:
        The top-level object produced by the YAML loader. ``__main__`` in
        this script indexes the data config as a nested mapping.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    # NOTE(review): the config is a trusted local file; if it can ever come
    # from an untrusted source, switch to ``yaml.safe_load``.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)

def parse_arguments(cfgs, argv=None):
    """Parse command-line options and attach the dataset's test-split path.

    Args:
        cfgs: Parsed data config. Must provide
            ``cfgs["dataset"][<dataset name>]["path_test"]``.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps argparse's normal
            behaviour; passing a list is useful for tests and notebooks.

    Returns:
        argparse.Namespace with the parsed options plus a ``dataset_path``
        attribute taken from ``cfgs``.

    Exits through ``parser.error`` (usage message, status 2) when the
    selected dataset has no ``path_test`` entry in ``cfgs``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--random_seed", type=int, default=1, help="random seed")
    parser.add_argument("--dataset", type=str, default="gsm8k", choices=ALL_DATASETS, help="dataset used for experiment")
    parser.add_argument("--model", type=str, default="palm2", choices=["palm2", "gpt3"], help="language model used for experiment")
    parser.add_argument("--lm_output_max_length", type=int, default=850, help="maximum length of output tokens of language model")
    parser.add_argument("--log_dir", type=str, default="logs", help="log directory")
    parser.add_argument("--framework", type=str, default="few-shot_cot", help="framework name passed to get_framework()")
    args = parser.parse_args(argv)

    # A dataset may be an allowed CLI choice yet still be missing from the YAML
    # config; report that as a usage error instead of a bare KeyError traceback.
    try:
        args.dataset_path = cfgs["dataset"][args.dataset]["path_test"]
    except (KeyError, TypeError):
        parser.error(f"no 'path_test' configured for dataset '{args.dataset}' in the data config")
    return args

if __name__ == "__main__":
    cfgs = get_configs("configs/data_configs.yaml")
    args = parse_arguments(cfgs)
    fix_seed(args.random_seed)
    RM = ReportManager(args)

    lm = LM(args)
    print(f"Language Model: {args.model} initialized")
    answer_cleaner = AnswerCleaner(args)
    dataset = test_data_reader(args)
    raw_questions, raw_answers = dataset.get("questions"), dataset.get("answers")

    questions, labels = [], []
    cumul_api_usage = 0
    for i, (x, y) in enumerate(zip(raw_questions, raw_answers)):
        # Prepare Question Template
        x = "Question: " + x
        y = y.strip()
        y = answer_cleaner.gt_cleaning(y)
        questions.append(x.strip())
        labels.append(y)

    framework = get_framework(args.framework, args.dataset, questions, args.model.split("-")[0])
    cleaning_flags = framework.get_cleaning_flags()
    num_stages = framework.get_num_stages()
    prompt_examples = []

    # problem representation construction
    for i in range(num_stages):
        stage_prompts = framework.get_stage_prompts(stage_idx=i)["prompts"]
        stage_prompt_example = f"\n{'=' * 50}\nStage {i} Prompt Example:\n{'-' * 50}\n{stage_prompts[0]}\n{'=' * 50}\n"
        print(stage_prompt_example)
        prompt_examples.append(stage_prompt_example)
        outputs = lm.process(stage_prompts, desc=f"Stage {i} Prompt Completion")
        output_list = outputs["output_list"]
        cumul_api_usage += outputs["api_usage"]

        if cleaning_flags["rule_based"][i]:
            output_list = framework.rule_based_stage_cleaning(output_list, stage_idx=i)["outputs"]
        
        framework.update_stage_output(output_list, stage_idx=i)

    # solution searching
    integrated_prompts = framework.get_integrated_prompts()["prompts"]
    stage_prompt_example = f"\n{'=' * 50}\nIntegrated Prompt Example:\n{'-' * 50}\n{integrated_prompts[0]}\n{'=' * 50}\n"
    print(stage_prompt_example)
    prompt_examples.append(stage_prompt_example)
    outputs = lm.process(integrated_prompts, desc="Reasoning Prompt Completion")
    reasoning_output_list = outputs["output_list"]
    cumul_api_usage += outputs["api_usage"]

    cleansing_target_list = None
    if cleaning_flags["lm_extraction"]:
        extraction_prompts = framework.get_extraction_prompts(reasoning_output_list)["prompts"]
        outputs = lm.process(extraction_prompts, desc="LM based Answer Extraction")
        cleansing_target_list = outputs["output_list"]
        cumul_api_usage += outputs["api_usage"]
    else:
        cleansing_target_list = reasoning_output_list

    if not DATASET_TYPE[args.dataset] == "word":
        cleaned_preds = [answer_cleaner.get_pred(s) for s in cleansing_target_list]
    else:
        cleaned_preds = [answer_cleaner.get_pred(s, to_detect) for s, to_detect in zip(cleansing_target_list, labels)]

    whole_process = framework.get_whole_process(reasoning_output_list, cleansing_target_list)["outputs"]

    # Logging
    RM.update_log(whole_process, cleaned_preds, labels, cumul_api_usage, prompt_examples, framework.get_configs())
    RM.save_report()