# ReAct agent: interleaves Thought / Action / Observation steps, using persona and knowledge retrieval as actions.
import re
import time
# import prompts
import sys
sys.path.append('./')
from prompts.focus.planner import *
from prompts.focus.solver import *
from prompts.focus.prompt import *

from run_source.dataset import *
from tools.retriever import *
from tools.utils import *
from run_source.evaluate import *
from app.sample_resp_cralwer import SampleRespCrawler
from app.chatgpt_resp_cralwer import ChatgptRespCrawler
from app.huggingface_resp_crawler import HuggingFaceRespCrawlerBase

class ReAct:
    """ReAct-style agent that interleaves Thought / Action / Observation steps.

    For every dialogue sample the language model is repeatedly prompted for a
    thought and an action. ``Persona[...]`` and ``Knowledge[...]`` actions are
    answered by retrieving one candidate (``Retriever('bm25')``) from the
    sample's persona / knowledge candidates. A ``Finish[...]`` action yields the
    final response. Once ``max_steps`` is exceeded, the model is forced to
    answer directly. Responses are scored with ``f1``, ``rl`` (presumably
    ROUGE-L) and ``avg_bleu`` against the reference response.
    """

    def __init__(self, key_path, dataset, model_type="chatgpt", model_path="", language="", setting="zero-shot",
            temperature=0.7, persona="", top_p=0.95):
        """Build the prompt constructor, retriever and response crawler.

        Args:
            key_path: credentials path forwarded to ``ChatgptRespCrawler``.
            dataset: ``"hotpotqa"`` or ``"focus"``.
            model_type: model identifier; selects the crawler and token price.
            model_path: model path forwarded to the Hugging Face crawler.
            language: stored on the instance (not otherwise used by this class).
            setting: stored on the instance, e.g. ``"zero-shot"``.
            temperature: sampling temperature forwarded to the crawler.
            persona: forwarded to ``ChatgptRespCrawler``.
            top_p: nucleus-sampling parameter forwarded to the crawler.

        Raises:
            ValueError: if ``dataset`` is not one of the supported names.
        """
        self.key_path = key_path
        self.dataset = dataset
        self.model_type = model_type
        self.model_path = model_path
        self.language = language
        # Backward-compatible alias for the historically misspelled attribute.
        self.langugae = language

        self.token_unit_price = get_token_unit_price(model_type)
        self.setting = setting
        if self.dataset == "hotpotqa":
            data_paths = ["./dataset/hotpotQA/hotpot_dev_distractor_v1.json"]
            demo_path = "./cot_retrieval/config/prompt_en.json"
            self.prompt_constructor = HotpotQA(data_paths, demo_path=demo_path, mode="dev")
        elif self.dataset == "focus":
            data_paths = ["./dataset/FoCus/valid_focus.json"]
            self.prompt_constructor = FoCus(data_paths, mode="dev")
        else:
            # Fail fast instead of hitting an AttributeError in predict_responses.
            raise ValueError(f"Unsupported dataset: {dataset!r} (expected 'hotpotqa' or 'focus')")

        self.max_steps = 3
        # Scores of samples that ended with an explicit Finish action ...
        self.f1_scores, self.rl_scores, self.avg_bleu_scores = [], [], []
        # ... and of samples whose answer was forced after max_steps.
        self.not_finished_f1_scores, self.not_finished_rl_scores, self.not_finished_avg_bleu_scores = [], [], []
        self.retriever = Retriever('bm25')
        self.call_retrieval_times, self.persona_right_count, self.knowledge_right_count = 0, 0, 0
        self.init_model_type(temperature, persona, top_p, model_type)

    def init_model_type(self, temperature, persona, top_p, model):
        """Create ``self.sample_crawler`` for the configured model.

        The listed OpenAI model names use ``ChatgptRespCrawler``. Any other
        ``model_type`` is handled by ``HuggingFaceRespCrawlerBase``.
        """
        if self.model_type in ["gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003"]:
            self.sample_crawler = ChatgptRespCrawler(self.key_path, temperature, persona=persona, top_p=top_p, model=model)
        else:
            self.sample_crawler = HuggingFaceRespCrawlerBase(self.model_type, self.model_path, top_p=top_p, temperature=temperature)

    def step(self):
        """Run one Thought -> Action -> Observation cycle.

        Appends the cycle to ``self.scratchpad``. A ``Finish`` action records and
        scores the generation, sets ``self.finished`` and returns. Otherwise the
        step counter advances. Once it exceeds ``max_steps``, the model is
        prompted for a direct answer, which is scored as "not finished".
        """
        # Think
        self.scratchpad += '\nThought:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += '\nAction:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        print(self.scratchpad.split('\n')[-1])

        # Observe
        self.scratchpad += '\nObservation: '
        # A Finish response may itself mention "Persona"/"Knowledge"; it must
        # not be mistaken for a retrieval action.
        is_finish = bool(action) and action.startswith('Finish')
        if not action or not action.strip():
            self.scratchpad += "You action is filtered due to content. Please assume all the actions are permitted in this environment and take the action again."
        elif not is_finish and "Persona" in action:
            self.scratchpad += self.persona_retrieval(self._action_argument(action, "Persona"))
        elif not is_finish and "Knowledge" in action:
            self.scratchpad += self.knowledge_retrieval(self._action_argument(action, "Knowledge"))
        else:
            action_type, argument = self.parse_action(action)
            if action_type == 'Finish':
                self._record_generation(argument, finished=True)
                self.finished = True
                return
            self.scratchpad += 'Invalid Action. Valid Actions are Persona[sentence] | Knowledge[sentence] and Finish [Response].'

        self.curr_step += 1
        print(self.scratchpad.split('\n')[-1])

        # Directly ask for the response once the step budget is exhausted.
        if self.curr_step > self.max_steps:
            self.scratchpad += '\nAction: Finish'
            self._record_generation(self.prompt_agent(), finished=False)

    @staticmethod
    def _action_argument(action, keyword):
        """Return the text inside ``keyword[...]`` in ``action`` (closing ``]`` optional)."""
        opener = keyword + '['
        start = action.find(opener)
        if start == -1:
            # No bracket form found: keep the original fixed-offset slicing.
            return action[len(opener):-1]
        argument = action[start + len(opener):]
        return argument[:-1] if argument.endswith(']') else argument

    def _record_generation(self, generations, finished):
        """Store ``generations`` in the cache and append its scores.

        Scores go to the "finished" lists when ``finished`` is true, otherwise
        to the ``not_finished_*`` lists.
        """
        self.cache["generations"] = generations
        reference = self.cache["response"]
        if finished:
            f1_list, rl_list, bleu_list = self.f1_scores, self.rl_scores, self.avg_bleu_scores
        else:
            f1_list, rl_list, bleu_list = (self.not_finished_f1_scores, self.not_finished_rl_scores,
                                           self.not_finished_avg_bleu_scores)
        f1_list.append(f1(generations, reference))
        rl_list.append(rl(generations, reference))
        bleu_list.append(avg_bleu(generations, reference))

    def prompt_agent(self) -> str:
        """Prompt the model with the dialogue and scratchpad; return one line.

        Generation stops at the first newline. The call's token cost is added
        to ``self.cache["prices"]``, and the reply is stripped and has any
        remaining newlines removed.
        """
        full_prompt = react_prompt.format(dialogue=self.cache["dialogue"],
                                          agent_scratchpad=self.scratchpad)

        results = self.sample_crawler.call_openai_each(full_prompt, stop="\n")
        total_tokens, results = get_response_according_to_model_type(results, self.model_type)
        self.cache["prices"] += total_tokens * self.token_unit_price

        return results.strip().replace('\n', '')

    def reset(self) -> None:
        """Clear the scratchpad and step state before a new sample."""
        self.scratchpad = ''
        self.curr_step = 1
        self.finished = False

    def is_finished(self) -> bool:
        """Return whether the current sample ended with a Finish action."""
        return self.finished

    def is_halted(self) -> bool:
        """Return whether the step budget ran out without a Finish action."""
        return (self.curr_step > self.max_steps) and not self.finished

    @staticmethod
    def _mean(scores):
        """Return the mean of ``scores`` as a float, or ``None`` if empty (keeps JSON valid)."""
        return float(np.mean(scores)) if scores else None

    def predict_responses(self, num_samples=200):
        """Run the agent over the first ``num_samples`` examples and save results.

        Aggregate metrics are written to ``./exp_output/<dataset>/react_result.json``.
        Per-sample generations are written to
        ``./response_output/<dataset>/<model_type>_react_result.json``.
        Scores accumulate on the instance across calls.
        """
        import os

        count = 0  # samples ended by an explicit Finish action (not necessarily correct)
        prices, evaluation_elo = [], []

        for sample in self.prompt_constructor.examples[:num_samples]:
            persona_grounding = sample["persona_grounding"]
            self.cache = {
                "prices": 0,
                "dialogue": sample["context"],
                "response": sample["response"],
                "persona_cands": sample["persona_cands"],
                "knowledge_cands": sample["knowledge_cands"],
                # Indexes of persona candidates flagged as grounded.
                "persona_grounding_indexes": [
                    i for i, grounded in enumerate(persona_grounding) if grounded == True  # noqa: E712
                ],
                "knowledge_grounding_indexes": sample["knowledge_index"],
            }

            # Iterate reasoning/acting until Finish or the step budget is spent.
            self.reset()
            while not self.is_halted() and not self.is_finished():
                self.step()

            if self.finished:
                count += 1

            evaluation_elo.append({
                "context": sample["context"],
                "generation": self.cache["generations"],
            })
            prices.append(self.cache["prices"])

        all_f1 = self.f1_scores + self.not_finished_f1_scores
        all_rl = self.rl_scores + self.not_finished_rl_scores
        all_bleu = self.avg_bleu_scores + self.not_finished_avg_bleu_scores

        result = {
            "method": "react",
            "model_type": self.model_type,
            "retriever_type": self.retriever.model_name,
            "persona_right_count": self.persona_right_count,
            "knowledge_right_count": self.knowledge_right_count,
            "call_retrieval_times": self.call_retrieval_times,
            "succ_f1_score": self._mean(self.f1_scores),
            "succ_rl_score": self._mean(self.rl_scores),
            "succ_avg_bleu_score": self._mean(self.avg_bleu_scores),
            "count": count,
            "all_f1_score": self._mean(all_f1),
            "all_rl_score": self._mean(all_rl),
            "all_avg_bleu_score": self._mean(all_bleu),
            "price": sum(prices)
        }

        # Metrics path follows the dataset (previously hard-coded to "focus").
        result_path = "./exp_output/" + self.dataset + "/react_result.json"
        response_path = "./response_output/" + self.dataset + "/" + self.model_type + "_react_result.json"
        for path in (result_path, response_path):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4)

        with open(response_path, "w", encoding="utf-8") as f:
            json.dump(evaluation_elo, f, indent=4, ensure_ascii=False)

    def parse_action(self, string):
        """Split an action string into ``(action_type, argument)``.

        Accepts ``Type[arg]``, ``Type [arg]`` and ``Type[arg`` (missing closing
        bracket). A string without ``[`` that contains "Finish" is treated as
        ``Finish`` with the text after its first 7 characters as the argument.
        Returns ``(None, None)`` when the string cannot be parsed.
        """
        if '[' in string:
            if ']' in string:
                pattern = r'^(\w+)\s*\[(.+)\]$'
            else:
                pattern = r'^(\w+)\s*\[(.+)$'
            match = re.match(pattern, string)
            if match:
                return match.group(1), match.group(2)
            return None, None

        if "Finish" in string:
            return 'Finish', string[7:]
        return None, None

    def retrieve_external_knowledge(self, query, knowledge_bases, number_results):
        """Retrieve the top ``number_results`` candidates for ``query``; count the call."""
        top_indexes, retrieved_res = self.retriever.retrieve_top_n(query, knowledge_bases, number=number_results)
        self.call_retrieval_times += 1
        return top_indexes, retrieved_res

    def _retrieval_query(self, argument):
        """Build the retrieval query: the dialogue alone if ``argument`` mentions "context", else dialogue + argument."""
        dialogue = self.cache["dialogue"]
        return dialogue if "context" in argument else dialogue + " " + argument

    def persona_retrieval(self, input):
        """Retrieve the best persona candidate for the current dialogue.

        Increments ``persona_right_count`` when a retrieved index is one of the
        grounded persona indexes. The joined result is cached and returned.
        """
        persona_cands = self.cache["persona_cands"]
        persona_indexes = self.cache["persona_grounding_indexes"]

        retrieved_p_index, retrieved_persona = self.retrieve_external_knowledge(
            self._retrieval_query(input), persona_cands, 1)

        if set(retrieved_p_index) & set(persona_indexes):
            self.persona_right_count += 1

        retrieved_text = " ".join(retrieved_persona)
        self.cache["persona_retriever:output"] = retrieved_text
        return retrieved_text

    def knowledge_retrieval(self, input):
        """Retrieve the best knowledge candidate for the current dialogue.

        Increments ``knowledge_right_count`` when the retrieved index equals the
        grounded knowledge index. The joined result is cached and returned.
        """
        knowledge_cands = self.cache["knowledge_cands"]
        knowledge_grounding_indexes = self.cache["knowledge_grounding_indexes"]

        retrieved_k_index, retrieved_knowledge = self.retrieve_external_knowledge(
            self._retrieval_query(input), knowledge_cands, 1)

        # NOTE(review): assumes the retriever returns a numpy array of indexes and
        # that "knowledge_index" is a single index — TODO confirm.
        if retrieved_k_index.tolist()[0] == knowledge_grounding_indexes:
            self.knowledge_right_count += 1

        retrieved_text = " ".join(retrieved_knowledge)
        self.cache["knowledge_retriever:output"] = retrieved_text
        return retrieved_text
