import argparse
import json
import os
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, TimeoutError

from dotenv import load_dotenv
from jinja2 import Template
from openai import OpenAI
from tqdm import tqdm
from zep_cloud import Message,EntityEdge, EntityNode
from zep_cloud.client import Zep

# Prompt sent to the chat model when answering a benchmark question.
# It is a Jinja2 template (rendered in ZepClient.answer_question) with two
# placeholders: {{memories}} receives the retrieved Zep context and
# {{question}} receives the question text.
ANSWER_PROMPT_ZEP = """
You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
    These memories contain timestamped information that may be relevant to answering the question.
    # INSTRUCTIONS:
    1. Carefully analyze all provided memories
    2. Pay special attention to the timestamps to determine the answer
    3. If the question asks about a specific event or fact, look for direct evidence in the memories
    4. If the memories contain contradictory information, prioritize the most recent memory
    5. If there is a question about time references (like "last year", "two months ago", etc.), 
       calculate the actual date based on the memory timestamp. For example, if a memory from 
       4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
    6. Always convert relative time references to specific dates, months, or years. For example, 
       convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory 
       timestamp. Ignore the reference while answering the question.
    7. Focus only on the content of the memories. Do not confuse character 
       names mentioned in memories with the actual users who created those memories.
    8. The answer should be less than 5-6 words.

    # APPROACH (Think step by step):
    1. First, examine all memories that contain information related to the question
    2. Examine the timestamps and content of these memories carefully
    3. Look for explicit mentions of dates, times, locations, or events that answer the question
    4. If the answer requires calculation (e.g., converting relative time references), show your work
    5. Formulate a precise, concise answer based solely on the evidence in the memories
    6. Double-check that your answer directly addresses the question asked
    7. Ensure your final answer is specific and avoids vague time references

    Memories:

    {{memories}}

    Question: {{question}}
    Answer:
    """

# Load variables from a local .env file into the process environment at import
# time; ZepClient.answer_question later reads MODEL through os.getenv.
load_dotenv()

# Layout of the retrieved-memory context handed to the answer prompt.
# Filled with str.format in ZepClient.compose_search_context: {facts} receives
# one bullet per graph edge and {entities} one bullet per graph node.
TEMPLATE = """
FACTS and ENTITIES represent relevant context to the current conversation.

# These are the most relevant facts and their valid date ranges
# format: FACT (Date range: from - to)

{facts}


# These are the most relevant entities
# ENTITY_NAME: entity summary

{entities}

"""
# Zep API key. Read from the environment (which load_dotenv() above may have
# populated from a .env file) so the secret does not have to be committed to
# source; the original placeholder is kept as the fallback when it is unset.
ZEP_API_KEY = os.getenv("ZEP_API_KEY", "Your ZEP_API_KEY")

