# coding: utf-8
import re
import time
# import prompts
from prompts.focus.planner import *
from prompts.focus.solver import *
from prompts.focus.prompt import *

from run_source.dataset import *
from tools.retriever import *
from tools.utils import *
from evaluate import *
from app.sample_resp_cralwer import SampleRespCrawler
from app.chatgpt_resp_cralwer import ChatgptRespCrawler
from app.huggingface_resp_crawler import HuggingFaceRespCrawlerBase


class ThinkPlanDo:
    """Reflect -> plan -> execute -> solve pipeline for grounded dialogue generation.

    For each dialogue sample the model first reflects on the conversation,
    then writes a plan whose steps bind evidence variables (``#E1``, ``#E2``,
    ...) to tool calls such as ``KNOWLEDGE[...]`` or ``PERSONA[...]``. Those
    tool calls are executed with a BM25 retriever over the sample's
    persona/knowledge candidates, and a final solver prompt combines the
    plans and evidences to produce the response.
    """

    def __init__(self, key_path, dataset, model_type="chatgpt", model_path="", language="",
                 setting="zero-shot", temperature=0.7, persona="", top_p=0.95):
        """Configure the dataset, retriever and response crawler.

        Args:
            key_path: Credential path forwarded to ``ChatgptRespCrawler``.
            dataset: ``"hotpotqa"`` or ``"focus"``; selects the prompt constructor.
                Any other value leaves ``self.prompt_constructor`` unset.
            model_type: Model identifier; selects the crawler and the token price.
            model_path: Model path forwarded to ``HuggingFaceRespCrawlerBase``.
            language: Stored as ``self.language``.
            setting: Prompting-setting label; stored only.
            temperature: Sampling temperature forwarded to the crawler.
            persona: Forwarded to ``ChatgptRespCrawler``.
            top_p: Nucleus-sampling parameter forwarded to the crawler.
        """
        self.key_path = key_path
        self.dataset = dataset
        self.model_type = model_type
        self.model_path = model_path
        self.language = language
        # Misspelled alias kept so existing code reading ``langugae`` keeps working.
        self.langugae = language

        self.setting = setting
        self.token_unit_price = get_token_unit_price(model_type)
        if self.dataset == "hotpotqa":
            data_paths = ["./dataset/hotpotQA/hotpot_dev_distractor_v1.json"]
            demo_path = "./cot_retrieval/config/prompt_en.json"
            self.prompt_constructor = HotpotQA(data_paths, demo_path=demo_path, mode="dev")
        elif self.dataset == "focus":
            data_paths = ["./dataset/FoCus/valid_focus.json"]
            self.prompt_constructor = FoCus(data_paths, mode="dev")

        self.retriever = Retriever('bm25')
        self.call_retrieval_times, self.persona_right_count, self.knowledge_right_count = 0, 0, 0
        # Per-sample state; initialized up front so the helper methods are usable
        # even before predict_responses() has run.
        self.cache = {}
        self._reinitialize()
        self.init_model_type(temperature, persona, top_p, model_type)

    def init_model_type(self, temperature, persona, top_p, model):
        """Create ``self.sample_crawler`` for ``self.model_type``.

        The listed OpenAI models use ``ChatgptRespCrawler``; every other model
        type falls back to ``HuggingFaceRespCrawlerBase``.
        """
        # NOTE(review): the constructor default model_type="chatgpt" is not in this
        # list, so it takes the HuggingFace branch -- confirm that is intended.
        if self.model_type in ["gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003", "gpt-4-0613"]:
            self.sample_crawler = ChatgptRespCrawler(self.key_path, temperature, persona=persona, top_p=top_p, model=model)
        else:
            self.sample_crawler = HuggingFaceRespCrawlerBase(self.model_type, self.model_path, top_p=top_p, temperature=temperature)

    def get_api_result(self, input_path, output_path):
        """Re-run the crawler until ``is_finished_all_prompts`` reports completion.

        The loop acts as a manual retry: each pass resumes crawling from
        ``input_path`` into ``output_path``.
        """
        while not is_finished_all_prompts(input_path, output_path):
            self.sample_crawler.get_all_result(input_path, output_path)

    def retrieve_external_knowledge(self, query, knowledge_bases, number_results):
        """Return ``(top_indexes, retrieved_texts)`` for ``query`` and count the call."""
        top_indexes, retrieved_res = self.retriever.retrieve_top_n(query, knowledge_bases, number=number_results)
        self.call_retrieval_times += 1
        return top_indexes, retrieved_res

    def _call_model(self, prompt):
        """Send ``prompt`` to the crawler, add its token cost to the cache, return the text."""
        raw_response = self.sample_crawler.call_openai_each(prompt)
        total_tokens, text = get_response_according_to_model_type(raw_response, self.model_type)
        self.cache["prices"] = self.cache.get("prices", 0) + total_tokens * self.token_unit_price
        return text

    def reflect_global(self, prompt):
        """Ask the model for a reflection on the dialogue (step 1)."""
        return self._call_model(prompt)

    def plan(self, prompt):
        """Ask the model for a plan with ``#E<n> = TOOL[...]`` steps (step 2)."""
        return self._call_model(prompt)

    def solve(self, prompt):
        """Ask the model for the final response given plans and evidences (step 3)."""
        return self._call_model(prompt)

    def _parse_plans(self, response):
        """Return the lines of ``response`` that start with "Plan:" or contain "plan"."""
        # NOTE(review): the lowercase "plan" substring test can also match
        # evidence lines whose tool input mentions "plan" -- confirm intent.
        return [line for line in response.splitlines()
                if line.startswith("Plan:") or "plan" in line]

    def _parse_planner_evidences(self, response):
        """Map evidence labels to tool calls from ``#E<n> = TOOL[...]`` lines.

        Lines that do not start with ``#E<digit>``, or that contain no ``=``,
        are ignored. A label that is exactly ``#E<digits>`` maps to its tool
        call; any other label (e.g. ``#E3 (x)``) maps to "No evidence found".
        """
        evidences = {}
        for line in response.splitlines():
            if not re.match(r"#E\d", line) or "=" not in line:
                continue
            label, tool_call = line.split("=", 1)
            label, tool_call = label.strip(), tool_call.strip()
            # Accept any number of digits so plans with ten or more steps work.
            if re.fullmatch(r"#E\d+", label):
                evidences[label] = tool_call
            else:
                evidences[label] = "No evidence found"
        return evidences

    def predict_responses(self, max_examples=200):
        """Run the pipeline over the dev examples, score them, and save the results.

        Args:
            max_examples: Number of leading examples to process (``None`` = all).

        Aggregate metrics go to ``./exp_output/<dataset>/<model_type>_ours_result.json``;
        per-sample generations go to
        ``./response_output/<dataset>/<model_type>_ours_result.json``.
        """
        f1_scores, rl_scores, avg_bleu_scores = [], [], []
        prices, evaluation_elo = [], []

        for sample in self.prompt_constructor.examples[:max_examples]:
            generations = self._run_sample(sample)

            # [4] Evaluate the generation against the reference response.
            reference = self.cache["response"]
            f1_scores.append(f1(generations, reference))
            rl_scores.append(rl(generations, reference))
            prices.append(self.cache["prices"])
            avg_bleu_scores.append(avg_bleu(generations, reference))
            evaluation_elo.append(self._make_elo_record(sample, generations))

        f1_score, rl_score, avg_bleu_score = np.mean(f1_scores), np.mean(rl_scores), np.mean(avg_bleu_scores)
        result = {
            "method": "ours",
            "model_type": self.model_type,
            "retriever_type": self.retriever.model_name,
            "persona_right_count": self.persona_right_count,
            "knowledge_right_count": self.knowledge_right_count,
            "call_retrieval_times": self.call_retrieval_times,
            "f1_score": f1_score,
            "rl_score": rl_score,
            "avg_bleu_score": avg_bleu_score,
            "price": sum(prices)
        }

        self._save_json(f"./exp_output/{self.dataset}/{self.model_type}_ours_result.json", result)
        self._save_json(f"./response_output/{self.dataset}/{self.model_type}_ours_result.json", evaluation_elo)

    def _run_sample(self, sample):
        """Reset per-sample state, run reflect/plan/do/solve, and return the generation."""
        self._reinitialize()
        dialogue = sample["context"]

        self.cache = {
            "dialogue": dialogue,
            "response": sample["response"],
            "prices": 0,
            # persona and knowledge candidates
            "persona_cands": sample["persona_cands"],
            "knowledge_cands": sample["knowledge_cands"],
            # indexes of the gold (grounded) persona sentences / knowledge entry
            "persona_grounding_indexes": [
                i for i, grounded in enumerate(sample["persona_grounding"])
                if grounded == True  # noqa: E712 -- keep exact equality semantics
            ],
            "knowledge_grounding_indexes": sample["knowledge_index"],
        }

        # [1] Reflect on current user interest and used knowledge.
        reflection_prompt = f"Dialogue: {dialogue}\n\nReflection: "
        reflection = self.reflect_global(ours_refection_prompt + "\n\n" + reflection_prompt)
        self.cache["thoughts"] = reflection

        # [2] Plan the modules.
        test_prompt = f"Dialogue: {dialogue}\nReflection: {reflection}"
        plan_text = self.plan(ours_modules_prompt + "\n\n" + test_prompt)
        self.plans = self._parse_plans(plan_text)
        self.planner_evidences = self._parse_planner_evidences(plan_text)

        # [3] Do: execute the tool calls, then solve with the collected evidence.
        self._get_worker_evidences()
        worker_log = self._build_worker_log()
        solver_prompt = (SOLVER_DEFAULT_PREFIX + f"Dialogue: {dialogue}" + "\n" + worker_log
                         + SOLVER_DEFAULT_SUFFIX + f"Dialogue: {dialogue}" + '\n')
        return self.solve(solver_prompt).replace("SYSTEM:", "").strip()

    def _build_worker_log(self):
        """Render one plan/evidence entry per step #E1..#En for the solver prompt.

        Missing plans render as an empty string and missing evidences as
        "No evidence found", so plan/evidence count mismatches cannot crash.
        """
        entries = []
        for i in range(len(self.planner_evidences)):
            label = f"#E{i + 1}"
            plan_text = self.plans[i] if i < len(self.plans) else ""
            evidence = self.worker_evidences.get(label, "No evidence found")
            entries.append(f"{plan_text}\nEvidence:\n{evidence}\n")
        return "".join(entries)

    def _make_elo_record(self, sample, generations):
        """Build the per-sample record (context, gold grounding, generation) for saving."""
        grounding_persona = " ".join(
            self.cache["persona_cands"][i] for i in self.cache["persona_grounding_indexes"]
        )
        grounding_knowledge = self.cache["knowledge_cands"][self.cache["knowledge_grounding_indexes"]]
        return {
            "context": sample["context"],
            "grounding_persona": grounding_persona,
            "grounding_knowledge": grounding_knowledge,
            "generation": generations
        }

    @staticmethod
    def _save_json(path, payload):
        """Write ``payload`` as indented UTF-8 JSON, creating parent directories first."""
        import json
        import os

        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4, ensure_ascii=False)

    def _get_worker_evidences(self):
        """Execute each planned tool call and store its output in ``self.worker_evidences``.

        A tool call has the form ``TOOL[input]``. References to earlier
        evidence variables inside ``input`` are replaced by those evidences
        first. ``KNOWLEDGE`` and ``PERSONA`` run the matching retrieval;
        any other tool yields "No evidence found". Entries without ``[`` are
        stored verbatim.
        """
        def substitute(match):
            # Replace a whole "#E<n>" token; unknown references are left unchanged.
            token = match.group(0)
            return self.worker_evidences.get(token, token)

        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            tool = tool.strip()
            # Keep the text up to the last "]", tolerating trailing text or a missing "]".
            tool_input = tool_input.rsplit("]", 1)[0]
            # Token-wise substitution so "#E1" never clobbers the prefix of "#E12".
            tool_input = re.sub(r"#E\d+", substitute, tool_input)

            if tool == "KNOWLEDGE":
                self.worker_evidences[e] = self.knowledge_retrieval(tool_input)
            elif tool == "PERSONA":
                self.worker_evidences[e] = self.persona_retrieval(tool_input)
            else:
                self.worker_evidences[e] = "No evidence found"

    def _build_retrieval_query(self, tool_input):
        """Build the retrieval query from the dialogue, the tool input, and the reflection.

        If the tool input mentions "context", only the dialogue and reflection are used.
        """
        dialogue = self.cache["dialogue"]
        thoughts = self.cache["thoughts"]
        if "context" in tool_input:
            return dialogue + " " + thoughts
        return dialogue + " " + tool_input + " " + thoughts

    def persona_retrieval(self, input):
        """Retrieve the top persona sentence for ``input`` and track whether it is grounded."""
        query = self._build_retrieval_query(input)
        retrieved_p_index, retrieved_persona = self.retrieve_external_knowledge(
            query, self.cache["persona_cands"], 1)

        if set(retrieved_p_index) & set(self.cache["persona_grounding_indexes"]):
            self.persona_right_count += 1

        output = " ".join(retrieved_persona)
        self.cache["persona_retriever:output"] = output
        return output

    def knowledge_retrieval(self, input):
        """Retrieve the top knowledge entry for ``input`` and track whether it is the gold one."""
        query = self._build_retrieval_query(input)
        retrieved_k_index, retrieved_knowledge = self.retrieve_external_knowledge(
            query, self.cache["knowledge_cands"], 1)

        # Assumes the retriever returns a numpy array of indexes (``.tolist()``).
        if retrieved_k_index.tolist()[0] == self.cache["knowledge_grounding_indexes"]:
            self.knowledge_right_count += 1

        output = " ".join(retrieved_knowledge)
        self.cache["knowledge_retriever:output"] = output
        return output

    def _reinitialize(self):
        """Clear the per-sample plan, planner-evidence and worker-evidence state."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}