import random
from typing import Any, Callable

import torch
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
from datasets.arrow_dataset import Dataset


def collect_cer(
    examples: dict[str, list[list[float]]], models: list[str]
) -> list[list[float]]:
    """
    Average the per-candidate CER (Character Error Rate) of each example across models.

    For every example, each model column holds one CER per candidate. The result holds,
    per example, the element-wise mean of those lists over the models that are present.

    Args:
        examples (dict[str, list[list[float]]]): Batch mapping a CER column name to one
            list of per-candidate CERs per example.
        models (list[str]): CER column names to average. Names missing from ``examples``
            are skipped silently.

    Returns:
        list[list[float]]: One list per example containing the averaged CER of each candidate.

    Raises:
        ValueError: If none of ``models`` is present in ``examples``, if the models have
            different numbers of examples, or if an example's CER lists differ in length.
    """
    model_cer_lists = [examples[model] for model in models if model in examples]

    # Report this case on its own; the length check below would otherwise blame
    # "inconsistent numbers of examples", which is not the actual problem.
    if not model_cer_lists:
        raise ValueError(
            f"None of the requested models {models!r} are present in the examples."
        )

    num_examples = len(model_cer_lists[0])
    if any(len(cer_list) != num_examples for cer_list in model_cer_lists):
        raise ValueError("All models must have the same number of examples.")

    averaged_cer = []
    # example_group: one per-candidate CER list per model, all for the same example.
    for example_group in zip(*model_cer_lists):
        num_candidates = len(example_group[0])
        if any(len(cer_list) != num_candidates for cer_list in example_group):
            raise ValueError("All CER lists for an example must have the same length.")
        # zip(*example_group) yields, per candidate, the CERs from every model.
        averaged_cer.append(
            [sum(values) / len(values) for values in zip(*example_group)]
        )

    return averaged_cer


