import os.path
import random
from collections import defaultdict

import torch
from tqdm import trange
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase
from evaluation.tasks.metrics import get_average


def download_toxigen(minority_group: str):
    current_dir = os.path.abspath(os.path.dirname(__file__))
    url_format = "https://raw.githubusercontent.com/microsoft/TOXIGEN/main/prompts/hate_{minority_group}_1k.txt"

    url = url_format.format(minority_group=minority_group)
    download_path = os.path.join(current_dir, f"hate_{minority_group}.txt")
    os.system(f"wget {url} -O {download_path}")


@torch.no_grad()
def score_generations(
        hate_classifier_tokenizer, hate_classifier_model, generations, batch_size=32
):
    classifications = []
    for i in trange(0, len(generations), batch_size):
        batch = generations[i: i + batch_size]
        inputs = hate_classifier_tokenizer(
            batch, padding=True, truncation=True, return_tensors="pt"
        ).to(hate_classifier_model.device)
        outputs = hate_classifier_model(**inputs)
        classifications.extend(outputs.logits.argmax(dim=-1).tolist())
    return classifications


class ToxiGen(OpenEndedTaskBase):
    def __init__(self, classifier_model_name: str):
        super().__init__()
        self.classifier_model_name = classifier_model_name
        self.max_new_tokens, self.temperature, self.top_p = self.prepare_hparams()
        self.batch_size = 32

    def prepare_hparams(self):
        return 512, 0.0, 1.0

    def prepare_evaluator(self) -> tuple[AutoModelForSequenceClassification, AutoTokenizer]:
        classifier = AutoModelForSequenceClassification.from_pretrained(self.classifier_model_name)
        classifier = classifier.eval()
        classifier = classifier.cuda()
        tokenizer = AutoTokenizer.from_pretrained(self.classifier_model_name)
        return classifier, tokenizer

    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        inputs = [{"instruction": row["instruction"]} for row in self.data]
        completions = model.generate_completions(inputs,
                                                 max_new_tokens=self.max_new_tokens,
                                                 temperature=self.temperature,
                                                 top_p=self.top_p)
        assert len(completions) == len(self.data)
        outputs = [completion for completion in completions]

        classifier, tokenizer = self.prepare_evaluator()
        classifications = score_generations(
            tokenizer,
            classifier,
            outputs,
            batch_size=self.batch_size,
        )

        all_groups_scores = defaultdict(list)
        final_outputs = []
        for example, output, classification in zip(self.data, outputs, classifications):
            final_outputs.append({
                "id": example["id"],
                "input": example["instruction"],
                "output": output,
                "classification": classification,
            })
            for group in example["target_groups"]:
                all_groups_scores[group].append(classification)

        # Calculate the performance
        performance = {}
        # calculate score by targeted group
        for group in all_groups_scores:
            performance[group] = get_average(all_groups_scores[group])

        # overall perf
        performance["overall"] = get_average(classifications)
        return performance, final_outputs

    def required_input_fields(self) -> list[str]:
        return ["instruction"]

    def load(self) -> list[dict]:
        current_dir = os.path.abspath(os.path.dirname(__file__))
        minority_groups = ["asian",
                           "black",
                           "chinese",
                           "jewish",
                           "latino",
                           "lgbtq",
                           "mental_disability",
                           "mexican",
                           "middle_east",
                           "muslim",
                           "native_american",
                           "physical_disability",
                           "trans",
                           "women"]
        examples = []
        for minor in minority_groups:
            task_file = os.path.join(current_dir, f"hate_{minor}.txt")
            if not os.path.exists(task_file):
                download_toxigen(minor)

            with open(task_file, "r") as f:
                group_name = os.path.basename(task_file).split(".")[0]
                label = group_name.split("_")[0]
                minority_group = "_".join(group_name.split("_")[1:])
                group_prompts = [line.strip() for line in f]
                random.shuffle(group_prompts)
                for prompt in group_prompts:
                    # minor cleaning: replace \\ with \
                    prompt = prompt.replace("\\\\", "\\")
                    prompt = prompt.replace("\\n", "\n")
                    examples.append({
                        "instruction": prompt,
                        "label": label,
                        "target_groups": [minority_group],
                    })

        return examples


class TinyToxiGen(ToxiGen):
    """
    Subsample 100 per each minority group.
    """
    NUM_SAMPLES = 100

    def load(self) -> list[dict]:
        random.seed(42)
        current_dir = os.path.abspath(os.path.dirname(__file__))
        minority_groups = ["asian",
                           "black",
                           "chinese",
                           "jewish",
                           "latino",
                           "lgbtq",
                           "mental_disability",
                           "mexican",
                           "middle_east",
                           "muslim",
                           "native_american",
                           "physical_disability",
                           "trans",
                           "women"]
        examples = []
        for minor in minority_groups:
            task_file = os.path.join(current_dir, f"hate_{minor}.txt")
            if not os.path.exists(task_file):
                download_toxigen(minor)

            with open(task_file, "r") as f:
                group_name = os.path.basename(task_file).split(".")[0]
                label = group_name.split("_")[0]
                minority_group = "_".join(group_name.split("_")[1:])
                group_prompts = [line.strip() for line in f]
                random.shuffle(group_prompts)
                for prompt in group_prompts[:self.NUM_SAMPLES]:
                    # minor cleaning: replace \\ with \
                    prompt = prompt.replace("\\\\", "\\")
                    prompt = prompt.replace("\\n", "\n")
                    examples.append({
                        "instruction": prompt,
                        "label": label,
                        "target_groups": [minority_group],
                    })

        return examples
