from mem0 import Memory
import json
import os
import threading
import time
from mem0.configs.prompts import DEFAULT_UPDATE_MEMORY_PROMPT, ANSWER_PROMPT, ANSWER_PROMPT_GRAPH  
from tqdm import tqdm
from jinja2 import Template
from openai import OpenAI
from collections import defaultdict

# Connection settings handed to the OpenAI client built in MemoryManager.
# The base URL points at an OpenAI-compatible endpoint on localhost; the key is
# a placeholder string (presumably not checked by the local server -- confirm).
openai_api_key = "EMPTY"
openai_api_base = "http://127.0.0.1:8866/v1"


class MemoryManager:
    """Store conversation turns in mem0 and answer dataset questions from them.

    The manager owns a mem0 ``Memory`` instance (FAISS vector store, vLLM chat
    model, HuggingFace embedder), an OpenAI-compatible client used to generate
    answers, and the evaluation dataset loaded from ``data_path``.
    """

    # Model name given to mem0's LLM config; also the fallback for answer
    # generation when the MODEL environment variable is not set.
    DEFAULT_LLM_MODEL = "Llama-3.1-8B-Instruct"

    def __init__(
        self,
        is_graph=False,
        output_path="/root/storage/mem0_result/mem_result.txt",
        data_path="/root/storage/locomo10.json",
    ):
        """Build the mem0 store and the answer client, then load the dataset.

        Args:
            is_graph: When true, the graph answer prompt is used and graph
                relations are extracted from search results.
            output_path: File path kept on the instance for callers that
                persist ``self.results``.
            data_path: JSON dataset to load into ``self.data``. A falsy value
                skips loading and leaves ``self.data`` as ``None``.
        """
        config = {
            "vector_store": {
                "provider": "faiss",
                "config": {
                    "collection_name": "test",
                    "path": "/root/storage/faiss_memories",
                    "distance_strategy": "euclidean",
                    "embedding_model_dims": 1024,
                },
            },
            "llm": {
                "provider": "vllm",
                "config": {
                    "model": self.DEFAULT_LLM_MODEL,
                    "vllm_base_url": "http://127.0.0.1:8866/v1",
                    "temperature": 0,
                    "max_tokens": 512,
                },
            },
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "model": "/root/storage/models/bge-m3"
                },
            },
        }
        self.mem = Memory.from_config(config)
        self.is_graph = is_graph
        self.ANSWER_PROMPT = ANSWER_PROMPT_GRAPH if is_graph else ANSWER_PROMPT
        self.output_path = output_path
        self.data_path = data_path
        self.data = None
        self.openai_client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
        # Maps a conversation index to the list of per-question result dicts.
        self.results = defaultdict(list)
        if data_path:
            # JSON text is UTF-8; do not depend on the platform default encoding.
            with open(self.data_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)

    def add_memory(self, user_id, message, metadata, retries=3):
        """Add one message to the store for ``user_id``, retrying on failure.

        Args:
            user_id: Identifier the memory is stored under.
            message: Payload forwarded unchanged to ``self.mem.add``.
            metadata: Metadata dict stored alongside the memory.
            retries: Total number of attempts. Values below 1 still result in
                a single attempt so the message is never silently dropped.

        Raises:
            Exception: Whatever the final failed attempt raised.
        """
        attempts = max(1, retries)
        for attempt in range(attempts):
            try:
                self.mem.add(message, user_id=user_id, metadata=metadata)
                return
            except Exception:
                if attempt == attempts - 1:
                    # Out of attempts: propagate with the original traceback.
                    raise
                time.sleep(1)  # Wait before retrying

    def add_memories_for_speaker(self, speaker, messages, timestamp, desc):
        """Add every message in ``messages`` under the ``speaker`` user id.

        Args:
            speaker: User id the memories are stored under.
            messages: Iterable of message payloads, added one at a time.
            timestamp: Value stored as ``metadata["timestamp"]`` on each memory.
            desc: Currently unused; kept so existing callers keep working.
        """
        for message in messages:
            self.add_memory(speaker, message, metadata={"timestamp": timestamp})

    @staticmethod
    def _is_session_key(key):
        """Return True when a conversation key holds a list of chat turns."""
        if key in {"speaker_a", "speaker_b"}:
            return False
        return not any(marker in key for marker in ("date", "timestamp", "sample_id"))

    def memory_warmup_for_conversation(self, item, idx):
        """Reset both speakers' memories and re-add the whole conversation.

        Every session is stored twice: once under speaker A's user id with A
        in the "user" role, and once under speaker B's user id with the roles
        swapped.

        Args:
            item: Dataset entry containing a ``"conversation"`` dict.
            idx: Conversation index, appended to speaker names to form user ids.

        Raises:
            ValueError: If a chat turn names a speaker other than A or B.
            KeyError: If a session has no matching ``<session>_date_time`` key.
        """
        conversation = item["conversation"]
        speaker_a = conversation["speaker_a"]
        speaker_b = conversation["speaker_b"]

        speaker_a_user_id = f"{speaker_a}_{idx}"
        speaker_b_user_id = f"{speaker_b}_{idx}"

        self.mem.delete_all(user_id=speaker_a_user_id)
        self.mem.delete_all(user_id=speaker_b_user_id)

        # Filter first so the progress bar counts only real sessions.
        session_keys = [key for key in conversation if self._is_session_key(key)]
        for key in tqdm(session_keys):
            timestamp = conversation[key + "_date_time"]

            messages = []
            messages_reverse = []
            for chat in conversation[key]:
                speaker = chat["speaker"]
                if speaker == speaker_a:
                    role, reverse_role = "user", "assistant"
                elif speaker == speaker_b:
                    role, reverse_role = "assistant", "user"
                else:
                    raise ValueError(f"Unknown speaker: {chat['speaker']}")
                content = f"{speaker}: {chat['text']}"
                messages.append({"role": role, "content": content})
                messages_reverse.append({"role": reverse_role, "content": content})

            self.add_memories_for_speaker(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A")
            self.add_memories_for_speaker(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B")

        print("Messages added successfully")

    @staticmethod
    def _format_semantic_memories(memories):
        """Reduce raw search hits to memory text, timestamp and rounded score."""
        return [
            {
                "memory": memory["memory"],
                "timestamp": memory["metadata"]["timestamp"],
                "score": round(memory["score"], 2),
            }
            for memory in memories["results"]
        ]

    def search_memory(self, user_id, query, max_retries=3, retry_delay=1):
        """Search the store for ``query`` under ``user_id``, retrying on failure.

        Args:
            user_id: User id whose memories are searched.
            query: Search text.
            max_retries: Total number of attempts. Values below 1 still result
                in a single attempt.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            A tuple ``(semantic_memories, graph_memories, elapsed_seconds)``.
            ``graph_memories`` is ``None`` unless ``self.is_graph`` is true.
            The elapsed time covers all attempts, including retry sleeps.

        Raises:
            Exception: Whatever the final failed attempt raised.
        """
        start_time = time.time()
        attempts = max(1, max_retries)
        for attempt in range(1, attempts + 1):
            try:
                memories = self.mem.search(query, user_id=user_id)
                break
            except Exception:
                if attempt == attempts:
                    raise
                print("Retrying...")
                time.sleep(retry_delay)
        end_time = time.time()

        semantic_memories = self._format_semantic_memories(memories)
        graph_memories = None
        if self.is_graph:
            # Treat a missing "relations" key as "no relations" (presumably it
            # is absent when no graph store is configured -- confirm in mem0).
            graph_memories = [
                {"source": relation["source"], "relationship": relation["relationship"], "target": relation["target"]}
                for relation in (memories.get("relations") or [])
            ]
        return semantic_memories, graph_memories, end_time - start_time

    def answer_question(self, speaker_1_user_id, speaker_2_user_id, question, answer, category):
        """Retrieve both speakers' memories and ask the chat model the question.

        Args:
            speaker_1_user_id: User id of the first speaker (``<name>_<idx>``).
            speaker_2_user_id: User id of the second speaker (``<name>_<idx>``).
            question: Question text used both as search query and in the prompt.
            answer: Reference answer; not used here.
            category: Question category; not used here.

        Returns:
            An 8-tuple: model answer text, speaker 1 memories, speaker 2
            memories, speaker 1 search time, speaker 2 search time, speaker 1
            graph memories, speaker 2 graph memories, model response time.
        """
        speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(
            speaker_1_user_id, question
        )
        speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(
            speaker_2_user_id, question
        )

        search_1_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_1_memories]
        search_2_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_2_memories]

        template = Template(self.ANSWER_PROMPT)
        answer_prompt = template.render(
            # Drop only the trailing "_<idx>" so names containing underscores
            # are rendered in full.
            speaker_1_user_id=speaker_1_user_id.rsplit("_", 1)[0],
            speaker_2_user_id=speaker_2_user_id.rsplit("_", 1)[0],
            speaker_1_memories=json.dumps(search_1_memory, indent=4),
            speaker_2_memories=json.dumps(search_2_memory, indent=4),
            speaker_1_graph_memories=json.dumps(speaker_1_graph_memories, indent=4),
            speaker_2_graph_memories=json.dumps(speaker_2_graph_memories, indent=4),
            question=question,
        )

        # MODEL overrides the model name; fall back to the configured one
        # instead of sending a null model when the variable is unset.
        model_name = os.getenv("MODEL") or self.DEFAULT_LLM_MODEL

        t1 = time.time()
        response = self.openai_client.chat.completions.create(
            model=model_name, messages=[{"role": "system", "content": answer_prompt}], temperature=0.0, max_tokens=20
        )
        t2 = time.time()
        response_time = t2 - t1
        return (
            response.choices[0].message.content,
            speaker_1_memories,
            speaker_2_memories,
            speaker_1_memory_time,
            speaker_2_memory_time,
            speaker_1_graph_memories,
            speaker_2_graph_memories,
            response_time,
        )

    def process_question(self, val, speaker_a_user_id, speaker_b_user_id):
        """Answer one dataset question and package the outcome as a dict.

        Args:
            val: Question record; missing fields fall back to defaults.
            speaker_a_user_id: User id of speaker A.
            speaker_b_user_id: User id of speaker B.

        Returns:
            A result dict, or ``None`` when the record has a non-empty
            ``"adversarial_answer"`` (such questions are skipped).
        """
        question = val.get("question", "")
        answer = val.get("answer", "")
        category = val.get("category", -1)
        evidence = val.get("evidence", [])
        adversarial_answer = val.get("adversarial_answer", "")
        if adversarial_answer != "":
            return None
        (
            response,
            speaker_1_memories,
            speaker_2_memories,
            speaker_1_memory_time,
            speaker_2_memory_time,
            speaker_1_graph_memories,
            speaker_2_graph_memories,
            response_time,
        ) = self.answer_question(speaker_a_user_id, speaker_b_user_id, question, answer, category)

        result = {
            "question": question,
            "standard answer": answer,
            "category": category,
            "evidence": evidence,
            "answer": response,
            "adversarial_answer": adversarial_answer,
            "speaker_1_memories": speaker_1_memories,
            "speaker_2_memories": speaker_2_memories,
            "num_speaker_1_memories": len(speaker_1_memories),
            "num_speaker_2_memories": len(speaker_2_memories),
            "speaker_1_memory_time": speaker_1_memory_time,
            "speaker_2_memory_time": speaker_2_memory_time,
            "speaker_1_graph_memories": speaker_1_graph_memories,
            "speaker_2_graph_memories": speaker_2_graph_memories,
            "response_time": response_time,
        }
        return result


