import os
from collections import Counter
from typing import Any, Dict, List, Optional

from dataset import BaseDataset


class MTBenchDataset(BaseDataset):
    """
    Dataset wrapper for MT-Bench pairwise comparison records.

    Each record holds the conversations of two models and a list of
    per-judge winner votes. Reference record format:
      {
        "question_id": 81,
        "model_a": "alpaca-13b",
        "model_b": "gpt-3.5-turbo",
        "conversation_a": [...],
        "conversation_b": [...],
        "winner": ["b", "b", "b", "b"],
        "turn": "1",
        "pairID": 304
      }
    """

    def __init__(self, file_path: str = os.path.join(os.path.dirname(__file__), "train", "mt_bench_train.jsonl"),
                 keys: Optional[List[str]] = None,
                 name: str = "mtbench"):
        """Configure the base dataset for MT-Bench records.

        Args:
            file_path: Path of the JSONL file to load. Defaults to
                ``train/mt_bench_train.jsonl`` next to this module.
            keys: Record fields forwarded to the base class. Defaults to the
                two conversations, ``winner``, ``turn`` and ``pairID``.
            name: Dataset name forwarded to the base class.
        """
        if keys is None:
            keys = ["conversation_a", "conversation_b", "winner", "turn", "pairID"]
        # MT-Bench compares outputs from two models, winner can be "a", "b" or "tie"
        label_mapping = {
            "a": 0,
            "b": 1,
            "tie": 2,
        }
        super().__init__(name=name,
                         template_name="MT_Bench",
                         file_path=file_path,
                         keys=keys,
                         id_key="pairID",
                         label_mapping=label_mapping)

    def make_prompt(self, item: Dict[str, Any]) -> str:
        """Build the comparison prompt for one record.

        Both conversations are cut down to their first ``turn * 2`` messages
        and passed to ``self.template``.

        Args:
            item: A record; reads ``conversation_a``, ``conversation_b`` and
                ``turn``. Missing or empty conversations become ``[]``.

        Returns:
            Whatever ``self.template`` returns for the trimmed conversations.
        """
        conversation_a = item.get("conversation_a") or []
        conversation_b = item.get("conversation_b") or []

        try:
            turn_num = int(item.get("turn", "1"))  # Default to 1
        except (ValueError, TypeError):
            turn_num = 1
        # A zero or negative turn would turn the slice below into "drop
        # everything" or "drop from the end"; fall back to the first turn,
        # the same default used for unparsable values.
        turn_num = max(turn_num, 1)

        # NOTE(review): assumes each turn is two messages (presumably one
        # user message and one assistant reply) - confirm against the data.
        max_messages = turn_num * 2

        return self.template(
            conversation_a=conversation_a[:max_messages],
            conversation_b=conversation_b[:max_messages],
        )

    def get_label(self, item: Dict[str, Any]) -> Any:
        """Return the raw ``winner`` field of the record (``[]`` if absent)."""
        return item.get("winner", [])

    def get_gold_label(self, item: Dict[str, Any]) -> Any:
        """Return the majority vote among the record's winner votes.

        Args:
            item: A record; reads ``winner``, normally a list of votes.
                A bare string is treated as a single vote.

        Returns:
            The vote with the strictly highest count, or ``""`` when there
            are no votes or the top count is shared by several values.
        """
        votes = item.get("winner") or []
        if isinstance(votes, str):
            # Iterating a string would count its characters ("tie" -> t/i/e).
            votes = [votes]
        if not votes:
            return ""

        top_two = Counter(votes).most_common(2)
        if len(top_two) == 2 and top_two[0][1] == top_two[1][1]:
            # No unique majority.
            return ""
        return top_two[0][0]

    def phrase_output(self, llm_output: str) -> str:
        """Return the LLM output with leading/trailing whitespace removed."""
        return llm_output.strip()

    def get_conversation(self, item: Dict[str, Any], model_key: str = "model_a") -> List[Dict[str, str]]:
        """Return the conversation stored for one side of the comparison.

        Args:
            item: A record.
            model_key: Key whose last character selects the side, e.g.
                ``"model_a"`` -> ``"conversation_a"``.

        Returns:
            The stored conversation, or ``[]`` when it is missing or empty.
        """
        # Slice instead of index so an empty model_key cannot raise IndexError.
        conv_key = f"conversation_{model_key[-1:]}"  # Get "conversation_a" or "conversation_b"
        return item.get(conv_key) or []

    def get_responses(self, item: Dict[str, Any], model_key: str = "model_a") -> List[str]:
        """Return the ``content`` of every message whose role is "assistant".

        Args:
            item: A record.
            model_key: Side selector, interpreted as in ``get_conversation``.

        Returns:
            Assistant message contents in conversation order; a message
            without ``content`` contributes ``""``.
        """
        conversation = self.get_conversation(item, model_key)
        return [
            message.get("content", "")
            for message in conversation
            if message.get("role") == "assistant"
        ]