import os
import yaml
from utils import *
from configs import *
from nltk import sent_tokenize

# Relative path of the YAML file that holds the framework settings.
# prompt_framework.get_configs() loads it and selects the entry whose key is
# the framework's config_index. Being relative, it is resolved against the
# process's current working directory.
FRAMEWORK_CONFIGS = "frameworks/settings.yaml"

def get_lm_extraction_trigger(dataset):
    """Return the answer-extraction trigger sentence for ``dataset``.

    The answer type is looked up in ``DATASET_TYPE`` (not defined in this
    module; presumably supplied by one of the star imports) and mapped to a
    fixed sentence that ends with exactly one trailing space.

    Args:
        dataset: Key into ``DATASET_TYPE``.

    Returns:
        The trigger sentence followed by a single space.

    Raises:
        KeyError: If ``dataset`` is not a key of ``DATASET_TYPE``.
        NotImplementedError: If the answer type is not one of
            ``"mc"``, ``"num"``, ``"yn"`` or ``"word"``.
    """
    answer_type = DATASET_TYPE[dataset]
    known_triggers = (
        ("mc", "Therefore, the answer (a letter of the given options) is"),
        ("num", "Therefore, the answer (arabic numerals) is"),
        ("yn", "Therefore, the answer (yes or no) is"),
        ("word", "Therefore, the answer (a word) is"),
    )
    for type_name, sentence in known_triggers:
        if answer_type == type_name:
            # Normalise to "sentence + one space" so the model's answer can
            # be appended directly after the trigger.
            return sentence.strip() + " "
    raise NotImplementedError

def get_prompt(index, dataset):
    """Load one of the prompt text files configured for ``dataset``.

    The prompt directory is read from ``configs/data_configs.yaml`` under
    ``dataset -> <dataset> -> prompt``; ``index`` selects the file inside it.

    Args:
        index: One of ``"cot"``, ``"precot_gi"``, ``"precot_obj"`` or
            ``"precot_full"``.
        dataset: Dataset key inside the ``dataset`` section of the config.

    Returns:
        The file content with surrounding whitespace stripped and a single
        trailing newline appended.

    Raises:
        ValueError: If ``index`` is not a known prompt name.
        KeyError: If the config has no prompt path for ``dataset``.
        OSError: If the config file or the prompt file cannot be opened.
    """
    prompt_files = {
        "cot": "cot_full.txt",
        "precot_gi": "precot_gi.txt",
        "precot_obj": "precot_obj.txt",
        "precot_full": "precot_full.txt",
    }
    # Validate before any I/O: previously an unknown index only surfaced later
    # as an UnboundLocalError on ``filename``.
    if index not in prompt_files:
        raise ValueError(
            f"Unknown prompt index {index!r}; expected one of {sorted(prompt_files)}"
        )
    filename = prompt_files[index]

    # NOTE: yaml.load with FullLoader is acceptable for a project-local config
    # file; do not point this at untrusted input.
    with open("configs/data_configs.yaml", "r") as f:
        data_configs = yaml.load(f, Loader=yaml.FullLoader)
    prompt_path = data_configs["dataset"][dataset]["prompt"]

    with open(os.path.join(prompt_path, filename), "r") as f:
        prompt = f.read()
    return prompt.strip() + "\n"

def get_framework(config_index, dataset, questions, model="palm2"):
    """Instantiate the framework object that matches ``config_index``.

    ``"few-shot_cot"`` and ``"few-shot_precot"`` have dedicated subclasses;
    every other index (including the zero-shot ones) is served by the generic
    ``prompt_framework``.

    Args:
        config_index: Name of the framework configuration.
        dataset: Dataset name, forwarded to the framework constructor.
        questions: Questions, forwarded to the framework constructor.
        model: Model name, forwarded to the framework constructor.

    Returns:
        A ``prompt_framework`` instance or an instance of one of its
        subclasses.
    """
    # Built inside the function because the classes are defined further down
    # in the module.
    dedicated_classes = {
        "few-shot_cot": few_shot_cot,
        "few-shot_precot": few_shot_precot,
    }
    framework_cls = dedicated_classes.get(config_index, prompt_framework)
    return framework_cls(config_index, dataset, questions, model=model)

