import itertools
import json
import os

class Dataset:
    """Base class for the dataset wrappers in this module.

    Subclasses are expected to fill ``self.prompts`` with JSON-serialisable
    records before calling :meth:`output_prompt`.
    """

    def __init__(self, setting) -> None:
        # Stored as-is; not read anywhere in this class.
        self.setting = setting

    def output_prompt(self, output_path):
        """Write ``self.prompts`` to ``output_path`` as JSON Lines, then clear it.

        Each prompt is written on its own line with ``ensure_ascii=False``, so
        non-ASCII text (e.g. Chinese) is stored verbatim in the UTF-8 file.
        Missing parent directories are created first.

        Args:
            output_path: Destination file path; an existing file is overwritten.

        Raises:
            AssertionError: If there are no prompts to write.
        """
        assert len(self.prompts) > 0, "no prompts to write"
        # Create the output directory if it does not exist yet. dirname() is ""
        # for a bare filename and "/" for a root-level file, so neither case
        # ends up calling os.makedirs("") (which raises). exist_ok guards
        # against the directory appearing between the check and the call.
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
            print("[CREATE DIRECTORY]", output_dir)
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(prompt, ensure_ascii=False) + "\n" for prompt in self.prompts
            )
        # Only reset after a successful write so nothing is lost on I/O errors.
        self.prompts = []

class MathDial(Dataset):
    """Dialogue data whose turns are grounded on persona and knowledge candidates."""

    def __init__(self, paths="", mode="") -> None:
        self.mode = mode
        self.paths = paths
        self.examples = self.load_data(paths, mode)
        self.prompts = []

    def load_data(self, files, mode):
        """Flatten every utterance entry of every dialogue into one example dict.

        Each file must contain a JSON object with a "data" list of dialogues.
        Within a dialogue's "utterance" list, the n-th entry (1-based) keeps its
        turns under the key "dialogue<n>". The last turn becomes the response;
        the earlier turns are labelled alternately USER / SYSTEM (starting with
        USER) and tab-joined into the context. Each example also carries the
        grounding of the preceding entry (empty lists for the first entry).
        Per file, the number of entries whose grounding equals the preceding
        entry's grounding is printed.

        ``mode`` is accepted but not used.
        """
        examples = []
        for path in files:
            with open(path, "r", encoding="utf-8") as handle:
                corpus = json.load(handle)

            repeated_persona = repeated_knowledge = 0
            for dialogue in corpus["data"]:
                entries = dialogue["utterance"]
                for position, entry in enumerate(entries):
                    session = entry["dialogue" + str(position + 1)]
                    persona_cands = entry["persona_candidate"]
                    knowledge_cands = entry["knowledge_candidates"]
                    persona_grounding = entry["persona_grounding"]
                    knowledge_index = entry["knowledge_answer_index"]

                    if not position:
                        prev_persona, prev_knowledge = [], []
                    else:
                        earlier = entries[position - 1]
                        prev_persona = earlier["persona_grounding"]
                        prev_knowledge = earlier["knowledge_answer_index"]
                        if prev_persona == persona_grounding:
                            repeated_persona += 1
                        if prev_knowledge == knowledge_index:
                            repeated_knowledge += 1

                    # Even positions are user turns, odd positions system turns.
                    context = "\t".join(
                        ("SYSTEM: " if turn_no % 2 else "USER: ") + turn
                        for turn_no, turn in enumerate(session[:-1])
                    )
                    examples.append({
                        "context": context,
                        "response": session[-1],
                        "persona_cands": persona_cands,
                        "knowledge_cands": knowledge_cands,
                        "persona_grounding": persona_grounding,
                        "knowledge_index": knowledge_index,
                        "previous_used_persona": prev_persona,
                        "previous_used_knowledge": prev_knowledge,
                    })

            print("There are {} overlap persona; {} overlap knowledge.".format(repeated_persona, repeated_knowledge))
        return examples

class PsyQA(Dataset):
    """Question / answer / description records loaded from JSON files."""

    def __init__(self, paths="", mode="") -> None:
        self.mode = mode
        self.paths = paths
        self.examples = self.load_data(paths, mode)
        self.prompts = []

    def load_data(self, files, mode):
        """Read every file (a JSON list of samples) into flat example dicts.

        Each sample contributes its "question", the "answer_text" field of its
        "answer" object, and its "desc". ``mode`` is accepted but not used.
        """
        collected = []
        for path in files:
            with open(path, "r", encoding="utf-8") as handle:
                records = json.load(handle)
            collected.extend(
                {
                    "question": record["question"],
                    "answer": record["answer"]["answer_text"],
                    "desc": record["desc"],
                }
                for record in records
            )
        return collected


