import os
import statistics
from collections import Counter
from typing import List, Dict, Any, Union

from dataset import BaseDataset
from prompts import PromptTemplate


class SummEvalDataset(BaseDataset):
    """
    Subclass for processing SummEval data.

    Format reference (one record)::

      {
        "decoded": "summary content...",
        "text": "original text...",
        "task_type": "relevance",
        "scores": [2, 1, 2, 4, 4, 4, 4, 4],
        "id": 6400
      }

    ``task_type`` selects which prompt template ``make_prompt`` uses, and
    ``scores`` is the list of ratings attached to the summary. Labels are
    handled as strings ("1" .. "5") so that they match the keys of the
    label mapping passed to the base class.
    """

    def __init__(self, file_path: str = os.path.join(os.path.dirname(__file__), "train", "summeval_train.jsonl"),
                 keys: Union[List[str], None] = None,
                 name: str = "summeval"):
        """Create the dataset and its per-task prompt templates.

        Args:
            file_path: Location of the data file; defaults to
                ``train/summeval_train.jsonl`` next to this module.
            keys: Record fields to use. ``None`` selects the five fields
                shown in the class docstring.
            name: Dataset name forwarded to the base class.
        """
        if keys is None:
            keys = ["decoded", "text", "task_type", "scores", "id"]

        # SummEval dataset typically has scores from 1-5
        label_mapping = {
            "1": 0,  # Very poor
            "2": 1,  # Poor
            "3": 2,  # Average
            "4": 3,  # Good
            "5": 4   # Excellent
        }

        # NOTE(review): the templates are created before super().__init__()
        # on purpose, in case the base initializer already calls
        # make_prompt() - confirm against BaseDataset.
        self.coherence_template = PromptTemplate("SummEval_Coherence")
        self.consistency_template = PromptTemplate("SummEval_Consistency")
        self.fluency_template = PromptTemplate("SummEval_Fluency")
        self.relevance_template = PromptTemplate("SummEval_Relevance")

        super().__init__(name=name,
                         template_name="SummEval_Relevance",
                         file_path=file_path,
                         keys=keys,
                         id_key="id",
                         label_mapping=label_mapping)

    def make_prompt(self, item: Dict[str, Any]) -> str:
        """Render the prompt for ``item`` using the template of its task type.

        Args:
            item: Record providing ``"decoded"`` (summary) and ``"text"``
                (source document); ``"task_type"`` is optional.

        Returns:
            The result of calling the selected template. Unknown or missing
            task types fall back to ``self.template``.

        Raises:
            KeyError: If ``"decoded"`` or ``"text"`` is missing from ``item``.
        """
        templates = {
            "coherence": self.coherence_template,
            "consistency": self.consistency_template,
            "fluency": self.fluency_template,
            "relevance": self.relevance_template,
        }
        task_type = item.get("task_type")
        # Only strings can name a template; the isinstance guard also keeps an
        # unhashable value (e.g. a list) from raising inside the dict lookup.
        template = templates.get(task_type) if isinstance(task_type, str) else None
        if template is None:
            # Looked up lazily so it is only touched when actually needed.
            # presumably set by BaseDataset from template_name - verify.
            template = self.template
        return template(Summary=item["decoded"], Document=item["text"])

    def get_label(self, item: Dict[str, Any]) -> List[str]:
        """Return every score of ``item`` as a string.

        A missing or ``None`` ``"scores"`` field yields an empty list.
        """
        scores = item.get("scores") or []
        return [str(score) for score in scores]

    def get_gold_label(self, item: Dict[str, Any]) -> str:
        """Return the gold label of ``item``: its most frequent score."""
        return self.get_mode_label(item)

    def get_mode_label(self, item: Dict[str, Any]) -> str:
        """Return the most frequent score of ``item`` as a string.

        When several scores share the highest count, the one that appears
        first in the list wins. Returns ``""`` when there are no scores.
        """
        scores = item.get("scores", [])
        if not scores:
            return ""

        counts = Counter(scores)
        # max() keeps the first maximal key it meets and Counter iterates in
        # first-seen order, so ties resolve to the earliest score.
        return str(max(counts, key=counts.__getitem__))

    def phrase_output(self, llm_output: str) -> str:
        """Extract a 1-5 score from raw model output.

        The output is split on whitespace and the first token that is a
        whole number between 1 and 5 is returned as a string.

        Args:
            llm_output: Raw text produced by the model.

        Returns:
            The score as ``"1"`` .. ``"5"``, or the stripped output unchanged
            when no such token exists.
        """
        clean_output = llm_output.strip()

        for word in clean_output.split():
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as "²" that int() cannot convert.
            if not word.isdecimal():
                continue
            try:
                score = int(word)
            except ValueError:
                # e.g. a digit run longer than the interpreter's int-parsing
                # limit; skip the token instead of abandoning the scan.
                continue
            if 1 <= score <= 5:
                return str(score)

        return clean_output

    def get_summary(self, item: Dict[str, Any]) -> str:
        """Return the summary text (``"decoded"``) of ``item``, or ``""``."""
        return item.get("decoded", "")

    def get_original_text(self, item: Dict[str, Any]) -> str:
        """Return the source document (``"text"``) of ``item``, or ``""``."""
        return item.get("text", "")

    def get_task_type(self, item: Dict[str, Any]) -> str:
        """Return the task type of ``item``.

        Falls back to ``self.task_type`` only when the item has no
        ``"task_type"`` key, so an instance without that attribute can still
        answer for items that carry their own task type.
        """
        if "task_type" in item:
            return item["task_type"]
        # NOTE(review): task_type is not assigned anywhere in this class;
        # presumably provided by BaseDataset - confirm.
        return self.task_type

