import json
import yaml


def load_config(config_path):
    """Parse the YAML file at *config_path* and return the resulting object.

    The file is read as UTF-8 text and parsed with ``yaml.safe_load``, which
    refuses to construct arbitrary Python objects from the document.
    """
    with open(config_path, encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def create_context_str(context_dict: dict):
    """Render knowledge-graph context as two fenced JSON sections.

    Reads the ``"entities_context"`` and ``"relations_context"`` entries of
    *context_dict* (``KeyError`` if either is missing), serializes each with
    ``json.dumps(..., ensure_ascii=False)`` so non-ASCII text stays readable,
    and places them under "Entities(KG)" and "Relationships(KG)" headers.
    A ``"text_units_context"`` entry, if present, is currently ignored: the
    document-chunk section is disabled.
    """
    section_keys = ("entities_context", "relations_context")
    # Look up both entries before serializing either, matching the original order.
    payloads = [context_dict[key] for key in section_keys]
    entities_str, relations_str = (
        json.dumps(payload, ensure_ascii=False) for payload in payloads
    )

    return f"""-----Entities(KG)-----

```json
{entities_str}
```

-----Relationships(KG)-----

```json
{relations_str}
```
"""

# -----Document Chunks(DC)-----

# ```json
# {text_units_str}
# ```

# """

def write_to_txt(output_file: str, config: dict, answer_before: str, answers_after: dict, qtext: str, strategy: str, perturb_level: str = "entity", merge_logs: list = None):
    """Write a plain-text experiment report to *output_file* (overwritten, UTF-8).

    The report contains, in order: the configuration dumped as indented JSON,
    an optional list of entity-deduplication merge log lines (only when
    *merge_logs* is truthy), the query and the full-context answer, and one
    section per entry of *answers_after*, numbered from 0.

    Args:
        output_file: Destination path; any existing file is replaced.
        config: Experiment configuration; must be JSON-serializable.
        answer_before: Answer produced with the unperturbed context.
        answers_after: Mapping of perturbed item -> answer after perturbation.
        qtext: The original query text.
        strategy: Verb stem used to build labels as ``f"{strategy}ing"`` and
            ``f"{strategy}ed"``. NOTE(review): the suffixes are appended
            verbatim, so a stem like "remove" yields "removeing"/"removeed".
        perturb_level: Level name; capitalized for headings (e.g. "Entity").
        merge_logs: Optional sequence of log entries, one bullet line each.

    After the file is closed, a confirmation line is printed to stdout.
    """
    with open(output_file, "w", encoding="utf-8") as out:
        out.writelines(("Experiment Configuration:\n", "=========================\n"))
        out.write(json.dumps(config, indent=4))
        out.write("\n\n")

        if merge_logs:
            out.writelines(("Entity Deduplication Merging Process:\n", "=====================================\n"))
            out.writelines(f"- {entry}\n" for entry in merge_logs)
            out.write("\n\n")

        out.write("Experiment Results:\n")
        out.write("==================\n\n")
        out.write(f"Original Query: {qtext}\n\n")
        out.write("---\n\n")
        out.writelines(("Answer with Full Context:\n", f"{answer_before}\n\n", "---\n\n"))

        label = perturb_level.capitalize()
        out.write(f"--- {len(answers_after)} {label} perturbations performed ---\n\n")
        out.write(f"Answers after {strategy}ing {label}s:\n")
        out.write("--------------------------------\n\n")
        for idx, (perturbed, response) in enumerate(answers_after.items()):
            # Each perturbed item gets a heading, its answer, and a separator.
            out.write(f"### {strategy}ed {label} {idx}: '{perturbed}'\n\n")
            out.write(f"{response}\n\n")
            out.write("---\n\n")
    print(f"\nExperiment results saved to: {output_file}")


def format_answer(answer: str, delimiter: str):
    """Extract the answer body from raw model output.

    If *delimiter* is non-empty and occurs in *answer*, everything up to and
    including its first occurrence is discarded and leading whitespace is
    removed. The result is then cut at the first occurrence of
    ``"References"`` and stripped.

    Args:
        answer: Raw model output.
        delimiter: Marker after which the actual answer begins. An empty
            string (or a marker not found in *answer*) leaves the text as-is.

    Returns:
        The cleaned answer string.

    NOTE(review): truncation happens at the FIRST "References" substring, so
    an answer that mentions the word in its body is cut there too.
    """
    # An empty delimiter would make str.split/str.partition raise
    # "ValueError: empty separator"; treat it as "no delimiter".
    if delimiter:
        _, found, tail = answer.partition(delimiter)
        if found:
            answer = tail.lstrip()
    return answer.split("References")[0].strip()
    
def format_explainer_input(context: str, question: str):
    """Build the explainer prompt: a "Context:" line followed by a "Question:" line.

    NOTE(review): there is a space after "Context:" but none after
    "Question:"; kept verbatim in case downstream consumers rely on the
    exact prompt format.
    """
    context_line = f"Context: {context}"
    question_line = f"Question:{question}"
    return "\n".join((context_line, question_line))