class prompt_framework:
    """Config-driven builder of stage, extraction and integrated prompts.

    The settings for ``config_index`` are read from ``FRAMEWORK_CONFIGS``.
    Stage outputs registered through ``update_stage_output`` are later woven
    into the final prompt by ``get_integrated_prompts``.
    """

    def __init__(self, config_index, dataset, questions, model="palm2"):
        """Load the framework settings and normalise triggers/instructions.

        Args:
            config_index: Key of this framework inside the settings file.
            dataset: Dataset name; only used to pick the LM extraction
                trigger when ``lm_extraction`` is enabled in the config.
            questions: Sequence of question strings.
            model: Model name, stored as-is.

        Raises:
            KeyError: If a required config entry (``num_stages``,
                ``rule_based_stage_cleaning``, the ``final`` trigger) is
                missing.
            TypeError: If the config has no ``triggers`` section at all.
        """
        self.config_index = config_index

        configs = self.get_configs()
        self.num_stages = configs["num_stages"]
        self.instructions = configs.get("instructions", None)
        self.cleaning_flags = {
            "rule_based": configs["rule_based_stage_cleaning"],
            "lm_extraction": configs.get("lm_extraction", False),
        }
        self.triggers = configs.get("triggers", None)
        self.questions = questions
        self.model = model
        # Maps "stage_<idx>" to the list of outputs produced by that stage.
        self.stage_output_dict = {}

        # Triggers always end with one space so generated text can follow
        # directly; instructions always end with one newline.
        if self.triggers is not None:
            self.triggers = {
                key: value.strip() + " " for key, value in self.triggers.items()
            }

        if self.instructions is not None:
            self.instructions = {
                key: value.strip() + "\n" for key, value in self.instructions.items()
            }

        self.final_instruction = None  # placeholder, set by few-shot subclasses
        self.final_trigger = self.triggers["final"]

        # lm_extraction_trigger only exists when LM extraction is enabled.
        if self.cleaning_flags["lm_extraction"]:
            self.lm_extraction_trigger = get_lm_extraction_trigger(dataset)

    def get_num_stages(self):
        """Return the configured number of stages."""
        return self.num_stages

    def get_cleaning_flags(self):
        """Return a fresh dict with the ``rule_based`` and ``lm_extraction`` flags."""
        return {
            "rule_based": self.cleaning_flags["rule_based"],
            "lm_extraction": self.cleaning_flags["lm_extraction"],
        }

    def get_configs(self):
        """Read the settings file and return the entry for ``self.config_index``.

        Raises:
            OSError: If the settings file cannot be opened.
            KeyError: If ``self.config_index`` is not present in the file.
        """
        # NOTE: yaml.load with FullLoader is acceptable for a project-local
        # settings file; do not point this at untrusted input.
        with open(FRAMEWORK_CONFIGS, "r") as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)
        return configs[self.config_index]

    def get_stage_prompts(self, stage_idx=0):
        """Build one prompt per question for stage ``stage_idx``.

        Each prompt is ``instruction``, question and stage trigger joined by
        newlines; the instruction is left out when none is configured.

        Returns:
            ``{"prompts": [...]}`` in the order of ``self.questions``.
        """
        instruction = self.get_instruction(stage_idx)["instruction"]
        stage_trigger = self.get_trigger(stage_idx)["trigger"]
        prompts = []
        for q in self.questions:
            parts = [q, stage_trigger] if instruction is None else [instruction, q, stage_trigger]
            prompts.append("\n".join(parts))
        return {"prompts": prompts}

    def get_lm_based_stage_cleaning_prompts(self, stage_output_list, stage_idx=0):
        """Build LM-based cleaning prompts for the outputs of one stage.

        NOTE(review): ``get_instruction`` and ``get_trigger`` as defined in
        this class only return the ``"instruction"`` / ``"trigger"`` keys, so
        the ``"cleaning_instruction"`` / ``"cleaning_trigger"`` lookups below
        raise ``KeyError`` unless a subclass overrides those getters. The
        config keys the cleaning texts should come from are not visible here,
        so the behaviour is left unchanged.

        Returns:
            ``{"prompts": [...]}``, one prompt per stage output.
        """
        instruction = self.get_instruction(stage_idx)["cleaning_instruction"]
        triggers = self.get_trigger(stage_idx)
        stage_trigger, cleaning_trigger = triggers["trigger"], triggers["cleaning_trigger"]
        prompts = [
            "\n".join([instruction, stage_trigger + o, cleaning_trigger])
            for o in stage_output_list
        ]
        return {"prompts": prompts}

    def rule_based_stage_cleaning(self, stage_output_list, stage_idx=0):
        """Clean raw stage outputs with fixed text rules.

        For every output: collapse repeated line breaks, keep only the first
        section, flatten the remaining line breaks to spaces, remove the text
        that follows the last ``:`` of the stage trigger, drop empty sentences
        and sentences containing "answer is", and make sure the result ends
        with a period.

        Returns:
            ``{"outputs": [...]}`` with one cleaned string per input.
        """
        collapsed = self.multiple_linebreaks_to_single(stage_output_list)
        # Feed the collapsed text on; previously the raw list was passed here
        # and the collapsing step had no effect.
        first_sections = self.take_first_sections(collapsed)
        trigger_content = self.get_trigger(stage_idx)["trigger"].split(":")[-1].strip()
        cleaned_outputs = []
        for text in first_sections:
            text = text.replace("\n", " ").replace(trigger_content, "")
            sentences = [s.strip() for s in sent_tokenize(text)]
            kept = [s for s in sentences if s and "answer is" not in s.lower()]
            text = " ".join(kept)
            if not text.endswith("."):
                text += "."
            cleaned_outputs.append(text)
        return {"outputs": cleaned_outputs}

    def update_stage_output(self, stage_output_list, stage_idx=0):
        """Store the outputs of stage ``stage_idx`` under ``"stage_<idx>"``."""
        self.stage_output_dict[f"stage_{stage_idx}"] = stage_output_list

    def get_extraction_prompts(self, rationale_output_list):
        """Build answer-extraction prompts from question, rationale and trigger.

        Reads ``self.lm_extraction_trigger``, which is only set when the
        ``lm_extraction`` flag is enabled.

        Returns:
            ``{"prompts": [...]}``, one prompt per rationale.
        """
        prompts = [
            f"{self.questions[i]}\n{self.final_trigger}{o}\n\n{self.lm_extraction_trigger}"
            for i, o in enumerate(rationale_output_list)
        ]
        return {"prompts": prompts}

    def get_whole_process(self, rationale_output_list, cleansing_target_list=None):
        """Concatenate integrated prompt, rationale and (optionally) extraction.

        Args:
            rationale_output_list: Rationales, aligned with ``self.questions``.
            cleansing_target_list: Extracted prediction sentences. Required
                when the ``lm_extraction`` flag is enabled; may be omitted
                otherwise.

        Returns:
            ``{"outputs": [...]}`` with one transcript per zipped entry.

        Raises:
            ValueError: If LM extraction is enabled but no
                ``cleansing_target_list`` was given.
        """
        integrated_prompts = self.get_integrated_prompts()["prompts"]
        use_extraction = self.get_cleaning_flags()["lm_extraction"]

        if cleansing_target_list is None:
            # Previously zip() was handed None and failed with a TypeError
            # even when no extraction text was needed.
            if use_extraction:
                raise ValueError(
                    "cleansing_target_list is required when lm_extraction is enabled"
                )
            return {
                "outputs": [
                    f"{prompt}{rationale}"
                    for prompt, rationale in zip(integrated_prompts, rationale_output_list)
                ]
            }

        whole_process = []
        for prompt, rationale, pred_sentence in zip(
            integrated_prompts, rationale_output_list, cleansing_target_list
        ):
            if use_extraction:
                whole_process.append(
                    f"{prompt}{rationale}\n\n{self.lm_extraction_trigger}{pred_sentence}"
                )
            else:
                whole_process.append(f"{prompt}{rationale}")
        return {"outputs": whole_process}

    def get_integrated_prompts(self):
        """Build the final prompts from questions and stored stage outputs.

        Every stored stage contributes a ``"\\n<trigger><output>"`` line after
        the question; the final trigger is appended last and the final
        instruction, when set, is put in front.

        Returns:
            ``{"prompts": [...]}`` in the order of ``self.questions``.
        """
        prompts = list(self.questions)
        for stage, output_list in self.stage_output_dict.items():
            for i, o in enumerate(output_list):
                prompts[i] += "\n" + self.triggers[stage] + o
        if self.final_instruction is None:
            prompts = [f"{p}\n{self.final_trigger}" for p in prompts]
        else:
            prompts = [
                f"{self.final_instruction}\n{p}\n{self.final_trigger}" for p in prompts
            ]
        return {"prompts": prompts}

    def take_first_sections(self, output_list):
        """Keep, for each string, only the text before the first header line.

        A header is a line break followed by letters/whitespace and a colon
        (for example ``"\\nSome Title:"``).
        """
        def tfs(s):
            s = s.strip()
            # Zero-width split: the text before the first match is returned
            # without the header itself.
            result = re.split(r'(?=\n+[a-zA-Z][a-zA-Z\s]*:)', s)
            return result[0].strip()
        return [tfs(o) for o in output_list]

    def multiple_linebreaks_to_single(self, output_list):
        """Collapse every run of line breaks in each string to a single one."""
        return [re.sub(r"\n+", "\n", o) for o in output_list]

    def take_first_lines(self, output_list):
        """Return the first line of each string after stripping it."""
        return [o.strip().split("\n")[0] for o in output_list]

    def get_instruction(self, stage_idx=0):
        """Return ``{"instruction": ...}`` for stage ``stage_idx``.

        The value is ``None`` when the config has no ``instructions`` section
        or no entry for this stage.
        """
        if self.instructions is None:
            # A config without instructions is valid; get_stage_prompts
            # handles the None case.
            return {"instruction": None}
        return {"instruction": self.instructions.get(f"stage_{stage_idx}", None)}

    def get_trigger(self, stage_idx=0):
        """Return ``{"trigger": ...}`` for stage ``stage_idx``.

        Raises:
            KeyError: If no trigger is configured for this stage.
        """
        return {"trigger": self.triggers[f"stage_{stage_idx}"]}