def main(conversation_idx):
    """Run the full store-then-answer pipeline for one dataset conversation.

    Loads the dataset through ``MemoryManager``, rebuilds the memories for the
    conversation at ``conversation_idx``, answers each of its questions, and
    appends the collected results as JSON to the manager's output path.

    Args:
        conversation_idx: Index of the conversation in the loaded dataset.
    """
    manager = MemoryManager()
    item = manager.data[conversation_idx]
    print(f"Processing Conversation {conversation_idx}")

    manager.memory_warmup_for_conversation(item, conversation_idx)

    qa = item["qa"]
    conversation = item["conversation"]
    speaker_a_user_id = f"{conversation['speaker_a']}_{conversation_idx}"
    speaker_b_user_id = f"{conversation['speaker_b']}_{conversation_idx}"

    for question_item in qa:
        result = manager.process_question(question_item, speaker_a_user_id, speaker_b_user_id)
        if result is None:
            # process_question returns None for questions it skips.
            continue
        # Read the question from the result so a record without a "question"
        # key cannot raise after the answer has already been computed.
        print(f"Processed Question {result['question']}")
        manager.results[conversation_idx].append(result)

    # NOTE: the file is opened in append mode, so each run adds another JSON
    # document to it; the file as a whole is not a single JSON value.
    with open(manager.output_path, "a") as f:
        json.dump(manager.results, f, indent=4)

    print(f"Finished processing Conversation {conversation_idx}")

if __name__ == "__main__":
    import sys

    # Exactly one command-line argument is expected: the conversation index.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        main(int(cli_args[0]))
    else:
        print("Usage: python script.py <conversation_index>")
