# coding: utf-8
import json
import sys
sys.path.append('./')
from dataset import *
from evaluate import *
from prompts.psyqa.prompt import *
from prompts.mathdial.prompt import *
from prompts.tutoring.cima.prompt import *

from tools.utils import *
from tools.retriever import *
from app.sample_resp_cralwer import SampleRespCrawler
from app.chatgpt_resp_cralwer import ChatgptRespCrawler
from app.huggingface_resp_crawler import HuggingFaceRespCrawlerBase

class CoT:
    """Chain-of-Thought (CoT) baseline for response generation.

    For every test example of ``dataset`` a zero-shot or few-shot CoT prompt
    is built and sent to the configured model. The generation is then scored
    against the reference. Averaged metrics and the raw generations are
    written to JSON files (see ``_output_paths``).
    """

    # Models routed to ChatgptRespCrawler; any other model_type is loaded
    # through HuggingFaceRespCrawlerBase.
    OPENAI_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003", "gpt-4-0613")

    def __init__(self, key_path, dataset, model_type="chatgpt", model_path="", language="", setting="zero-shot",
            temperature=0.7, persona="", top_p=0.95):
        """Load the test split of ``dataset`` and build the model client.

        Args:
            key_path: forwarded to ChatgptRespCrawler (presumably the API-key
                file -- confirm).
            dataset: "hotpotqa", "psyqa" or "cima".
            model_type: model name. It selects the crawler and the token
                price, and is used in the metrics file name.
            model_path: forwarded to HuggingFaceRespCrawlerBase.
            language: language code for the metrics. It is overridden with
                "cn" for "psyqa" and "en" for "cima".
            setting: "zero-shot" selects the zero-shot CoT prompts; any other
                value selects the few-shot CoT prompts.
            temperature, persona, top_p: sampling options for the crawler.
        """
        self.key_path = key_path
        self.dataset = dataset
        self.model_type = model_type
        self.model_path = model_path
        # Metric language. The historical attribute spelling is kept for
        # compatibility. It falls back to the ``language`` argument instead
        # of being left unset for datasets without a fixed language.
        self.langugae = language

        self.token_unit_price = get_token_unit_price(model_type)

        self.setting = setting
        if self.dataset == "hotpotqa":
            data_paths = ["./dataset/hotpotQA/hotpot_dev_distractor_v1.json"]
            demo_path = "./cot_retrieval/config/prompt_en.json"
            self.prompt_constructor = HotpotQA(data_paths, demo_path=demo_path, mode="dev")
        elif self.dataset == "psyqa":
            data_paths = ["./dataset/PsyQA/psyqa_test.json"]
            self.prompt_constructor = PsyQA(data_paths, mode="dev")
            self.langugae = "cn"
        elif self.dataset == "cima":
            data_paths = ["./dataset/tutoring/CIMA_test.json"]
            self.prompt_constructor = StrategyTutoring(data_paths, mode="dev")
            self.langugae = "en"

        self.init_model_type(temperature, persona, top_p, model_type)

    def init_model_type(self, temperature, persona, top_p, model):
        """Create ``self.sample_crawler`` appropriate for ``self.model_type``."""
        if self.model_type in self.OPENAI_MODELS:
            self.sample_crawler = ChatgptRespCrawler(self.key_path, temperature, persona=persona, top_p=top_p, model=model)
        else:
            self.sample_crawler = HuggingFaceRespCrawlerBase(self.model_type, self.model_path, top_p=top_p, temperature=temperature)

    def _output_paths(self):
        """Return ``(metrics_path, responses_path)`` for the current setting.

        Zero-shot and few-shot runs use distinct file names, so one setting
        does not overwrite the other's results.
        """
        if self.setting == "zero-shot":
            metrics_path = "./exp_output/" + self.dataset + "/" + self.model_type + "_cot_zero_shot_result.json"
            responses_path = "./response_output/" + self.dataset + "/cot_zero_shot_result.json"
        else:
            metrics_path = "./exp_output/" + self.dataset + "/" + self.model_type + "_cot_result.json"
            responses_path = "./response_output/" + self.dataset + "/cot_result.json"
        return metrics_path, responses_path

    def predict_responses(self):
        """Generate, score and save a CoT response for every test example.

        The per-example cost is ``total_tokens * self.token_unit_price``.
        Metrics are averaged with ``np.mean`` and written, together with the
        per-example generations, to the paths from ``_output_paths``. Any
        dataset other than "psyqa" is handled with the CIMA prompts and must
        provide ``sample["context"]`` and ``sample["references"]``.
        """
        # init the evaluation metrics
        rl_scores, f1_scores, avg_bleu_scores = [], [], []
        d1_scores, d2_scores, prices = [], [], []
        bleu_scores, bert_scores = [], []
        evaluation_elo = []
        zero_shot = self.setting == "zero-shot"

        for sample in self.prompt_constructor.examples:
            # [1] reset the per-example cache and build the prompt
            self.cache = {"prices": 0}
            if self.dataset == "psyqa":
                question = sample["question"]
                desc = sample["desc"]
                self.cache["question"] = question
                self.cache["description"] = desc
                self.cache["response"] = sample["answer"]

                instruction = psyqa_zero_shot_cot_prompt if zero_shot else psyqa_cot_prompt
                test_prompt = f"Question: {question}\nDescription: {desc}\nResponse: "
            else:
                context = sample["context"]
                self.cache["context"] = context
                self.cache["response"] = sample["references"]

                instruction = cima_zero_shot_cot_prompt if zero_shot else cima_cot_prompt
                test_prompt = f"Dialogue: {context}\nResponse: "
            full_prompt = instruction + "\n\n" + test_prompt

            # [2] generate the response and record its cost
            generations = self.sample_crawler.call_openai_each(full_prompt)
            total_tokens, generations = get_response_according_to_model_type(generations, self.model_type)
            self.cache["prices"] += total_tokens * self.token_unit_price
            prices.append(self.cache["prices"])

            # [3] Evaluate the results
            if self.dataset == "psyqa":
                reference = self.cache["response"]
                f1_scores.append(f1(generations, reference, self.langugae))
                rl_scores.append(rl(generations, reference, self.langugae))
                d1_scores.append(distinct_n_sentence_level(generations, 1, self.langugae))
                d2_scores.append(distinct_n_sentence_level(generations, 2, self.langugae))
                avg_bleu_scores.append(avg_bleu(generations, reference, self.langugae))
                # tutoring_scores (BLEU / BERTScore) is not computed for PsyQA.

                evaluation_elo.append({
                    "question": question,
                    "desc": desc,
                    "answer": reference,
                    "generation": generations
                })
            else:
                f1_scores.append(f1(generations, self.cache["response"], self.langugae))
                bert_score, bleu_score = tutoring_scores(generations, self.cache["response"], self.langugae)
                bleu_scores.append(bleu_score)
                bert_scores.append(bert_score)

                evaluation_elo.append({
                    "context": context,
                    "generation": generations
                })

        # save the final evaluation results
        if self.dataset == "psyqa":
            result = {
                "method": "cot",
                "model_type": self.model_type,
                "dataset": self.dataset,
                "rl_score": np.mean(rl_scores),
                "f1_score": np.mean(f1_scores),
                "d1_score": np.mean(d1_scores),
                "d2_score": np.mean(d2_scores),
                "avg_bleu_score": np.mean(avg_bleu_scores),
                "price": sum(prices),
                # bleu_scores / bert_scores stay empty for PsyQA, so these are NaN.
                "bleu_score": np.mean(bleu_scores, dtype=np.float64),
                "bert_score": np.mean(bert_scores, dtype=np.float64)
            }
        else:
            result = {
                "method": "cot",
                "model_type": self.model_type,
                "dataset": self.dataset,
                "f1_score": np.mean(f1_scores),
                "price": sum(prices),
                "bleu_score": np.mean(bleu_scores, dtype=np.float64),
                "bert_score": np.mean(bert_scores, dtype=np.float64)
            }

        # dialogue as query: 2 30
        # change order of prompts: 2 33
        metrics_path, responses_path = self._output_paths()
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)

        # Generations are saved in both settings, so zero-shot outputs are kept too.
        with open(responses_path, "w", encoding="utf-8") as f:
            json.dump(evaluation_elo, f, indent=4, ensure_ascii=False)
        

