"""
Base Task Interface

This module defines the base interface for tasks.
"""

from abc import ABC, abstractmethod
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd

from src.utils.decorator_utils import with_logger


class TaskInterface(ABC):
    """
    Interface for tasks.

    This abstract class defines the interface that all task implementations
    must adhere to: loading data, building prompts, running the task with an
    LLM, evaluating results, and reading/replacing the prompt message
    template. It also provides two concrete logging helpers,
    ``log_data_stats`` and ``log_evaluation_details``.

    Note:
        The bare name ``logger`` used in method bodies is provided by the
        project decorator ``with_logger`` (``src.utils.decorator_utils``, not
        shown here). Every method that references ``logger`` is therefore
        decorated with it, including the abstract methods, whose default
        bodies can still be reached from subclasses via ``super()``.
    """

    @with_logger
    def __init__(self):
        """Initialise the task interface.

        No state is set here; the logger is made available by ``with_logger``.
        """

    # NOTE: ``@abstractmethod`` is kept outermost so it marks the object that
    # ``with_logger`` returns; ABCMeta inspects ``__isabstractmethod__`` on the
    # attribute actually stored in the class namespace.

    @abstractmethod
    @with_logger
    def load_data(self, **kwargs: Any) -> Any:
        """
        Load task-specific data.

        Subclasses must override this. The default body only logs a message
        and returns ``None``.

        Args:
            **kwargs: Additional keyword arguments for data loading

        Returns:
            The loaded data
        """
        logger.info("Loading task data")

    @abstractmethod
    @with_logger
    def create_prompt(self, data_item: Any, **kwargs: Any) -> List[Dict[str, Any]]:
        """
        Create a prompt for a specific data item.

        Subclasses must override this. The default body only logs the item's
        type and returns ``None``.

        Args:
            data_item: The data item to create a prompt for
            **kwargs: Additional keyword arguments for prompt creation

        Returns:
            A list of message dictionaries, each with 'role' and 'content' keys
        """
        logger.info(f"Creating prompt for data item: {type(data_item)}")

    @abstractmethod
    @with_logger
    def run(self, llm: Any, **kwargs: Any) -> Tuple[pd.DataFrame, float]:
        """
        Run the task with the given LLM.

        Subclasses must override this. The default body only logs the LLM's
        class name and returns ``None``.

        Args:
            llm: The language model to use
            **kwargs: Additional keyword arguments for task execution

        Returns:
            A tuple of (results_dataframe, evaluation_score)
        """
        logger.info(f"Running task with LLM: {type(llm).__name__}")

    @abstractmethod
    @with_logger
    def evaluate(self, results: pd.DataFrame, **kwargs: Any) -> float:
        """
        Evaluate the results of the task.

        Subclasses must override this. The default body only logs the number
        of result rows and returns ``None``.

        Args:
            results: The results to evaluate
            **kwargs: Additional keyword arguments for evaluation

        Returns:
            The evaluation score
        """
        logger.info(f"Evaluating results: {results.shape[0]} items")

    @abstractmethod
    def get_prompt_msg_template(self):
        """
        Return the task's current prompt message template.

        NOTE(review): presumably the same list-of-message-dicts shape accepted
        by ``update_prompt_msg_template``; confirm against implementations.
        """

    @abstractmethod
    def update_prompt_msg_template(
        self,
        new_prompt_msg_template: List[Dict[str, Any]],
    ):
        """
        Replace the task's prompt message template.

        Args:
            new_prompt_msg_template: List of message dictionaries to use as
                the new template (presumably with 'role'/'content' keys like
                the output of ``create_prompt``; verify in implementations).
        """

    @staticmethod
    def _count_unique(series: pd.Series) -> Union[int, str]:
        """Return ``series.nunique()``, or a placeholder if it cannot be computed.

        ``nunique`` hashes cell values, so it raises ``TypeError`` for object
        columns holding unhashable values such as lists or dicts. A logging
        helper should report that instead of crashing.
        """
        try:
            return series.nunique()
        except TypeError:
            return "n/a (unhashable values)"

    @with_logger
    def log_data_stats(self, data: Any) -> None:
        """
        Log statistics about the loaded data.

        - DataFrame: row/column counts, column names, and per-column dtype and
          unique-value count (only when INFO logging is enabled).
        - list: its length and the type of its first element (if any).
        - anything else: its type.

        Args:
            data: The data to log statistics for
        """
        if isinstance(data, pd.DataFrame):
            logger.info(
                f"Data statistics: {data.shape[0]} rows, {data.shape[1]} columns"
            )
            logger.info(f"Column names: {list(data.columns)}")
            # Per-column unique counts can be costly, so only compute them
            # when the messages would actually be emitted.
            if logger.isEnabledFor(logging.INFO):
                # ``items()`` yields one Series per column even when column
                # labels are duplicated (``data[col]`` would return a DataFrame
                # there, which has no ``.dtype``).
                for col, series in data.items():
                    logger.info(
                        f"Column '{col}' - {series.dtype}, "
                        f"{self._count_unique(series)} unique values"
                    )
        elif isinstance(data, list):
            logger.info(f"Data statistics: {len(data)} items")
            if data and logger.isEnabledFor(logging.INFO):
                logger.info(f"First item type: {type(data[0])}")
        else:
            logger.info(f"Data type: {type(data)}")

    @with_logger
    def log_evaluation_details(self, results: pd.DataFrame, score: float) -> None:
        """
        Log detailed evaluation information.

        Always logs the overall score. If ``results`` is a DataFrame and INFO
        logging is enabled, it also logs accuracy from a ``correct`` column and
        min/max/mean of a ``score`` column, when those columns exist.

        Args:
            results: The evaluation results
            score: The overall evaluation score
        """
        logger.info(f"Evaluation score: {score:.4f}")

        if isinstance(results, pd.DataFrame) and logger.isEnabledFor(logging.INFO):
            if "correct" in results.columns:
                correct_count = results["correct"].sum()
                total_count = len(results)
                # Guard against division by zero on an empty results frame.
                accuracy = correct_count / total_count if total_count > 0 else 0
                logger.info(f"Accuracy: {correct_count}/{total_count} = {accuracy:.4f}")

            if "score" in results.columns:
                scores = results["score"]
                logger.info(
                    f"Score statistics: min={scores.min():.4f}, "
                    f"max={scores.max():.4f}, "
                    f"mean={scores.mean():.4f}"
                )
