import argparse
import atexit
import json
import os
import time
import weakref

import datasets
import numpy as np
from langchain.prompts import PromptTemplate
from rich.progress import track
from sklearn.metrics.pairwise import cosine_similarity

from src.graph_utils import path_to_str_with_graph, summary
from src.models import select_embedding_model
from src.utils import logger


# PromptTemplate text used to render one memory in "qa" style. Memory.format_memory
# fills {answer} with a JSON string (candidate answers + reasoning paths), which
# this template wraps in a fenced code block.
QA_FORMAT_TEMPLATE = """
Q: {question}
A: ```{answer}```
"""

# PromptTemplate text used to render one memory in "agent" style: question,
# answer and reasoning paths as separate labelled fields.
AGENT_FORMAT_TEMPLATE = """
Question: {question}
Answer: {answer}
Paths: {path}
"""




class Memory:
    def __init__(self,
                 embed_model,
                 data: list[str] | None = None,
                 dataset_name: str | None = None,
                 memory_file: str | None = None):
        """
        Args:
            embed_model: 嵌入模型
            data: 数据列表，如果为 None，则从 memory_file 中读取
            dataset_name: 数据集名称
            memory_file: 记忆文件，如果为 None，则从 dataset_name 中读取
        """

        self.embed_model = embed_model
        self.inference_memory = []  # 存储推理阶段的记忆
        self.inference_embeddings = []  # 新增: 存储推理记忆的embeddings
        self.memory_template_map = {
            "qa": PromptTemplate(
                input_variables=["question", "answer"],
                template=QA_FORMAT_TEMPLATE
            ),
            "agent": PromptTemplate(
                input_variables=["question", "answer", "path"],
                template=AGENT_FORMAT_TEMPLATE
            ),
        }

        if dataset_name is not None:
            dataset_name = dataset_name.split("/")[-1]
            memory_file_map = {
                "webqsp": "data/memory/rmanluo_RoG-webqsp.valid.txt",
                "cwq": "data/memory/rmanluo_RoG-cwq.shortest.txt",
                # "cwq": "data/memory/rmanluo_RoG-cwq.shortest.txt",
            }
            self.memory_file = memory_file or memory_file_map[dataset_name]
            logger.info(f"Loading memory from {self.memory_file}")
            with open(self.memory_file, "r") as f:
                self.lines = [json.loads(line) for line in f.readlines()]
        else:
            self.lines = data or []

        # 初始化基础记忆的embeddings
        self.base_embeddings = self._calculate_embeddings(self.lines)

        # 注册退出时的清理函数
        atexit.register(self.close)

    def format_memory(self, memory):
        return self.memory_template_map["qa"].invoke({
            "question": memory["question"],
            "answer": json.dumps({
                "candidate_answers": memory["answer"],
                "reasoning_paths": memory["path"]
            }, indent=4, ensure_ascii=False)
        }).text

    def _calculate_embeddings(self, memories):
        """计算给定记忆的embeddings"""
        if not memories:
            return None

        format_lines = []
        for line in memories:
            format_lines.append(self.format_memory(line))

        return self.embed_model.encode([d[:500] for d in format_lines])

    def retrieve_similar(self, query, neg_query=None, top_n=10, return_format="json"):
        if top_n == 0:
            return [], []

        query_embed = self.embed_model.encode(query)

        # 合并基础记忆和推理记忆的embeddings
        all_embeddings = np.vstack([self.base_embeddings, self.inference_embeddings]) if self.inference_embeddings else self.base_embeddings
        all_memories = self.lines + self.inference_memory

        similarities = cosine_similarity([query_embed], all_embeddings)[0]
        if neg_query is not None:
            neg_query_embed = self.embed_model.encode(neg_query)
            neg_similarities = cosine_similarity([neg_query_embed], all_embeddings)[0]
            similarities = similarities - neg_similarities

        top_indices = np.argsort(similarities)[-top_n:][::-1]
        lines = [all_memories[idx] for idx in top_indices]
        sim_scores = [similarities[idx] for idx in top_indices]

        if return_format == "qa":
            lines = [self.memory_template_map["qa"].invoke({
                "question": line["question"],
                "answer": json.dumps({
                    "candidate_answers": line["answer"],
                    "reasoning_paths": line["path"]
                }, indent=4, ensure_ascii=False)
            }).text for line in lines]
        elif return_format == "agent":
            lines = [self.memory_template_map["agent"].invoke(line).text for line in lines]

        return [lines, sim_scores]

    def write_instances(self, instance, answers, paths):
        """写入新的推理记忆实例"""
        new_memory = {
            "question": instance["question"],
            "entity": instance.get("q_entity", [""])[0],
            "answer": answers,
            "path": paths
        }

        memory = self.format_memory(new_memory)
        new_embedding = self.embed_model.encode([memory[:500]])

        self.inference_memory.append(new_memory)
        self.inference_embeddings.extend(new_embedding)
        assert len(self.inference_memory) == len(self.inference_embeddings), "Inference memory and embeddings count mismatch"
        return new_memory

    def close(self):
        """显式关闭和清理资源，特别是嵌入模型的多进程资源"""
        try:
            # 清理嵌入模型资源
            if hasattr(self.embed_model, 'stop_self_pool'):
                self.embed_model.stop_self_pool()
            # 其他可能需要清理的资源...
        except Exception as e:
            logger.warning(f"清理资源时发生错误: {e}")

    def __del__(self):
        """确保对象被垃圾回收时资源被正确清理"""
        self.close()

