from decision.xp.data.base import BaseDataset, register_ds, ForwardedMixin
from pathlib import Path
import os
import pandas as pd
from abc import abstractmethod
import torch
from datasets import Dataset
import numpy as np
import datasets
from tqdm import tqdm

# Directory under which the forwarded datasets are stored.
data_path = Path("/decision_suboptimal_classifiers/decision/datasets/forwarded")


@register_ds("hate", "UCB")
class HateDataset(ForwardedMixin, BaseDataset):
    """UC Berkeley ``measuring-hate-speech`` corpus, binarised at a score of 0.5."""

    y_name = "hate"
    S_name = "confidence"
    ds_name = "ucberkeley-dlab/measuring-hate-speech"
    xcol = "text"
    ycol = "hate_speech_score"
    # Same threshold as ``get_label``; presumably consumed by the base class (confirm).
    map = "y > 0.5"

    def extract(self, examples: dict):
        """Return the raw text column (``xcol``) of a batch."""
        return examples[self.xcol]

    def get_label(self, examples: dict):
        """Binarise ``hate_speech_score``: 1 when strictly above 0.5, else 0 (int64 tensor)."""
        return (torch.as_tensor(examples[self.ycol]) > 0.5).long()

    def get_groups(self, dataset: Dataset) -> np.ndarray:
        """Return the per-row ``comment_id`` column as a NumPy array."""
        return np.asarray(dataset["comment_id"])

    def subsample_df(
        self, df_X: pd.DataFrame, df_y: pd.Series, df_S: pd.Series, df_group: pd.Series
    ) -> (pd.DataFrame, pd.Series, pd.Series, pd.Series):
        """Keep only the first row (first annotator) of every comment.

        Args:
            df_X, df_y, df_S: features, labels and scores; all are indexed with the
                same row labels as ``df_group``.
            df_group: comment id of each row.

        Returns:
            ``(df_X, df_y, df_S, df_group)`` restricted to the kept rows.
        """
        # ``drop_duplicates`` returns the first occurrence of each comment id with its
        # original row label. ``groupby(df_group).first()`` would instead be indexed by
        # the comment ids themselves, which are not row labels of the frames.
        keep = df_group.drop_duplicates(keep="first").index
        return df_X.loc[keep], df_y.loc[keep], df_S.loc[keep], df_group.loc[keep]


def filter_ds(
    ds: datasets.Dataset, col_text: str, col_label: str, label_map: dict = None
):
    """Reduce a dataset to exactly two columns, ``text`` and ``label``.

    Each label is first coerced with ``int()`` when possible. Float labels are
    truncated toward zero. Strings such as ``"hateful"``, lists and ``None`` are
    kept unchanged. The label is then passed through ``label_map``.

    Args:
        ds: source dataset.
        col_text: name of the column copied into ``text``.
        col_label: name of the column copied (after mapping) into ``label``.
        label_map: either a dict (labels missing from it pass through unchanged)
            or a callable applied to every label. ``None`` means identity.

    Returns:
        A dataset with only ``text`` and ``label`` columns, without rows where
        either value is ``None``.
    """
    label_map = label_map or {}
    if callable(label_map):
        map_label = label_map
    else:  # dict-like: unmapped labels pass through unchanged

        def map_label(label):
            return label_map.get(label, label)

    def to_text_label(examples):
        labels = []
        for raw in examples[col_label]:
            # Only the errors ``int()`` raises for non-numeric input are expected here.
            try:
                value = int(raw)
            except (TypeError, ValueError, OverflowError):
                value = raw
            labels.append(map_label(value))
        return {"text": examples[col_text], "label": labels}

    # Columns listed in ``remove_columns`` are dropped before the mapped outputs are
    # added, so removing every original column leaves exactly ``text`` and ``label``.
    filtered_ds = ds.map(to_text_label, batched=True, remove_columns=ds.column_names)
    # Rows whose text or (mapped) label is missing cannot be used.
    return filtered_ds.filter(
        lambda example: example["text"] is not None and example["label"] is not None
    )


