import random
from collections import defaultdict

from datasets import load_dataset

from evaluation.schemas import ClassificationTaskBase
from evaluation.tasks.metrics import get_average, get_f1
from src.classifier_models.base import SafetyClassifierBase, ResponseHarmfulness


class SafeRLHFResponseHarmEval(ClassificationTaskBase):
    def required_input_fields(self) -> list[str]:
        return ["prompt", "response"]

    @property
    def label_field(self) -> str:
        return "response_harmfulness"

    def required_output_fields(self) -> list[list[str]]:
        return [["response_harmfulness"]]

    def _evaluate(self, model: SafetyClassifierBase) -> tuple[dict, list[dict]]:
        classification_results = model.classify(self.data)
        pred_is_correct = defaultdict(list)
        individual_results = []
        for gt, pred in zip(self.data, classification_results):
            gt_label = gt[self.label_field]
            individual_results.append({
                "id": gt["id"],
                "prompt": gt["prompt"],
                "response": gt["response"],
                "gt_response_harmfulness": gt_label.value,
                **pred.asdict(),
            })
            if pred.is_parsing_error:
                pred_is_correct["parsing_error"].append(1)
                pred_is_correct["micro_acc"].append(0)
                gt_label_is_harmful = gt_label == ResponseHarmfulness.HARMFUL
                pred_is_correct["f1"].append({"pred": not gt_label_is_harmful,
                                              "gt": gt_label_is_harmful})
                pred_is_correct[gt_label.value].append(0)
            else:
                pred_is_correct["parsing_error"].append(0)
                pred_is_correct["micro_acc"].append(gt_label == pred.response_harmfulness)
                pred_is_correct["f1"].append({"pred": pred.response_harmfulness == ResponseHarmfulness.HARMFUL,
                                              "gt": gt_label == ResponseHarmfulness.HARMFUL})
                pred_is_correct[gt_label.value].append(gt_label == pred.response_harmfulness)

        return {k: get_average(v) if "f1" not in k else get_f1(v) for k, v in
                pred_is_correct.items()}, individual_results

    def load(self) -> list[dict]:
        dataset = load_dataset("PKU-Alignment/PKU-SafeRLHF", revision="v0")
        test_dataset = dataset["test"]

        prompts = test_dataset["prompt"]
        response_0 = test_dataset["response_0"]
        response_1 = test_dataset["response_1"]
        is_response_0_safe = test_dataset["is_response_0_safe"]
        is_response_1_safe = test_dataset["is_response_1_safe"]

        safe_responses, harmful_responses = [], []
        for p, r0, r1, r0s, r1s in zip(prompts, response_0, response_1, is_response_0_safe, is_response_1_safe):
            # pick only when r0s != r1s
            if r0s == r1s:
                continue
            if r0s:
                safe_responses.append({
                    "prompt": p,
                    "response": r0,
                    "response_harmfulness": ResponseHarmfulness.UNHARMFUL,
                })
            else:
                harmful_responses.append({
                    "prompt": p,
                    "response": r0,
                    "response_harmfulness": ResponseHarmfulness.HARMFUL,
                })

            if r1s:
                safe_responses.append({
                    "prompt": p,
                    "response": r1,
                    "response_harmfulness": ResponseHarmfulness.UNHARMFUL,
                })
            else:
                harmful_responses.append({
                    "prompt": p,
                    "response": r1,
                    "response_harmfulness": ResponseHarmfulness.HARMFUL,
                })

        # pick 1000 from each category
        random.seed(42)
        indices = random.sample(range(len(safe_responses)), 1000)
        return [safe_responses[i] for i in indices] + [harmful_responses[i] for i in indices]
