# ReAct agent: iterates Thought / Action / Observation steps with an LLM and scores the generated responses.
import re
import time
# import prompts

from run_strategy.dataset import *
from run_strategy.evaluate import *
from prompts.psyqa.prompt import *
from prompts.tutoring.cima.prompt import *

from tools.retriever import *
from tools.utils import *

from app.sample_resp_cralwer import SampleRespCrawler
from app.chatgpt_resp_cralwer import ChatgptRespCrawler
from app.huggingface_resp_crawler import HuggingFaceRespCrawlerBase

class ReAct:
    """Agent that builds a response through repeated Thought / Action / Observation turns.

    For every dataset sample the agent keeps a textual ``scratchpad``. Each call
    to :meth:`step` asks the language model to complete a "Thought", then an
    "Action" and, unless the action contains ``"Response"``, an "Observation".
    Observations are concatenated into ``self.answer``, which is scored against
    the reference once the agent finishes or runs out of steps.

    Only the ``"psyqa"`` and ``"cima"`` datasets have a step budget and a prompt
    template wired up in this class. For ``"hotpotqa"`` a prompt constructor is
    created but ``max_steps`` is never set, so :meth:`predict_responses` cannot
    run for it.
    """

    def __init__(self, key_path, dataset, model_type="chatgpt", model_path="", language="", setting="zero-shot",
            temperature=0.7, persona="", top_p=0.95):
        """Store the configuration, load the dataset and create the model client.

        Args:
            key_path: Forwarded to ``ChatgptRespCrawler`` when an OpenAI model is used.
            dataset: Dataset name; ``"hotpotqa"``, ``"psyqa"`` or ``"cima"``.
            model_type: Model identifier; selects the crawler and the token price.
            model_path: Forwarded to ``HuggingFaceRespCrawlerBase``.
            language: Language code passed to the metrics. Overridden with
                ``"cn"`` for psyqa and ``"en"`` for cima.
            setting: Stored on the instance; not otherwise used by this class.
            temperature: Sampling temperature forwarded to the crawler.
            persona: Forwarded to ``ChatgptRespCrawler``.
            top_p: Nucleus-sampling value forwarded to the crawler.
        """
        self.key_path = key_path
        self.dataset = dataset
        self.model_type = model_type
        self.model_path = model_path
        self.language = language
        self.token_unit_price = get_token_unit_price(model_type)

        self.setting = setting
        if self.dataset == "hotpotqa":
            data_paths = ["./dataset/hotpotQA/hotpot_dev_distractor_v1.json"]
            demo_path = "./cot_retrieval/config/prompt_en.json"
            self.prompt_constructor = HotpotQA(data_paths, demo_path=demo_path, mode="dev")
        elif self.dataset == "psyqa":
            data_paths = ["./dataset/PsyQA/psyqa_test.json"]
            self.prompt_constructor = PsyQA(data_paths, mode="dev")
            self.max_steps = 7
            self.language = "cn"
        elif self.dataset == "cima":
            data_paths = ["./dataset/tutoring/CIMA_test.json"]
            self.prompt_constructor = StrategyTutoring(data_paths, mode="dev")
            self.language = "en"
            self.max_steps = 3

        # Per-sample metric values, accumulated over a whole predict_responses run.
        # NOTE: nothing in this class appends to avg_bleu_scores; it is kept so the
        # "all_avg_bleu_score" key of the result file stays present.
        self.f1_scores, self.rl_scores, self.avg_bleu_scores = [], [], []
        self.d1_scores, self.d2_scores, self.bleu_scores, self.bert_scores = [], [], [], []
        self.init_model_type(temperature, persona, top_p, model_type)

    def init_model_type(self, temperature, persona, top_p, model):
        """Create ``self.sample_crawler``, the client used to query the language model.

        The listed OpenAI model names get a ``ChatgptRespCrawler``; every other
        ``self.model_type`` gets a ``HuggingFaceRespCrawlerBase``.

        NOTE(review): the constructor default ``model_type="chatgpt"`` is not in
        the OpenAI list and therefore takes the HuggingFace branch - confirm
        this is intended.
        """
        if self.model_type in ["gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003"]:
            self.sample_crawler = ChatgptRespCrawler(self.key_path, temperature, persona=persona, top_p=top_p, model=model)
        else:
            self.sample_crawler = HuggingFaceRespCrawlerBase(self.model_type, self.model_path, top_p=top_p, temperature=temperature)

    def _record_generation(self, generations):
        """Store ``generations`` in the sample cache and append its metric scores.

        Both exit paths of :meth:`step` share this helper so that they always
        score with ``self.language``.
        """
        self.cache["generations"] = generations
        reference = self.cache["response"]
        lang = self.language

        self.f1_scores.append(f1(generations, reference, lang))
        self.d1_scores.append(distinct_n_sentence_level(generations, 1, lang=lang))
        self.d2_scores.append(distinct_n_sentence_level(generations, 2, lang=lang))
        self.rl_scores.append(rl(generations, reference, lang))
        bert_score, bleu_score = tutoring_scores(generations, reference, lang)
        self.bleu_scores.append(bleu_score)
        self.bert_scores.append(bert_score)

    def step(self):
        """Run one Thought / Action / Observation turn for the current sample.

        If the action contains ``"Response"`` the accumulated answer is scored
        and the agent is marked finished. Otherwise an observation is generated
        and appended to the answer; when that pushes ``curr_step`` past
        ``max_steps`` the answer is scored as-is and ``finished`` stays False.
        """
        # Think
        self.scratchpad += '\nThought:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += '\nAction:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        print(self.scratchpad.split('\n')[-1])

        if "Response" in action:
            self._record_generation(self.answer)
            self.finished = True
            return

        # Observe
        self.scratchpad += '\nObservation: '
        obs = self.prompt_agent()
        self.scratchpad += ' ' + obs
        self.answer += obs + " "

        self.curr_step += 1
        print(self.scratchpad.split('\n')[-1])

        # Step budget exhausted: score whatever has been collected so far.
        if self.curr_step > self.max_steps:
            self._record_generation(self.answer)

    def prompt_agent(self):
        """Ask the model to continue the current scratchpad and return the cleaned text.

        The prompt template is chosen by ``self.dataset``. The token count
        reported for the call, multiplied by ``self.token_unit_price``, is added
        to ``self.cache["prices"]``.

        Returns:
            The completion with surrounding whitespace stripped and newlines removed.

        Raises:
            ValueError: If no ReAct prompt template exists for ``self.dataset``.
        """
        if self.dataset == "psyqa":
            full_prompt = psyqa_react_prompt.format(question=self.cache["question"],
                                                    desc=self.cache["description"],
                                                    agent_scratchpad=self.scratchpad)
        elif self.dataset == "cima":
            full_prompt = cima_react_prompt.format(context=self.cache["context"],
                                                   agent_scratchpad=self.scratchpad)
        else:
            raise ValueError(f"No ReAct prompt is defined for dataset {self.dataset!r}")

        # The call is made with stop="\n", i.e. one scratchpad line per request.
        results = self.sample_crawler.call_openai_each(full_prompt, stop="\n")
        total_tokens, results = get_response_according_to_model_type(results, self.model_type)

        result = results.strip('\n').strip().replace('\n', '')
        self.cache["prices"] += total_tokens * self.token_unit_price
        return result

    def reset(self) -> None:
        """Clear the per-sample state before processing a new sample."""
        self.scratchpad = ''
        self.curr_step = 1
        self.answer = ""
        self.finished = False

    def is_finished(self) -> bool:
        """Return True once an action containing ``"Response"`` has been handled."""
        return self.finished

    def is_halted(self) -> bool:
        """Return True when the step budget is exceeded without finishing."""
        return (self.curr_step > self.max_steps) and not self.finished

    def predict_responses(self, max_samples=200):
        """Run the agent over the dataset samples and write the results to disk.

        Writes the aggregated metrics to ``./exp_output/<dataset>/react_result.json``
        and the per-sample generations to
        ``./response_output/<dataset>/react_result.json``.

        Args:
            max_samples: Number of leading examples of the prompt constructor to
                process. Defaults to 200.
        """
        import os  # local import: only needed to create the output directories

        count = 0  # samples that ended with a "Response" action (not a correctness measure)
        total_price = 0
        evaluation_elo = []

        for sample in self.prompt_constructor.examples[:max_samples]:
            # Per-sample cache: inputs for the prompt, the reference and the running cost.
            self.cache = {"prices": 0}
            if self.dataset == "psyqa":
                self.cache["question"] = sample["question"]
                self.cache["description"] = sample["desc"]
                self.cache["response"] = sample["answer"]
            elif self.dataset == "cima":
                self.cache["context"] = sample["context"]
                self.cache["response"] = sample["references"]

            # Keep stepping until the agent responds or exceeds its step budget.
            self.reset()
            while not self.is_halted() and not self.is_finished():
                self.step()

            if self.finished:
                count += 1

            if self.dataset == "psyqa":
                evaluation_elo.append({
                    "question": sample["question"],
                    "desc": sample["desc"],
                    "answer": sample["answer"],
                    "generation": self.cache["generations"]
                })
            elif self.dataset == "cima":
                evaluation_elo.append({
                    "context": sample["context"],
                    "generation": self.cache["generations"]
                })

            total_price += self.cache["prices"]

        # NOTE: avg_bleu_scores is never filled, so its mean is taken over an empty
        # list (NaN, assuming `np` is numpy); the key is kept for output compatibility.
        result = {
            "method": "react",
            "model_type": self.model_type,
            "all_f1_score": np.mean(self.f1_scores),
            "all_rl_score": np.mean(self.rl_scores),
            "all_d1_score": np.mean(self.d1_scores),
            "all_d2_score": np.mean(self.d2_scores),
            "all_avg_bleu_score": np.mean(self.avg_bleu_scores),
            "all_bleu_score": np.mean(self.bleu_scores, dtype=np.float64),
            "all_bert_score": np.mean(self.bert_scores, dtype=np.float64),
            "count": count,
            "price": total_price
        }

        result_path = "./exp_output/" + self.dataset + "/react_result.json"
        response_path = "./response_output/" + self.dataset + "/react_result.json"

        # Create the output directories up front so a missing folder cannot
        # discard the results after all model calls have already been made.
        for out_path in (result_path, response_path):
            os.makedirs(os.path.dirname(out_path), exist_ok=True)

        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)

        with open(response_path, "w", encoding="utf-8") as f:
            json.dump(evaluation_elo, f, indent=4, ensure_ascii=False)
    