"""
Question Answer Task Handler

This module provides a task handler for the Question Answer task.
"""

import asyncio
import copy
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import pandas as pd
from pydantic import create_model

from src.core.registry import task_registry
from src.llm.base import LLMInterface
from src.tasks.base import TaskInterface
from src.tasks.question_answer.data_handler import QATaskDataHandler
from src.tasks.question_answer.eval_handler import QATaskEvaluator
from src.utils.decorator_utils import with_logger
from src.utils.error_tracking import ErrorTracker


@task_registry.register("question_answer")
class QATask(TaskInterface):
    """
    Task handler for question-answer datasets (e.g. GSM8K, AQuA).

    Each selected item is turned into a chat prompt, sent to an LLM together
    with a dynamically built structured-output schema, and the model's
    ``final_answer`` is scored against the ground-truth ``answer`` column.
    """

    @with_logger
    def __init__(
        self,
        data_path: str,
        dataset_name: str,
        train_test_flag: str,
        prompt_msg_template: List[Dict[str, Any]],
        id_lst: Optional[List[int]],
    ):
        """
        Initialize the task.

        Args:
            data_path: Data location passed to ``QATaskDataHandler``.
            dataset_name: Dataset name passed to ``QATaskDataHandler``.
            train_test_flag: Split flag forwarded unchanged to every
                data-handler lookup.
            prompt_msg_template: Chat message template; the content of each
                ``user`` message may contain a ``{question}`` placeholder.
                The template is deep-copied, so later changes to the caller's
                list do not affect this task.
            id_lst: Item ids to run on. An empty list (or ``None``) selects
                every item of the split.
        """
        self.train_test_flag = train_test_flag
        # Deep copy for consistency with update_prompt_msg_template().
        self.prompt_msg_template = copy.deepcopy(prompt_msg_template)

        self.data_handler = QATaskDataHandler(data_path, dataset_name)
        self.eval_handler = QATaskEvaluator()

        # Ground-truth answer type: becomes the schema type of ``final_answer``
        # and is the type LLM answers are cast to for non-numeric scoring.
        self.answer_dtype = self.data_handler.get_answer_dtype()
        logger.debug(self.answer_dtype)
        self.answer_schema = self.create_qa_task_answer_model()

        # ``len`` (not truthiness) so array-like id lists keep working.
        if id_lst is None or len(id_lst) == 0:
            data_size = self.data_handler.get_size(self.train_test_flag)
            self.id_lst = list(range(data_size))
        else:
            # Own copy: the caller's list may be mutated after construction.
            self.id_lst = list(id_lst)

    @with_logger
    def create_qa_task_answer_model(self):
        """
        Build the pydantic model passed to the LLM as ``response_format``.

        Returns:
            A dynamically created model with fields ``question`` (str),
            ``reasoning_steps`` (list of str) and ``final_answer`` typed as
            ``self.answer_dtype``.
        """
        return create_model(
            "QATaskAnswerDynamic",
            question=str,
            reasoning_steps=list[str],
            final_answer=self.answer_dtype,
        )

    @with_logger
    def load_data(self, **kwargs):
        """No-op: this task reads its data through ``self.data_handler`` in ``run``."""
        return None

    @with_logger
    def evaluate(self, results, **kwargs):
        """No-op: scoring is performed inside ``run``."""
        return None

    @with_logger
    def create_prompt(self, data_item: Any) -> List[Dict[str, Any]]:
        """
        Create the chat messages for one data item.

        Args:
            data_item: A 3-element record as returned by
                ``data_handler.get_data_by_id``; only its first element (the
                question) is used.

        Returns:
            A deep copy of the prompt template in which each ``user`` message
            has been passed through ``str.format(question=...)``. Literal
            braces in the template must therefore be doubled.
        """
        question, _, _ = data_item
        messages = copy.deepcopy(self.prompt_msg_template)
        for msg in messages:
            if msg["role"] == "user":
                msg["content"] = msg["content"].format(question=question)
        return messages

    @with_logger
    def run(
        self,
        llm: LLMInterface,
        error_tracker: Optional[ErrorTracker] = None,
        max_concurrency: int = 50,
    ):
        """
        Run the question-answer task over every id in ``self.id_lst``.

        Args:
            llm: The language model interface to use.
            error_tracker: Optional error tracker to record failed LLM
                invocations; a fresh ``ErrorTracker`` is created when omitted.
            max_concurrency: Maximum number of ``llm.generate`` calls in
                flight at the same time.

        Returns:
            ``(result_df, score, error_tracker)``. ``result_df`` holds the
            selected rows plus the added ``response`` (parsed dict),
            ``llm_answer`` and ``score`` (0/1) columns. ``score`` is
            ``eval_handler.get_eval_score(result_df)``. ``error_tracker`` is
            the tracker that recorded any failures.

        Raises:
            ValueError: If ``max_concurrency`` is less than 1.
            RuntimeError: If called while an asyncio event loop is already
                running in this thread (an ``asyncio.run`` restriction).
        """
        if max_concurrency < 1:
            raise ValueError(
                f"max_concurrency must be at least 1, got {max_concurrency}"
            )
        # Create error tracker if not provided
        if error_tracker is None:
            error_tracker = ErrorTracker()

        item_ids = self.id_lst
        prompts = [
            self.create_prompt(
                self.data_handler.get_data_by_id(self.train_test_flag, item_id)
            )
            for item_id in item_ids
        ]

        raw_responses = asyncio.run(
            self._collect_responses(
                llm, prompts, item_ids, error_tracker, max_concurrency
            )
        )
        records = self._build_response_records(
            llm, prompts, item_ids, raw_responses, error_tracker
        )

        # NOTE(review): responses are attached positionally, which assumes the
        # handler returns rows in ``item_ids`` order -- confirm in
        # QATaskDataHandler. ``.copy()`` keeps the added columns from being
        # written into (or warning about) the handler's own frame.
        result_df = self.data_handler.get_data_by_id_lst(
            self.train_test_flag,
            item_ids,
        ).copy()
        result_df["response"] = records
        result_df["llm_answer"] = result_df["response"].apply(
            lambda record: record.get("final_answer")
        )
        result_df["score"] = self._score_answers(result_df)

        score = self.eval_handler.get_eval_score(result_df)

        # Log error statistics if any errors occurred
        error_count = error_tracker.get_error_count()
        if error_count > 0:
            total_attempts = len(item_ids)
            success_rate = error_tracker.get_success_rate(total_attempts)
            logger.info(
                f"LLM invocation errors: {error_count}/{total_attempts} failed"
            )
            logger.info(f"Success rate: {success_rate:.2%}")

        return result_df, score, error_tracker

    async def _collect_responses(
        self,
        llm: LLMInterface,
        prompts: List[List[Dict[str, Any]]],
        item_ids: List[Any],
        error_tracker: ErrorTracker,
        max_concurrency: int,
    ) -> List[Any]:
        """
        Send every prompt to ``llm`` concurrently; results keep prompt order.

        Intended to run inside a fresh loop created by ``asyncio.run``. That
        loop's default executor is replaced with a pool of ``max_concurrency``
        threads. Otherwise ``asyncio.to_thread`` would use the stock pool,
        whose worker count can cap real concurrency below the requested limit.

        Returns:
            One entry per prompt. Each entry is one of:
            - the raw ``llm.generate`` result;
            - a JSON error-placeholder string when ``generate`` raised (the
              error is recorded in ``error_tracker``);
            - an exception object for any failure that escaped that handler.
        """
        loop = asyncio.get_running_loop()
        # asyncio.run() shuts this executor down when the loop is closed.
        loop.set_default_executor(ThreadPoolExecutor(max_workers=max_concurrency))
        semaphore = asyncio.Semaphore(max_concurrency)

        async def _get_response(prompt: List[Dict[str, Any]], task_id: str) -> Any:
            async with semaphore:
                try:
                    # llm.generate is synchronous: run it in a worker thread so
                    # calls overlap instead of blocking the event loop.
                    return await asyncio.to_thread(
                        llm.generate,
                        messages=prompt,
                        response_format=self.answer_schema,
                    )
                except Exception as e:
                    # Record error with full context
                    error_tracker.record_error(
                        error=e,
                        input_messages=prompt,
                        model_config=llm.model_info,
                        retry_attempts=(
                            llm.retry_config.max_attempts if llm.retry_config else 0
                        ),
                        task_id=task_id,
                    )
                    return json.dumps(
                        self._error_record(
                            prompt,
                            f"Error occurred during LLM invocation: {str(e)}",
                        )
                    )

        # return_exceptions=True: one failure must not cancel the other calls.
        return await asyncio.gather(
            *(
                _get_response(prompt, str(item_id))
                for prompt, item_id in zip(prompts, item_ids)
            ),
            return_exceptions=True,
        )

    def _build_response_records(
        self,
        llm: LLMInterface,
        prompts: List[List[Dict[str, Any]]],
        item_ids: List[Any],
        raw_responses: List[Any],
        error_tracker: ErrorTracker,
    ) -> List[Dict[str, Any]]:
        """
        Decode raw responses into dicts, substituting error records on failure.

        Exceptions returned by ``gather`` and responses that are not a JSON
        object are recorded in ``error_tracker`` and replaced with an error
        record, so a single bad item cannot abort scoring of the whole batch.
        """
        records: List[Dict[str, Any]] = []
        for prompt, item_id, raw in zip(prompts, item_ids, raw_responses):
            # BaseException: gather() may also return e.g. CancelledError.
            if isinstance(raw, BaseException):
                # Normally unreachable: _get_response already converts LLM
                # failures into placeholders.
                error_tracker.record_error(
                    error=raw,
                    input_messages=prompt,
                    model_config=llm.model_info,
                    retry_attempts=0,
                    task_id=str(item_id),
                )
                records.append(
                    self._error_record(prompt, f"Unexpected error: {str(raw)}")
                )
                continue

            try:
                parsed = json.loads(raw)
                if not isinstance(parsed, dict):
                    raise ValueError(
                        f"expected a JSON object, got {type(parsed).__name__}"
                    )
            except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
                error_tracker.record_error(
                    error=e,
                    input_messages=prompt,
                    model_config=llm.model_info,
                    retry_attempts=0,
                    task_id=str(item_id),
                )
                records.append(
                    self._error_record(
                        prompt, f"Failed to parse LLM response: {str(e)}"
                    )
                )
                continue

            records.append(parsed)
        return records

    @staticmethod
    def _error_record(question: Any, message: str) -> Dict[str, Any]:
        """Build a placeholder response with no answer; it always scores 0."""
        return {
            "question": question,
            "reasoning_steps": [message],
            "final_answer": None,
        }

    def _score_answers(self, result_df: pd.DataFrame):
        """
        Return a NumPy 0/1 integer array marking which LLM answers are correct.

        The comparison method is chosen per row:
        - When both the ground truth and the LLM answer coerce to numbers
          (GSM8K-style), they are compared numerically, so "42" matches 42.0.
        - Every other row that has an LLM answer (AQuA-style letters, etc.) is
          compared after casting that answer to ``self.answer_dtype``. The
          cast avoids ``str``, which can lose float precision.
        - A missing LLM answer (None/NaN, e.g. from a failed call) scores 0.
        """
        answers = result_df["answer"]
        llm_answers = result_df["llm_answer"]

        answer_numeric = pd.to_numeric(answers, errors="coerce")
        llm_answer_numeric = pd.to_numeric(llm_answers, errors="coerce")
        numeric_rows = answer_numeric.notna() & llm_answer_numeric.notna()

        # Decided per row, not per dataset: a single numeric row must not force
        # non-numeric rows into a NaN comparison that can never succeed.
        matches = (numeric_rows & (answer_numeric == llm_answer_numeric)).to_numpy(
            dtype=bool, na_value=False, copy=True
        )

        # Cast only rows that actually have an answer: casting None to a numeric
        # dtype raises (e.g. when every LLM invocation failed).
        typed_rows = (~numeric_rows & llm_answers.notna()).to_numpy(dtype=bool)
        if typed_rows.any():
            converted = llm_answers[typed_rows].astype(self.answer_dtype)
            matches[typed_rows] = (answers[typed_rows] == converted).to_numpy(
                dtype=bool, na_value=False
            )

        return matches.astype(int)

    @with_logger
    def get_prompt_msg_template(self) -> List[Dict[str, Any]]:
        """
        Get the prompt message template.

        Returns:
            A deep copy of the template, so callers cannot mutate the task's
            own copy.
        """
        return copy.deepcopy(self.prompt_msg_template)

    @with_logger
    def update_prompt_msg_template(
        self, new_prompt_msg_template: List[Dict[str, Any]]
    ) -> None:
        """
        Replace the prompt message template.

        Args:
            new_prompt_msg_template: The new template; it is deep-copied, so
                later changes to the caller's object have no effect.
        """
        self.prompt_msg_template = copy.deepcopy(new_prompt_msg_template)
