"""Spider postprocessor for processing SQL generation task predictions."""
import logging
import re
from typing import Any

from postprocessors.base import Postprocessor

# Module-level logger; SpiderPostprocessor logs progress and generations it could not clean.
logger: logging.Logger = logging.getLogger(__name__)
# `propagate` already defaults to True; assigning it explicitly re-enables it in case
# other code disabled it, so records reach handlers on ancestor (e.g. root) loggers.
logger.propagate = True

class SpiderPostprocessor(Postprocessor):
    """
    Postprocessor for Spider (text-to-SQL) model predictions.

    Extracts a single SQL query from each raw model generation and collects
    the dataset fields needed to score those queries.
    """

    # Returned instead of an empty string when a generation contains no SQL.
    # NOTE(review): presumably chosen so downstream SQL evaluation never gets
    # an empty query — confirm against the metric implementation.
    _EMPTY_SQL_PLACEHOLDER = "SELECT"

    # First Markdown code fence with an optional, case-insensitive ``sql``
    # language tag. The captured body ends at the next fence or, when the
    # generation was truncated before the block was closed, at end of text.
    _CODE_BLOCK_RE = re.compile(r"```(?i:sql)?([\s\S]*?)(?:```|\Z)")

    def _clean_sql(self, sql: str) -> str:
        """Extract a single SQL query from a raw model generation.

        Args:
            sql: The raw text generated by the model.

        Returns:
            The cleaned SQL query, stripped of surrounding whitespace:
                - the contents of the first code block fenced with ``` and
                  optionally tagged ``sql`` (any case); later blocks are ignored.
                - if the opening fence is never closed (e.g. a truncated
                  generation), everything after the opening fence.
                - if no code fence is present, the whole generation.
                - "SELECT" if the input is blank or the result would be empty.

        Example:
            INPUT: "```sqlSELECT COUNT(*) from singer;``` ### Generate an ..."
            OUTPUT: "SELECT COUNT(*) from singer;"
        """
        if not sql.strip():
            return self._EMPTY_SQL_PLACEHOLDER

        match = self._CODE_BLOCK_RE.search(sql)
        if match is None:
            logger.warning(
                "Error cleaning generation: %s. Error: %s",
                sql,
                "no ``` code block found",
            )
            cleaned = sql
        else:
            cleaned = match.group(1)
        # An empty code block (e.g. "```sql\n```") must not produce "", the
        # same case the blank-input guard above protects against.
        return cleaned.strip() or self._EMPTY_SQL_PLACEHOLDER

    def process(
        self,
        dataset: list[dict],
        predictions: dict[str, list[str]],
        metric
    ) -> dict[str, Any]:
        """
        Clean model predictions and collect the dataset fields used for scoring.

        Args:
            dataset: Dataset records. "instruction" is read from every record
                ("" when absent); "model_target" from records that have it.
            predictions: Raw generations keyed by model name.
            metric: Not used by this postprocessor; accepted to keep the
                ``process`` signature (presumably required by the
                ``Postprocessor`` interface — confirm).

        Returns:
            A dict (passed through ``self.validate_output``) with keys:
                - "instructions": one instruction per dataset record.
                - "model_targets": reference SQL for records that have one.
                - "processed_predictions": per model name, the cleaned SQL for
                  each prediction, in the original order.
        """
        logger.info("Processing predictions with SpiderPostprocessor...")

        predictions = self.process_predictions(predictions)

        # Strip thinking content first, then pull the SQL out of what remains.
        processed_predictions: dict[str, list[str]] = {
            model_name: [
                self._clean_sql(self.remove_thinking_content(pred))
                for pred in preds
            ]
            for model_name, preds in predictions.items()
        }

        instructions = [record.get("instruction", "") for record in dataset]
        model_targets = [
            record["model_target"] for record in dataset if "model_target" in record
        ]
        # Records without a target are skipped, so a partially labelled dataset
        # yields targets that no longer line up index-for-index with the
        # instructions/predictions. Surface that rather than scoring silently.
        if model_targets and len(model_targets) != len(dataset):
            logger.warning(
                "%d of %d dataset records have no 'model_target'; targets will "
                "be misaligned with predictions.",
                len(dataset) - len(model_targets),
                len(dataset),
            )

        output = {
            "instructions": instructions,
            "model_targets": model_targets,
            "processed_predictions": processed_predictions,
        }
        self.validate_output(output)
        return output
