import os
import json
from typing import (
    Tuple,
    List,
)

import uuid
import openai
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from loguru import logger
from dotenv import load_dotenv

from src.schema import (
    QA,
    Answer,
    Memory,
    MemoryType,
    Turn,
    Session,
    User,
    ColoredGraph,
)

from openai import AsyncOpenAI
import asyncio, os
from asyncio import Semaphore


# MEMORY_GENERATION_PROMPT_TEMPLATE = """
# You are a memory generation assistant.

# Your goal is to create a realistic and coherent dialogue between a User and an Assistant.

# The conversation should help the User arrive at the given short answer naturally, using the provided question and evidences.

# ### Requirements
# - Simulate a natural dialogue that could plausibly happen in a real conversation.
# - The User is curious or uncertain, and the Assistant guides them toward understanding.
# - Use the evidences to construct intermediate reasoning steps.
# - Include clarifications, follow-ups, or analogies if helpful.
# - Use casual but clear language.
# - Keep the number of dialogue turns between 3 and 6.
# - You should generate under 256 tokens.
# - Do not mention the term "evidence" or refer to this being a prompt.
# - Avoid repetition or stating the short answer too early.

# ### Input
# Question:
# {question}

# Rubric Question (to help interpret the original question):
# {rubric_question}

# Short Answer (the answer the conversation should lead to):
# {short_answer}

# Supporting Evidences (facts or prior knowledge to use naturally in the dialogue):
# {evidences_str}

# Previous Sessions (optional context from prior related conversations):
# {prev_sessions_str}

# ### Output Format (in JSON)
# Write the conversation as a list of turns in **JSON format** using the structure.
# Because it's a JSON list, you should not add ',' at the end of last item.
# [
#   {{"speaker": "User", "content": "..." }},
#   {{"speaker": "Assistant", "content": "..." }},
#   ...
#   {{"speaker": "Assistant", "content": "..." }}
# ]
# Only return the JSON content, nothing else.
# """.strip()

# Prompt sent to the chat model to synthesize one User/Assistant dialogue per answer.
# Filled with str.format(); the placeholders are {question}, {rubric_question},
# {short_answer} and {evidences_str}. The literal braces of the JSON example are
# escaped as {{ }} so format() leaves them intact.
# Unlike the commented-out version above, this template has no slot for previous
# sessions, so prior dialogue context is never shown to the model.
# NOTE(review): the text asks for a bare JSON list, while the async generation path
# sends a strict json_schema response_format whose root is {"dialogue": [...]};
# presumably the schema takes precedence there — confirm the wording is intended.
MEMORY_GENERATION_PROMPT_TEMPLATE = """
You are a memory generation assistant.

Your goal is to create a realistic and coherent dialogue between a User and an Assistant.

The conversation should help the User arrive at the given short answer naturally, using the provided question and evidences.

### Requirements
- Simulate a natural dialogue that could plausibly happen in a real conversation.
- The User is curious or uncertain, and the Assistant guides them toward understanding.
- Use the evidences to construct intermediate reasoning steps.
- Include clarifications, follow-ups, or analogies if helpful.
- Use casual but clear language.
- Keep the number of dialogue turns between 3 and 6.
- You must generate under 256 tokens. Never exceed this limit.
- Do not mention the term "evidence" or refer to this being a prompt.
- Avoid repetition or stating the short answer too early.

### Input
Question:
{question}

Rubric Question (to help interpret the original question):
{rubric_question}

Short Answer (the answer the conversation should lead to):
{short_answer}

Supporting Evidences (facts or prior knowledge to use naturally in the dialogue):
{evidences_str}

### Output Format (in JSON)
Write the conversation as a list of turns in **JSON format** using the structure.
Because it's a JSON list, you should not add ',' at the end of last item.
[
  {{"speaker": "User", "content": "..." }},
  {{"speaker": "Assistant", "content": "..." }},
  ...
  {{"speaker": "Assistant", "content": "..." }}
]
Only return the JSON content, nothing else.
""".strip()


