from model import SafeguardModel
from datasets import load_dataset
from typing import Optional, List
from metrics import SafeguardMetrics
from data_types import SafetyLabel, SafeguardSample

class _InvalidBenchmarkArgument(ValueError, AssertionError):
    """Raised when a subset, split, or language argument is not recognised.

    It also derives from ``AssertionError`` because this validation used to be
    written with ``assert`` statements; code that catches that type keeps
    working.
    """


class SEASafeguardBench:
    """Load benchmark CSV files as samples and score a safeguard model on them.

    ``available_subsets_splits`` maps every accepted subset name to the split
    names accepted for it. Each (subset, split) pair is read from
    ``{data_dir}/{subset}/{split}.csv``.
    """

    available_subsets_splits = {
        "general": ["EN", "TA", "TH", "TL", "MS", "IN", "MY", "VI"],
        "cultural_content_generation": ["EN", "TA_EN", "TH_EN", "TL_EN", "MS_EN", "IN_EN", "MY_EN", "VI_EN"],
        "cultural_in_the_wild": ["TA_EN", "TH_EN", "TL_EN", "MS_EN", "IN_EN", "MY_EN", "VI_EN"],
    }

    def __init__(self, data_dir: str = "./dataset", cache_dir: str = "./cache"):
        """Remember where the CSV files live and where ``load_dataset`` may cache.

        Args:
            data_dir: Root directory containing one sub-directory per subset.
            cache_dir: Directory handed to ``load_dataset`` as ``cache_dir``.
        """
        self.data_dir = data_dir
        self.cache_dir = cache_dir

    def _check_arguments(self, subset: str, split: str, language: Optional[str]) -> None:
        """Reject unknown subset/split/language values with an explicit raise."""
        # Explicit raises instead of ``assert`` so the checks still run under
        # ``python -O``, which strips assert statements.
        if subset not in self.available_subsets_splits:
            raise _InvalidBenchmarkArgument(f"Invalid subset: {subset}. Valid subsets are {self.available_subsets_splits.keys()}")
        if split not in self.available_subsets_splits[subset]:
            raise _InvalidBenchmarkArgument(f"Invalid split: {split}. Valid splits are {self.available_subsets_splits[subset]}")
        # Falsy values (None, "") skip the language check.
        if language and language not in ["English", "Local"]:
            raise _InvalidBenchmarkArgument(f"Invalid language: {language}. Valid languages are ['English', 'Local']")

    @staticmethod
    def _label(is_harmful: bool) -> SafetyLabel:
        """Turn a boolean into ``SafetyLabel(1)`` (True) or ``SafetyLabel(0)`` (False)."""
        return SafetyLabel(int(is_harmful))

    @classmethod
    def _samples_from_row(cls, data: dict, subset: str, language: Optional[str]) -> List[SafeguardSample]:
        """Build the samples that a single CSV row contributes for ``subset``."""
        if subset == "general":
            # One sample per row; ``language`` plays no role for this subset.
            return [
                SafeguardSample(
                    prompt=data["prompt"],
                    prompt_gold_label=cls._label(data["prompt_label"] == "Harmful"),
                    response=data["response"],
                    response_gold_label=cls._label(data["response_label"] == "Harmful") if data["response_label"] else None,
                )
            ]

        # The other subsets store an English and a Local variant in the same
        # row, under "en_*" and "local_*" columns. English comes first.
        column_prefixes = []
        if language is None or language == "English":
            column_prefixes.append("en")
        if language is None or language == "Local":
            column_prefixes.append("local")

        row_samples = []
        for prefix in column_prefixes:
            if subset == "cultural_content_generation":
                row_samples.append(
                    SafeguardSample(
                        prompt=data[f"{prefix}_prompt"],
                        prompt_gold_label=cls._label(data["prompt_label"] == "Harmful"),
                        response=data[f"{prefix}_response"],
                        # NOTE(review): the second test reads *prompt_label*, not
                        # response_label. Kept exactly as found; confirm that a
                        # "Sensitive" prompt is meant to flag the response.
                        response_gold_label=cls._label(
                            data["response_label"] == "Harmful" or data["prompt_label"] == "Sensitive"
                        ) if data["response_label"] else None,
                    )
                )
            else:
                # Remaining subset: prompt-only samples, no response columns read.
                row_samples.append(
                    SafeguardSample(
                        prompt=data[f"{prefix}_prompt"],
                        prompt_gold_label=cls._label(data["prompt_label"] == "Harmful"),
                    )
                )
        return row_samples

    def get_samples(self, subset: str, split: str, language: Optional[str] = None) -> List[SafeguardSample]:
        """Read ``{data_dir}/{subset}/{split}.csv`` and convert its rows to samples.

        Args:
            subset: A key of ``available_subsets_splits``.
            split: A split listed for ``subset`` in ``available_subsets_splits``.
            language: ``"English"``, ``"Local"`` or ``None``. For the two
                non-"general" subsets it selects which variant of each row is
                emitted; ``None`` emits both, English first. It is ignored for
                the "general" subset.

        Returns:
            The samples in row order. A label column equal to ``"Harmful"``
            becomes ``SafetyLabel(1)``; other values become ``SafetyLabel(0)``.
            A response gold label is ``None`` when the row's
            ``response_label`` is empty.

        Raises:
            ValueError: If ``subset``, ``split`` or ``language`` is not an
                accepted value. The exception is also an ``AssertionError``.
        """
        self._check_arguments(subset, split, language)

        dataset = load_dataset("csv", data_files=f"{self.data_dir}/{subset}/{split}.csv", cache_dir=self.cache_dir)["train"]

        samples = []
        for data in dataset:
            samples.extend(self._samples_from_row(data, subset, language))
        return samples

    def eval(
        self, 
        model: SafeguardModel,
        subset: str = "general",
        split: str = "EN",
        language: Optional[str] = None,     # cultural_content_generation and cultural_in_the_wild subsets have both English and Local languages. If None, it will evaluate all languages.
        verbose: bool = False,
    ) -> SafeguardMetrics:
        """Run ``model`` over one subset/split and collect its outputs.

        Args:
            model: Object whose ``predict(samples, verbose=...)`` yields one
                output per sample.
            subset: Forwarded to ``get_samples``.
            split: Forwarded to ``get_samples``.
            language: Forwarded to ``get_samples``.
            verbose: Forwarded to ``model.predict``.

        Returns:
            A ``SafeguardMetrics`` holding every output's prompt, response,
            gold labels and harmful scores.

        Raises:
            ValueError: Propagated from ``get_samples`` for invalid arguments.
        """
        print(f"Subset: {subset}, Split: {split}")
        metrics = SafeguardMetrics()
        samples = self.get_samples(subset, split, language)
        for output in model.predict(samples, verbose=verbose):
            metrics.add(
                prompt=output.prompt,
                prompt_gold_label=output.prompt_gold_label,
                prompt_harmful_score=output.prompt_harmful_score,
                response=output.response,
                response_gold_label=output.response_gold_label,
                response_harmful_score=output.response_harmful_score,
            )
        return metrics


if __name__ == "__main__":
    # Smoke run: score the dummy model on the "general"/"EN" data and write
    # the resulting metrics to a JSON file in the working directory.
    from model import DummySafeguardModel

    dummy_model = DummySafeguardModel()
    bench = SEASafeguardBench()
    dummy_metrics = bench.eval(model=dummy_model, subset="general", split="EN")
    dummy_metrics.save_to_json("dummy_performance.json")