class MergedDataset(ForwardedMixin, BaseDataset):
    """Concatenation of several datasets normalised to ``text`` / ``label`` columns.

    Subclasses set ``ds_name`` and ``ds_specs``; ``load_dataset`` loads every spec,
    normalises it with ``filter_ds`` and concatenates the results.
    """

    y_name = "label"
    S_name = "confidence"

    @property
    @abstractmethod
    def ds_specs(self) -> list:
        """List of dataset specs, one tuple per source dataset:

        ``(dataset_name, splits, text_column, label_column, label_map)``

        - ``dataset_name``: a dataset id for ``datasets.load_dataset``, or a
          ``(builder, data_files)`` tuple such as ``("csv", url)``. In the tuple
          case ``splits`` is ignored and the ``"train"`` split is loaded.
        - ``splits``: a split name, a list of split names, or a dict mapping a
          subset (config) name to a split name or list of split names.
        - ``label_map``: dict or callable forwarded to ``filter_ds``.
        """
        pass

    def load_dataset(self):
        """Load every spec in ``ds_specs`` and concatenate them into one dataset.

        Every part is normalised by ``filter_ds`` and its ``label`` column is cast
        to ``int32`` so that all parts share the same features.
        """
        ds_list = []

        for ds_name, splits, col_text, col_label, label_map in tqdm(self.ds_specs):
            url = None
            if isinstance(ds_name, tuple):
                ds_name, url = ds_name
                splits = "train"  # a builder fed with data_files only exposes "train"

            # Normalise to {subset: [split, ...]}; None selects the default subset.
            if not isinstance(splits, dict):
                splits = {None: splits}

            for subset, subset_splits in splits.items():
                if not isinstance(subset_splits, list):
                    subset_splits = [subset_splits]
                for split in subset_splits:
                    ds = datasets.load_dataset(
                        ds_name, data_files=url, name=subset, split=split
                    )
                    ds = filter_ds(ds, col_text, col_label, label_map)
                    ds = ds.cast_column("label", datasets.Value("int32"))
                    ds_list.append(ds)

        return datasets.concatenate_datasets(ds_list)

    def extract(self, examples: dict):
        """Return the normalised ``text`` column of a batch."""
        return examples["text"]

    def get_label(self, examples: dict):
        """Return the normalised ``label`` column of a batch."""
        return examples["label"]


@register_ds("hate_en_tweets", "Tweets")
class HateEnglishTweets(MergedDataset):
    """English tweets from ``tweets_hate_speech_detection`` (train split), labels unchanged."""

    ds_name = "hate_en_tweets"

    # (dataset_name, split, text_column, label_column, label_map)
    ds_specs = [
        (
            "tweets_hate_speech_detection",
            "train",
            "tweet",
            "label",
            {},
        ),
    ]


@register_ds("hate_en_speech18", "Speech18")
class HateEnglishSpeech(MergedDataset):
    """``hate_speech18`` train split; labels 3 and 2 are both mapped to 0."""

    ds_name = "hate_en_speech18"
    ds_specs = [
        (
            "hate_speech18",
            "train",
            "text",
            "label",
            {3: 0, 2: 0},
        ),
    ]


@register_ds("hate_en_speech_off", "Offensive")
class HateEnglishSpeech2(MergedDataset):
    """``hate_speech_offensive`` train split.

    The ``class`` column is remapped 0 -> 1 and 2 -> 0; any other value (e.g. 1)
    is kept as is.
    """

    ds_name = "hate_en_speech_off"
    ds_specs = [
        (
            "hate_speech_offensive",
            "train",
            "tweet",
            "class",
            {0: 1, 2: 0},
        ),
    ]


@register_ds("hate_en_davidson", "Davidson")
class HateEnglishDavidson(MergedDataset):
    """Davidson hate-speech data (train + test); labels 0 and 1 are swapped."""

    ds_name = "hate_en_davidson"

    _spec = (
        "krishan-CSE/Davidson_Hate_Speech",
        ["train", "test"],
        "text",
        "labels",
        {0: 1, 1: 0},
    )
    ds_specs = [_spec]
    del _spec


@register_ds("hate_en_gender", "Gender")
class HateEnglishGender(MergedDataset):
    """Gender hate-speech data (train + test); label 2 is mapped to 1."""

    ds_name = "hate_en_gender"

    _spec = (
        "ctoraman/gender-hate-speech",
        ["train", "test"],
        "Text",
        "Label",
        {2: 1},
    )
    ds_specs = [_spec]
    del _spec


