import asyncio
import logging
import json
import yaml
import copy
from datetime import datetime
from pathlib import Path

import sys
sys.path.append('../')

from LightRAG.lightrag import LightRAG, QueryParam
from LightRAG.lightrag.utils import EmbeddingFunc
from LightRAG.lightrag.llm.ollama import ollama_embed, ollama_model_complete
from LightRAG.lightrag.kg.shared_storage import initialize_pipeline_status, initialize_share_data
from src.xgrag.perturbations import manipulate_entity, manipulate_relation
from src.xgrag.deduplications import deduplicate_entities
from src.xgrag.utils import write_to_txt, create_context_str, format_answer, format_explainer_input, load_config

import os
# Bypass any HTTP(S) proxy for loopback addresses — presumably so requests to a
# locally running model server (e.g. Ollama) are not routed through a proxy.
# Both spellings are set because HTTP clients differ in which one they read.
os.environ["NO_PROXY"] = '127.0.0.1, localhost'
os.environ['no_proxy'] = '127.0.0.1, localhost'

import nest_asyncio
# Patch asyncio to allow nested/re-entrant event-loop use (e.g. running this
# script's coroutines from an environment that already has a running loop).
nest_asyncio.apply()

CONFIG_FILE = "../config.yaml"
config = load_config(CONFIG_FILE)

# --- Experiment Settings ---
# Read in a fixed order so a missing key fails at the same place as before.
STRATEGY, PERTURB_LEVEL, QTEXT, DOC_NAME = (
    config[setting_key]
    for setting_key in ('strategy', 'perturb_level', 'query_text', 'doc_name')
)

# --- File Paths ---
# Each run gets its own experiment directory, stamped with the local start time.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
WORKING_DIR = Path(config['working_dir'])
EXPERIMENT_DIR = Path(
    config['experiment_dir'].format(working_dir=WORKING_DIR, timestamp=timestamp)
)
TXT_FILE = Path(config['input_txt_file'])


def _template_path(template_key, **placeholders):
    """Return ``Path(config[template_key].format(**placeholders))``.

    Only the placeholders explicitly passed are available to the template.
    """
    return Path(config[template_key].format(**placeholders))


# Placeholder set shared by the per-experiment output templates.
_DIR_FIELDS = {'working_dir': WORKING_DIR, 'experiment_dir': EXPERIMENT_DIR}

KG_STORE_PATH = _template_path('kg_store_path_template', working_dir=WORKING_DIR)
OUTPUT_FILE = _template_path('output_file_template', strategy=STRATEGY, **_DIR_FIELDS)
CONTEXT_OUTPUT_PATH = _template_path('initial_context_path_template', **_DIR_FIELDS)
CONTEXT_DEDUPLICATED_OUTPUT_PATH = _template_path('initial_context_deduplicated_path_template', **_DIR_FIELDS)
PERTURBATION_OUTPUT_PATH = _template_path('perturbed_context_path_template', **_DIR_FIELDS)
PERTURBED_ANSWER_PATH = _template_path('perturbed_answers_path_template', **_DIR_FIELDS)
INITIAL_ANSWER_PATH = _template_path('initial_answer_path_template', **_DIR_FIELDS)
EXPLAINER_INPUT_FILE = _template_path('explainer_input_file_template', **_DIR_FIELDS)

# Make sure both the shared working dir and this run's directory exist.
WORKING_DIR.mkdir(parents=True, exist_ok=True)
EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)


# Perturbation levels accepted in the config's ``perturb_level`` setting.
_PERTURB_LEVELS = ('entity', 'relation')


