import os
import json

class Dataset:
    """Base class holding a prompt buffer that can be dumped as JSON Lines.

    ``setting`` is stored exactly as passed in.  ``prompts`` is the list of
    JSON-serialisable objects that :meth:`output_prompt` writes out.
    """

    def __init__(self, setting) -> None:
        """Store ``setting`` and start with an empty prompt buffer.

        Args:
            setting: Value kept on the instance as ``self.setting``.
        """
        self.setting = setting
        # Start with an empty buffer so output_prompt reports "nothing to
        # write" instead of failing with AttributeError on a bare instance.
        self.prompts = []

    def output_prompt(self, output_path):
        """Write the buffered prompts to ``output_path`` and clear the buffer.

        Each prompt is written as one JSON document per line (UTF-8,
        non-ASCII characters kept as-is).  The parent directory of
        ``output_path`` is created when it does not exist.

        Args:
            output_path: Destination file path; an existing file is overwritten.

        Raises:
            AssertionError: If there are no prompts to write.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.
        if not self.prompts:
            raise AssertionError("no prompts to write")

        # Create the output directory if it does not exist yet.
        # os.path.dirname copes with paths directly under the root and with
        # the platform's own separator, unlike splitting on '/'.
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            # exist_ok avoids failing if the directory appears between the
            # existence check and the creation.
            os.makedirs(output_dir, exist_ok=True)
            print("[CREATE DIRECTORY]", output_dir)

        with open(output_path, "w", encoding="utf-8") as f:
            for prompt in self.prompts:
                json.dump(prompt, f, ensure_ascii=False)
                f.write("\n")
        self.prompts = []
    
class FoCus(Dataset):
    """Dialogue examples with persona and knowledge grounding read from JSON files."""

    def __init__(self, paths="", mode="") -> None:
        """Remember the arguments and eagerly parse every file in ``paths``.

        Args:
            paths: Iterable of JSON file paths handed to :meth:`load_data`.
            mode: Stored on the instance and forwarded to :meth:`load_data`.
        """
        self.mode = mode
        self.paths = paths
        self.examples = self.load_data(paths, mode)
        self.prompts = []

    def load_data(self, files, mode):
        """Turn every dialogue turn found in ``files`` into an example dict.

        Each file is a JSON object whose ``"data"`` list holds dialogues; each
        dialogue has an ``"utterance"`` list whose N-th entry (1-based) stores
        the running conversation under the key ``"dialogue<N>"``.  The last
        element of that conversation becomes the response and the preceding
        ones become the tab-separated context, labelled alternately as USER
        and SYSTEM starting with USER.

        For every file, the number of turns whose persona grounding / knowledge
        index equals that of the immediately preceding turn is printed.

        Args:
            files: Iterable of paths to JSON files.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            A list with one dict per dialogue turn.
        """
        collected = []
        for path in files:
            with open(path, "r", encoding="utf-8") as handle:
                payload = json.load(handle)

            # Overlap statistics are tracked per input file.
            same_persona = 0
            same_knowledge = 0
            for dialogue in payload["data"]:
                turns = dialogue["utterance"]
                for turn_no, turn in enumerate(turns, start=1):
                    session = turn["dialogue" + str(turn_no)]

                    persona_cands = turn["persona_candidate"]
                    knowledge_cands = turn["knowledge_candidates"]
                    persona_grounding = turn["persona_grounding"]
                    knowledge_index = turn["knowledge_answer_index"]

                    # Grounding used by the preceding turn (empty for the first turn).
                    if turn_no == 1:
                        prev_persona, prev_knowledge = [], []
                    else:
                        earlier = turns[turn_no - 2]
                        prev_persona = earlier["persona_grounding"]
                        prev_knowledge = earlier["knowledge_answer_index"]

                        if prev_persona == persona_grounding:
                            same_persona += 1

                        if prev_knowledge == knowledge_index:
                            same_knowledge += 1

                    # Everything but the last utterance forms the context.
                    labelled = []
                    for position, utterance in enumerate(session[:-1]):
                        if position % 2 == 0:
                            labelled.append("USER: " + utterance)
                        else:
                            labelled.append("SYSTEM: " + utterance)

                    collected.append({
                        "context": "\t".join(labelled),
                        "response": session[-1],
                        "persona_cands": persona_cands,
                        "knowledge_cands": knowledge_cands,
                        "persona_grounding": persona_grounding,
                        "knowledge_index": knowledge_index,
                        "previous_used_persona": prev_persona,
                        "previous_used_knowledge": prev_knowledge
                    })

            print("There are {} overlap persona; {} overlap knowledge.".format(same_persona, same_knowledge))
        return collected



class HotpotQA(Dataset):
    """Question/answer examples and the prompts built from them.

    On construction the files in ``paths`` are parsed with
    :meth:`construct_knowledge_boundary` (one JSON object per line) and the
    accuracy-evaluation prompt is read from a JSON config file.
    :meth:`load_data` parses a richer format carrying ``context`` and
    ``supporting_facts`` and is available to callers that need those fields.
    """

    # Per-file cap on the number of samples turned into examples.
    MAX_SAMPLES_PER_FILE = 500

    def __init__(
        self,
        paths="",
        demo_path="",
        mode="",
        prompt_path="./cot_retrieval/config/prompt_en.json",
    ) -> None:
        """Load examples and the evaluation prompt.

        Args:
            paths: Iterable of JSON Lines files with the question/answer data.
            demo_path: JSON file with demonstrations, read only when
                :meth:`construct_prompt` runs with ``setting="in-context"``.
            mode: Stored on the instance and forwarded to the loader.
            prompt_path: JSON config file read by :meth:`load_prompts`.
        """
        self.paths = paths
        # Kept so construct_prompt can find the demonstrations later.
        self.demo_path = demo_path
        self.mode = mode
        self.examples = self.construct_knowledge_boundary(paths, mode)
        self.prompts = []
        self.load_prompts(path=prompt_path)

    def load_prompts(self, path, mode="qa"):
        """Read the prompt config at ``path`` and keep the evaluation prompt.

        Sets ``self.acc_evaluation_prompt`` to the config's
        ``qa_acc_evaluation_prompt`` entry, or ``None`` when it is absent.

        Args:
            path: Path to a JSON file containing a JSON object.
            mode: Accepted for interface compatibility; not used here.
        """
        with open(path, "r", encoding="utf-8") as f:
            prompt_config = json.load(f)

        self.acc_evaluation_prompt = prompt_config.get('qa_acc_evaluation_prompt')

    def load_demo(self, path):
        """Return the parsed JSON content of the demonstration file ``path``."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def construct_knowledge_boundary(self, files, mode):
        """Build question/answer examples from JSON Lines files.

        Every non-blank line must be a JSON object with an ``"input"`` field
        (the question) and an ``"output"`` list whose first element carries
        the ``"answer"``.  At most ``MAX_SAMPLES_PER_FILE`` samples are taken
        from each file; reading stops once that many have been collected.

        Args:
            files: Iterable of file paths.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            A list of ``{"question": ..., "answer": ...}`` dicts.
        """
        examples = []

        for file in files:
            taken = 0
            with open(file, "r", encoding="utf-8") as f:
                for line in f:
                    if taken >= self.MAX_SAMPLES_PER_FILE:
                        break
                    # Blank lines (e.g. trailing newlines) are not records.
                    if not line.strip():
                        continue
                    sample = json.loads(line)
                    examples.append({
                        "question": sample["input"],
                        "answer": sample["output"][0]["answer"]
                    })
                    taken += 1
        return examples

    def load_data(self, files, mode):
        """Build examples with context passages and supporting facts.

        Each file must contain a JSON list of samples with ``question``,
        ``answer``, ``type``, ``level``, ``supporting_facts`` (pairs of
        title and sentence index) and ``context`` (pairs of title and list of
        sentences).  At most ``MAX_SAMPLES_PER_FILE`` samples per file are
        used.

        Args:
            files: Iterable of file paths.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            A list of dicts with keys ``question``, ``answer``, ``type``,
            ``level``, ``wiki`` (every sentence prefixed with its title) and
            ``support_facts`` (de-duplicated, in first-seen order).
        """
        examples = []
        for file in files:
            with open(file, "r", encoding="utf-8") as f:
                raw_qas = json.load(f)

            for sample in raw_qas[:self.MAX_SAMPLES_PER_FILE]:
                support_facts, wikis = [], []

                for support_title, number in sample["supporting_facts"]:
                    for title, documents in sample["context"]:
                        if support_title == title:
                            # A sentence index outside the paragraph cannot be
                            # resolved; skip it instead of aborting the load.
                            if 0 <= number < len(documents):
                                support_facts.append(title + " " + documents[number].strip())
                            break

                for title, documents in sample["context"]:
                    wikis.extend(title + " " + docu.strip() for docu in documents)

                examples.append({
                    "question": sample["question"],
                    "answer": sample["answer"],
                    "type": sample["type"],
                    "level": sample["level"],
                    "wiki": wikis,
                    # dict.fromkeys de-duplicates while keeping a stable order.
                    "support_facts": list(dict.fromkeys(support_facts))
                })
        return examples

    def construct_prompt(self, template="", setting="zero-shot"):
        """Append one prompt instance per example to ``self.prompts``.

        Args:
            template: Format string with a ``{question}`` placeholder.
            setting: With ``"in-context"`` the demonstrations stored under
                ``"q->a:demos"`` in the file at ``self.demo_path`` are loaded
                and attached to every instance as ``"demo"``.
        """
        if setting == "in-context":
            self.demo_pools = self.load_demo(self.demo_path)
            self.demo = self.demo_pools["q->a:demos"]

        for example in self.examples:
            question = example["question"]

            instance = {}
            instance["prompt"] = template.format(question=question)
            instance["question"] = question
            # Examples from construct_knowledge_boundary carry no type/level;
            # copy these fields only when the example provides them.
            for optional_key in ("type", "level"):
                if optional_key in example:
                    instance[optional_key] = example[optional_key]
            instance["answer"] = example["answer"]
            if setting == "in-context":
                instance["demo"] = self.demo
            self.prompts.append(instance)