@register_ds("hate_merged_en", "Merged")
class HateMergedEnglishDataset(MergedDataset):
    """Union of the five single-source English datasets above."""

    ds_name = "hate_merged_en"

    ds_specs = (
        HateEnglishTweets.ds_specs
        + HateEnglishSpeech.ds_specs
        + HateEnglishSpeech2.ds_specs
        + HateEnglishDavidson.ds_specs
        + HateEnglishGender.ds_specs
    )


@register_ds("hate_merged_no_en", "Hate M ¬en")
class HateMergedNoEnglishDataset(MergedDataset):
    """Non-English hate-speech datasets: Filipino, Portuguese, Polish and Spanish sources."""

    ds_name = "hate_merged_no_en"

    # (dataset_name, splits, text_column, label_column, label_map)
    ds_specs = [
        ("hate_speech_filipino", ["train", "validation", "test"], "text", "label", {}),
        ("hate_speech_portuguese", "train", "text", "label", {}),
        # Ratings 2-4 collapse to 1; other ratings pass through unchanged.
        ("hate_speech_pl", "train", "text", "rating", {2: 1, 3: 1, 4: 1}),
        (
            "mapsoriano/2016_2022_hate_speech_filipino",
            ["train", "validation", "test"],
            "text",
            "label",
            {},
        ),
        (
            "piuba-bigdata/contextualized_hate_speech",
            ["train", "dev", "test"],
            "text",
            "HATEFUL",
            {},
        ),
        # jeanlee/kmhas_korean_hate_speech is loaded by HateMergedNoEnglish2Dataset.
    ]


@register_ds("hate_merged_no_en2", "Hate M2 ¬en")
class HateMergedNoEnglish2Dataset(MergedDataset):
    """Second group of non-English datasets: Korean, Slovenian, Croatian, Roman Urdu, Portuguese."""

    ds_name = "hate_merged_no_en2"

    ds_specs = [
        # Labels are lists of ints; a list containing 8 maps to 0, anything else to 1.
        (
            "jeanlee/kmhas_korean_hate_speech",
            ["train", "validation", "test"],
            "text",
            "label",
            lambda label: 0 if 8 in label else 1,
        ),
        ("classla/FRENK-hate-sl", ["train", "validation", "test"], "text", "label", {}),
        ("classla/FRENK-hate-hr", ["train", "validation", "test"], "text", "label", {}),
        # Subset "Fine_Grained": label 1 -> 0, every other listed label -> 1.
        (
            "roman_urdu_hate_speech",
            {"Fine_Grained": ["train", "validation", "test"]},
            "tweet",
            "label",
            {0: 1, 1: 0, 2: 1, 3: 1, 4: 1},
        ),
        ("kor_hate", ["train", "test"], "comments", "hate", {0: 1, 2: 0}),
        (
            "ruanchaves/hatebr",
            ["train", "validation", "test"],
            "instagram_comments",
            "specialist_1_hate_speech",
            {},
        ),
        # NOTE(review): disabled spec kept for reference; reason not recorded.
        # (
        #     "arbml/Arabic_Hate_Speech",
        #     ["train", "validation"],
        #     "tweet",
        #     "is_hate",
        #     {
        #         "NOT_HS": 0,
        #         "HS1": 1,
        #         "HS2": 1,
        #         "HS3": 1,
        #         "HS5": 1,
        #         "HS6": 1,
        #     },
        # ),
    ]


@register_ds("merged_hate_check", "Checks")
class HateCheckDataset(MergedDataset):
    """HateCheck functional test suites (``test`` split) in eleven languages."""

    ds_name = "merged_hate_check"

    # Every suite shares the same split, columns and string-to-int label map.
    # Entry 0 must stay the English suite: HateEnglishCheck takes ds_specs[0] and
    # HateMergedNoEnglishLargeDataset takes ds_specs[1:].
    ds_specs = [
        (
            suite,
            "test",
            "test_case",
            "label_gold",
            {"hateful": 1, "non-hateful": 0},
        )
        for suite in (
            "Paul/hatecheck",
            "Paul/hatecheck-french",
            "Paul/hatecheck-spanish",
            "Paul/hatecheck-polish",
            "Paul/hatecheck-italian",
            "Paul/hatecheck-german",
            "Paul/hatecheck-dutch",
            "Paul/hatecheck-arabic",
            "Paul/hatecheck-hindi",
            "Paul/hatecheck-mandarin",
            "Paul/hatecheck-portuguese",
        )
    ]