class ZepClient:
    """Run a conversation-memory benchmark against Zep.

    The client uploads conversations to Zep (one user and one thread per
    conversation index), retrieves graph facts and entities for each question,
    and asks an OpenAI-compatible chat model to answer from that context.
    Per-question results are accumulated in ``self.results`` and written to a
    JSON file after every question.
    """

    def __init__(self, data_path=None, openai_api_key="key", openai_base_url="http://localhost:8866/v1"):
        """Create the Zep and OpenAI clients and optionally load the dataset.

        Args:
            data_path: Path to the JSON dataset. When truthy the file is loaded
                immediately into ``self.data``.
            openai_api_key: API key handed to the OpenAI client.
            openai_base_url: Base URL of the OpenAI-compatible endpoint.
        """
        self.zep_client = Zep(api_key=ZEP_API_KEY)
        self.data_path = data_path
        self.data = None
        # Maps conversation index -> list of per-question result dicts.
        self.results = defaultdict(list)
        self.openai_client = OpenAI(api_key=openai_api_key, base_url=openai_base_url)
        if data_path:
            self.load_data()

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _user_id(run_id, idx):
        """Return the Zep user id used for conversation ``idx`` of run ``run_id``."""
        return f"run_id_{run_id}_experiment_user_{idx}"

    @staticmethod
    def _session_id(run_id, idx):
        """Return the Zep thread id used for conversation ``idx`` of run ``run_id``."""
        return f"run_id_{run_id}_experiment_session_{idx}"

    @staticmethod
    def _call_with_timeout(fn, timeout, /, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` in a worker thread, waiting at most ``timeout`` seconds.

        Returns:
            Whatever ``fn`` returns.

        Raises:
            concurrent.futures.TimeoutError: ``fn`` did not finish in time.
            Exception: anything raised by ``fn`` itself is propagated.
        """
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            return executor.submit(fn, *args, **kwargs).result(timeout=timeout)
        finally:
            # Must not wait here: using the executor as a context manager
            # joins the worker on exit, which blocks until ``fn`` finishes and
            # silently defeats the timeout. The abandoned call keeps running
            # in its thread until it completes on its own.
            executor.shutdown(wait=False)

    def _save_results(self, output_file_path):
        """Overwrite ``output_file_path`` with the current results as JSON."""
        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f, indent=4)

    def _answer_questions(self, run_id, qa, idx, output_file_path, desc):
        """Answer every non-adversarial question in ``qa`` for conversation ``idx``.

        Each result is appended to ``self.results[idx]`` and the results file
        is rewritten after every question so an interrupted run keeps its
        progress.
        """
        for question_item in tqdm(qa, total=len(qa), desc=desc, leave=False):
            # Questions that carry a non-empty adversarial answer are skipped.
            if question_item.get("adversarial_answer", "") != "":
                continue
            result = self.process_question(run_id, question_item, idx)
            self.results[idx].append(result)
            self._save_results(output_file_path)

    # ------------------------------------------------------------------
    # Dataset loading and ingestion
    # ------------------------------------------------------------------

    def load_data(self):
        """Load the JSON file at ``self.data_path`` into ``self.data`` and return it."""
        with open(self.data_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        return self.data

    def process_conversation(self, run_id, item, idx):
        """Reset the Zep user/thread for one conversation and upload its messages.

        Args:
            run_id: Run identifier embedded in the Zep user and thread ids.
            item: Dataset entry; ``item["conversation"]`` maps session keys to
                lists of chat dicts (with ``speaker`` and ``text``) and holds a
                ``<session key>_date_time`` entry for each session.
            idx: Index of the conversation, embedded in the Zep ids.
        """
        conversation = item["conversation"]

        user_id = self._user_id(run_id, idx)
        session_id = self._session_id(run_id, idx)

        # Best-effort cleanup of leftovers from an earlier run with the same
        # ids; failures (e.g. nothing to delete) are deliberately ignored.
        try:
            self.zep_client.user.delete(user_id=user_id)
        except Exception:
            pass

        try:
            self.zep_client.thread.delete(thread_id=session_id)
        except Exception:
            pass

        self.zep_client.user.add(user_id=user_id)
        self.zep_client.thread.create(user_id=user_id, thread_id=session_id)

        print("Starting to add memories... for user", user_id)
        for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"):
            # Only session keys hold chats; speaker names and date entries are metadata.
            if key in ["speaker_a", "speaker_b"] or "date" in key:
                continue

            timestamp = conversation[key + "_date_time"]
            chats = conversation[key]

            for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False):
                # NOTE(review): the role test compares the chat's speaker to the
                # literal "speaker_a". If the dataset stores real names in
                # chat["speaker"] (the conversation dict has its own
                # "speaker_a" entry), every message is sent as "assistant" —
                # confirm this is intended.
                self.zep_client.thread.add_messages(
                    thread_id=session_id,
                    messages=[
                        Message(
                            role="user" if chat["speaker"] == "speaker_a" else "assistant",
                            role_type="user",
                            content=f"{timestamp} | {chat['speaker']} : {chat['text']}",
                        )
                    ],
                )

    def process_all_conversations_add(self, run_id, start_idx=1):
        """Upload the loaded conversations to Zep, starting at ``start_idx``.

        Args:
            run_id: Run identifier embedded in the Zep ids.
            start_idx: First conversation index to upload. The default of 1
                preserves the historical behaviour of skipping conversation 0.
                NOTE(review): ``process_data_file_qa`` queries every
                conversation including index 0, so pass ``start_idx=0`` unless
                conversation 0 was ingested separately.

        Raises:
            ValueError: No (or empty) data has been loaded.
        """
        if not self.data:
            raise ValueError("No data loaded. Please set data_path and call load_data() first.")
        for idx, item in tqdm(enumerate(self.data), total=len(self.data)):
            if idx < start_idx:
                continue
            self.process_conversation(run_id, item, idx)

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def format_edge_date_range(self, edge: "EntityEdge") -> str:
        """Format an edge's validity window as ``"<valid_at> - <invalid_at>"``.

        A falsy ``valid_at`` is rendered as ``date unknown`` and a falsy
        ``invalid_at`` as ``present``.
        """
        return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}"

    def compose_search_context(self, edges: "list[EntityEdge]", nodes: "list[EntityNode]") -> str:
        """Render edges (facts) and nodes (entities) into the ``TEMPLATE`` text block."""
        facts = [f"  - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges]
        entities = [f"  - {node.name}: {node.summary}" for node in nodes]
        return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))

    def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
        """Search the Zep graph of one conversation and build the answer context.

        Args:
            run_id: Run identifier embedded in the Zep user id.
            idx: Conversation index embedded in the Zep user id.
            query: Search text (the question).
            max_retries: Maximum number of attempts; values below 1 are treated
                as a single attempt.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            ``(context, elapsed_seconds)``; the elapsed time includes any
            retry delays.

        Raises:
            Exception: The error of the final attempt, once all attempts failed.
        """
        user_id = self._user_id(run_id, idx)
        # Always try at least once so ``context`` can never be left unbound.
        attempts = max(1, max_retries)
        start_time = time.time()

        for attempt in range(1, attempts + 1):
            try:
                edges_results = self.zep_client.graph.search(
                    user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=25
                ).edges
                node_results = self.zep_client.graph.search(
                    user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=25
                ).nodes
                # Treat a missing (None) result list as "no results" instead of
                # failing while iterating it — presumably the SDK may return
                # None for an empty scope; confirm against the Zep SDK.
                context = self.compose_search_context(edges_results or [], node_results or [])
                break
            except Exception as e:
                if attempt == attempts:
                    raise
                print(f"Search failed ({e!r}). Retrying...")
                time.sleep(retry_delay)

        return context, time.time() - start_time

    # ------------------------------------------------------------------
    # Question answering
    # ------------------------------------------------------------------

    def process_question(self, run_id, val, idx):
        """Answer one question item and return a result record.

        Args:
            run_id: Run identifier.
            val: Question dict; ``question``, ``answer``, ``category``,
                ``evidence`` and ``adversarial_answer`` are read with defaults.
            idx: Conversation index the question belongs to.

        Returns:
            A dict with the question metadata, the rendered prompt, the model
            answer, both timings and the retrieved context.
        """
        question = val.get("question", "")
        answer = val.get("answer", "")
        category = val.get("category", -1)
        evidence = val.get("evidence", [])
        adversarial_answer = val.get("adversarial_answer", "")

        answer_prompt, response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)

        return {
            "question": question,
            "standard answer": answer,
            "category": category,
            "evidence": evidence,
            "prompt": answer_prompt,
            "answer": response,
            "adversarial_answer": adversarial_answer,
            "search_memory_time": search_memory_time,
            "response_time": response_time,
            "context": context,
        }

    def answer_question(self, run_id, idx, question, timeout=300):
        """Retrieve context for ``question`` and ask the chat model to answer it.

        Args:
            run_id: Run identifier.
            idx: Conversation index whose memories are searched.
            question: Question text.
            timeout: Maximum seconds to wait for the chat completion.

        Returns:
            ``(answer_prompt, answer_text, search_memory_time, response_time,
            context)``. ``answer_text`` is ``""`` when the completion timed out.
        """
        context, search_memory_time = self.search_memory(run_id, idx, question)

        answer_prompt = Template(ANSWER_PROMPT_ZEP).render(memories=context, question=question)

        t1 = time.time()
        try:
            # The model name comes from the MODEL environment variable.
            response = self._call_with_timeout(
                self.openai_client.chat.completions.create,
                timeout,
                model=os.getenv("MODEL"),
                messages=[{"role": "system", "content": answer_prompt}],
                temperature=0.0,
            )
        except TimeoutError:
            response_time = time.time() - t1
            return answer_prompt, "", search_memory_time, response_time, context

        response_time = time.time() - t1
        return answer_prompt, response.choices[0].message.content, search_memory_time, response_time, context

    def process_data_file_qa(self, file_path, run_id, output_file_path):
        """Answer the questions of every conversation in ``file_path``.

        Args:
            file_path: JSON dataset; each item may carry a ``qa`` list.
            run_id: Run identifier used to address the Zep users.
            output_file_path: JSON file rewritten after every question and once
                more at the end.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
            self._answer_questions(
                run_id,
                item.get("qa", []),
                idx,
                output_file_path,
                desc=f"Processing questions for conversation {idx}",
            )

        # Final save at the end (also creates the file when nothing was answered).
        self._save_results(output_file_path)

    def process_conversation_and_qa(self, run_id, item, idx, output_file_path):
        """Ingest one conversation into Zep, then answer its questions.

        Args:
            run_id: Run identifier.
            item: Dataset entry with ``conversation`` and optional ``qa``.
            idx: Conversation index.
            output_file_path: JSON file rewritten after every question.
        """
        # step 1: add conversation
        self.process_conversation(run_id, item, idx)

        # step 2: QA for this conversation
        self._answer_questions(
            run_id,
            item.get("qa", []),
            idx,
            output_file_path,
            desc=f"QA for conversation {idx}",
        )

    

if __name__ == "__main__":
    # Command-line entry point: ingest the dataset into Zep, then answer its
    # questions. Every option defaults to the value that used to be hard-coded,
    # so running the script without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Ingest conversations into Zep and answer the dataset's questions."
    )
    parser.add_argument("--data_path", default="./data/locomo10.json", help="Path to the JSON dataset.")
    parser.add_argument("--run_id", default="1", help="Run identifier embedded in Zep user/thread ids.")
    parser.add_argument("--output_path", default="your_output_path", help="File the QA results are written to.")
    args = parser.parse_args()

    zep_client = ZepClient(data_path=args.data_path)
    zep_client.process_all_conversations_add(args.run_id)
    zep_client.process_data_file_qa(args.data_path, args.run_id, args.output_path)