class MemoryGenerationStage:
    """Generate synthetic dialogue sessions for QA answers and group them per user.

    For every answer in a QA dataset, an OpenAI chat model is prompted with the
    question, the answer's rubric question, its short answer and its evidences,
    and asked to write a short User/Assistant dialogue leading to the short
    answer. Each answer is assigned to user ``U_{color_id}`` according to the
    color of its node in a ``ColoredGraph``; the sessions are then wrapped into
    ``User`` records carrying several memory views.
    """

    # Strict structured-output schema sent on the async path. With ``strict``
    # enabled the API constrains the reply to
    # ``{"dialogue": [{"speaker": str, "content": str}, ...]}`` -- an object,
    # even though the prompt text itself describes a bare JSON list.
    _DIALOGUE_RESPONSE_FORMAT = {
        "type": "json_schema",
        "json_schema": {
            "name": "dialogue_schema",
            "strict": True,
            "schema": {
                "type": "object",
                "properties": {
                    "dialogue": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "speaker": {"type": "string"},
                                "content": {"type": "string"},
                            },
                            "required": ["speaker", "content"],
                            "additionalProperties": False,
                        },
                    },
                },
                "required": ["dialogue"],
                "additionalProperties": False,
            },
        },
    }

    def __init__(self, model: str = "gpt-4o-mini", max_concurrency: int = 8) -> None:
        """Create the OpenAI clients and the async concurrency limiter.

        Args:
            model: Chat model used for generation. Both request paths pass
                ``max_tokens``, so the model must accept that parameter.
            max_concurrency: Maximum number of in-flight async requests.

        Raises:
            KeyError: If ``OPENAI_API_KEY`` is not set (after loading ``.env``).
            ValueError: If ``max_concurrency`` is smaller than 1.
        """
        logger.info("Initializing MemoryGenerationStage.")
        if max_concurrency < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

        load_dotenv()
        api_key = os.environ["OPENAI_API_KEY"]
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        # Blocking client for the synchronous `_generate_session` path: calling the
        # async client there returns an un-awaited coroutine instead of a response.
        self.sync_client = openai.OpenAI(api_key=api_key)
        # NOTE(review): on Python < 3.10 an asyncio.Semaphore created outside a
        # running loop binds to the default loop, which can break under
        # asyncio.run() -- confirm the target Python version.
        self.sem = Semaphore(max_concurrency)

    def _build_prompt(self, question: str, answer: Answer) -> str:
        """Fill MEMORY_GENERATION_PROMPT_TEMPLATE for one question/answer pair."""
        evidences_str = "\n".join(f"{e.question}: {e.answer}" for e in answer.evidences)
        return MEMORY_GENERATION_PROMPT_TEMPLATE.format(
            question=question,
            short_answer=answer.short_answer,
            rubric_question=answer.rubric_question,
            evidences_str=evidences_str,
        )

    @staticmethod
    def _parse_turns(raw: "str | None") -> List[Turn]:
        """Parse a model reply into a list of turns, each with a fresh UUID.

        Accepts either a bare JSON list of ``{"speaker", "content"}`` objects
        (what the prompt asks for) or an object holding that list under
        ``"dialogue"`` (what the async structured-output schema enforces).
        Markdown code fences are stripped before decoding.

        Raises:
            ValueError: If ``raw`` is None (e.g. a refusal), is not valid JSON
                (``json.JSONDecodeError`` subclasses ``ValueError``), or does
                not have the expected structure.
        """
        if raw is None:
            raise ValueError("model returned no message content")
        cleaned = raw.replace("```json", "").replace("```", "").strip()
        payload = json.loads(cleaned)
        try:
            if isinstance(payload, dict):
                payload = payload["dialogue"]
            return [
                Turn(turn_id=str(uuid.uuid4()), speaker=t["speaker"], content=t["content"])
                for t in payload
            ]
        except (KeyError, TypeError) as e:
            raise ValueError(f"unexpected dialogue structure: {e!r}") from e

    def _generate_session(
        self,
        question: str,
        answer: Answer,
        prev_sessions: List[Session],
    ) -> Session:
        """Synchronously generate one dialogue session for ``answer``.

        ``prev_sessions`` is accepted for interface compatibility but unused:
        the current prompt template has no previous-sessions slot.

        Raises:
            ValueError: If the model reply cannot be parsed into turns.
        """
        prompt = self._build_prompt(question, answer)
        response = self.sync_client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
            temperature=1.0,
        )
        result = response.choices[0].message.content
        try:
            turns = self._parse_turns(result)
        except ValueError as e:
            raise ValueError(
                f"Failed to parse session JSON: {e}\nQuestion: {question}\nAnswer: {answer}\nRaw response:\n{result}"
            ) from e
        return Session(session_id=f"S_{answer.answer_id}", turns=turns)

    def _session_to_str(self, session: Session) -> str:
        """Render a session as ``"speaker: content"`` lines joined by newlines."""
        return "\n".join(f"{turn.speaker}: {turn.content}" for turn in session.turns).strip()

    def _sessions_to_memories(self, sessions: List[Session], memory_type: MemoryType) -> List[Memory]:
        """Wrap each session's rendered transcript as a Memory of ``memory_type``.

        The memory id is the session id, and the session's metadata is carried over.
        """
        return [
            Memory(
                memory_id=session.session_id,
                memory_type=memory_type,
                content=self._session_to_str(session),
                metadata=session.metadata,
            )
            for session in sessions
        ]

    def _transform_to_dialogue(self, sessions: List[Session]) -> List[Memory]:
        """Return DIALOGUE memories holding each session's rendered transcript."""
        return self._sessions_to_memories(sessions, MemoryType.DIALOGUE)

    def _transform_to_observation(self, sessions: List[Session]) -> List[Memory]:
        """Return OBSERVATION memories.

        TODO: derive real observations; the raw transcript is stored for now.
        """
        return self._sessions_to_memories(sessions, MemoryType.OBSERVATION)

    def _transform_to_summary(self, sessions: List[Session]) -> List[Memory]:
        """Return SUMMARY memories.

        TODO: summarize sessions; the raw transcript is stored for now.
        """
        return self._sessions_to_memories(sessions, MemoryType.SUMMARY)

    def _transform_to_episodic_memory(self, sessions: List[Session]) -> List[Memory]:
        """Return EPISODIC_MEMORY memories.

        TODO: build episodic memories; the raw transcript is stored for now.
        """
        return self._sessions_to_memories(sessions, MemoryType.EPISODIC_MEMORY)

    def _transform_to_semantic_memory(self, sessions: List[Session]) -> List[Memory]:
        """Return semantic memories (TODO: not implemented; always empty)."""
        return []

    async def _generate_session_async(self, question: str, answer: Answer, prev_sessions: List[Session]) -> Session:
        """Asynchronously generate one dialogue session for ``answer``.

        At most ``max_concurrency`` requests run at once (via ``self.sem``).
        ``prev_sessions`` is accepted for interface compatibility but unused.

        If the reply cannot be parsed (e.g. JSON cut off by the 256-token cap,
        or a refusal with no content), a warning is logged and a placeholder
        session holding a single empty turn is returned, so one bad reply does
        not abort the whole batch. API errors are not caught here.
        """
        prompt = self._build_prompt(question, answer)
        async with self.sem:
            resp = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=256,
                temperature=1.0,
                response_format=self._DIALOGUE_RESPONSE_FORMAT,
            )
        choice = resp.choices[0]
        session_id = f"S_{answer.answer_id}"
        try:
            turns = self._parse_turns(choice.message.content)
        except ValueError as e:
            # Positional {} args: loguru substitutes them without re-formatting
            # the JSON braces that may appear in the raw response.
            logger.warning(
                "Failed to parse session JSON: {}\nQuestion: {}\nAnswer: {}\nFinish reason: {}\nRaw response:\n{}",
                e,
                question,
                answer,
                choice.finish_reason,
                choice.message.content,
            )
            return Session(
                session_id=session_id,
                turns=[Turn(turn_id=str(uuid.uuid4()), speaker="", content="")],
            )
        return Session(session_id=session_id, turns=turns)

    async def run(self, qa_dataset: List[QA], colored_graph: ColoredGraph) -> Tuple[List[QA], List[User]]:
        """Generate one session per answer and build per-user memory records.

        Each answer is assigned to user ``U_{color_id}``, where ``color_id`` is
        the color of the answer's node in ``colored_graph``. Sessions are
        generated concurrently, then attached to their users in dataset order,
        so the result does not depend on which API request finished first.

        Side effects: sets ``user_id``, ``session_id``, ``ref_memory_ids`` and
        ``ref_memory_contents`` on every answer in ``qa_dataset``.

        Returns:
            ``(qa_dataset, memory_dataset)``: the same (mutated) QA list and one
            ``User`` per color in ``range(colored_graph.n_colors)``, including
            colors that received no sessions.

        Raises:
            KeyError: If an answer's color has no matching user; raised before
                any request is sent.
        """
        user_memory_dict = {f"U_{i}": [] for i in range(colored_graph.n_colors)}

        # Resolve every answer's user up front, in dataset order.
        jobs = []
        for qa_sample in qa_dataset:
            for answer in qa_sample.answers:
                node_id = colored_graph.get_node_id(answer_id=answer.answer_id)
                color_id = colored_graph.get_color_id(node_id=node_id)
                user_id = f"U_{color_id}"
                if user_id not in user_memory_dict:
                    raise KeyError(user_id)
                jobs.append((qa_sample.question, answer, user_id))

        # tqdm_asyncio.gather returns results in input order. No previous-session
        # context is passed: the prompt does not use it, and with concurrent
        # generation a user's earlier sessions would not exist yet anyway.
        sessions = await tqdm_asyncio.gather(
            *(self._generate_session_async(question, answer, []) for question, answer, _ in jobs),
            desc="Generating sessions",
        )

        for (_, answer, user_id), session in zip(jobs, sessions):
            user_memory_dict[user_id].append(session)
            answer.user_id = user_id
            answer.session_id = session.session_id
            answer.ref_memory_ids = [session.session_id]
            answer.ref_memory_contents = [self._session_to_str(session)]

        memory_dataset = [
            User(
                user_id=uid,
                raw_dialogue=user_sessions,
                dialogue=self._transform_to_dialogue(user_sessions),
                observation=self._transform_to_observation(user_sessions),
                summary=self._transform_to_summary(user_sessions),
                episodic_memory=self._transform_to_episodic_memory(user_sessions),
                semantic_memory=self._transform_to_semantic_memory(user_sessions),
            )
            for uid, user_sessions in user_memory_dict.items()
        ]
        return qa_dataset, memory_dataset