@register_ds("hate_en_frenk", "FRENK")
class HateEnglishFRENK(MergedDataset):
    """English FRENK split plus the text of the expanded Hateful Memes data; labels unchanged."""

    ds_name = "hate_en_frenk"

    _all_splits = ["train", "validation", "test"]
    ds_specs = [
        ("classla/FRENK-hate-en", list(_all_splits), "text", "label", {}),
        ("limjiayi/hateful_memes_expanded", list(_all_splits), "text", "label", {}),
    ]
    del _all_splits


@register_ds("hate_en_check", "Check")
class HateEnglishCheck(MergedDataset):
    """English HateCheck suite only (first spec of HateCheckDataset)."""

    ds_name = "hate_en_check"
    ds_specs = HateCheckDataset.ds_specs[:1]


@register_ds("hate_en_twitter", "Tweets 2")
class HateEnglishTwitter(MergedDataset):
    """English tweets from ``thefrankhsu/hate_speech_twitter`` (train + test), labels unchanged."""

    ds_name = "hate_en_twitter"
    ds_specs = [
        (
            "thefrankhsu/hate_speech_twitter",
            ["train", "test"],
            "tweet",
            "label",
            {},
        ),
    ]


@register_ds("hate_en_open", "Open")
class HateEnglishOpen(MergedDataset):
    """Open hate-speech test set; ``class`` uses the same 0 -> 1, 2 -> 0 remap as HateEnglishSpeech2."""

    ds_name = "hate_en_open"

    _spec = (
        "parnoux/hate_speech_open_data_original_class_test_set",
        "test",
        "tweet",
        "class",
        {0: 1, 2: 0},
    )
    ds_specs = [_spec]
    del _spec


@register_ds("hate_merged_en2", "Merged 2")
class HateMergedEnglish2Dataset(MergedDataset):
    """Second English union: FRENK/memes, English HateCheck, Twitter and the open test set."""

    ds_name = "hate_merged_en2"

    # Each entry: (dataset_name, splits, text_column, label_column, label_map)
    ds_specs = (
        HateEnglishFRENK.ds_specs
        + HateEnglishCheck.ds_specs
        + HateEnglishTwitter.ds_specs
        + HateEnglishOpen.ds_specs
    )


@register_ds("hate_dyn_gen", "DynGen")
class DynamicallyGeneratedHateSpeechDataset(MergedDataset):
    """https://github.com/bvidgen/Dynamically-Generated-Hate-Speech-Dataset

    Loaded straight from the CSV release; ``"hate"`` maps to 1, ``"nothate"`` to 0.
    """

    ds_name = "hate_dyn_gen"

    _csv_source = (
        "csv",
        "https://raw.githubusercontent.com/bvidgen/Dynamically-Generated-Hate-Speech-Dataset/main/Dynamically%20Generated%20Hate%20Dataset%20v0.2.3.csv",
    )
    # The split slot is None: for (builder, data_files) names the loader uses "train".
    ds_specs = [(_csv_source, None, "text", "label", {"hate": 1, "nothate": 0})]
    del _csv_source


@register_ds("hate_merged_large_en", "Merged3")
class HateMergedEnglishLargeDataset(MergedDataset):
    """All English sources: both English unions plus the dynamically generated set."""

    ds_name = "hate_merged_large_en"

    ds_specs = (
        HateMergedEnglishDataset.ds_specs
        + HateMergedEnglish2Dataset.ds_specs
        + DynamicallyGeneratedHateSpeechDataset.ds_specs
    )


@register_ds("hate_merged_large_no_en", "Hate L ¬en")
class HateMergedNoEnglishLargeDataset(MergedDataset):
    """All non-English sources, including the non-English HateCheck suites."""

    ds_name = "hate_merged_large_no_en"

    ds_specs = (
        HateMergedNoEnglishDataset.ds_specs
        + HateMergedNoEnglish2Dataset.ds_specs
        # HateCheck entry 0 is the English suite, so skip it.
        + HateCheckDataset.ds_specs[1:]
    )


@register_ds("hate_merged_large", "Hate L multi")
class HateMergedLargeDataset(MergedDataset):
    """Multilingual union: every English source followed by every non-English source."""

    ds_name = "hate_merged_large"

    ds_specs = (
        HateMergedEnglishLargeDataset.ds_specs
        + HateMergedNoEnglishLargeDataset.ds_specs
    )
