from colmind.agents.agent import LLMAgent
import dspy
from pathlib import Path
import colmind.utils as U
import random
import re
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma

class NextTask(dspy.Signature):
    """Generate the next task in the curriculum."""
    # NOTE(review): dspy is understood to feed this docstring and every
    # field's `desc` into the prompt, so they are runtime text rather than
    # plain documentation -- confirm against the dspy version in use before
    # rewording any of them.
    #
    # Inputs: CurriculumAgent.__call__ expands its `context` dict straight
    # into keyword arguments for this signature, so the dict keys have to line
    # up with the field names declared below.
    context: str = dspy.InputField(desc="Current context about the environment")
    biome: str = dspy.InputField(desc="Current biome")
    time: str = dspy.InputField(desc="Current time of day")
    nearby_blocks: list = dspy.InputField(desc="Blocks in the immediate vicinity")
    other_blocks: list = dspy.InputField(desc="Other blocks seen recently")
    nearby_entities: dict = dspy.InputField(desc="Nearby entities")
    health: float = dspy.InputField(desc="Current health")
    hunger: float = dspy.InputField(desc="Current hunger level")
    position: dict = dspy.InputField(desc="Current position")
    equipment: str = dspy.InputField(desc="Current equipment")
    inventory: dict = dspy.InputField(desc="Current inventory")
    chests: str = dspy.InputField(desc="Chest observations")
    completed_tasks: list = dspy.InputField(desc="Previously completed tasks")
    failed_tasks: list = dspy.InputField(desc="Previously failed tasks")

    # Output: the single field CurriculumAgent.__call__ reads back
    # (`task.next_task`).
    next_task: str = dspy.OutputField(desc="The next task to attempt")

class TaskDecomposition(dspy.Signature):
    """Decompose a complex task into subtasks."""
    # NOTE(review): dspy is understood to feed this docstring and the field
    # `desc` strings into the prompt -- confirm before rewording them.
    #
    # Inputs: CurriculumAgent.decompose_task passes `final_task` explicitly and
    # expands its `context` dict into the remaining keyword arguments. The
    # `context` field below is therefore one *key* of that dict, not the dict
    # itself.
    final_task: str = dspy.InputField(desc="The target task to decompose")
    context: str = dspy.InputField(desc="Current environment context")
    # The fields below carry no `desc`; only their names and annotations are
    # declared.
    biome: str = dspy.InputField()
    time: str = dspy.InputField()
    nearby_blocks: list = dspy.InputField()
    nearby_entities: dict = dspy.InputField()
    inventory: dict = dspy.InputField()

    # Outputs: decompose_task returns both, repackaged as a plain dict with
    # the same two keys.
    subtasks: list = dspy.OutputField(desc="List of subtasks")
    reasoning: str = dspy.OutputField(desc="Explanation of the decomposition")

class Question(dspy.Signature):
    """Generate questions about the environment."""
    # NOTE(review): dspy is understood to feed this docstring and the field
    # `desc` strings into the prompt -- confirm before rewording them.
    #
    # Inputs: CurriculumAgent.run_qa expands its `context` dict into keyword
    # arguments for this signature, so the dict keys have to line up with
    # these field names.
    context: str = dspy.InputField()
    biome: str = dspy.InputField()
    time: str = dspy.InputField()
    nearby_blocks: list = dspy.InputField()
    nearby_entities: dict = dspy.InputField()
    inventory: dict = dspy.InputField()

    # Outputs: run_qa reads `questions`; nothing in this file reads `concepts`.
    questions: list = dspy.OutputField(desc="List of relevant questions")
    concepts: list = dspy.OutputField(desc="Core concepts for each question")

class Answer(dspy.Signature):
    """Answer questions about Minecraft mechanics."""
    # NOTE(review): dspy is understood to feed this docstring and the field
    # `desc` strings into the prompt -- confirm before rewording them.
    #
    # Input: one question string, supplied by
    # CurriculumAgent._get_cached_or_new_answer as `question=...`.
    question: str = dspy.InputField(desc="Question about Minecraft")

    # Output: the text that _get_cached_or_new_answer stores in the QA cache
    # and returns to its caller.
    answer: str = dspy.OutputField(desc="Detailed answer")

