"""SQL scoring metrics for text-to-SQL evaluation tasks.

This module provides the SqlScore class for evaluating SQL generation quality
using execution accuracy and exact set match metrics from the Spider benchmark.
"""

import os
from typing import List, Tuple, Dict, Optional, Union

import nltk
import pandas as pd

from metrics.metrics import Metrics
from metrics.text2sql.evaluation import evaluate
from models.model_response import ModelResponse
from utils import util
from utils.custom_logging import write_record_log, append_final_score

# Relative paths to the Spider benchmark assets: the dataset root and the
# database directory located directly beneath it.
SPIDER_DATA_DIR = "data/spider/"
SPIDER_DB_DIR = f"{SPIDER_DATA_DIR}database"


class SqlScore(Metrics):
    """Score generated SQL against reference SQL with the Spider evaluator.

    Two metrics are reported: execution accuracy and exact set match.
    ``__call__`` returns the aggregate scores and stores the per-record values
    of that run on ``record_level_score``.
    """

    def __init__(self):
        super().__init__()
        self.name = "sql_score"
        # Labels used as keys/column names for the per-record score table.
        self.metric_ex = "Execution Accuracy"
        self.metric_em = "Exact Set Match"
        self.processed_text = "Post Processed Text"

        # Download the NLTK "punkt" tokenizer data only if it is not already
        # available locally.
        try:
            nltk.data.find("tokenizers/punkt")
        except LookupError:
            nltk.download("punkt")

        # Per-record results of the most recent __call__; empty until then.
        self.record_level_score = {}

    @staticmethod
    def _run_spider_eval(candidates, references) -> dict:
        """Run the Spider evaluator on the predictions and gold references.

        Single place where the evaluator arguments (database directory, table
        file, evaluation type) are defined, so every caller uses the same setup.
        """
        return evaluate(
            glist=references,
            plist=candidates,
            db_dir=SPIDER_DB_DIR,
            etype="all",
            table=os.path.join(SPIDER_DATA_DIR, "tables.jsonl"),
        )

    def __call__(
        self,
        candidates: List[str],
        references: List[Tuple[List[str], List[Dict[str, Optional[Union[str, int]]]]]],
        *,
        instructions: Optional[List[str]] = None,
        task_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_responses: Optional[List[ModelResponse]] = None,
    ) -> dict[str, dict[str, float] | float]:
        """Evaluate execution accuracy and exact set match with Spider.

        Args:
            candidates: Generated SQL strings.
            references: Gold entries, forwarded unchanged to the Spider
                evaluator. Per the annotation each entry is a tuple of a list
                of strings and a list of dicts.
                NOTE(review): the exact layout is defined by the evaluator;
                confirm against ``metrics.text2sql.evaluation.evaluate``.
            instructions: Optional instructions, forwarded to the record log.
            task_name: Task identifier. Logs are written only when both
                ``task_name`` and ``model_name`` are truthy.
            model_name: Model identifier (see ``task_name``).
            model_responses: Optional raw model responses, forwarded to the
                logging helpers.

        Returns:
            dict: Flat mapping with ``<level>_exec_accuracy`` and
            ``<level>_exact_set_match`` entries, each multiplied by 100.
        """
        scores = self._run_spider_eval(candidates, references)

        # Keep the per-record results so get_all_score_df can tabulate them.
        self.record_level_score = {
            self.processed_text: candidates,
            self.metric_ex: scores.get("per_record_ex", []),
            self.metric_em: scores.get("per_record_em", []),
        }

        # Flatten the evaluator output into the reported score format.
        cleaned_scores = self._clean_scores(scores)

        # Detailed logs are only written when the run can be identified.
        if task_name and model_name:
            write_record_log(
                self,
                refs=references,
                cands=candidates,
                scores=scores,
                task_name=task_name,
                model_name=model_name,
                explanations=None,
                instructions=instructions,
                model_responses=model_responses,
            )
            append_final_score(self, cleaned_scores, task_name, model_name, model_responses)
        return cleaned_scores

    def _clean_scores(self, scores: dict) -> dict:
        """Flatten the evaluator output into one score entry per level/metric.

        Args:
            scores: Evaluator output containing the parallel sequences
                ``levels``, ``exec_accuracy_score`` and ``exact_match_score``.

        Returns:
            Dict keyed ``<level>_exec_accuracy`` / ``<level>_exact_set_match``;
            the ``all`` level is reported as ``overall``.
        """
        flattened_scores = {}
        for level, ex, em in zip(
            scores["levels"], scores["exec_accuracy_score"], scores["exact_match_score"]
        ):
            if level == "all":
                level = "overall"
            # Report as percentages: scale by 100 and round to 2 places.
            flattened_scores[f"{level}_exec_accuracy"] = util.smart_round(ex * 100.0, 2)
            flattened_scores[f"{level}_exact_set_match"] = util.smart_round(em * 100.0, 2)
        return flattened_scores

    def get_all_score_df(
        self, ids: List[int], candidates: List[str], references: List[str]
    ) -> pd.DataFrame:
        """Build a DataFrame of the per-record scores together with their IDs.

        Args:
            ids: Sample IDs, one per record.
            candidates: Generated SQL strings.
            references: Reference entries for the candidates.

        Returns:
            pd.DataFrame: The per-record score columns plus an ``id`` column.
        """
        # NOTE(review): scores cached by an earlier call are reused as-is and
        # are not checked against the candidates/references passed here.
        if not self.record_level_score:
            _ = self.get_score(candidates, references)
        # Build a new mapping so the cached per-record scores are not mutated
        # with caller-supplied ids.
        all_scores = {**self.record_level_score, "id": ids}
        return pd.DataFrame(all_scores)

    def compute_record_level_scores(
        self,
        candidates: List[str],
        references: List[str],
    ) -> List[float]:
        """Compute record-level execution accuracy scores.

        Runs a fresh evaluation; it neither reads nor updates
        ``record_level_score``.

        Args:
            candidates: Generated SQL strings.
            references: Reference entries for the candidates.

        Returns:
            List[float]: The evaluator's ``per_record_ex`` values converted to
            float (empty list if the evaluator did not report them).
        """
        scores = self._run_spider_eval(candidates, references)
        return [float(x) for x in scores.get("per_record_ex", [])]
