from src.prompts.agent_prompt import (
    ascendc_template,
    ascend_optimize_template,
    generate_optimize_guidelines_template
)
from src.prompts.retrieve_prompt import get_task_api_description
from src.prompts.utils import read_example_files
from src.providers.base import BaseLLM, LLMOutput
from src.configs.agent import CodeAgentConfig
from src.agents.retriever import MemoryRetriever, RetrievalResult, EMPTY
from src.memorykit.note import NoteType
from src.utils import (
    extract_first_code,
)
from src.models.agent_models import ActionType
from typing import List, Dict, Any
import json
import logging

# Module-level logger named after this module; the agent logs prompts, LLM responses
# and retrieval results through it.
logger = logging.getLogger(__name__)


class CodeAgent:
    """LLM-driven agent that drafts or optimizes operator code using retrieved memories.

    :meth:`generate` first looks up a correct reference attempt, then either drafts new
    code (the reference serves as a worked example) or optimizes an existing correct
    attempt for the same op. Both paths query ``retriever`` for notes (error attempts,
    experiences, APIs, best practices, ...) and render them into a prompt template that
    is sent to ``llm``.
    """

    def __init__(self, retriever: MemoryRetriever, config: CodeAgentConfig, llm: BaseLLM):
        """Store the collaborators used by every generation step.

        Args:
            retriever: Memory store queried for references, experiences, APIs, etc.
            config: Agent config; ``enabled_memory_types`` and ``enabled_stages`` are read.
            llm: Model used for every prompt completion.
        """
        self.llm: BaseLLM = llm
        self.config: CodeAgentConfig = config
        self.retriever = retriever

    def _draft_retrieve(self, arc_src: str, op: str, correct_attempt: Dict):
        """Retrieve memories for the draft action.

        Depending on ``config.enabled_memory_types``, this collects recent failed
        attempts (CODE_EXAMPLE), experiences (EXPERIENCE) and related APIs (API).

        Args:
            arc_src: Source of the task to implement.
            op: Operator name.
            correct_attempt: Reference attempt; its ``op_name``/``arc_src``/``code`` are
                used to help the LLM describe the APIs needed. May be empty.

        Returns:
            Tuple ``(error_contents, experience_contents, api_contents, retrieved_note_ids)``.
        """
        retrieved_note_ids: List[str] = []
        error_results, exp_results, api_results = EMPTY, EMPTY, EMPTY
        # Experience texts gathered so far; fed into the API-description prompt below.
        exps_from_previous_attempts: List[str] = []

        task = self._create_task_description(arc_src, op)

        if NoteType.CODE_EXAMPLE in self.config.enabled_memory_types:
            # Retrieve failed code examples with experiences
            error_results = self.retriever.find_recent_error_attempts_with_experiences(
                op=op,
                top_k=1,
            )
            logger.info(
                f"[_draft_retrieve ({op})] find_recent_error_attempts_with_experiences: Retrieved {len(error_results.ids)} error attempts, "
                f"note_ids={error_results.ids}"
            )
            retrieved_note_ids.extend(error_results.ids)

            # Collect review note IDs and experience texts attached to each error attempt.
            for attempt in error_results.contents:
                retrieved_note_ids.extend(attempt.get("reviews", []))
                exps_from_previous_attempts.extend(attempt.get("exp_contents", []))

        if NoteType.EXPERIENCE in self.config.enabled_memory_types:
            # Exclude notes already returned above so the same memory is not injected twice.
            exp_results = self.retriever.find_experiences(
                op=op,
                task=task,
                specific_exp_k=2,
                general_exp_k=2,
                exclude_ids=retrieved_note_ids,
                similarity_threshold=0.8,
            )
            retrieved_note_ids.extend(exp_results.ids)
            logger.info(
                f"[_draft_retrieve ({op})] find_experiences: Retrieved {len(exp_results.ids)} experiences, "
                f"note_ids={exp_results.ids}"
            )
            exps_from_previous_attempts.extend(exp.get("content", "") for exp in exp_results.contents)

        if NoteType.API in self.config.enabled_memory_types:
            # The LLM first turns the task into short functional descriptions, which are
            # then used as semantic-search queries against the API memory.
            func_descs = self._create_needed_api_descriptions(
                arc_src=arc_src,
                op=op,
                example_op_name=correct_attempt.get("op_name"),
                example_arc_src=correct_attempt.get("arc_src"),
                example_code=correct_attempt.get("code"),
                experience=exps_from_previous_attempts,
            )
            api_results = self.retriever.find_related_apis(
                descriptions=func_descs, total_k=4
            )
            retrieved_note_ids.extend(api_results.ids)
            logger.info(
                f"[_draft_retrieve ({op})] find_related_apis: Retrieved {len(api_results.ids)} APIs, "
                f"note_ids={api_results.ids}"
            )

        logger.info(
            f"[_draft_retrieve ({op})] Completed: Retrieved {len(retrieved_note_ids)} memories in total, note_ids={retrieved_note_ids}"
        )
        return error_results.contents, exp_results.contents, api_results.contents, retrieved_note_ids

    def _optimize_retrieve(self, arc_src: str, op: str, cur_attempt: Dict):
        """Retrieve memories for the optimize action.

        Args:
            arc_src: Source of the task being optimized.
            op: Operator name.
            cur_attempt: The correct attempt to optimize; ``memory_id`` and ``code`` are read.

        Returns:
            Tuple ``(failed_attempts, optimized_references, succ_attempts,
            optimize_history, best_practice, experiences, retrieved_note_ids)``; every
            element except the last is the ``contents`` of a retrieval result.
        """
        parent_memory_id = cur_attempt.get("memory_id", "")
        retrieved_note_ids: List[str] = []
        optimized_results, exp_results = EMPTY, EMPTY

        # Previous failed optimizations of this same attempt, plus their review notes.
        failed_attempts = self.retriever.find_failed_optimize_attempts(parent_memory_id=parent_memory_id, k=1)
        for attempt in failed_attempts.contents:
            retrieved_note_ids.extend(attempt.get("reviews", []))
        retrieved_note_ids.extend(failed_attempts.ids)

        succ_attempts = self.retriever.find_succ_optimize_attempts(parent_memory_id=parent_memory_id)

        task = self._create_task_description(arc_src, op)
        if not failed_attempts.ids:
            # Only look up an optimized reference implementation when there is no
            # failed attempt to learn from.
            optimized_results = self.retriever.find_optimized_references(
                task=task,
                attempt=cur_attempt,
                top_k=1,
                similarity_threshold=0.8,
            )
            retrieved_note_ids.extend(optimized_results.ids)

        optimize_history = self.retriever.find_optimization_history(memory_id=parent_memory_id)

        if NoteType.EXPERIENCE in self.config.enabled_memory_types:
            exp_results = self.retriever.find_failed_optimize_experiences(
                parent_memory_id=parent_memory_id,
                optimize_k=2,
                draft_k=2,
                exclude_ids=retrieved_note_ids,
            )
            retrieved_note_ids.extend(exp_results.ids)
            logger.info(
                f"[_optimize_retrieve ({op})] find_failed_optimize_experiences: Retrieved {len(exp_results.ids)} experiences, "
                f"note_ids={exp_results.ids}"
            )

        # Ask the LLM to classify the next optimization (category/action) and use that
        # classification to look up a matching best-practice note.
        query = generate_optimize_guidelines_template(
            arc_src=arc_src, op=op,
            code=cur_attempt.get("code", ""),
            optimize_history=optimize_history.contents,
            failed_attempts=failed_attempts.contents,
            succ_attempts=succ_attempts.contents,
        )
        logger.info(f"generate_optimize_guidelines_template(op-{op}):\n{query}")
        llm_output: LLMOutput = self.llm.generate_single(query, extract_first_block=True, language_type="json")
        response_json = llm_output.response_txt
        logger.info(f"optimize guidelines response_json(op-{op}):\n{response_json}")
        if not isinstance(response_json, dict):
            # A malformed or non-object LLM reply must not abort the whole optimize step;
            # fall back to the same empty defaults used for missing keys.
            logger.warning(
                f"[_optimize_retrieve ({op})] optimize guidelines response is not a JSON object "
                f"(got {type(response_json).__name__}); using empty category/action"
            )
            response_json = {}
        best_practice = self.retriever.find_best_practice(
            category=response_json.get("category", ""),
            action=response_json.get("action", ""),
            k=1,
        )
        retrieved_note_ids.extend(best_practice.ids)

        return (
            failed_attempts.contents,
            optimized_results.contents,
            succ_attempts.contents,
            optimize_history.contents,
            best_practice.contents,
            exp_results.contents,
            retrieved_note_ids,
        )

    @staticmethod
    def _split_plan_and_code(response_txt: str):
        """Split an LLM reply into ``(plan, code)``; use the raw reply as code if no code block is found."""
        plan, response_code = extract_first_code(response_txt, ['python', 'cpp'])
        if response_code is None:
            response_code = response_txt
        return plan, response_code

    @staticmethod
    def _attach_llm_usage(result: Dict[str, Any], llm_output: LLMOutput) -> None:
        """Copy reasoning content and token counts from ``llm_output`` into ``result`` when they are truthy."""
        for field in ("reasoning_content", "total_tokens", "completion_tokens", "prompt_tokens"):
            value = getattr(llm_output, field)
            if value:
                result[field] = value

    def _draft(self, arc_src: str, op: str, correct_attempt: Dict) -> Dict[str, Any]:
        """Draft a new implementation of ``op`` with a single LLM call.

        Args:
            arc_src: Source of the task to implement.
            op: Operator name.
            correct_attempt: Reference attempt used as the worked example (keys read:
                ``op_name``, ``arc_src``, ``code``, ``performance``). May be empty.

        Returns:
            Dict with ``plan``, ``code`` and ``retrieved_note_ids``. Adds
            ``error_attempt``/``error_exps`` when an error attempt was retrieved, and
            ``experience_contents`` when experiences were retrieved. Also adds LLM usage
            fields when the provider reports them.
        """
        error_attempts, exp_contents, apis_content, retrieved_note_ids = self._draft_retrieve(
            arc_src, op, correct_attempt
        )

        draft_template = ascendc_template(
            arc_src=arc_src,
            example_arc_src=correct_attempt.get("arc_src"),
            example_new_arc_src=correct_attempt.get("code"),
            op=op,
            example_op=correct_attempt.get("op_name"),
            example_performance=correct_attempt.get("performance"),
            error_attempts=error_attempts,
            apis_contents=apis_content,
            experience_contents=exp_contents,
        )
        logger.info(f"draft_template for op {op}:\n{draft_template}")

        llm_output: LLMOutput = self.llm.generate_single(draft_template)
        plan, response_code = self._split_plan_and_code(llm_output.response_txt)

        draft_result: Dict[str, Any] = {
            "plan": plan,
            "code": response_code,
            "retrieved_note_ids": retrieved_note_ids,
        }
        if error_attempts:
            # TODO: hard code the first error attempt (retrieval currently uses top_k=1).
            first_error = error_attempts[0]
            draft_result["error_attempt"] = first_error.get("code", "")
            draft_result["error_exps"] = first_error.get("exp_contents", [])
        if exp_contents:
            draft_result["experience_contents"] = exp_contents

        self._attach_llm_usage(draft_result, llm_output)
        return draft_result

    def _optimize(self, arc_src: str, op: str, cur_attempt: Dict) -> Dict[str, Any]:
        """Ask the LLM to optimize ``cur_attempt`` for ``op``.

        Args:
            arc_src: Source of the task being optimized.
            op: Operator name.
            cur_attempt: The correct attempt to optimize (``memory_id`` and ``code`` read).

        Returns:
            Dict with ``parent`` (``cur_attempt``'s memory id), ``plan``, ``code``,
            ``retrieved_note_ids`` and, when reported, LLM usage fields.
        """
        optimize_result: Dict[str, Any] = {"parent": cur_attempt.get("memory_id", "")}

        (
            failed_attempts,
            optimized_results,
            succ_attempts,
            _optimize_history,
            best_practice,
            exp_contents,
            retrieved_note_ids,
        ) = self._optimize_retrieve(arc_src, op, cur_attempt)

        # Only the plan and speed-up figures of successful siblings are shown, not their code.
        succ_attempt_plans = [
            {
                "plan": attempt.get("plan", ""),
                "optimized_degree_to_parent": attempt.get("optimized_degree_to_parent", ""),
                "optimized_degree_to_root": attempt.get("optimized_degree_to_root", ""),
                "optimized_degree_to_best": attempt.get("optimized_degree_to_best", ""),
            }
            for attempt in succ_attempts
        ]

        optimize_template = ascend_optimize_template(
            op=op,
            arc_src=arc_src,
            last_attempt=cur_attempt.get("code", ""),
            failed_attempt=failed_attempts[0] if failed_attempts else None,
            succ_attempt_plans=succ_attempt_plans,
            optimized_attempt=optimized_results[0] if optimized_results else None,
            best_practice=best_practice[0] if best_practice else None,
            experience_contents=exp_contents,
        )
        logger.info(f"optimize_template for op-{op}:\n{optimize_template}")

        llm_output: LLMOutput = self.llm.generate_single(optimize_template)
        plan, response_code = self._split_plan_and_code(llm_output.response_txt)
        optimize_result.update({"plan": plan, "code": response_code, "retrieved_note_ids": retrieved_note_ids})

        self._attach_llm_usage(optimize_result, llm_output)
        return optimize_result

    def generate(self, arc_src: str, op: str) -> Dict[str, Any]:
        """Choose DRAFT or OPTIMIZE for ``op`` and generate code accordingly.

        OPTIMIZE is chosen only when the best correct reference is for the same op,
        has a truthy ``performance``, and OPTIMIZE is in ``config.enabled_stages``.
        Otherwise the agent drafts, using the reference (if any) as an example.

        Returns:
            Dict with ``action``, ``retrieved_note_ids`` and the fields produced by
            :meth:`_draft` (plus ``example``) or :meth:`_optimize` (with ``parent``
            replaced by the full parent attempt).
        """
        # 1. Find a correct reference: the starting point for OPTIMIZE, or a worked
        #    example for DRAFT.
        correct_result: RetrievalResult = self.retriever.find_correct_references(
            task=self._create_task_description(arc_src, op),
            top_k=1,
        )
        if correct_result.contents:
            correct_attempt = correct_result.contents[0]
        else:
            # Without this guard an empty memory raised IndexError; draft without an example instead.
            logger.warning(
                f"[generate ({op})] find_correct_references returned no results; drafting without an example"
            )
            correct_attempt = {}

        if (
            op == correct_attempt.get("op_name")
            and correct_attempt.get("performance")
            and ActionType.OPTIMIZE in self.config.enabled_stages
        ):
            action = ActionType.OPTIMIZE
        else:
            action = ActionType.DRAFT

        result: Dict[str, Any] = {'action': action, "retrieved_note_ids": []}
        if action == ActionType.DRAFT:
            # For the draft stage, correct_result serves as a reference implementation
            draft_result = self._draft(arc_src, op, correct_attempt)
            draft_result['example'] = correct_attempt

            result.update(draft_result)
            result["retrieved_note_ids"].extend(correct_result.ids)
        elif action == ActionType.OPTIMIZE:
            # For the optimize stage, prefer a correct attempt looked up by exact op name
            # as the code to be optimized; keep the similarity match otherwise.
            similar_correct_result = self.retriever.find_correct_references_by_op_name(op_name=op, top_k=1)
            if similar_correct_result.ids:
                correct_result = similar_correct_result

            correct_attempt = correct_result.contents[0]
            optimize_result = self._optimize(arc_src, op, correct_attempt)
            # Replace the parent memory id with the full parent attempt for downstream consumers.
            optimize_result['parent'] = correct_attempt

            result.update(optimize_result)
            result["retrieved_note_ids"].extend(correct_result.ids)
        else:
            raise ValueError(f"Unsupported action: {action}")
        return result

    def _create_task_description(self, arc_src: str, op: str) -> str:
        """Return the retrieval query for a task: the op name, a newline, then the source."""
        return op + '\n' + arc_src

    def _create_needed_api_descriptions(self, arc_src: str, op: str, example_op_name: str, example_arc_src: str, example_code: str, experience: List[str]) -> List[str]:
        """Ask the LLM for short functional descriptions of the APIs the task needs.

        The descriptions are used as semantic-search queries for API memories. The
        prompt includes the example attempt and prior experiences for context.

        Returns:
            The JSON block parsed from the LLM reply, as produced by ``generate_single``;
            expected to be a list of strings (not validated here — TODO confirm).
        """
        prompt = get_task_api_description(
            op=op,
            arc_src=arc_src,
            example_op=example_op_name,
            example_arc_src=example_arc_src,
            example_new_arc_src=example_code,
            exps=experience,
        )
        logger.info(f"get_task_api_description_template for op-{op}:\n{prompt}")
        llm_output: LLMOutput = self.llm.generate_single(prompt, extract_first_block=True, language_type="json")
        response_json = llm_output.response_txt
        logger.info(f"get_task_api_description response for op-{op}:\n{response_json}")
        return response_json