import abc
from collections import defaultdict

from datasets import load_dataset

from evaluation.schemas import ClassificationTaskBase
from evaluation.tasks.metrics import get_average, get_f1
from src.classifier_models.base import SafetyClassifierBase, PromptHarmfulness


class WildGuardTestPromptHarmBase(ClassificationTaskBase, abc.ABC):
    def __init__(self, include_response: bool = False, **kwargs):
        self.include_response = include_response
        super().__init__(**kwargs)

    @abc.abstractmethod
    def load(self) -> list[dict]:
        raise NotImplementedError

    @property
    def label_field(self) -> str:
        return "prompt_harmfulness"

    def required_input_fields(self) -> list[str]:
        return ["prompt"]

    def required_output_fields(self) -> list[list[str]]:
        return [["prompt_harmfulness"]]

    def _evaluate(self, model: SafetyClassifierBase) -> tuple[dict, list[dict]]:
        classification_results = model.classify(self.data)
        pred_is_correct = defaultdict(list)
        individual_results = []
        for gt, pred in zip(self.data, classification_results):
            gt_label = gt[self.label_field]
            individual_results.append({
                "id": gt["id"],
                "prompt": gt["prompt"],
                "gt_prompt_harmfulness": gt_label.value,
                "gt_prompt_harm_category": gt["prompt_harm_category"] if gt.get("prompt_harm_category", None) else None,
                **pred.asdict(),
            })
            gt_harmful = gt_label == PromptHarmfulness.HARMFUL
            if pred.is_parsing_error:
                pred_is_correct["parsing_error"].append(1)
                pred_is_correct["micro_acc"].append(0)
                pred_is_correct["f1"].append({"pred": not gt_harmful, "gt": gt_harmful})  # consider as false negative
                pred_is_correct[gt_label.value].append(0)
                if "prompt_type" in gt:
                    pred_is_correct[f'{gt["prompt_type"]}_f1'].append({"pred": not gt_harmful, "gt": gt_harmful})
                    pred_is_correct[gt["prompt_type"]].append(0)
                if "prompt_harm_category" in gt:
                    pred_is_correct[f'harm_cat_{gt["prompt_harm_category"]}_f1'].append(
                        {"pred": not gt_harmful, "gt": gt_harmful})
                    pred_is_correct[f'harm_cat_{gt["prompt_harm_category"]}'].append(0)
            else:
                pred_is_correct["parsing_error"].append(0)
                pred_is_correct["micro_acc"].append(gt_label == pred.prompt_harmfulness)
                pred_harmful = pred.prompt_harmfulness == PromptHarmfulness.HARMFUL
                pred_is_correct["f1"].append({"pred": pred_harmful, "gt": gt_harmful})
                pred_is_correct[gt_label.value].append(gt_label == pred.prompt_harmfulness)
                if "prompt_type" in gt:
                    pred_is_correct[f'{gt["prompt_type"]}_f1'].append({"pred": pred_harmful, "gt": gt_harmful})
                    pred_is_correct[gt["prompt_type"]].append(gt_label == pred.prompt_harmfulness)
                if "prompt_harm_category" in gt:
                    pred_is_correct[f'harm_cat_{gt["prompt_harm_category"]}_f1'].append(
                        {"pred": pred_harmful, "gt": gt_harmful})
                    pred_is_correct[f'harm_cat_{gt["prompt_harm_category"]}'].append(
                        gt_label == pred.prompt_harmfulness)

                if gt_label == PromptHarmfulness.HARMFUL and \
                        gt.get("prompt_harm_category", None) and "prompt_harm_category" in model.get_output_fields():
                    gt_harm_category = gt["prompt_harm_category"]
                    pred_harm_category = pred.prompt_harm_category
                    pred_is_correct["prompt_harm_category_micro_acc"].append(gt_harm_category == pred_harm_category)
                    pred_is_correct[f"prompt_harm_category_{gt_harm_category.value}"].append(
                        gt_harm_category == pred_harm_category)

        return {k: get_average(v) if "f1" not in k else get_f1(v) for k, v in
                pred_is_correct.items()}, individual_results


class WildGuardPromptEval(WildGuardTestPromptHarmBase):
    def load(self) -> list[dict]:
        df = load_dataset("allenai/wildguardmix", "wildguardtest")["test"]
        df = df.to_pandas()
        # drop if prompt_harm_label is None
        df = df.dropna(subset=["prompt_harm_label"])
        # setup prompt_type based on "adversarial"
        df["prompt_type"] = df.apply(lambda row: "adversarial" if row["adversarial"] else "vanilla", axis=1)
        df["gt_label"] = df["prompt_harm_label"]

        return [{"prompt": row["prompt"],
                 "response": row["response"] if self.include_response else "",
                 # Pass empty response in case classifier requires it
                 "prompt_harmfulness": PromptHarmfulness.HARMFUL if \
                     row["gt_label"] == "harmful" else PromptHarmfulness.UNHARMFUL,
                 "prompt_type": row["prompt_type"],
                 "prompt_harm_category": row["subcategory"]} for
                i, row in df.iterrows()]