class few_shot_cot(prompt_framework):
    """Framework variant whose final instruction is the ``"cot"`` prompt file."""

    def __init__(self, config_index, dataset, questions, model="palm2"):
        """Initialise the base framework, then load the ``"cot"`` prompt."""
        super().__init__(config_index, dataset, questions, model=model)
        self.final_instruction = get_prompt("cot", dataset)

    def get_whole_process(self, rationale_output_list, cleansing_target_list=None):
        """Return the last question of each integrated prompt plus its rationale.

        Only the text after the final ``"Question:"`` marker of each
        integrated prompt is kept. ``cleansing_target_list`` is accepted for
        signature compatibility and ignored.

        Returns:
            ``{"outputs": [...]}`` with one ``"Question: ... <rationale>"``
            string per zipped prompt/rationale pair.
        """
        integrated = self.get_integrated_prompts()["prompts"]
        outputs = []
        for prompt, rationale in zip(integrated, rationale_output_list):
            # rpartition(...)[2] is the text after the last marker, or the
            # whole prompt when the marker does not occur.
            last_question = prompt.rpartition("Question:")[2].strip()
            query = "Question: " + last_question
            outputs.append(f"{query} {rationale}")
        return {"outputs": outputs}

class few_shot_precot(prompt_framework):
    """Framework variant driven by the three ``precot_*`` prompt files.

    ``precot_full`` becomes the final instruction; ``precot_gi`` is the
    instruction for stage 0 and ``precot_obj`` for every other stage.
    """

    def __init__(self, config_index, dataset, questions, model="palm2"):
        """Initialise the base framework, then load the three prompt files."""
        super().__init__(config_index, dataset, questions, model=model)
        self.final_instruction = get_prompt("precot_full", dataset)
        self.gi = get_prompt("precot_gi", dataset)
        self.obj = get_prompt("precot_obj", dataset)

    def get_stage_prompts(self, stage_idx=0):
        """Build one prompt per question for stage ``stage_idx``.

        Returns:
            ``{"prompts": [...]}`` where each prompt is the stage's prompt
            text, the question and the stage trigger joined by newlines.
        """
        instruction = self.gi if stage_idx == 0 else self.obj
        stage_trigger = self.get_trigger(stage_idx)["trigger"]
        prompts = ["\n".join([instruction, q, stage_trigger]) for q in self.questions]
        return {"prompts": prompts}

    def rule_based_stage_cleaning(self, stage_output_list, stage_idx=0):
        """Collapse line breaks, keep the first section and flatten it.

        ``stage_idx`` is accepted for signature compatibility and unused.

        Returns:
            ``{"outputs": [...]}`` with one single-line string per input.
        """
        collapsed = self.multiple_linebreaks_to_single(stage_output_list)
        # Feed the collapsed text on; previously the raw list was passed here,
        # so runs of blank lines ended up as runs of spaces.
        first_sections = self.take_first_sections(collapsed)
        cleaned_outputs = [s.replace("\n", " ") for s in first_sections]
        return {"outputs": cleaned_outputs}

    def get_whole_process(self, rationale_output_list, cleansing_target_list=None):
        """Return the last question of each integrated prompt plus its rationale.

        Only the text after the final ``"Question:"`` marker of each
        integrated prompt is kept. ``cleansing_target_list`` is accepted for
        signature compatibility and ignored.

        Returns:
            ``{"outputs": [...]}`` with one ``"Question: ... <rationale>"``
            string per zipped prompt/rationale pair.
        """
        full_prompts = self.get_integrated_prompts()["prompts"]
        queries = ["Question: " + s.split("Question:")[-1].strip() for s in full_prompts]
        whole_process = [
            f"{query} {rationale}"
            for query, rationale in zip(queries, rationale_output_list)
        ]
        return {"outputs": whole_process}