"""
Tamper Detection Evaluator

This module provides an evaluator for the tamper detection task.
"""

import pandas as pd

from src.evaluation.base import BaseEvaluator
from src.utils.decorator_utils import with_logger


class TamperDetectionEvaluator(BaseEvaluator):
    """
    Evaluator for the tamper detection task.

    Each row of a results dataframe carries two answers, ``response_original``
    and ``response_tampered``, which are compared against the literal strings
    ``"Yes"`` and ``"No"``. Presumably these are the model's answers for the
    untampered and the tampered version of the same item — TODO confirm
    against whatever produces the dataframe.

    Rows are scored by ``_add_eval_score``:

    * ``1``  — original answered ``"No"`` and tampered answered ``"Yes"``
    * ``-1`` — original answered ``"Yes"`` and tampered answered ``"No"``
    * ``0``  — any other combination (including unexpected values)

    The metric helpers add the ``eval_score`` column on a copy when it is
    missing, so the caller's dataframe is not mutated, and they report ``0.0``
    when the dataframe has no rows.
    """

    @with_logger
    def __init__(self):
        """
        Initialise the tamper detection evaluator.
        """
        logger.info("Initialising TamperDetectionEvaluator")
        super().__init__()
        logger.info("TamperDetectionEvaluator initialisation complete")

    @with_logger
    def _calculate_accuracy(self, results: pd.DataFrame) -> float:
        """
        Calculate the accuracy of the results.

        Accuracy is the fraction of rows whose ``eval_score`` is exactly ``1``;
        rows scored ``0`` or ``-1`` both count as incorrect.

        Args:
            results: The results dataframe. ``eval_score`` is added (on a copy)
                if it is not already present.

        Returns:
            The accuracy score, or ``0.0`` when there are no rows.
        """
        logger.info("Calculating accuracy")
        # Add eval_score column if it doesn't exist
        if "eval_score" not in results.columns:
            logger.info("eval_score column not found, adding it")
            results = self._add_eval_score(results)

        # Count matches directly instead of materialising a filtered frame.
        correct = int((results["eval_score"] == 1).sum())
        total = len(results)

        accuracy = correct / total if total > 0 else 0.0
        logger.info(f"Accuracy: {accuracy:.4f} ({correct}/{total} correct)")
        return accuracy

    @with_logger
    def _calculate_precision(self, results: pd.DataFrame) -> float:
        """
        Calculate the precision of the results.

        The denominator is every row whose tampered answer is ``"Yes"``; the
        numerator is the subset of those rows that are also scored ``1``.

        Args:
            results: The results dataframe. ``eval_score`` is added (on a copy)
                if it is not already present.

        Returns:
            The precision score, or ``0.0`` when no row has a tampered answer
            of ``"Yes"``.
        """
        logger.info("Calculating precision")
        # Add eval_score column if it doesn't exist
        if "eval_score" not in results.columns:
            logger.info("eval_score column not found, adding it")
            results = self._add_eval_score(results)

        flagged = results["response_tampered"] == "Yes"
        true_positives = int((flagged & (results["eval_score"] == 1)).sum())
        all_positives = int(flagged.sum())

        precision = true_positives / all_positives if all_positives > 0 else 0.0
        logger.info(
            f"Precision: {precision:.4f} ({true_positives}/{all_positives} true positives)"
        )
        return precision

    @with_logger
    def _calculate_recall(self, results: pd.DataFrame) -> float:
        """
        Calculate the recall of the results.

        Every row is treated as an actual positive (each row is assumed to
        contain a tampered input). The numerator is the rows whose tampered
        answer is ``"Yes"`` and that are scored ``1``.

        NOTE(review): when ``eval_score`` is produced by ``_add_eval_score``,
        a score of ``1`` already implies ``response_tampered == "Yes"``, so this
        value is numerically identical to ``_calculate_accuracy``. If a
        per-image recall (tampered inputs flagged / rows) was intended, the
        ``eval_score`` condition should be dropped from the numerator —
        confirm the intended definition.

        Args:
            results: The results dataframe. ``eval_score`` is added (on a copy)
                if it is not already present.

        Returns:
            The recall score, or ``0.0`` when there are no rows.
        """
        logger.info("Calculating recall")
        # Add eval_score column if it doesn't exist
        if "eval_score" not in results.columns:
            logger.info("eval_score column not found, adding it")
            results = self._add_eval_score(results)

        detected = (results["response_tampered"] == "Yes") & (results["eval_score"] == 1)
        true_positives = int(detected.sum())
        actual_positives = len(results)  # All tampered images should be detected

        recall = true_positives / actual_positives if actual_positives > 0 else 0.0
        logger.info(
            f"Recall: {recall:.4f} ({true_positives}/{actual_positives} detected)"
        )
        return recall

    @with_logger
    def _calculate_f1_score(self, results: pd.DataFrame) -> float:
        """
        Calculate the F1 score of the results.

        F1 is the harmonic mean of ``_calculate_precision`` and
        ``_calculate_recall``.

        Args:
            results: The results dataframe. ``eval_score`` is added (on a copy)
                if it is not already present.

        Returns:
            The F1 score, or ``0.0`` when precision and recall are both zero.
        """
        logger.info("Calculating F1 score")
        # Score once here so precision and recall don't each rebuild the
        # eval_score column on their own private copy.
        if "eval_score" not in results.columns:
            logger.info("eval_score column not found, adding it")
            results = self._add_eval_score(results)

        precision = self._calculate_precision(results)
        recall = self._calculate_recall(results)

        f1_score = (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        logger.info(
            f"F1 score: {f1_score:.4f} (precision={precision:.4f}, recall={recall:.4f})"
        )
        return f1_score

    @with_logger
    def _calculate_mean_score(self, results: pd.DataFrame) -> float:
        """
        Calculate the mean score of the results.

        Args:
            results: The results dataframe. ``eval_score`` is added (on a copy)
                if it is not already present.

        Returns:
            The mean of ``eval_score`` (so in ``[-1, 1]`` for scores produced by
            ``_add_eval_score``), or ``0.0`` when there are no rows.
        """
        logger.info("Calculating mean score")
        # Add eval_score column if it doesn't exist
        if "eval_score" not in results.columns:
            logger.info("eval_score column not found, adding it")
            results = self._add_eval_score(results)

        # Series.mean() of an empty column is NaN; report 0.0 instead so this
        # agrees with the other metrics for an empty result set.
        total = len(results)
        mean_score = float(results["eval_score"].mean()) if total > 0 else 0.0
        logger.info(f"Mean score: {mean_score:.4f}")
        return mean_score

    @with_logger
    def _add_eval_score(self, results: pd.DataFrame) -> pd.DataFrame:
        """
        Add an eval_score column to a copy of the results dataframe.

        Scoring: ``1`` when ``response_original == "No"`` and
        ``response_tampered == "Yes"``; ``-1`` when ``response_original == "Yes"``
        and ``response_tampered == "No"``; ``0`` otherwise.

        Args:
            results: The results dataframe; must contain ``response_original``
                and ``response_tampered`` columns. It is not modified.

        Returns:
            A copy of the results dataframe with an eval_score column.

        Raises:
            KeyError: If either response column is missing.
        """
        logger.info("Adding evaluation score column to results dataframe")
        # Make a copy to avoid modifying the original
        results = results.copy()

        # Define conditions for correct and incorrect predictions
        cond_correct = (results["response_original"] == "No") & (
            results["response_tampered"] == "Yes"
        )
        cond_incorrect = (results["response_original"] == "Yes") & (
            results["response_tampered"] == "No"
        )

        # Set eval_score based on conditions
        results["eval_score"] = 0
        results.loc[cond_correct, "eval_score"] = 1
        results.loc[cond_incorrect, "eval_score"] = -1

        # Log the distribution of scores
        scores = results["eval_score"]
        correct_count = int((scores == 1).sum())
        incorrect_count = int((scores == -1).sum())
        neutral_count = int((scores == 0).sum())
        logger.info(
            f"Evaluation score distribution: correct={correct_count}, incorrect={incorrect_count}, neutral={neutral_count}"
        )

        return results

    @with_logger
    def get_eval_score(self, results: pd.DataFrame) -> float:
        """
        Get the evaluation score for the results.

        This is a compatibility method for the old API; it delegates to
        ``_calculate_mean_score``.

        Args:
            results: The results dataframe

        Returns:
            The mean evaluation score, or ``0.0`` when there are no rows.
        """
        logger.info("Getting evaluation score")
        score = self._calculate_mean_score(results)
        logger.info(f"Final evaluation score: {score:.4f}")
        return score