def generate_memory(dataset, dataset_name: str, memory_file: str):
    """Build memory records from a dataset and write them to ``memory_file``.

    Each instance is passed through ``summary(ins, with_graph=True)``. Instances
    whose graph or truth paths are empty are skipped, with a message. Each kept
    record holds:
    - the question;
    - the first question entity;
    - up to 10 answers;
    - the reasoning paths rendered from up to 10 truth paths.

    Args:
        dataset: Iterable of dataset instances.
        dataset_name: Dataset name; must contain "webqsp" or "cwq".
        memory_file: Output path, written as UTF-8 JSON lines (one record per
            line). Its parent directory is created if missing.

    Returns:
        The list of record dicts that was written.

    Raises:
        ValueError: ``dataset_name`` is not a supported dataset.
    """
    # Validate once up front rather than on every iteration (and not via assert,
    # which is stripped under -O).
    if "webqsp" not in dataset_name and "cwq" not in dataset_name:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    memorys = []
    for ins in track(dataset, description="Memory Creation"):
        ins = summary(ins, with_graph=True)
        if len(ins["graph"]) == 0 or len(ins["truth_paths"]) == 0:
            print(f"Skipping {ins['question']} because it has no graph or truth paths")
            continue
        memorys.append({
            "question": ins["question"],
            "entity": ins["q_entity"][0],
            "answer": ins["answer"][:10],
            "path": path_to_str_with_graph(ins["q_entity"], ins["truth_paths"][:10], ins["graph"])
        })

    memory_dir = os.path.dirname(memory_file)
    if memory_dir:
        os.makedirs(memory_dir, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII text, so the encoding must be explicit.
    with open(memory_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(memory, ensure_ascii=False) + "\n" for memory in memorys)

    return memorys

def parse_args():
    """Parse the command-line options for memory generation.

    Returns:
        argparse.Namespace with ``dataset_name``, ``data_path``, ``path_type``,
        ``valid_path_file`` and ``embed_model`` (all strings).
    """
    cli = argparse.ArgumentParser(description='Generate and manage memory for the model')
    # (flag, options) in the order they should appear in --help.
    option_specs = [
        ('--dataset-name', {
            'default': "webqsp",
            'choices': ['cwq', 'webqsp'],
            'help': 'Name of the dataset to use',
        }),
        ('--data-path', {
            'default': "data/webqsp",
            'help': 'Path to the dataset to use',
        }),
        ('--path-type', {
            'default': "shortest",
            'choices': ['shortest', 'valid'],
            'help': 'Type of paths to use (shortest or valid)',
        }),
        ('--valid-path-file', {
            'default': "data/RoG-webqsp.train.valid.jsonl",
            'help': 'Path to the valid path file to use',
        }),
        ('--embed-model', {
            'default': "bge-m3",
            'help': 'Name of the embedding model to use',
        }),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, type=str, **options)
    return cli.parse_args()

def main():
    """Generate the memory file for one dataset, load it and run a sample query.

    CLI options come from parse_args(). With ``--path-type shortest`` the train
    split is loaded via ``datasets.load_dataset(--data-path)``; otherwise the
    JSON-lines file ``--valid-path-file`` is read. Memories are written to
    ``data/memory/<dataset_name>.<base_name>.txt``.
    """
    args = parse_args()

    if args.path_type == "shortest":  # Hugging Face datasets only
        assert args.data_path is not None, "data_path is required for Huggingface Datasets"
        dataset = datasets.load_dataset(args.data_path, split="train")
    else:
        assert args.valid_path_file is not None, "valid_path_file is required for local dataset"
        with open(args.valid_path_file, "r", encoding="utf-8") as f:
            lines = [line for line in f if line.strip()]
        # Progress via rich's track: tqdm is not imported in this module.
        dataset = [json.loads(line) for line in track(lines, description="Loading dataset")]

    # Memory file name: data/memory/<dataset_name>.<base_name>.txt, where base_name
    # is the valid-path file name with ".jsonl" removed, or "shortest".
    base_name = args.valid_path_file.split("/")[-1] if args.path_type == "valid" else "shortest"
    base_name = base_name.replace(".jsonl", "")
    memory_file = f"data/memory/{args.dataset_name.replace('/', '_')}.{base_name}.txt"
    print(f"Memory file: {memory_file}")

    memory = None
    try:
        memorys = generate_memory(dataset, args.dataset_name, memory_file)
        embed_model = select_embedding_model(args.embed_model)
        # Pass memory_file as well so the file generated above is the one used,
        # never a dataset-default memory file.
        memory = Memory(embed_model=embed_model, data=memorys,
                        dataset_name=args.dataset_name, memory_file=memory_file)

        # Sample query
        print(memory.retrieve_similar("What is the capital of France?", "France"))
    finally:
        # Release embedding-model resources even if something above failed.
        if memory is not None:
            memory.close()

def merge_memory(memory_file_list: list[str]):
    """Concatenate the lines of several memory (JSON-lines) files, in order.

    Every returned line ends with a newline. The result can therefore be
    written back with ``f.writelines`` even when an input file lacks a trailing
    newline; otherwise that file's last record would be glued to the next
    file's first record.

    Args:
        memory_file_list: Paths of the files to merge (read as UTF-8).

    Returns:
        List of lines, each terminated by "\\n".
    """
    lines = []
    for memory_file in memory_file_list:
        with open(memory_file, "r", encoding="utf-8") as f:
            lines.extend(line if line.endswith("\n") else line + "\n" for line in f)
    return lines

if __name__ == "__main__":
    # Script entry point: generate a memory file from the CLI options and run a sample query.
    main()

    # One-off utility kept for reference: merge several memory files into one
    # combined file with merge_memory. Uncomment to use.
    # memory_paths = [
    #     "data/memory/cwq.train.valid.qwen3-32b.txt",
    #     "data/memory/webqsp.RoG-webqsp.train.valid.txt",
    # ]
    # lines = merge_memory(memory_paths)
    # with open("data/memory/cwq.qwen3-32b+webqsp.gpt-4o.train.valid.txt", "w") as f:
    #     f.writelines(lines)