def select_audio_mapper(
    language: str,
    kind: str = "best",
) -> Callable[[dict[str, list[Any]]], dict[str, list[Any]]]:
    """
    Create a batched mapping function that keeps one audio candidate per example, chosen by CER.

    Every column named in ``keys`` below is expected to hold, per example, one entry per
    candidate recording. The returned function:

    1. averages the per-candidate CERs of the ASR models that support ``language``
       (via ``collect_cer``);
    2. picks one candidate index per example according to ``kind``;
    3. replaces each of those columns with the entry at the picked index.

    Other columns pass through unchanged.

    Args:
        language (str): Language code used to exclude the CERs of ASR models that do not
            support it. If neither model supports the language, both CERs are averaged.
        kind (str, optional): Selection strategy: 'best' (lowest CER), 'worst' (highest
            CER), or 'random' (via ``random.randint``). Defaults to 'best'.

    Returns:
        Callable[[dict[str, list[Any]]], dict[str, list[Any]]]: A function for batched
        dataset mapping; it updates the batch dict it receives in place and returns it.

    Raises:
        ValueError: If an invalid selection strategy is provided.
    """

    # Per-candidate columns that are realigned to the selected candidate.
    keys = {
        "audio",
        "filename",
        "gender",
        "num_samples",
        "seamlessm4t_asr",
        "seamlessm4t_asr_cer",
        "seamlessm4t_asr_translation",
        "seamlessm4t_asr_wer",
        "speaker_id",
        "split",
        "whisper_asr",
        "whisper_asr_cer",
        "whisper_asr_translation",
        "whisper_asr_wer",
    }

    # Languages for which each model's CER is not used for selection.
    seamless_unsupported = {
        "ast_Latn",
        "hau_Latn",
        "kam_Latn",
        "kea_Latn",
        "lin_Latn",
        "mri_Latn",
        "nso_Latn",
        "oci_Latn",
        "tgl_Latn",
        "umb_Latn",
        "wol_Latn",
        "xho_Latn",
    }
    whisper_unsupported = {
        "ast_Latn",
        "ceb_Latn",
        "ckb_Arab",
        "fuv_Latn",
        "gle_Latn",
        "ibo_Latn",
        "kam_Latn",
        "kea_Latn",
        "kir_Cyrl",
        "lug_Latn",
        "luo_Latn",
        "nso_Latn",
        "tgl_Latn",
        "umb_Latn",
        "wol_Latn",
        "xho_Latn",
        "zul_Latn",
    }

    # Define the selection strategy: maps one example's averaged CERs to a candidate index.
    if kind == "best":

        def select_func(scores: list[float]) -> int:
            # Lowest CER; min() keeps the first candidate on ties.
            return min(range(len(scores)), key=lambda i: scores[i])

    elif kind == "worst":

        def select_func(scores: list[float]) -> int:
            # Highest CER; max() keeps the first candidate on ties.
            return max(range(len(scores)), key=lambda i: scores[i])

    elif kind == "random":

        def select_func(scores: list[float]) -> int:
            return random.randint(0, len(scores) - 1)

    else:
        raise ValueError("Invalid 'kind'. Must be one of 'best', 'worst', or 'random'.")

    # Determine which models' CERs to use for the given language. Previously a language
    # unsupported by BOTH models fell into the "whisper unsupported" branch and used only
    # the (equally unsupported) seamless CER, leaving the final fallback unreachable.
    whisper_ok = language not in whisper_unsupported
    seamless_ok = language not in seamless_unsupported
    if whisper_ok and not seamless_ok:
        models = ["whisper_asr_cer"]
    elif seamless_ok and not whisper_ok:
        models = ["seamlessm4t_asr_cer"]
    else:
        # Both supported, or neither: no reason to prefer one model, so average both.
        models = ["whisper_asr_cer", "seamlessm4t_asr_cer"]

    def map_fn(examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        """
        Keep, for every example in the batch, only the candidate chosen by its CER.

        Args:
            examples (dict[str, list[Any]]): Batched dataset examples.

        Returns:
            dict[str, list[Any]]: The same dict, with each per-candidate column reduced
            to the selected entry per example.
        """
        cers = collect_cer(examples, models)
        indices = [select_func(cer) for cer in cers]

        # Only per-candidate columns present in the batch are realigned.
        for key in keys.intersection(examples):
            examples[key] = [
                candidates[idx] for idx, candidates in zip(indices, examples[key])
            ]
        return examples

    return map_fn



# Dataset configuration names passed to load_dataset("wuenlp/fleurs-sib", lang, ...) by the
# per-language smoke sweep later in this script. Codes appear to follow a
# "<language>_<Script>" pattern (e.g. "afr_Latn") — presumably ISO 639-3 + ISO 15924; confirm.
# The list is kept in alphabetical order.
LANGS = [
    "afr_Latn",
    "amh_Ethi",
    "arb_Arab",
    "asm_Beng",
    "ast_Latn",
    "azj_Latn",
    "bel_Cyrl",
    "ben_Beng",
    "bos_Latn",
    "bul_Cyrl",
    "cat_Latn",
    "ceb_Latn",
    "ces_Latn",
    "ckb_Arab",
    "cym_Latn",
    "dan_Latn",
    "deu_Latn",
    "ell_Grek",
    "eng_Latn",
    "est_Latn",
    "fin_Latn",
    "fra_Latn",
    "fuv_Latn",
    "gaz_Latn",
    "gle_Latn",
    "glg_Latn",
    "guj_Gujr",
    "hau_Latn",
    "heb_Hebr",
    "hin_Deva",
    "hrv_Latn",
    "hun_Latn",
    "hye_Armn",
    "ibo_Latn",
    "ind_Latn",
    "isl_Latn",
    "ita_Latn",
    "jav_Latn",
    "jpn_Jpan",
    "kam_Latn",
    "kan_Knda",
    "kat_Geor",
    "kaz_Cyrl",
    "kea_Latn",
    "khk_Cyrl",
    "khm_Khmr",
    "kir_Cyrl",
    "kor_Hang",
    "lao_Laoo",
    "lin_Latn",
    "lit_Latn",
    "ltz_Latn",
    "lug_Latn",
    "luo_Latn",
    "lvs_Latn",
    "mal_Mlym",
    "mar_Deva",
    "mkd_Cyrl",
    "mlt_Latn",
    "mri_Latn",
    "mya_Mymr",
    "nld_Latn",
    "nob_Latn",
    "npi_Deva",
    "nso_Latn",
    "nya_Latn",
    "oci_Latn",
    "ory_Orya",
    "pan_Guru",
    "pbt_Arab",
    "pes_Arab",
    "pol_Latn",
    "por_Latn",
    "ron_Latn",
    "rus_Cyrl",
    "slk_Latn",
    "slv_Latn",
    "sna_Latn",
    "snd_Arab",
    "som_Latn",
    "spa_Latn",
    "srp_Cyrl",
    "swe_Latn",
    "swh_Latn",
    "tam_Taml",
    "tel_Telu",
    "tgk_Cyrl",
    "tgl_Latn",
    "tha_Thai",
    "tur_Latn",
    "ukr_Cyrl",
    "umb_Latn",
    "urd_Arab",
    "uzn_Latn",
    "vie_Latn",
    "wol_Latn",
    "xho_Latn",
    "yor_Latn",
    "zho_Hans",
    "zho_Hant",
    "zsm_Latn",
    "zul_Latn",
]

from datasets import load_dataset

# Exploratory check for a single language:
# - for each split, keep the best-CER audio per example and count the values of the
#   resulting "split" column;
# - then build a worst-CER variant of the last split loaded.
from collections import Counter

# One variable drives both the dataset config and the mapper. Previously the mapper was
# built for "eng_Latn" while the data was "swh_Latn", so the mapper's per-language model
# filtering did not correspond to the data being mapped.
DEMO_LANGUAGE = "swh_Latn"

out = {}
map_ = select_audio_mapper(kind="best", language=DEMO_LANGUAGE)
for split in ("train", "validation", "test"):
    dataset = load_dataset("wuenlp/fleurs-sib", DEMO_LANGUAGE, split=split)
    d = dataset.map(map_, batched=True, batch_size=10)
    out[split] = Counter(d["split"])

# ``dataset`` is still the last split loaded in the loop above ("test").
map_ = select_audio_mapper(kind="worst", language=DEMO_LANGUAGE)
d2 = dataset.map(map_, batched=True, batch_size=10)

import random
from typing import Any

import torch
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast


def argmin(iterable: list) -> int:
    """Return the position of the smallest item; the earliest position wins ties.

    Raises ValueError (from ``min``) when ``iterable`` is empty.
    """
    position, _smallest = min(enumerate(iterable), key=lambda pair: pair[1])
    return position


def argmax(iterable: list) -> int:
    """Return the position of the largest item; the earliest position wins ties.

    Raises ValueError (from ``max``) when ``iterable`` is empty.
    """
    position, _largest = max(enumerate(iterable), key=lambda pair: pair[1])
    return position


def select_audio_mapper2(
    kind: str = "best", model: str = "whisper"
) -> Callable[[dict[str, list[Any]]], dict[str, list[Any]]]:
    """
    Create a batched mapping function that keeps one audio candidate per example,
    ranked by a single model's CER.

    Args:
        kind: 'best' (index from ``argmin``), 'worst' (index from ``argmax``) or
            'random' (via ``random.randint``).
        model: Prefix of the CER column used for ranking; the column read is
            ``f"{model}_asr_cer"``.

    Returns:
        A function for batched dataset mapping. It updates the batch dict it receives
        in place and returns it.

    Raises:
        ValueError: If ``kind`` is not one of 'best', 'worst', or 'random'.
    """
    # Per-candidate columns realigned to the selected candidate.
    # NOTE(review): select_audio_mapper also realigns "gender"; confirm whether it belongs here.
    keys = {
        "audio",
        "filename",
        "num_samples",
        "seamlessm4t_asr",
        "seamlessm4t_asr_cer",
        "seamlessm4t_asr_translation",
        "seamlessm4t_asr_wer",
        "speaker_id",
        "split",
        "whisper_asr",
        "whisper_asr_cer",
        "whisper_asr_translation",
        "whisper_asr_wer",
    }

    if kind == "best":
        select_func = argmin
    elif kind == "worst":
        select_func = argmax
    elif kind == "random":

        def select_func(list_: list) -> int:
            k = len(list_)
            return random.randint(0, k - 1)

    else:
        # Previously the exception was constructed but never raised, so an invalid
        # kind only surfaced later as a NameError on ``select_func``.
        raise ValueError("Must be one of 'best', 'worst', or 'random'")

    def map_fn(examples: dict[str, list]) -> dict[str, list]:
        """Reduce every per-candidate column in the batch to the selected entry per example."""
        selected = [select_func(line) for line in examples[f"{model}_asr_cer"]]
        # Only realign columns that are present; a missing column used to raise KeyError.
        # This matches select_audio_mapper.
        for k in keys.intersection(examples):
            examples[k] = [line[idx] for idx, line in zip(selected, examples[k])]
        return examples

    return map_fn


# Single-model variant applied to ``dataset``, i.e. whatever the preceding script code
# last loaded.
map_2 = select_audio_mapper2(kind="best")
d2 = dataset.map(map_2, batched=True, batch_size=10)


# Smoke test: apply the best-CER mapper to every language/split and report the outcome.
for lang in LANGS:
    # The mapper depends only on the language, so build it once per language.
    map_ = select_audio_mapper(kind="best", language=lang)
    for split in ("train", "validation", "test"):
        dataset = load_dataset("wuenlp/fleurs-sib", lang, split=split)
        try:
            d = dataset.map(map_, batched=True, batch_size=10)
        except Exception as err:
            # Best-effort sweep: report the failure and continue. Unlike the previous
            # bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            print(f"{lang}-{split} error: {err!r}")
        else:
            print(f"{lang}-{split} success")