class StrategyTutoring(Dataset):
    """Dialogue contexts paired with strategy-labelled reference responses."""

    def __init__(self, paths="", mode="") -> None:
        self.mode = mode
        self.paths = paths
        self.examples = self.load_data(paths, mode)
        self.prompts = []

    def load_data(self, files, mode):
        """Read every file (a JSON list of samples) into flat example dicts.

        The sample's "context" turns are tab-joined. Each entry of its
        "references" list is split into a parallel list of responses and a list
        of strategies. ``mode`` is accepted but not used.
        """
        examples = []
        for path in files:
            with open(path, "r", encoding="utf-8") as handle:
                records = json.load(handle)

            for record in records:
                turns = record["context"]
                pairs = [(ref["strategy"], ref["response"]) for ref in record["references"]]
                strategies = [strategy for strategy, _ in pairs]
                responses = [response for _, response in pairs]
                examples.append({
                    "context": "\t".join(turns),
                    "references": responses,
                    "ref_strategies": strategies,
                })
        return examples

class HotpotQA(Dataset):
    """Multi-hop QA data plus helpers that turn it into prompt records.

    ``examples`` is built by :meth:`construct_knowledge_boundary` (JSON Lines
    input). :meth:`load_data` is an alternative loader for a JSON list with
    supporting facts and context paragraphs; it is not called within this class.
    """

    def __init__(self, paths="", demo_path="", mode="",
                 prompt_path="./cot_retrieval/config/prompt_en.json") -> None:
        self.paths = paths
        self.mode = mode
        # Kept so construct_prompt(setting="in-context") can load demos; the
        # argument used to be accepted and then dropped.
        self.demo_path = demo_path
        self.examples = self.construct_knowledge_boundary(paths, mode)
        self.prompts = []
        self.load_prompts(path=prompt_path)

    def load_prompts(self, path, mode="qa"):
        """Load the prompt config JSON and keep the QA accuracy-evaluation prompt.

        ``acc_evaluation_prompt`` is None when the config lacks the
        "qa_acc_evaluation_prompt" key. ``mode`` is accepted but not used.
        """
        with open(path, "r", encoding="utf-8") as f:
            prompt_config = json.load(f)

        self.acc_evaluation_prompt = prompt_config.get('qa_acc_evaluation_prompt')

    def construct_knowledge_boundary(self, files, mode, limit=500):
        """Load question/answer pairs from JSON Lines files.

        Each non-blank line must be a JSON object with an "input" question and
        an "output" list whose first element has an "answer".

        Args:
            files: Iterable of file paths.
            mode: Accepted but not used.
            limit: Maximum records taken per file (None for no limit).

        Returns:
            List of {"question", "answer"} dicts.
        """
        examples = []
        for file in files:
            with open(file, "r", encoding="utf-8") as f:
                # Decode lazily and stop after `limit` records instead of
                # parsing the whole file and slicing; blank lines are skipped.
                records = (json.loads(line) for line in f if line.strip())
                for sample in itertools.islice(records, limit):
                    examples.append({
                        "question": sample["input"],
                        "answer": sample["output"][0]["answer"],
                    })
        return examples

    def load_data(self, files, mode, limit=500):
        """Load QA samples with their context paragraphs and supporting facts.

        Each file is a JSON list of samples with "question", "answer", "type",
        "level", "context" as [title, sentences] pairs, and "supporting_facts"
        as [title, sentence_index] pairs.

        Args:
            files: Iterable of file paths.
            mode: Accepted but not used.
            limit: Maximum samples taken per file (None for no limit).

        Returns:
            List of example dicts. "wiki" holds every "title sentence" string.
            "support_facts" holds the de-duplicated supporting sentences.
        """
        examples = []
        for file in files:
            with open(file, "r", encoding="utf-8") as f:
                raw_qas = json.load(f)

            for sample in raw_qas[:limit]:
                context = sample["context"]
                support_facts = []
                for support_title, number in sample["supporting_facts"]:
                    for title, documents in context:
                        if support_title == title:
                            # Skip an out-of-range sentence index rather than
                            # aborting the whole load with IndexError.
                            if number < len(documents):
                                support_facts.append(title + " " + documents[number].strip())
                            break

                wikis = [title + " " + docu.strip()
                         for title, documents in context
                         for docu in documents]

                examples.append({
                    "question": sample["question"],
                    "answer": sample["answer"],
                    "type": sample["type"],
                    "level": sample["level"],
                    "wiki": wikis,
                    # Order-preserving de-duplication; set() made the order
                    # vary between runs.
                    "support_facts": list(dict.fromkeys(support_facts)),
                })
        return examples

    def construct_prompt(self, template="", setting="zero-shot"):
        """Append one prompt record per example to ``self.prompts``.

        Args:
            template: ``str.format`` template with a ``{question}`` placeholder.
            setting: "in-context" also attaches the "q->a:demos" entry of the
                demo pool loaded from ``self.demo_path``.
        """
        if setting == "in-context":
            # NOTE(review): load_demo is not defined on this class or on
            # Dataset; it must be provided elsewhere — confirm.
            self.demo_pools = self.load_demo(self.demo_path)
            self.demo = self.demo_pools["q->a:demos"]

        for example in self.examples:
            question = example["question"]
            instance = {
                "prompt": template.format(question=question),
                "question": question,
                # "type"/"level" exist only on load_data() examples; the
                # examples built in __init__ lack them and used to raise KeyError.
                "type": example.get("type"),
                "level": example.get("level"),
                "answer": example["answer"],
            }
            if setting == "in-context":
                instance["demo"] = self.demo
            self.prompts.append(instance)
