import random
from typing import Any, Callable

import torch
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
from datasets.arrow_dataset import Dataset


def collect_cer(
    examples: dict[str, list[list[float]]], models: list[str]
) -> list[list[float]]:
    """
    Average, position by position, the CER lists of several models for every example.

    Models named in ``models`` but absent from ``examples`` are skipped.

    Args:
        examples (dict[str, list[list[float]]]): Maps a model column name to one CER list
            per example.
        models (list[str]): Column names of the models to average over.

    Returns:
        list[list[float]]: For each example, the mean CER at each position across models.

    Raises:
        ValueError: If none of ``models`` is present, if the present models disagree on the
            number of examples, or if an example's CER lists have different lengths.
    """
    per_model = [examples[name] for name in models if name in examples]

    # An empty selection is reported with the same message as a count mismatch.
    if not per_model:
        raise ValueError("All models must have the same number of examples.")
    n_examples = len(per_model[0])
    if any(len(cers) != n_examples for cers in per_model):
        raise ValueError("All models must have the same number of examples.")

    def _mean_per_position(group):
        # ``group`` holds one CER list per model for a single example.
        width = len(group[0])
        if any(len(cers) != width for cers in group):
            raise ValueError("All CER lists for an example must have the same length.")
        return [sum(column) / len(column) for column in zip(*group)]

    return [_mean_per_position(group) for group in zip(*per_model)]


def _supported_cer_models(language: str) -> list[str]:
    """
    Return the CER columns of the ASR models to use for ``language``.

    A model listed as unsupported for ``language`` is excluded. Languages unsupported by
    both models use SeamlessM4T alone.

    Args:
        language (str): Language code, e.g. ``"eng_Latn"``.

    Returns:
        list[str]: CER column names to average over, in a fixed order.
    """
    seamless_unsupported = {
        "ast_Latn",
        "hau_Latn",
        "kam_Latn",
        "kea_Latn",
        "lin_Latn",
        "mri_Latn",
        "nso_Latn",
        "oci_Latn",
        "tgl_Latn",
        "umb_Latn",
        "wol_Latn",
        "xho_Latn",
    }
    whisper_unsupported = {
        "ast_Latn",
        "ceb_Latn",
        "ckb_Arab",
        "fuv_Latn",
        "gle_Latn",
        "ibo_Latn",
        "kam_Latn",
        "kea_Latn",
        "kir_Cyrl",
        "lug_Latn",
        "luo_Latn",
        "nso_Latn",
        "tgl_Latn",
        "umb_Latn",
        "wol_Latn",
        "xho_Latn",
        "zul_Latn",
    }

    whisper_ok = language not in whisper_unsupported
    seamless_ok = language not in seamless_unsupported
    if whisper_ok and seamless_ok:
        return ["whisper_asr_cer", "seamlessm4t_asr_cer"]
    if not whisper_ok:
        # NOTE(review): languages unsupported by *both* models (e.g. ast_Latn, tgl_Latn)
        # also land here and use SeamlessM4T alone; confirm this is preferred over
        # averaging both models.
        return ["seamlessm4t_asr_cer"]
    return ["whisper_asr_cer"]