class KBP(Dataset):
    """Dialogue examples grounded on persona and knowledge, plus their prompts."""

    def __init__(self, paths="", demo_path="", mode=""):
        """Remember the arguments and eagerly parse every file in ``paths``.

        Args:
            paths: Iterable of JSON file paths handed to :meth:`load_data`.
            demo_path: JSON file with demonstrations, read only for the
                ``"in-context"`` setting.
            mode: Stored on the instance and forwarded to :meth:`load_data`.
        """
        self.paths = paths
        self.demo_path = demo_path
        self.mode = mode
        self.examples = self.load_data(paths, mode)
        self.prompts = []

    def _get_demo_pools(self):
        """Return the demonstration pools, loading them on first use."""
        pools = getattr(self, "demo_pools", None)
        if pools is None:
            pools = self.load_demo(self.demo_path)
            self.demo_pools = pools
        return pools

    def assemble_prompt(self, prompt="", prefix="", suffix="", dialogue="", description="", setting=""):
        """Append a single prompt instance to ``self.prompts``.

        Args:
            prompt: Main prompt text.
            prefix: Text placed before ``prompt``.
            suffix: Text placed after ``prompt``.
            dialogue: Stored under ``"context"``.
            description: Stored under ``"description"`` when non-empty.
            setting: With ``"in-context"`` the ``"c,p,k->r:demos"``
                demonstrations are attached under ``"demo"``.
        """
        instance = {}
        instance["prompt"] = prefix + prompt + suffix
        instance["context"] = dialogue
        if description:
            instance["description"] = description
        if setting == "in-context":
            # The pools may not have been loaded yet if construct_prompt was
            # never run in the in-context setting.
            instance["demo"] = self._get_demo_pools()["c,p,k->r:demos"]
        self.prompts.append(instance)

    def load_demo(self, path):
        """Return the parsed JSON content of the demonstration file ``path``."""
        with open(path, 'r', encoding="utf-8") as f:
            return json.load(f)

    def construct_prompt(self, template="", setting="zero-shot"):
        """Append one prompt instance per example to ``self.prompts``.

        Args:
            template: Format string with a ``{dialogue_history}`` placeholder.
            setting: With ``"in-context"`` the demonstrations stored under
                ``"c->p,k:demos"`` in the file at ``self.demo_path`` are
                (re)loaded and attached to every instance as ``"demo"``.
        """
        if setting == "in-context":
            self.demo_pools = self.load_demo(self.demo_path)
            self.demo = self.demo_pools["c->p,k:demos"]

        for example in self.examples:
            context = example["context"]
            instance = {}
            instance["prompt"] = template.format(dialogue_history=context)
            instance["context"] = context
            instance["persona"] = example["persona"]
            instance["sources"] = example["resources"]
            instance["middle"] = example["middle_result"]
            if setting == "in-context":
                instance["demo"] = self.demo
            instance["knowledge"] = example["knowledge"]
            self.prompts.append(instance)

    @staticmethod
    def _resolve_grounding(p_k, persona_database, knowledge_database):
        """Map a turn's ``P-K`` annotation to (resources, persona, knowledge).

        An empty annotation means no resource was used, one element means
        persona only, two elements mean persona plus knowledge.  Any other
        length yields an empty resource string and no grounding.
        """
        ground_persona, ground_knowledge = [], []
        if len(p_k) == 0:  # nothing used
            return "NULL", ground_persona, ground_knowledge
        if len(p_k) not in (1, 2):
            return "", ground_persona, ground_knowledge

        ground_persona = [persona_database[persona].replace(" ", "") for persona in p_k[0]]
        if len(p_k) == 1:  # persona only
            return "PERSONA", ground_persona, ground_knowledge

        # both persona and knowledge
        for per_knowledge in p_k[1]:
            # Each entry is "<persona key>-<knowledge key>".
            per, know = per_knowledge.split("-")
            ground_knowledge.append(knowledge_database[per][know].replace(" ", ""))
        return "PERSONA KNOWLEDGE", ground_persona, ground_knowledge

    def load_data(self, files, mode: str):
        """Turn every conversation turn found in ``files`` into an example dict.

        Each file must contain a JSON list of samples with ``persona``,
        ``persona_kg`` and ``conversation``; every turn of the conversation
        provides the user utterance ``U``, the system reply ``S`` and the
        grounding annotation ``P-K``.  The context of a turn is the full
        history of earlier turns followed by the current user utterance.

        Grounding is computed independently for every turn, so a turn never
        inherits the grounding of the turn before it.

        Args:
            files: Iterable of paths to JSON files.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            A list with one dict per conversation turn.
        """
        examples = []
        for file in files:
            with open(file, "r", encoding="utf-8") as f:
                raw_dialogues = json.load(f)

            for sample in raw_dialogues:
                # Persona and knowledge are treated as two separate databases.
                persona_database = sample["persona"]
                knowledge_database = sample["persona_kg"]

                history = ""
                for turn in sample["conversation"]:
                    user, system_reply, p_k = turn["U"], turn["S"], turn["P-K"]
                    context = history + "用户：" + user + "\n"
                    history = context + "系统：" + system_reply + "\n"

                    commands, ground_persona, ground_knowledge = self._resolve_grounding(
                        p_k, persona_database, knowledge_database
                    )
                    # Intermediate result lists persona facts first, then knowledge.
                    middle_results = ground_persona + ground_knowledge

                    examples.append({
                        "context": context,
                        "resources": commands,
                        "middle_result": "，".join(middle_results),
                        "resp": system_reply,
                        "used_persona": ground_persona,
                        "used_knowledge": ground_knowledge,
                        "persona": persona_database,
                        "knowledge": knowledge_database
                    })
        return examples