def _build_rag():
    """Construct a LightRAG instance from the ``lightrag`` section of the config.

    The store lives under KG_STORE_PATH; completions go through
    ``ollama_model_complete`` and embeddings through ``ollama_embed``.
    """
    rag_cfg = config['lightrag']
    llm_cfg = rag_cfg['llm']
    emb_cfg = rag_cfg['embedding']
    chunk_cfg = rag_cfg['chunk']

    return LightRAG(
        working_dir=str(KG_STORE_PATH),
        llm_model_func=ollama_model_complete,
        llm_model_name=llm_cfg['model_name'],
        llm_model_max_async=llm_cfg['max_async'],
        llm_model_max_token_size=llm_cfg['max_token_size'],
        llm_model_kwargs=llm_cfg['kwargs'],
        embedding_func=EmbeddingFunc(
            embedding_dim=emb_cfg['dim'],
            max_token_size=emb_cfg['max_token_size'],
            func=lambda texts: ollama_embed(
                texts, embed_model=emb_cfg['model'], host=emb_cfg['host']
            ),
        ),
        chunk_token_size=chunk_cfg['token_size'],
        chunk_overlap_token_size=chunk_cfg['overlap_token_size'],
        enable_llm_cache=rag_cfg['enable_llm_cache'],
        log_level=getattr(logging, rag_cfg['log_level'].upper()),
    )


def _build_query_param():
    """Build the QueryParam shared by every query, from the ``query_param`` config section."""
    qparam_cfg = config['query_param']
    return QueryParam(
        mode=qparam_cfg['mode'],
        conversation_history=qparam_cfg['conversation_history'],
        enable_rerank=qparam_cfg['enable_rerank'],
        response_type=qparam_cfg['response_type'],
    )


async def _answer_from_context(rag, ctx_dict, qparam):
    """Render *ctx_dict* to a context string and answer QTEXT from it.

    Returns:
        ``(context_str, answer)`` where *answer* has already been passed
        through ``format_answer(..., '</think>')``.
    """
    ctx_str = create_context_str(ctx_dict)
    answer = await rag.aquery_from_context(QTEXT, ctx_str, param=qparam)
    return ctx_str, format_answer(answer, '</think>')


