import os
from typing import Any, Dict, List, Optional

from dataset import BaseDataset

class SNLIDataset(BaseDataset):
    """
    Subclass for processing SNLI data
    - Format reference:
      {
        "annotator_labels": ["neutral", "entailment", ...],
        "gold_label": "entailment",
        "sentence1": "...",
        "sentence2": "...",
        ...
      }
    """

    def __init__(self,
                 file_path: str = os.path.join(os.path.dirname(__file__), "train", "snli_train.jsonl"),
                 keys: Optional[List[str]] = None,
                 name: str = "snli"):
        """
        Configure the SNLI dataset and hand everything to BaseDataset.

        Args:
            file_path: Path to the SNLI data file. Defaults to
                ``train/snli_train.jsonl`` in the directory containing this module.
            keys: Field names forwarded to BaseDataset as ``keys``. When None,
                a fresh default list is built on each call, so instances never
                share one mutable default.
            name: Dataset name forwarded to BaseDataset.
        """
        if keys is None:
            keys = ["annotator_labels", "gold_label", "sentence1", "sentence2", "pairID"]
        # Mapping from SNLI label strings to the integer ids handed to BaseDataset.
        label_mapping = {
            "neutral": 0,
            "entailment": 1,
            "contradiction": 2
        }
        super().__init__(name=name,
                         template_name="NLI",
                         file_path=file_path,
                         keys=keys,
                         id_key="pairID",
                         label_mapping=label_mapping)

    def make_prompt(self, item: Dict[str, Any]) -> str:
        """
        Build the NLI prompt for one record.

        ``sentence1`` is passed as the premise and ``sentence2`` as the
        hypothesis. A missing field falls back to an empty string.
        """
        return self.template(
            premise=item.get("sentence1", ""),
            hypothesis=item.get("sentence2", ""),
        )

    def get_label(self, item: Dict[str, Any]) -> Any:
        """Return the record's ``annotator_labels`` value, or ``[]`` if it is absent."""
        return item.get("annotator_labels", [])

    def get_gold_label(self, item: Dict[str, Any]) -> Any:
        """Return the record's ``gold_label`` value unchanged, or ``""`` if it is absent."""
        return item.get("gold_label", "")

    def phrase_output(self, llm_output: str) -> str:
        """
        Normalize a raw model response by removing every whitespace character.

        Spaces, newlines, carriage returns, tabs and similar characters are all
        removed, e.g. ``" entail ment\\n"`` -> ``"entailment"``.

        NOTE(review): the method name (probably meant as "parse") is kept
        unchanged, because it likely overrides or is called by BaseDataset.
        Confirm before renaming.
        """
        # str methods return new strings; the result must be returned.
        # Discarding it was the original bug.
        return "".join(llm_output.split())