def select_audio_mapper(
    language: str,
    kind: str = "best",
) -> Callable[[dict[str, list[Any]]], dict[str, list[Any]]]:
    """
    Create a batched mapping function that keeps one candidate per example, chosen by CER.

    For every example, the per-candidate CERs of the supported ASR models are averaged with
    ``collect_cer``. One candidate index is then picked by ``kind``, and only that entry is
    kept in each candidate-indexed column.

    Args:
        language (str): Language code used to exclude ASR models that do not support it.
        kind (str, optional): Selection strategy: 'best' (lowest mean CER, first on ties),
            'worst' (highest mean CER, first on ties) or 'random' (uniform, via ``random``).
            Defaults to 'best'.

    Returns:
        Callable[[dict[str, list[Any]]], dict[str, list[Any]]]: A function for mapping
        dataset batches (``Dataset.map(..., batched=True)``).

    Raises:
        ValueError: If an invalid selection strategy is provided.
    """

    # Columns that hold one entry per candidate of an example (presumably one per
    # recording of the sentence); map_fn keeps only the selected entry. Other columns
    # pass through unchanged.
    keys = {
        "audio",
        "filename",
        "gender",
        "num_samples",
        "seamlessm4t_asr",
        "seamlessm4t_asr_cer",
        "seamlessm4t_asr_translation",
        "seamlessm4t_asr_wer",
        "speaker_id",
        "split",
        "whisper_asr",
        "whisper_asr_cer",
        "whisper_asr_translation",
        "whisper_asr_wer",
    }

    def _best(scores: list[float]) -> int:
        return min(range(len(scores)), key=scores.__getitem__)

    def _worst(scores: list[float]) -> int:
        return max(range(len(scores)), key=scores.__getitem__)

    def _random(scores: list[float]) -> int:
        return random.randint(0, len(scores) - 1)

    strategies = {"best": _best, "worst": _worst, "random": _random}
    if kind not in strategies:
        raise ValueError("Invalid 'kind'. Must be one of 'best', 'worst', or 'random'.")
    select_func = strategies[kind]

    models = _supported_cer_models(language)

    def map_fn(examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        """
        Keep, for every example in the batch, only the candidate chosen by CER.

        Args:
            examples (dict[str, list[Any]]): A batch of dataset examples (column -> values).

        Returns:
            dict[str, list[Any]]: A new batch in which every candidate-indexed column holds
            the selected entry per example; other columns are passed through unchanged.
        """
        cers = collect_cer(examples, models)
        indices = [select_func(cer) for cer in cers]
        return {
            key: (
                [candidates[idx] for idx, candidates in zip(indices, values)]
                if key in keys
                else values
            )
            for key, values in examples.items()
        }

    return map_fn


class Chain:
    """
    Compose callables so that each one receives the previous one's output.

    Args:
        functions: Either a mapping whose values are the callables to apply, or a plain
            sequence of callables. The callables are applied in iteration order.
    """

    def __init__(self, functions) -> None:
        self.functions = functions

    def __call__(self, inputs) -> Any:
        """Feed ``inputs`` through every function in turn and return the final result."""
        # Mappings (anything exposing ``.values()``) contribute their values; other
        # iterables are treated as the sequence of callables itself.
        steps = (
            self.functions.values()
            if hasattr(self.functions, "values")
            else self.functions
        )
        result = inputs
        for function in steps:
            # Chain on the running result, not the original inputs.
            result = function(result)
        return result


def tokenize(
    examples: dict[str, list],
    tokenizer: "PreTrainedTokenizerFast",
    text_column: str,
    tokenize_kwargs: "dict | None" = None,
):
    """
    Tokenize the ``text_column`` of a batch of examples.

    Args:
        examples (dict[str, list]): Batch of examples (column -> values).
        tokenizer (PreTrainedTokenizerFast): Tokenizer called on the list of texts.
        text_column (str): Name of the column containing the texts.
        tokenize_kwargs (dict | None, optional): Extra keyword arguments forwarded to the
            tokenizer. ``None`` (default) means no extra arguments.

    Returns:
        Whatever ``tokenizer`` returns for the texts.
    """
    # ``None`` instead of a shared mutable ``{}`` default.
    kwargs = {} if tokenize_kwargs is None else tokenize_kwargs
    return tokenizer(examples[text_column], **kwargs)


def collate_text(
    inputs: list[dict],
    tokenizer,
    tokenize_kwargs,
    text_column: str,
    label_column: str = "category",
) -> dict[str, torch.Tensor]:
    """
    Collate example dicts into a tokenized batch with integer labels.

    The input dicts are left untouched, so cached examples can be collated again.

    Args:
        inputs (list[dict]): Examples, each holding ``text_column`` and ``label_column``.
        tokenizer: Callable applied to the list of texts; must return a mutable mapping.
        tokenize_kwargs: Keyword arguments forwarded to ``tokenizer`` (``None`` = none).
        text_column (str): Key of the text field.
        label_column (str, optional): Key of the label field. Defaults to "category".

    Returns:
        dict[str, torch.Tensor]: The tokenizer output with an added int64 ``"labels"``
        tensor.

    Raises:
        KeyError: If an example lacks ``text_column`` or ``label_column``.
    """
    texts = [example[text_column] for example in inputs]
    labels = [example[label_column] for example in inputs]
    batch = tokenizer(texts, **(tokenize_kwargs or {}))
    batch["labels"] = torch.tensor(labels, dtype=torch.long)
    return batch


def collate_fn(
    inputs: list[dict],
    tokenizer,
    tokenize_kwargs,
    text_column,
    label_column: str = "category",
) -> dict[str, torch.Tensor]:
    """
    Collate example dicts into a tokenized batch with integer labels.

    The input dicts are left untouched, so cached examples can be collated again.

    Args:
        inputs (list[dict]): Examples, each holding ``text_column`` and ``label_column``.
        tokenizer: Callable applied to the list of texts; must return a mutable mapping.
        tokenize_kwargs: Keyword arguments forwarded to ``tokenizer`` (``None`` = none).
        text_column: Key of the text field.
        label_column (str, optional): Key of the label field. Defaults to "category".

    Returns:
        dict[str, torch.Tensor]: The tokenizer output with an added int64 ``"labels"``
        tensor.

    Raises:
        KeyError: If an example lacks ``label_column`` or ``text_column``.
    """
    labels = [example[label_column] for example in inputs]
    texts = [example[text_column] for example in inputs]
    batch = tokenizer(texts, **(tokenize_kwargs or {}))
    batch["labels"] = torch.tensor(labels, dtype=torch.long)
    return batch


def load_translate_test(path: str, name: str, split: str):
    """
    Load ``<path>/<name>/<split>.parquet`` as a single Hugging Face dataset.

    The file is registered under the ``"train"`` split of the parquet builder, and that split
    is returned, whatever value ``split`` has.

    Args:
        path (str): Root directory.
        name (str): Sub-directory below ``path``.
        split (str): File stem of the parquet file inside ``path/name``.

    Returns:
        The loaded dataset.
    """
    from pathlib import Path

    parquet_file = Path(path).joinpath(name, f"{split}.parquet")
    return load_dataset(
        "parquet", data_files={"train": str(parquet_file)}, split="train"
    )


def main():
    """
    Ad-hoc smoke test of the SIB data pipeline.

    Steps:
        1. Build the test dataloader described by ``./configs/dataspec/sib.yaml`` and fetch
           one batch.
        2. Select one audio per example of the ``eng_Latn`` SIB test split with
           ``select_audio_mapper``.
        3. Run a wav2vec2 feature extractor on the first five selected recordings.
    """
    from omegaconf import OmegaConf
    from transformers import AutoFeatureExtractor
    from trident.core.dataspec import TridentDataspec

    with open("./configs/dataspec/sib.yaml", encoding="utf-8") as f:
        cfg = OmegaConf.create(f.read())
    OmegaConf.resolve(cfg)

    dataspec = TridentDataspec(cfg.test, "test")
    dl = dataspec.get_dataloader()
    # Fetch one batch only to check that the dataloader can be built and iterated.
    next(iter(dl))

    eng_Latn = load_dataset("wuenlp/sib", "eng_Latn", split="test")
    # Selection is applied exactly once: the mapper replaces each example's list of
    # candidates with the chosen entry. Re-applying it would index into that entry.
    dataset = eng_Latn.map(
        select_audio_mapper("eng_Latn"), batched=True, batch_size=50
    )

    model = "facebook/wav2vec2-xls-r-300m"  # alternative: "utter-project/mHuBERT-147"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model)
    speech = [dataset[i]["audio"]["array"] for i in range(5)]
    batch = feature_extractor(
        speech,
        sampling_rate=16000,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