def _write_json(path, obj, **json_kwargs) -> None:
    """Serialize *obj* as UTF-8 JSON to *path*; extra kwargs go to ``json.dump``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, **json_kwargs)


async def run_lightrag_with_perturbation():
    """Run one perturbation experiment end to end.

    1. Build LightRAG and insert TXT_FILE into its knowledge graph.
    2. Retrieve the context for QTEXT and answer it (the "before" answer).
    3. Deduplicate entities in a deep copy of the retrieved context.
    4. For every entity or relation (per PERTURB_LEVEL), apply STRATEGY via
       ``manipulate_entity`` / ``manipulate_relation`` to a fresh copy of the
       deduplicated context and re-answer from the perturbed context.
    5. Write the report, contexts, answers and explainer input to the
       module-level output paths.

    Raises:
        ValueError: if ``perturb_level`` is not 'entity' or 'relation'. This is
            checked before any indexing or LLM calls.
    """
    # Validate up front so a config typo fails before the slow indexing/LLM work.
    if PERTURB_LEVEL not in _PERTURB_LEVELS:
        raise ValueError(f"Unknown perturb_level in config: '{PERTURB_LEVEL}'. Must be 'entity' or 'relation'.")

    rag = _build_rag()
    await rag.initialize_storages()
    try:
        initialize_share_data(workers=1)
        await initialize_pipeline_status()

        # insert the txt file and create the knowledge graph
        text = TXT_FILE.read_text(encoding="utf-8")
        await rag.ainsert(text, file_paths=[str(TXT_FILE)], ids=[TXT_FILE.stem])

        qparam = _build_query_param()

        # retrieve the context and generate the answer
        _, ctx_dict_before, _ = await rag.aquery(QTEXT, param=qparam)
        ctx_before, answer_before = await _answer_from_context(rag, ctx_dict_before, qparam)
        _write_json(CONTEXT_OUTPUT_PATH, ctx_dict_before, indent=4, ensure_ascii=False)
        print(f"Initial context saved to: {CONTEXT_OUTPUT_PATH}\n")
        # answer_before is already formatted; applying format_answer again is redundant.
        print("Answer BEFORE manipulating a node: \n", answer_before, "\n")

        # --- Deduplicate entities on a deep copy so ctx_dict_before stays untouched ---
        print("--- Starting Entity Deduplication ---")
        # `or {}` also covers a `deduplication:` key that is present but empty (None).
        dedup_cfg = config.get('deduplication') or {}
        similarity_threshold = dedup_cfg.get('similarity_threshold', 70)  # default when not configured
        ctx_dict_deduplicated, merge_logs = deduplicate_entities(
            copy.deepcopy(ctx_dict_before), similarity_threshold
        )

        print("\n--- Deduplication Complete ---")
        print(f"Original entity count: {len(ctx_dict_before.get('entities_context', []))}")
        print(f"Deduplicated entity count: {len(ctx_dict_deduplicated.get('entities_context', []))}")
        print(f"Original relation count: {len(ctx_dict_before.get('relations_context', []))}")
        print(f"Deduplicated relation count: {len(ctx_dict_deduplicated.get('relations_context', []))}")

        _write_json(CONTEXT_DEDUPLICATED_OUTPUT_PATH, ctx_dict_deduplicated, indent=4, ensure_ascii=False)
        print(f"Deduplicated context saved to: {CONTEXT_DEDUPLICATED_OUTPUT_PATH}\n")

        answers_after = {}
        ctx_dict_after_list = []

        if PERTURB_LEVEL == 'entity':
            # Unique entity names, sorted for a deterministic experiment order.
            entities = sorted({e['entity'] for e in ctx_dict_deduplicated.get('entities_context', [])})
            print(f"--- Found {len(entities)} unique entities to perturb ---")

            for i, entity in enumerate(entities):
                print(f"--- Running experiment by {STRATEGY}ing entity {i}: '{entity}' ---")
                # Each perturbation starts from a pristine copy of the deduplicated context.
                ctx_dict_after = manipulate_entity(
                    copy.deepcopy(ctx_dict_deduplicated), entity, STRATEGY, DOC_NAME
                )
                ctx_dict_after_list.append(ctx_dict_after)

                _, answer_after = await _answer_from_context(rag, ctx_dict_after, qparam)
                answers_after[entity] = answer_after
                print(f"Answer AFTER {STRATEGY}ing an entity: \n", answer_after, "\n")
        else:  # 'relation' (validated at the top)
            relations = ctx_dict_deduplicated.get('relations_context', [])
            print(f"--- Found {len(relations)} relations to perturb ---")

            for i, relation in enumerate(relations):
                relation_str = f"{relation.get('entity1')} -> {relation.get('entity2')}"
                print(f"--- Running experiment by {STRATEGY}ing relation {i}: '{relation_str}' ---")
                ctx_dict_after = manipulate_relation(
                    copy.deepcopy(ctx_dict_deduplicated), relation, STRATEGY, DOC_NAME
                )
                ctx_dict_after_list.append(ctx_dict_after)

                _, answer_after = await _answer_from_context(rag, ctx_dict_after, qparam)
                # NOTE(review): relations sharing the same entity1/entity2 pair map to
                # the same key, so a later answer overwrites an earlier one here.
                answers_after[relation_str] = answer_after
                print(f"Answer AFTER {STRATEGY}ing a relation: \n", answer_after, "\n")
    finally:
        # Release storages only after every query has run (and even on error).
        await rag.finalize_storages()

    # write the result to the output txt file
    write_to_txt(
        output_file=OUTPUT_FILE,
        config=config,
        answer_before=answer_before,
        answers_after=answers_after,
        qtext=QTEXT,
        strategy=STRATEGY,
        perturb_level=PERTURB_LEVEL,
        merge_logs=merge_logs,
    )

    explainer_input = format_explainer_input(ctx_before, QTEXT)
    INITIAL_ANSWER_PATH.write_text(answer_before, encoding="utf-8")
    EXPLAINER_INPUT_FILE.write_text(explainer_input, encoding="utf-8")
    _write_json(PERTURBATION_OUTPUT_PATH, ctx_dict_after_list, ensure_ascii=False)
    _write_json(PERTURBED_ANSWER_PATH, answers_after, ensure_ascii=False)

if __name__ == "__main__":
    # Script entry point: drive the whole async experiment on a fresh event loop.
    experiment = run_lightrag_with_perturbation()
    asyncio.run(experiment)