import abc
import json
import os
from collections import defaultdict

import pandas as pd

from evaluation.schemas import ClassificationTaskBase
from evaluation.tasks.metrics import get_average, get_f1
from src.classifier_models.base import SafetyClassifierBase, PromptHarmfulness, ResponseHarmfulness, ResponseRefusal


class HarmBenchPromptSafetyEvalBase(ClassificationTaskBase, abc.ABC):
    @abc.abstractmethod
    def load(self) -> list[dict]:
        raise NotImplementedError

    def required_input_fields(self) -> list[str]:
        return ["prompt"]

    @property
    def label_field(self) -> str:
        return "prompt_harmfulness"

    def required_output_fields(self) -> list[list[str]]:
        return [["prompt_harmfulness"]]

    def _evaluate(self, model: SafetyClassifierBase) -> tuple[dict, list[dict]]:
        classification_results = model.classify(self.data)
        pred_is_correct = defaultdict(list)
        individual_results = []
        for gt, pred in zip(self.data, classification_results):
            gt_label = gt[self.label_field]
            individual_results.append({
                "id": gt["id"],
                "prompt": gt["prompt"],
                "gt_prompt_harmfulness": gt_label.value,
                **pred.asdict(),
            })
            if pred.is_parsing_error:
                pred_is_correct["parsing_error"].append(1)
                pred_is_correct["micro_acc"].append(0)
                gt_label_is_harmful = gt_label == PromptHarmfulness.HARMFUL
                pred_is_correct["f1"].append({"pred": not gt_label_is_harmful,
                                              "gt": gt_label_is_harmful})
                pred_is_correct[gt["FunctionalCategory"]].append(0)
            else:
                pred_is_correct["parsing_error"].append(0)
                pred_is_correct["micro_acc"].append(gt_label == pred.prompt_harmfulness)
                pred_is_correct["f1"].append({"pred": pred.prompt_harmfulness == PromptHarmfulness.HARMFUL,
                                              "gt": gt_label == PromptHarmfulness.HARMFUL})
                pred_is_correct[gt["FunctionalCategory"]].append(gt_label == pred.prompt_harmfulness)

        return {k: get_average(v) if "f1" not in k else get_f1(v) for k, v in
                pred_is_correct.items()}, individual_results


class HarmBenchPromptSafetyEval(HarmBenchPromptSafetyEvalBase):
    def load(self) -> list[dict]:
        current_path = os.path.dirname(os.path.abspath(__file__))
        datapath = os.path.join(current_path, "harmbench_behaviors_text_test.csv")
        df = pd.read_csv(datapath, sep=',')
        # Filter to only "standard" or "copyright" FunctionalCategory
        # (remove "contextual" and "multimodal")
        df = df[df["FunctionalCategory"].isin(["standard", "copyright"])]
        return [{"prompt": row["Behavior"],
                 "response": "",  # Pass empty response in case classifier requires it
                 "prompt_harmfulness": PromptHarmfulness.HARMFUL,
                 "FunctionalCategory": row["FunctionalCategory"]}
                for i, row in df.iterrows()]


class HarmBenchResponseSafetyEvalBase(ClassificationTaskBase, abc.ABC):
    @abc.abstractmethod
    def load(self) -> list[dict]:
        raise NotImplementedError

    def required_input_fields(self) -> list[str]:
        return ["prompt", "response"]

    @property
    def label_field(self) -> str:
        return "response_harmfulness"

    def required_output_fields(self) -> list[list[str]]:
        return [["response_refusal"], ["response_harmfulness"]]

    def _evaluate(self, model: SafetyClassifierBase) -> tuple[dict, list[dict]]:
        map_refusal_to_harmfulness = lambda \
                r: ResponseHarmfulness.UNHARMFUL if r == ResponseRefusal.REFUSAL else ResponseHarmfulness.HARMFUL
        classification_results = model.classify(self.data)
        pred_is_correct = defaultdict(list)
        individual_results = []
        for gt, pred in zip(self.data, classification_results):
            if pred.response_harmfulness is None:
                pred_label = map_refusal_to_harmfulness(pred.response_refusal)
            else:
                pred_label = pred.response_harmfulness

            gt_label = gt["response_harmfulness"]
            individual_results.append({
                "id": gt["id"],
                "prompt": gt["prompt"],
                "response": gt["response"],
                "gt_response_harmfulness": gt_label.value,
                **pred.asdict(),
            })
            if pred.is_parsing_error:
                pred_is_correct["parsing_error"].append(1)
                pred_is_correct["micro_acc"].append(0)
                gt_label_is_harmful = gt_label == ResponseHarmfulness.HARMFUL
                pred_is_correct[gt_label.value].append(0)
                pred_is_correct["f1"].append({"pred": not gt_label_is_harmful,
                                              "gt": gt_label_is_harmful})
                pred_is_correct[f"{gt['prompt_type']}_f1"].append({"pred": not gt_label_is_harmful,
                                                                 "gt": gt_label_is_harmful})
            else:
                pred_is_correct["parsing_error"].append(0)
                pred_is_correct["micro_acc"].append(gt_label == pred_label)
                pred_is_correct[gt_label.value].append(gt_label == pred_label)
                pred_is_correct["f1"].append({"pred": pred_label == ResponseHarmfulness.HARMFUL,
                                              "gt": gt_label == ResponseHarmfulness.HARMFUL})
                pred_is_correct[f"{gt['prompt_type']}_f1"].append({"pred": pred_label == ResponseHarmfulness.HARMFUL,
                                                                 "gt": gt_label == ResponseHarmfulness.HARMFUL})

        return {k: get_average(v) if "f1" not in k else get_f1(v) for k, v in
                pred_is_correct.items()}, individual_results


class HarmBenchResponseSafetyEval(HarmBenchResponseSafetyEvalBase):
    def load(self) -> list[dict]:
        current_path = os.path.dirname(os.path.abspath(__file__))
        datapath = os.path.join(current_path, "harmbench_text_behaviors_val_set.json")

        with open(datapath, 'r') as f:
            raw_data = json.load(f)
        keys = list(raw_data.keys())
        all_data = []

        def preprocess_data(x):
            # x includes human_0, human_1, human_2.
            # if more than two have 1, then it is harmful.
            # if more than two have 0, then it is non-harmful.
            human_0 = x["human_0"]
            human_1 = x["human_1"]
            human_2 = x["human_2"]

            harmfulness = ResponseHarmfulness.HARMFUL if sum(
                [human_0, human_1, human_2]) > 1 else ResponseHarmfulness.UNHARMFUL
            is_adversarial = True if x["experiment"] != "default" else False
            return {"prompt": x["test_case"], "response": x["generation"], "response_harmfulness": harmfulness,
                    "prompt_type": "adversarial" if is_adversarial else "vanilla"}

        for k in keys:
            all_data.extend([preprocess_data(x) for x in raw_data[k]])
        return all_data