class CurriculumAgent(LLMAgent):
    """Propose the next task for the agent and track curriculum progress.

    In ``"auto"`` mode tasks come from the ``NextTask`` dspy signature; in
    ``"manual"`` mode they are typed in on stdin. The agent also keeps the
    lists of completed and failed tasks plus a question -> answer cache whose
    questions are indexed in a Chroma collection. Task lists and the QA cache
    are written as JSON below ``<ckpt_dir>/curriculum``.
    """

    def __init__(
        self,
        name: str,
        llm: str,
        qa_llm: str,
        temperature: float,
        request_timeout: int,
        ckpt_dir: str,
        resume: bool,
        mode: str,
        warm_up: dict = None,
        core_inventory_items: str = None,
        logger=None
    ):
        """Set up the agent state, on-disk checkpoint paths and the QA index.

        Args:
            name: Forwarded to ``LLMAgent.__init__``.
            llm: Forwarded to ``LLMAgent.__init__``.
            qa_llm: Model used for question answering. When it equals ``llm``
                the base agent's ``self.llm`` is reused.
            temperature: Forwarded to ``LLMAgent.__init__``.
            request_timeout: Stored on the instance; not used in this class.
            ckpt_dir: Root checkpoint directory.
            resume: When true, load the task lists and the QA cache from the
                JSON files under ``<ckpt_dir>/curriculum``; otherwise start
                with empty ones.
            mode: ``"auto"`` or ``"manual"``.
            warm_up: Warm-up thresholds; falls back to ``default_warmup``
                when falsy.
            core_inventory_items: Optional regular-expression pattern. It is
                compiled into ``_core_inv_items_regex``, which is ``None``
                when no pattern is given.
            logger: Forwarded to ``LLMAgent.__init__``.

        Raises:
            ValueError: If ``mode`` is neither ``"auto"`` nor ``"manual"``.
        """
        super().__init__(name, llm, temperature, logger)
        # NOTE(review): `get_llm` is neither defined nor imported in the
        # visible part of this file, and `self.llm` is presumably set by
        # LLMAgent.__init__ -- confirm both.
        self.qa_llm = get_llm(qa_llm) if qa_llm != llm else self.llm
        # Explicit raise instead of `assert`, which is stripped under `-O`.
        if mode not in ("auto", "manual"):
            raise ValueError(f"mode must be 'auto' or 'manual', got {mode!r}")
        self.mode = mode
        self.ckpt_dir = Path(ckpt_dir)
        self.request_timeout = request_timeout

        # Initialize QA caching
        qa_cache_dir = self.ckpt_dir / "curriculum/vectordb"
        qa_cache_dir.mkdir(parents=True, exist_ok=True)

        if resume:
            self.completed_tasks = U.load_json(self.ckpt_dir / "curriculum/completed_tasks.json")
            self.failed_tasks = U.load_json(self.ckpt_dir / "curriculum/failed_tasks.json")
            self.qa_cache = U.load_json(self.ckpt_dir / "curriculum/qa_cache.json")
        else:
            self.completed_tasks = []
            self.failed_tasks = []
            self.qa_cache = {}

        # Initialize vector DB for QA. It is opened on the persisted directory
        # regardless of `resume`, so it may already hold questions that are
        # missing from `self.qa_cache`; _get_cached_or_new_answer copes with
        # that.
        self.qa_vectordb = Chroma(
            collection_name="qa_cache_questions",
            embedding_function=OpenAIEmbeddings(),
            persist_directory=str(qa_cache_dir)
        )

        # Warm up settings
        self.warm_up = warm_up or self.default_warmup
        # Always define the attribute so later reads cannot hit AttributeError.
        self._core_inv_items_regex = (
            re.compile(core_inventory_items) if core_inventory_items else None
        )

    @property
    def default_warmup(self) -> dict:
        """Return the built-in warm-up thresholds, keyed by observation field."""
        return {
            "context": 15,
            "biome": 10,
            "time": 15,
            "nearby_blocks": 0,
            "other_blocks": 10,
            "nearby_entities": 5,
            "health": 15,
            "hunger": 15,
            "position": 0,
            "equipment": 0,
            "inventory": 0,
            "optional_inventory_items": 7,
            "chests": 0,
            "completed_tasks": 0,
            "failed_tasks": 0,
        }

    def __call__(self, context: dict) -> tuple:
        """Generate next task based on current context.

        Args:
            context: Observation dict. Reads ``context["status"]["inventoryUsed"]``
                and, in auto mode, expands the whole dict into keyword
                arguments for the ``NextTask`` predictor.

        Returns:
            A ``(task, task_context)`` pair.
        """
        # Bootstrap task: returned while nothing has been completed yet.
        if self.progress == 0 and self.mode == "auto":
            return "Mine 1 wood log", "You can mine oak, birch, spruce, jungle, acacia, dark oak, or mangrove logs."

        # Check inventory fullness
        if context["status"]["inventoryUsed"] >= 33:
            return self._handle_full_inventory(context)

        if self.mode != "auto":
            return self._get_manual_task()

        generate_task = dspy.Predict(NextTask)
        task = generate_task(**context)
        # NOTE(review): `get_task_context` is not defined in this class;
        # presumably inherited from LLMAgent -- confirm.
        task_context = self.get_task_context(task.next_task)
        return task.next_task, task_context

    def decompose_task(self, task: str, context: dict) -> dict:
        """Break down complex task into subtasks.

        Args:
            task: The task to decompose; passed as ``final_task``.
            context: Expanded into the remaining ``TaskDecomposition`` inputs.

        Returns:
            A dict with the keys ``"subtasks"`` and ``"reasoning"``.
        """
        generate_decomposition = dspy.Predict(TaskDecomposition)
        decomposition = generate_decomposition(
            final_task=task,
            **context
        )
        return {"subtasks": decomposition.subtasks, "reasoning": decomposition.reasoning}

    def run_qa(self, context: dict) -> tuple:
        """Run the QA pipeline to gather relevant information.

        Three fixed biome questions are asked first, followed by the
        questions produced by the ``Question`` predictor. Only questions that
        yield a non-empty answer are kept.

        Args:
            context: Observation dict; ``context["biome"]`` is read directly
                and the whole dict is expanded into the predictor call.

        Returns:
            A ``(questions, answers)`` pair of equally long lists.
        """
        # Generate questions
        generate_questions = dspy.Predict(Question)
        qa = generate_questions(**context)

        # Default biome questions
        biome = context["biome"].replace("_", " ")
        default_questions = [
            f"What can I find in the {biome} in Minecraft?",
            f"What blocks exist in the {biome} in Minecraft?",
            f"What mobs spawn in the {biome} in Minecraft?"
        ]

        # Fall back to the defaults alone when the model produced no questions
        # instead of failing on `list + None`.
        all_questions = default_questions + (qa.questions or [])

        questions = []
        answers = []

        # Get answers for each question
        for question in all_questions:
            answer = self._get_cached_or_new_answer(question)
            if answer:
                questions.append(question)
                answers.append(answer)

        return questions, answers

    def _get_cached_or_new_answer(self, question: str) -> str:
        """Get cached answer or generate new one.

        The closest indexed question is looked up first; a score below 0.05
        counts as a hit. Otherwise the ``Answer`` predictor is called and the
        result is cached and persisted.

        Args:
            question: The question to answer.

        Returns:
            The cached or freshly generated answer text.
        """
        cache_key = question
        needs_indexing = True

        if self.qa_vectordb._collection.count() > 0:
            similar = self.qa_vectordb.similarity_search_with_score(question, k=1)
            if similar and similar[0][1] < 0.05:
                cached_q = similar[0][0].page_content
                if cached_q in self.qa_cache:
                    return self.qa_cache[cached_q]
                # The index knows this question but the JSON cache does not
                # (e.g. a fresh run over an existing vector store). Regenerate
                # and store the answer under the already-indexed text instead
                # of raising KeyError or indexing a near-duplicate.
                cache_key = cached_q
                needs_indexing = False

        generate_answer = dspy.Predict(Answer)
        answer = generate_answer(question=question)

        # Cache the new QA pair
        self.qa_cache[cache_key] = answer.answer
        if needs_indexing:
            self.qa_vectordb.add_texts([question])
        U.dump_json(self.qa_cache, self.ckpt_dir / "curriculum/qa_cache.json")
        if needs_indexing:
            self.qa_vectordb.persist()

        return answer.answer

    def _handle_full_inventory(self, context: dict) -> tuple:
        """Handle case when inventory is nearly full.

        Args:
            context: Observation dict; reads ``context["inventory"]`` and, when
                no chest is carried, ``context["chest_observation"]``.

        Returns:
            A ``(task, task_context)`` pair: place a carried chest, deposit
            into a known chest, or craft a new chest.
        """
        if "chest" in context["inventory"]:
            return "Place a chest", "Place a chest nearby to store items."

        observation = context["chest_observation"]
        if observation == "Chests: None\n\n":
            return "Craft 1 chest", "Craft a chest using 8 wood planks."

        # The "no chests" sentinel shows the observation starts with a
        # "Chests:" header, so splitting the raw text on ":" would yield the
        # literal "Chests" rather than a position. Drop the header first.
        # NOTE(review): assumes each entry after the header looks like
        # "<position>: <contents>" -- confirm against the observation producer.
        header = "Chests:"
        entries = observation[len(header):] if observation.startswith(header) else observation
        first_entry = entries.strip().split("\n", 1)[0]
        chest_pos = first_entry.split(":")[0].strip()
        return (f"Deposit useless items into chest at {chest_pos}",
                "Deposit until only 20 inventory slots are used.")

    def _get_manual_task(self) -> tuple:
        """Get task input from human.

        Returns:
            A ``(task, context)`` pair read from stdin, both stripped.
        """
        task = input("Enter task: ").strip()
        context = input("Enter context: ").strip()
        return task, context

    def update_exploration_progress(self, info: dict) -> None:
        """Update task completion status.

        Tasks whose text starts with ``"Deposit"`` are not recorded.

        Args:
            info: Dict with the keys ``"task"`` (str) and ``"success"``.
        """
        task = info["task"]
        if task.startswith("Deposit"):
            return

        if info["success"]:
            self.completed_tasks.append(task)
        else:
            self.failed_tasks.append(task)
        self._clean_up_tasks()

    def _clean_up_tasks(self) -> None:
        """Clean up and persist task lists.

        Completed tasks are de-duplicated with their first-seen order kept,
        and any task that has been completed is dropped from the failed list.
        Both lists are then written to their JSON checkpoint files.
        """
        # dict.fromkeys de-duplicates in one pass and keeps insertion order.
        completed = list(dict.fromkeys(self.completed_tasks))

        # Remove completed tasks from failed list
        completed_lookup = set(completed)
        failed = [t for t in self.failed_tasks if t not in completed_lookup]

        self.completed_tasks = completed
        self.failed_tasks = failed

        # Save to disk
        U.dump_json(self.completed_tasks, self.ckpt_dir / "curriculum/completed_tasks.json")
        U.dump_json(self.failed_tasks, self.ckpt_dir / "curriculum/failed_tasks.json")

    @property
    def progress(self) -> int:
        """Return the number of completed tasks."""
        return len(self.completed_tasks)

    def restore(self):
        """Placeholder for checkpoint restoration; currently does nothing."""
        # TODO: implement checkpoint restoration
        pass