class CueCoT:
    """CueCoT: infer the user's status first, then generate a response.

    Each test example costs two model calls. The first asks for a "Status"
    of the user given the question/dialogue. The second generates the
    response conditioned on that inferred status. Generations are scored and
    the results written to JSON under ./exp_output and ./response_output.
    """

    # Models routed to ChatgptRespCrawler; any other model_type is loaded
    # through HuggingFaceRespCrawlerBase. Kept in line with CoT.
    OPENAI_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003", "gpt-4-0613")

    def __init__(self, key_path, dataset, model_type="chatgpt", model_path="", language="", setting="zero-shot",
            temperature=0.7, persona="", top_p=0.95):
        """Load the test split of ``dataset`` and build the model client.

        Args:
            key_path: forwarded to ChatgptRespCrawler (presumably the API-key
                file -- confirm).
            dataset: "hotpotqa", "psyqa" or "cima". Only "psyqa" and "cima"
                are supported by ``predict_responses``.
            model_type: model name. It selects the crawler and the token
                price, and is used in the output file names.
            model_path: forwarded to HuggingFaceRespCrawlerBase.
            language: language code for the metrics. It is overridden with
                "cn" for "psyqa" and "en" for "cima".
            setting: stored on the instance; not used by ``predict_responses``.
            temperature, persona, top_p: sampling options for the crawler.
        """
        self.key_path = key_path
        self.dataset = dataset
        self.model_type = model_type
        self.model_path = model_path
        # Metric language (historical attribute spelling kept for compatibility).
        self.langugae = language

        self.token_unit_price = get_token_unit_price(model_type)
        self.setting = setting
        if self.dataset == "hotpotqa":
            data_paths = ["./dataset/hotpotQA/hotpot_dev_distractor_v1.json"]
            demo_path = "./cot_retrieval/config/prompt_en.json"
            self.prompt_constructor = HotpotQA(data_paths, demo_path=demo_path, mode="dev")
        elif self.dataset == "psyqa":
            data_paths = ["./dataset/PsyQA/psyqa_test.json"]
            self.prompt_constructor = PsyQA(data_paths, mode="dev")
            # PsyQA metrics are always computed in Chinese.
            self.langugae = "cn"
        elif self.dataset == "cima":
            data_paths = ["./dataset/tutoring/CIMA_test.json"]
            self.prompt_constructor = StrategyTutoring(data_paths, mode="dev")
            self.langugae = "en"

        self.retriever = Retriever("bm25")
        self.call_retrieval_times, self.persona_right_count, self.knowledge_right_count = 0, 0, 0
        self.init_model_type(temperature, persona, top_p, model_type)

    def init_model_type(self, temperature, persona, top_p, model):
        """Create ``self.sample_crawler`` appropriate for ``self.model_type``."""
        if self.model_type in self.OPENAI_MODELS:
            self.sample_crawler = ChatgptRespCrawler(self.key_path, temperature, persona=persona, top_p=top_p, model=model)
        else:
            self.sample_crawler = HuggingFaceRespCrawlerBase(self.model_type, self.model_path, top_p=top_p, temperature=temperature)

    def predict_responses(self):
        """Run CueCoT on (up to) the first 200 test examples, score, and save.

        Each metric is recorded once per example, in ``self.langugae``.
        PsyQA additionally gets avg-BLEU. CIMA additionally gets ROUGE-L and
        ``tutoring_scores`` (BLEU / BERTScore).

        Raises:
            ValueError: if ``self.dataset`` is neither "psyqa" nor "cima".
        """
        # init the evaluation metrics
        rl_scores, f1_scores, avg_bleu_scores = [], [], []
        d1_scores, d2_scores, prices = [], [], []
        bleu_scores, bert_scores = [], []
        evaluation_elo = []

        for sample in self.prompt_constructor.examples[:200]:
            # per-example cache
            self.cache = {"prices": 0}
            if self.dataset == "psyqa":
                question = sample["question"]
                desc = sample["desc"]
                self.cache["question"] = question
                self.cache["description"] = desc
                self.cache["response"] = sample["answer"]

                # [1] infer the status
                status_prompt = f"Question: {question}\nDescription: {desc}\nStatus: "
                full_status_prompt = psyqa_cuecot_status_prompt + "\n\n" + status_prompt
            elif self.dataset == "cima":
                context = sample["context"]
                self.cache["context"] = context
                self.cache["response"] = sample["references"]

                # [1] infer the status
                status_prompt = f"Dialogue: {context}\nStatus: "
                full_status_prompt = cima_cuecot_status_prompt + "\n\n" + status_prompt
            else:
                raise ValueError(f"CueCoT does not support dataset {self.dataset!r}")

            user_status = self.sample_crawler.call_openai_each(full_status_prompt)
            total_tokens, user_status = get_response_according_to_model_type(user_status, self.model_type)
            self.cache["prices"] += total_tokens * self.token_unit_price

            # [2] generate the response conditioned on the inferred status
            if self.dataset == "psyqa":
                test_prompt = f"Question: {question}\nDescription: {desc}\nStatus: {user_status}\nResponse: "
                full_prompt = psyqa_cuecot_response_prompt + "\n\n" + test_prompt
            else:
                test_prompt = f"Dialogue: {context}\nStatus: {user_status}\nResponse: "
                full_prompt = cima_cuecot_response_prompt + "\n\n" + test_prompt

            generations = self.sample_crawler.call_openai_each(full_prompt)
            total_tokens, generations = get_response_according_to_model_type(generations, self.model_type)
            self.cache["prices"] += total_tokens * self.token_unit_price
            prices.append(self.cache["prices"])

            # [3] Evaluate the results -- one entry per example per metric,
            # in the dataset's language (previously CIMA f1 was appended twice
            # and the shared metrics were hard-coded to "cn").
            reference = self.cache["response"]
            f1_scores.append(f1(generations, reference, self.langugae))
            d1_scores.append(distinct_n_sentence_level(generations, 1, lang=self.langugae))
            d2_scores.append(distinct_n_sentence_level(generations, 2, lang=self.langugae))

            if self.dataset == "psyqa":
                avg_bleu_scores.append(avg_bleu(generations, reference, lang=self.langugae))
                evaluation_elo.append({
                    "question": sample["question"],
                    "desc": sample["desc"],
                    "answer": sample["answer"],
                    "generation": generations
                })
            else:
                rl_scores.append(rl(generations, reference, self.langugae))
                bert_score, bleu_score = tutoring_scores(generations, reference, self.langugae)
                bleu_scores.append(bleu_score)
                bert_scores.append(bert_score)

                # "generation" matches the key used by the other methods' outputs.
                evaluation_elo.append({
                    "context": sample["context"],
                    "generation": generations
                })

        # save the final evaluation results (metrics that were never
        # collected for a dataset average to NaN)
        result = {
            "method": "cuecot",
            "model_type": self.model_type,
            "dataset": self.dataset,
            "f1_score": np.mean(f1_scores),
            "rl_score": np.mean(rl_scores),
            "d1_score": np.mean(d1_scores),
            "d2_score": np.mean(d2_scores),
            "avg_bleu_score": np.mean(avg_bleu_scores),
            "bleu_score": np.mean(bleu_scores, dtype=np.float64),
            "bert_score": np.mean(bert_scores, dtype=np.float64),
            "price": sum(prices)
        }
        # dialogue as query: 2 30
        # change order of prompts: 2 33
        with open("./exp_output/" + self.dataset + "/" + self.model_type + "_cuecot_result.json", "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)

        with open("./response_output/" + self.dataset + "/" + self.model_type + "_cuecot_result.json", "w", encoding="utf-8") as f:
            json.dump(evaluation_elo, f, indent=4, ensure_ascii=False)

class Chameleon:
    """Chameleon-style plan-and-execute response generation.

    For each test example the model first predicts a list of response
    strategies (``predict_modules``). Each strategy is mapped to one of the
    strategy methods below, which prompts the model for a short segment
    written in that strategy. The segments are concatenated into the final
    response, which is scored and saved as JSON.
    """

    # Models routed to ChatgptRespCrawler; any other model_type is loaded
    # through HuggingFaceRespCrawlerBase. Kept in line with CoT.
    OPENAI_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003", "gpt-4-0613")

    # Ordered (keyword, method name) pairs that map a lower-cased strategy
    # label to the method realising it. The first keyword contained in the
    # label wins.
    _STRATEGY_METHODS = (
        ("approval and reassurance", "approval_reassurance"),
        ("interpretation", "interpretation"),
        ("direct guidance", "direct_guidance"),
        ("direct_guidance", "direct_guidance"),
        ("information", "inform"),
        ("restatement", "restatement"),
        ("self-disclosure", "disclosure"),
        ("hint", "hint"),
        ("question", "question"),
        ("correction", "correction"),
        ("confirmation", "confirmation"),
    )

    def __init__(self, key_path, dataset, model_type="chatgpt", model_path="", language="", setting="zero-shot",
            temperature=0.7, persona="", top_p=0.95):
        """Load the test split of ``dataset`` and build the model client.

        Args:
            key_path: forwarded to ChatgptRespCrawler (presumably the API-key
                file -- confirm).
            dataset: "hotpotqa", "psyqa" or "cima". Only "psyqa" and "cima"
                are supported by ``predict_responses``.
            model_type: model name; selects the crawler and the token price.
            model_path: forwarded to HuggingFaceRespCrawlerBase.
            language: language code for the metrics. It is overridden with
                "cn" for "psyqa" and "en" for "cima".
            setting: stored on the instance; not used by ``predict_responses``.
            temperature, persona, top_p: sampling options for the crawler.
        """
        self.key_path = key_path
        self.dataset = dataset
        self.model_type = model_type
        self.model_path = model_path
        self.token_unit_price = get_token_unit_price(model_type)
        # Metric language; fixed below for datasets with a known language.
        self.language = language

        self.setting = setting
        if self.dataset == "hotpotqa":
            data_paths = ["./dataset/hotpotQA/hotpot_dev_distractor_v1.json"]
            demo_path = "./cot_retrieval/config/prompt_en.json"
            self.prompt_constructor = HotpotQA(data_paths, demo_path=demo_path, mode="dev")
        elif self.dataset == "psyqa":
            data_paths = ["./dataset/PsyQA/psyqa_test.json"]
            self.prompt_constructor = PsyQA(data_paths, mode="dev")
            self.language = "cn"
        elif self.dataset == "cima":
            data_paths = ["./dataset/tutoring/CIMA_test.json"]
            self.prompt_constructor = StrategyTutoring(data_paths, mode="dev")
            self.language = "en"

        self.init_model_type(temperature, persona, top_p, model_type)

    def init_model_type(self, temperature, persona, top_p, model):
        """Create ``self.sample_crawler`` appropriate for ``self.model_type``."""
        if self.model_type in self.OPENAI_MODELS:
            self.sample_crawler = ChatgptRespCrawler(self.key_path, temperature, persona=persona, top_p=top_p, model=model)
        else:
            self.sample_crawler = HuggingFaceRespCrawlerBase(self.model_type, self.model_path, top_p=top_p, temperature=temperature)

    # ------------------------------------------------------------------ helpers

    def _psyqa_prompt(self, label):
        """Build the PsyQA test prompt for the cached example, ending in ``"<label>: "``."""
        return f"Question: {self.cache['question']}\nDescription: {self.cache['description']}\n{label}: "

    def _cima_prompt(self, label):
        """Build the CIMA test prompt for the cached dialogue, ending in ``"<label>: "``."""
        return f"Dialogue: {self.cache['context']}\n{label}: "

    def _call_strategy(self, instruction, test_prompt, cache_key):
        """Query the model for one strategy segment (with ``max_tokens=32``).

        Stores the text in ``self.cache[cache_key]``, adds the call's cost to
        ``self.cache["prices"]``, and returns the text.
        """
        full_prompt = instruction + "\n\n" + test_prompt
        output = self.sample_crawler.call_openai_each(full_prompt, max_tokens=32)
        total_tokens, output = get_response_according_to_model_type(output, self.model_type)

        # update the cache
        self.cache[cache_key] = output
        self.cache["prices"] += total_tokens * self.token_unit_price
        return output

    # ------------------------------------------------------- PsyQA strategies

    def approval_reassurance(self):
        """Generate the "Approval and Reassurance" segment for the PsyQA example."""
        return self._call_strategy(psyqa_chameleon_ar_prompt, self._psyqa_prompt("Approval and Reassurance"), "ar:output")

    def interpretation(self):
        """Generate the "Interpretation" segment for the PsyQA example."""
        return self._call_strategy(psyqa_chameleon_interpretation_prompt, self._psyqa_prompt("Interpretation"), "interpretation:output")

    def direct_guidance(self):
        """Generate the "Direct Guidance" segment for the PsyQA example."""
        # NOTE(review): this reuses psyqa_chameleon_interpretation_prompt as
        # the instruction; a dedicated direct-guidance prompt may have been
        # intended -- confirm against prompts/psyqa/prompt.py.
        return self._call_strategy(psyqa_chameleon_interpretation_prompt, self._psyqa_prompt("Direct Guidance"), "direct_guidance:output")

    def inform(self):
        """Generate the "Information" segment for the PsyQA example."""
        return self._call_strategy(psyqa_chameleon_information_prompt, self._psyqa_prompt("Information"), "information:output")

    def restatement(self):
        """Generate the "Restatement" segment for the PsyQA example."""
        # The label previously read "Information" (copy-paste from inform()).
        return self._call_strategy(psyqa_chameleon_restatement_prompt, self._psyqa_prompt("Restatement"), "restatement:output")

    def disclosure(self):
        """Generate the "Self-disclosure" segment for the PsyQA example."""
        return self._call_strategy(psyqa_chameleon_self_disclosure_prompt, self._psyqa_prompt("Self-disclosure"), "disclosure:output")

    def others_psyqa(self):
        """Generate a segment for a PsyQA strategy not covered by the methods above."""
        return self._call_strategy(psyqa_chameleon_others_prompt, self._psyqa_prompt("Others"), "others:output")

    # -------------------------------------------------------- CIMA strategies

    def others_cima(self):
        """Generate a segment for a CIMA strategy not covered by the methods below."""
        return self._call_strategy(cima_chameleon_others_prompt, self._cima_prompt("Others"), "others:output")

    def hint(self):
        """Generate a "Hint" segment for the CIMA dialogue."""
        return self._call_strategy(cima_chameleon_hint_prompt, self._cima_prompt("Hint"), "hint:output")

    def question(self):
        """Generate a "Question" segment for the CIMA dialogue."""
        return self._call_strategy(cima_chameleon_question_prompt, self._cima_prompt("Question"), "question:output")

    def confirmation(self):
        """Generate a "Confirmation" segment for the CIMA dialogue."""
        return self._call_strategy(cima_chameleon_confirmation_prompt, self._cima_prompt("Confirmation"), "confirmations:output")

    def correction(self):
        """Generate a "Correction" segment for the CIMA dialogue."""
        return self._call_strategy(cima_chameleon_correction_prompt, self._cima_prompt("Correction"), "corrections:output")

    # ---------------------------------------------------------------- planning

    @staticmethod
    def _parse_modules(text, default_modules):
        """Parse the model's strategy list, falling back to ``default_modules``.

        ``text`` is lower-cased and parsed with ``ast.literal_eval``. Model
        output is never executed, unlike the previous ``eval``. A bare string
        is treated as a one-element list. Anything that is not a non-empty
        list/tuple yields a copy of ``default_modules``.
        """
        import ast

        try:
            parsed = ast.literal_eval(text.lower().strip())
        except (AttributeError, ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return list(default_modules)
        if isinstance(parsed, str):
            parsed = [parsed]
        if not isinstance(parsed, (list, tuple)) or not parsed:
            return list(default_modules)
        return list(parsed)

    def _modules_to_functions(self, modules):
        """Map strategy labels to strategy-method names (case-insensitive).

        A label matching no known strategy maps to ``others_psyqa`` /
        ``others_cima`` for the current dataset. For any other dataset it is
        dropped.
        """
        functions = []
        for module in modules:
            label = str(module).lower()
            for keyword, method_name in self._STRATEGY_METHODS:
                # Match against this label, not the whole list (previous bug).
                if keyword in label:
                    functions.append(method_name)
                    break
            else:
                if self.dataset == "psyqa":
                    functions.append("others_psyqa")
                elif self.dataset == "cima":
                    functions.append("others_cima")
        return functions

    def predict_modules(self, prompt):
        """Ask the model which strategies to use and return their method names.

        The cost of the call is added to ``self.cache["prices"]``. If the
        reply cannot be parsed into a non-empty list, a dataset-specific
        default plan is used instead.
        """
        if self.dataset == "psyqa":
            # NOTE(review): "Interpretation" and "Direct Guidance" appear twice,
            # so this fallback plan calls each of them twice -- confirm intended.
            default_modules = ["Approval and Reassurance", "Interpretation", "Direct Guidance", "Interpretation", "Direct Guidance"]
        elif self.dataset == "cima":
            default_modules = ["Question"]
        else:
            default_modules = []

        modules = self.sample_crawler.call_openai_each(prompt)
        total_tokens, modules = get_response_according_to_model_type(modules, self.model_type)
        self.cache["prices"] += total_tokens * self.token_unit_price

        return self._modules_to_functions(self._parse_modules(modules, default_modules))

    # ------------------------------------------------------------- evaluation

    def predict_responses(self):
        """Run Chameleon on (up to) the first 200 test examples, score, and save.

        For each example:
          [1] predict the strategy plan;
          [2] run each strategy method and join the segments (with a space
              for English, directly for Chinese);
          [3] score the result.

        Raises:
            ValueError: if ``self.dataset`` is neither "psyqa" nor "cima".
        """
        # init the evaluation metrics
        rl_scores, f1_scores, avg_bleu_scores = [], [], []
        d1_scores, d2_scores, prices = [], [], []
        bleu_scores, bert_scores = [], []
        evaluation_elo = []

        for count, sample in enumerate(self.prompt_constructor.examples[:200], start=1):
            # per-example cache
            self.cache = {"prices": 0}
            if self.dataset == "psyqa":
                question = sample["question"]
                desc = sample["desc"]
                self.cache["question"] = question
                self.cache["description"] = desc
                self.cache["response"] = sample["answer"]

                # [1] Predict the modules
                test_prompt = f"Question: {question}\nDescription: {desc}\nStrategies: "
                full_prompt = psyqa_chameleon_modules_prompt + "\n\n" + test_prompt
            elif self.dataset == "cima":
                context = sample["context"]
                self.cache["context"] = context
                self.cache["response"] = sample["references"]

                # [1] Predict the modules
                test_prompt = f"Dialogue: {context}\nStrategies: "
                full_prompt = cima_chameleon_modules_prompt + "\n\n" + test_prompt
            else:
                raise ValueError(f"Chameleon does not support dataset {self.dataset!r}")

            # [2] Execute the modules (each call also updates self.cache)
            outputs = []
            for name in self.predict_modules(full_prompt):
                output = getattr(self, name)()
                if count < 4:
                    print(f"======== [Strategy]: self.{name} ========\n")
                    print(f"# [Output]\n{output}\n")
                outputs.append(output)
            # A space keeps English segments from fusing into one word.
            separator = " " if self.language == "en" else ""
            generations = separator.join(outputs)

            # [3] Evaluate the results
            f1_scores.append(f1(generations, self.cache["response"], self.language))
            rl_scores.append(rl(generations, self.cache["response"], self.language))
            d1_scores.append(distinct_n_sentence_level(generations, 1, lang=self.language))
            d2_scores.append(distinct_n_sentence_level(generations, 2, lang=self.language))
            prices.append(self.cache["prices"])
            bert_score, bleu_score = tutoring_scores(generations, self.cache["response"], self.language)
            bleu_scores.append(bleu_score)
            bert_scores.append(bert_score)

            if self.dataset == "psyqa":
                evaluation_elo.append({
                    "question": sample["question"],
                    "desc": sample["desc"],
                    "answer": sample["answer"],
                    "generation": generations
                })
            else:
                evaluation_elo.append({
                    "context": sample["context"],
                    "generation": generations
                })

        # save the final evaluation results (avg_bleu is not collected, so it is NaN)
        result = {
            "method": "chameleon",
            "model_type": self.model_type,
            "dataset": self.dataset,
            "f1_score": np.mean(f1_scores),
            "d1_score": np.mean(d1_scores),
            "d2_score": np.mean(d2_scores),
            "avg_bleu_score": np.mean(avg_bleu_scores),
            "price": sum(prices),
            "bleu_score": np.mean(bleu_scores, dtype=np.float64),
            "bert_score": np.mean(bert_scores, dtype=np.float64)
        }
        # dialogue as query: 2 30
        # change order of prompts: 2 33
        with open("./exp_output/" + self.dataset + "/chameleon_result.json", "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)

        with open("./response_output/" + self.dataset + "/chameleon_result.json", "w", encoding="utf-8") as f:
            json.dump(evaluation_elo, f, indent=4, ensure_ascii=False)

       