"""
Question Answer Task Handler

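This module provides the task handler registered as
``question_answer_psao_optuna``: a Question Answer task that passes each
prompt through ``psao_optuna_optimisation`` before querying the LLM.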
"""

import asyncio
import copy
import json
import sys
from typing import Any, Dict, List, Optional

import pandas as pd
from pydantic import create_model

from src.core.registry import task_registry
from src.llm.base import LLMInterface
from src.tasks.base import TaskInterface
from src.tasks.question_answer_psao_optuna.data_handler import QATaskDataHandler
from src.tasks.question_answer_psao_optuna.eval_handler import QATaskEvaluator
from src.utils.decorator_utils import with_logger
from src.utils.error_tracking import ErrorTracker

from .psao_optuna import psao_optuna_optimisation


@task_registry.register("question_answer_psao_optuna")
class QAPSAOOptunaTask(TaskInterface):
    """
    Question Answer task that optimises each prompt before answering.

    For every selected data item the task renders a prompt from the message
    template, hands it to ``psao_optuna_optimisation`` and then asks the LLM
    to answer using the prompt returned by that call.
    """

    # Upper bound on the number of LLM generation calls in flight at once.
    MAX_CONCURRENT_REQUESTS = 50

    @with_logger
    def __init__(
        self,
        data_path: str,
        dataset_name: str,
        train_test_flag: str,
        prompt_msg_template: List[Dict[str, Any]],
        id_lst: List[int],
        optuna_study_name: str,
        optuna_db_name: str,
        ann_option_list: List[str],
    ):
        """
        Set up the data handler, evaluator and answer schema.

        Args:
            data_path: Passed to ``QATaskDataHandler``
            dataset_name: Passed to ``QATaskDataHandler``
            train_test_flag: Split selector forwarded to every data handler call
            prompt_msg_template: Chat messages; the content of each ``user``
                message is formatted with ``question=...``
            id_lst: Item ids to process; empty (or None) selects every item
                of the chosen split
            optuna_study_name: Forwarded to ``psao_optuna_optimisation``
            optuna_db_name: Forwarded to ``psao_optuna_optimisation``
            ann_option_list: Forwarded to ``psao_optuna_optimisation``
        """
        self.train_test_flag = train_test_flag
        self.prompt_msg_template = prompt_msg_template

        self.data_handler = QATaskDataHandler(data_path, dataset_name)
        self.eval_handler = QATaskEvaluator()

        # The answer dtype drives both the response schema and, in run(),
        # the type the LLM answers are cast to before comparison.
        self.answer_dtype = self.data_handler.get_answer_dtype()
        self.answer_schema = self.create_qa_task_answer_model()

        if id_lst is None or len(id_lst) == 0:
            data_size = self.data_handler.get_size(self.train_test_flag)
            self.id_lst = list(range(data_size))
        else:
            self.id_lst = id_lst

        self.optuna_study_name = optuna_study_name
        self.optuna_db_name = optuna_db_name
        self.ann_option_list = ann_option_list

    @with_logger
    def create_qa_task_answer_model(self):
        """
        Build the pydantic model used as the LLM response format.

        Returns:
            A dynamically created model with ``question``, ``reasoning_steps``
            and a ``final_answer`` field typed as ``self.answer_dtype``
        """
        return create_model(
            "QATaskAnswerDynamic",
            question=str,
            reasoning_steps=list[str],
            final_answer=self.answer_dtype,
        )

    @with_logger
    def load_data(self, **kwargs):
        """No-op in this task; always returns None."""
        return None

    @with_logger
    def evaluate(self, results, **kwargs):
        """No-op in this task; always returns None (scoring happens in ``run``)."""
        return None

    @with_logger
    def create_prompt(self, data_item: Any) -> List[Dict[str, Any]]:
        """
        Create a prompt for a specific data item.

        Args:
            data_item: A record that unpacks into four values; the first one
                is used as the question
        Returns:
            A deep copy of the prompt template in which every ``user``
            message has been formatted with the question
        """
        question, _, _, _ = data_item
        # Deep copy so formatting never mutates the shared template.
        messages = copy.deepcopy(self.prompt_msg_template)
        for msg in messages:
            if msg["role"] == "user":
                msg["content"] = msg["content"].format(question=question)
        return messages

    @with_logger
    def get_answer(self, data_item: Any) -> Any:
        """
        Extract the ground-truth answer from a data item.

        Args:
            data_item: A record that unpacks into four values
        Returns:
            The second value of the record
        """
        _, answer, _, _ = data_item
        return answer

    @with_logger
    def get_seg_lst(self, data_item: Any) -> List[str]:
        """
        Extract and parse the segment list from a data item.

        Args:
            data_item: A record that unpacks into four values; the fourth is
                a Python-expression string describing the segments
        Returns:
            The object obtained by evaluating that string
        """
        _, _, _, seg_lst = data_item

        # NOTE(review): eval() executes arbitrary Python taken from the
        # dataset. If this field is always a plain literal (e.g. a list of
        # strings), ast.literal_eval would be the safe replacement - confirm
        # the stored format before switching.
        seg_lst = eval(seg_lst)

        return seg_lst

    @with_logger
    def run(
        self,
        llm: LLMInterface,
        error_tracker: Optional[ErrorTracker] = None,
    ):
        """
        Run the question-answer task over ``self.id_lst``.

        Each item's prompt is optimised with ``psao_optuna_optimisation``,
        the LLM is queried with the optimised prompt, and the parsed answers
        are scored against the ground truth.

        Args:
            llm: The language model interface to use
            error_tracker: Optional error tracker to record failed LLM invocations

        Returns:
            A tuple ``(result_df, score, error_tracker)`` where ``result_df``
            is the selected data with ``response``, ``best_score``,
            ``llm_answer`` and ``score`` columns added, and ``score`` is the
            value returned by the evaluator for that frame
        """
        # Create error tracker if not provided
        if error_tracker is None:
            error_tracker = ErrorTracker()

        async def _get_response(
            llm: LLMInterface,
            prompt: List[Dict[str, Any]],
            answer: Any,
            seg_lst: List[str],
            semaphore: asyncio.Semaphore,
            task_id: str,
        ) -> tuple[Any, Any]:
            """Optimise one prompt, query the LLM and return (response, best score)."""
            # NOTE: this call is synchronous and runs on the event loop
            # thread, outside the semaphore and outside the try block below.
            # An exception raised here propagates to asyncio.gather and is
            # turned into a placeholder by _async_main.
            prompt_opt, score_best = psao_optuna_optimisation(
                llm,
                prompt,
                answer,
                self.answer_schema,
                self.answer_dtype,
                task_id,
                self.optuna_study_name,
                self.optuna_db_name,
                seg_lst,
                self.ann_option_list,
            )

            async with semaphore:
                try:
                    resp = await asyncio.to_thread(
                        llm.generate,
                        messages=prompt_opt,
                        response_format=self.answer_schema,
                    )
                    return (resp, score_best)
                except Exception as e:
                    # Record error with full context
                    error_tracker.record_error(
                        error=e,
                        input_messages=prompt_opt,
                        model_config=llm.model_info,
                        retry_attempts=(
                            llm.retry_config.max_attempts if llm.retry_config else 0
                        ),
                        task_id=task_id,
                    )

                    # Same JSON shape as a real response so the parsing in
                    # run() works; final_answer=None marks it as incorrect.
                    return (
                        json.dumps(
                            {
                                "question": prompt_opt,
                                "reasoning_steps": [
                                    f"Error occurred during LLM invocation: {str(e)}"
                                ],
                                "final_answer": None,
                            }
                        ),
                        0,
                    )

        async def _async_main():
            """Fan out one coroutine per item id and collect the results in order."""
            semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_REQUESTS)

            tasks = []
            for item_id in self.id_lst:
                # Fetch each item once and derive prompt, answer and segments
                # from the same record.
                data_item = self.data_handler.get_data_by_id(
                    self.train_test_flag,
                    item_id,
                )
                tasks.append(
                    _get_response(
                        llm,
                        self.create_prompt(data_item),
                        self.get_answer(data_item),
                        self.get_seg_lst(data_item),
                        semaphore,
                        task_id=str(item_id),
                    )
                )

            # return_exceptions=True keeps one failing item from aborting the
            # others; failures come back as exception objects in the results.
            _responses_scores = await asyncio.gather(
                *tasks,
                return_exceptions=True,
            )

            # Convert any remaining exceptions to error placeholders
            processed_responses = []
            processed_best_score = []
            for item_id, r_s in zip(self.id_lst, _responses_scores):
                # BaseException (not Exception) so that asyncio.CancelledError
                # is also handled here instead of breaking the unpacking below.
                if isinstance(r_s, BaseException):
                    # Reached when something outside _get_response's try block
                    # fails, e.g. the prompt optimisation step.
                    error_tracker.record_error(
                        error=r_s,
                        input_messages=[
                            {"role": "system", "content": "Unknown prompt"}
                        ],
                        model_config=llm.model_info,
                        retry_attempts=0,
                        task_id=str(item_id),
                    )
                    processed_responses.append(
                        json.dumps(
                            {
                                "question": "Unknown",
                                "reasoning_steps": [f"Unexpected error: {str(r_s)}"],
                                "final_answer": None,
                            }
                        )
                    )
                    processed_best_score.append(0.0)
                else:
                    response, score = r_s
                    processed_responses.append(response)
                    processed_best_score.append(score)

            return processed_responses, processed_best_score

        responses, best_scores = asyncio.run(_async_main())

        result_df = self.data_handler.get_data_by_id_lst(
            self.train_test_flag,
            self.id_lst,
        )

        result_df.loc[:, ["response"]] = responses
        result_df.loc[:, ["best_score"]] = best_scores
        # Responses are JSON strings; parse them so final_answer can be read.
        result_df.loc[:, ["response"]] = result_df.loc[:, "response"].apply(json.loads)
        result_df.loc[:, ["llm_answer"]] = result_df.loc[:, "response"].apply(
            lambda x: x.get("final_answer")
        )

        # Calculate scores, treating None answers from errors as incorrect
        # Handle both numeric answers (GSM8K) and letter-based answers (AQuA)
        valid_llm_answers = result_df.loc[:, "llm_answer"].notna()

        # Try numeric comparison first (for datasets like GSM8K)
        answer_numeric = pd.to_numeric(result_df.loc[:, "answer"], errors="coerce")
        llm_answer_numeric = pd.to_numeric(
            result_df.loc[:, "llm_answer"], errors="coerce"
        )

        # Check if both ground truth and LLM answers can be converted to numeric
        numeric_answers_valid = answer_numeric.notna() & llm_answer_numeric.notna()

        if numeric_answers_valid.any():
            # Use numeric comparison for numeric answers. Rows that failed the
            # numeric conversion hold NaN and therefore compare as unequal.
            result_df.loc[:, ["score"]] = (
                (answer_numeric == llm_answer_numeric) & valid_llm_answers
            ).astype(int)
        else:
            # Convert the LLM answers to the same type as the ground truth answers
            # Not using str because sometime the float to str may lose precision
            result_df.loc[:, ["score"]] = (
                (
                    result_df.loc[:, "answer"]
                    == result_df.loc[:, "llm_answer"].astype(self.answer_dtype)
                )
                & valid_llm_answers
            ).astype(int)

        score = self.eval_handler.get_eval_score(result_df)

        # Log error statistics if any errors occurred
        if error_tracker.get_error_count() > 0:
            total_attempts = len(self.id_lst)
            success_rate = error_tracker.get_success_rate(total_attempts)
            logger.info(
                f"LLM invocation errors: {error_tracker.get_error_count()}/{total_attempts} failed"
            )
            logger.info(f"Success rate: {success_rate:.2%}")

        return result_df, score, error_tracker

    @with_logger
    def get_prompt_msg_template(self) -> List[Dict[str, Any]]:
        """
        Get the prompt message template.

        Returns:
            A deep copy of the prompt message template, so callers cannot
            mutate the task's own template through it
        """
        return copy.deepcopy(self.prompt_msg_template)

    @with_logger
    def update_prompt_msg_template(
        self, new_prompt_msg_template: List[Dict[str, Any]]
    ) -> None:
        """
        Update the prompt message template.

        Args:
            new_prompt_msg_template: The new prompt message template; it is
                deep-copied before being stored
        """
        self.prompt_msg_template = copy.deepcopy(new_prompt_msg_template)
