"""HellaSwag dataset handler."""
import os
from typing import Dict, List, Optional

import pandas as pd

from .base import DatasetHandler


class HellaSwagHandler(DatasetHandler):
    """Dataset handler that turns HellaSwag parquet shards into A-D prompts.

    Each example becomes a single-turn chat prompt listing the candidate
    endings as options A-D; the model is asked to answer with just the letter.
    Text preprocessing, scoring and answer extraction are delegated to
    ``utils.reward_score.hellaswag``.
    """

    name = "hellaswag"
    default_train_path = "data/hellaswag"
    default_test_path = "data/hellaswag"
    default_max_tokens = 64

    # Option letters shown in the prompt; letter i labels endings[i].
    _CHOICE_LABELS = ("A", "B", "C", "D")

    def __init__(self):
        # The reward module is imported here, so it is only loaded once a
        # handler is instantiated rather than when this module is imported.
        from utils.reward_score import hellaswag as hellaswag_reward
        self.reward_module = hellaswag_reward

    def load_data(self, path: str, split: str = "train", max_samples: Optional[int] = None) -> List[Dict]:
        """Load HellaSwag examples from a local parquet shard and build prompts.

        Args:
            path: Directory containing ``train-00000-of-00001.parquet`` and/or
                ``validation-00000-of-00001.parquet``.
            split: ``"train"`` reads the train shard; ANY other value
                (including ``"test"``) reads the validation shard.
            max_samples: If a positive integer, keep at most this many labeled
                examples. ``None`` or a non-positive value means no limit.

        Returns:
            A list of dicts with keys ``messages`` (a single user turn holding
            the prompt), ``ground_truth`` (int index of the correct ending) and
            ``ground_truth_label`` (the matching option letter).

        Raises:
            FileNotFoundError: If the expected parquet file does not exist.
            ValueError: If a label is not a valid option index (0-3).
        """
        shard = "train" if split == "train" else "validation"
        parquet_file = os.path.join(path, f"{shard}-00000-of-00001.parquet")
        if not os.path.exists(parquet_file):
            raise FileNotFoundError(f"Parquet file not found: {parquet_file}")

        df = pd.read_parquet(parquet_file)
        # Keep only rows carrying a usable label. astype(bool) guarantees a
        # boolean mask even for an empty frame, so it always selects rows.
        df = df[df["label"].map(self._has_label).astype(bool)]
        if max_samples is not None and max_samples > 0:
            # Truncate up front instead of building every row and breaking.
            df = df.head(max_samples)

        return [self._row_to_task(row) for row in df.to_dict("records")]

    @staticmethod
    def _has_label(value) -> bool:
        """Return True if *value* is a present, non-blank label."""
        # pd.isna also catches NaN / pd.NA, which str() would turn into the
        # non-empty strings "nan" / "<NA>" and later crash int().
        return not pd.isna(value) and str(value).strip() != ""

    def _row_to_task(self, row: Dict) -> Dict:
        """Convert one parquet record into a chat-style task dict."""
        preprocess = self.reward_module.preprocess
        # str.capitalize() upper-cases the first character of ctx_b and
        # lower-cases the rest before it is appended to ctx_a.
        ctx = row["ctx_a"] + " " + row["ctx_b"].capitalize()
        query = preprocess(row["activity_label"] + ": " + ctx)
        choices = [preprocess(ending) for ending in row["endings"]]

        gold = int(row["label"])
        # Reject out-of-range labels explicitly: a negative index would
        # otherwise silently select the wrong letter.
        if not 0 <= gold < len(self._CHOICE_LABELS):
            raise ValueError(f"HellaSwag label out of range 0-3: {row['label']!r}")

        return {
            "messages": [{"role": "user", "content": self._build_prompt(query, choices)}],
            "ground_truth": gold,
            "ground_truth_label": self._CHOICE_LABELS[gold],
        }

    @classmethod
    def _build_prompt(cls, query: str, choices: List[str]) -> str:
        """Render the multiple-choice prompt shown to the model."""
        # zip() stops at the shorter sequence, so at most four options appear.
        choices_text = "\n".join(
            f"{label}. {choice}" for label, choice in zip(cls._CHOICE_LABELS, choices)
        )
        return (
            "Please complete the following sentence by choosing the best option.\n"
            "\n"
            f"{query}\n"
            "\n"
            "Options:\n"
            f"{choices_text}\n"
            "\n"
            "Answer with only the letter (A, B, C, or D) of the correct option."
        )

    def compute_reward(self, response: str, ground_truth: int) -> float:
        """Score *response* against the gold option index via the reward module."""
        return self.reward_module.compute_score(response, ground_truth)

    def extract_answer(self, response: str) -> str:
        """Return whatever the reward module's ``extract_choice`` pulls from *response*.

        NOTE(review): presumably an option letter A-D; what it returns when no
        letter is found is defined in the reward module — confirm there.
        """
        return self.reward_module.extract_choice(response)

    def format_answer_for_check(self, answer: str) -> str:
        """Return *answer* unchanged; answers are compared as bare letters A/B/C/D."""
        return answer

