import random
import torch
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from typing import Any, Callable

from transformers import AutoTokenizer

from functools import lru_cache

@lru_cache()
def get_english_data():
    """
    Load the English Belebele test split and build a row index for it.

    The result is memoized with ``lru_cache``, so the dataset is only loaded
    on the first call.

    Returns:
        tuple: ``(dataset, index)`` where ``dataset`` is the ``eng_Latn`` test
        split of ``facebook/belebele`` and ``index`` maps each
        ``(link, question_number)`` pair to its row position in ``dataset``.
    """
    from datasets import load_dataset

    english = load_dataset("facebook/belebele", "eng_Latn", split="test")
    pair_keys = zip(english["link"], english["question_number"])
    # On duplicate pairs the last row wins.
    row_by_pair = {pair: row for row, pair in enumerate(pair_keys)}
    return english, row_by_pair


def select_audio_mapper(
    language: str,
    kind: str = "best",
) -> Callable[[dict[str, list[Any]]], dict[str, list[Any]]]:
    """
    Create a batched mapping function that picks one audio candidate per sentence by CER.

    For every sentence the CER lists of the ASR models supported for
    ``language`` are averaged candidate-wise, and one candidate index is
    chosen according to ``kind``. The chosen candidates' transcriptions and
    translations are then joined (in ascending sentence ``id`` order) into one
    passage per example.

    Args:
        language (str): Language code used to exclude ASR models that do not
            support the language.
        kind (str, optional): Selection strategy: 'best' (lowest averaged CER),
            'worst' (highest averaged CER) or 'random'. Defaults to 'best'.

    Returns:
        Callable[[dict[str, list[Any]]], dict[str, list[Any]]]: A function for
        mapping batched dataset examples. It reads the ``sentence_data``,
        ``link`` and ``question_number`` columns and adds
        ``{asr_key}_flores_passage`` and ``english_{column}`` columns.

    Raises:
        ValueError: If an invalid selection strategy is provided.
    """

    # Per-candidate fields copied from the selected candidate of each sentence.
    keys = {
        "audio",
        "filename",
        "gender",
        "num_samples",
        "seamlessm4t_asr",
        "seamlessm4t_asr_cer",
        "seamlessm4t_asr_translation",
        "seamlessm4t_asr_wer",
        "speaker_id",
        "split",
        "whisper_asr",
        "whisper_asr_cer",
        "whisper_asr_translation",
        "whisper_asr_wer",
    }

    # Define unsupported languages for each model
    seamless_unsupported = {
        "ast_Latn",
        "hau_Latn",
        "kam_Latn",
        "kea_Latn",
        "lin_Latn",
        "mri_Latn",
        "nso_Latn",
        "oci_Latn",
        "tgl_Latn",
        "umb_Latn",
        "wol_Latn",
        "xho_Latn",
    }
    whisper_unsupported = {
        "ast_Latn",
        "ceb_Latn",
        "ckb_Arab",
        "fuv_Latn",
        "gle_Latn",
        "ibo_Latn",
        "kam_Latn",
        "kea_Latn",
        "kir_Cyrl",
        "lug_Latn",
        "luo_Latn",
        "nso_Latn",
        "tgl_Latn",
        "umb_Latn",
        "wol_Latn",
        "xho_Latn",
        "zul_Latn",
    }

    # Define selection strategy: each selector maps a list of scores to an index.
    if kind == "best":
        select_func = lambda scores: min(range(len(scores)), key=lambda i: scores[i])
    elif kind == "worst":
        select_func = lambda scores: max(range(len(scores)), key=lambda i: scores[i])
    elif kind == "random":
        select_func = lambda scores: random.randint(0, len(scores) - 1)
    else:
        raise ValueError("Invalid 'kind'. Must be one of 'best', 'worst', or 'random'.")

    # Determine which models' CER scores are used for the given language.
    # NOTE(review): a language listed in BOTH sets (e.g. "ast_Latn") falls into
    # the first branch and is scored with SeamlessM4T only, although that model
    # is listed as unsupported too. Behaviour kept as-is; confirm it is intended.
    if language in whisper_unsupported:
        models = ["seamlessm4t_asr_cer"]
    elif language in seamless_unsupported:
        models = ["whisper_asr_cer"]
    else:
        models = ["whisper_asr_cer", "seamlessm4t_asr_cer"]

    asr_keys = [
        "whisper_asr",
        "whisper_asr_translation",
        "seamlessm4t_asr",
        "seamlessm4t_asr_translation",
    ]

    def map_fn(examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        """
        Map function to process dataset examples by selecting CER-based audio data.

        Args:
            examples (dict[str, list[Any]]): Batched dataset examples. Must
                contain ``sentence_data``, ``link`` and ``question_number``.

        Returns:
            dict[str, list[Any]]: The same ``examples`` mapping, extended in
            place with the passage and English question/answer columns.
        """
        sentence_data_containers: list[list[list]] = examples["sentence_data"]

        paragraphs = {k: [] for k in asr_keys}

        for sentence_data in sentence_data_containers:
            collected_sentence_data = []
            for sentence in sentence_data:
                # Average the CER of each candidate across the usable models.
                cer_lists = [sentence[model] for model in models]
                averaged_cer = [
                    sum(aligned_cer) / len(aligned_cer)
                    for aligned_cer in zip(*cer_lists)
                ]
                selected_idx = select_func(averaged_cer)
                sentence_dict = {key: sentence[key][selected_idx] for key in keys}
                sentence_dict["id"] = sentence["id"]
                collected_sentence_data.append(sentence_dict)

            # Restore sentence order before stitching the passage together.
            collected_sentence_data = sorted(
                collected_sentence_data, key=lambda x: x["id"]
            )
            for key in asr_keys:
                texts = " ".join(
                    [line[key].strip() for line in collected_sentence_data]
                ).strip()
                paragraphs[key].append(texts)
        for key in asr_keys:
            examples[f"{key}_flores_passage"] = paragraphs[key]

        # Attach the English question and answer options of the matching
        # Belebele row, looked up by (link, question_number).
        links = examples["link"]
        numbers = examples["question_number"]
        eng_Latn, eng_pairs = get_english_data()
        eng_ids = [eng_pairs[(link, number)] for link, number in zip(links, numbers)]
        eng_samples = [eng_Latn[idx] for idx in eng_ids]
        for col in ("question", "mc_answer1", "mc_answer2", "mc_answer3", "mc_answer4"):
            examples[f"english_{col}"] = [sample[col] for sample in eng_samples]
        return examples

    return map_fn


def construct_paragraph(
    sentence_data: list[dict[str, str]],
    model: str = "whisper",
    translation: bool = False,
) -> str:
    """
    Join the per-sentence ASR outputs of one model into a single paragraph.

    Args:
        sentence_data (list[dict[str, str]]): One dict per sentence, containing
            the ``"{model}_asr"`` and/or ``"{model}_asr_translation"`` field.
        model (str): Prefix of the ASR field to read. Defaults to "whisper".
        translation (bool): If True, read the ``"{model}_asr_translation"``
            field instead of the transcription. Defaults to False.

    Returns:
        str: The selected texts joined by single spaces, stripped of leading
        and trailing whitespace.
    """
    key = f"{model}_asr"
    # The suffix must be appended exactly once; appending it twice produced
    # the non-existent key "<model>_asr_translation_translation".
    if translation:
        key += "_translation"
    return " ".join(line[key] for line in sentence_data).strip()


def prepare_asr_belebele(
    examples: dict[str, list], model: str = "whisper", translation: bool = False
):
    """
    Build one ASR paragraph per example from its collected sentence data.

    Args:
        examples (dict[str, list]): Batched examples; the
            ``collected_sentence_data`` column is read.
        model (str): ASR model prefix forwarded to ``construct_paragraph``.
        translation (bool): Whether to use the translation field; forwarded to
            ``construct_paragraph``.

    Returns:
        dict: ``{"paragraphs": [...]}`` with one paragraph string per example.
    """
    joined_passages = []
    for per_example_sentences in examples["collected_sentence_data"]:
        passage = construct_paragraph(
            per_example_sentences, model=model, translation=translation
        )
        joined_passages.append(passage)
    return {"paragraphs": joined_passages}


def preprocess(
    examples: dict, tokenizer, columns, max_length: int = 512, debug=False
) -> dict[str, list]:
    """
    Preprocess the dataset for multiple-choice question answering.

    Paragraph, question and choices are tokenized separately (without special
    tokens) and stitched together as::

        [start token]paragraph \\n question \\n1 choice1 \\n2 choice2 ... [end token]

    The start/end special tokens are discovered by tokenizing one probe text
    with the tokenizer's default settings. If the stitched sequence is longer
    than ``max_length``, tokens are removed from the end of the paragraph
    first and then from the end of the question. Separators and choices are
    never truncated, so a sequence can still exceed ``max_length`` when they
    alone are too long.

    Args:
        examples (dict): A batch of examples from the dataset.
        tokenizer (PreTrainedTokenizerFast): The tokenizer to use for encoding text.
        columns (dict): A dictionary specifying dataset column names under the
            keys ``"context"``, ``"question"`` and ``"choices"`` (a list).
        max_length (int): Maximum tokenized input length for the model. Defaults to 512.
        debug (bool): If True, verifies that decoding each span mask yields the
            corresponding choice.

    Returns:
        dict: ``{"input_ids": ..., "span_masks": ...}``. ``input_ids`` holds
        one token-id list per example. ``span_masks`` holds, per example, one
        0/1 mask per choice; each mask has the length of that example's
        ``input_ids`` and marks the tokens of the choice text.

    Raises:
        AssertionError: In debug mode, if a decoded span does not match its choice.
    """
    choice_columns = columns["choices"]
    num_choices = len(choice_columns)

    # Missing / non-string / empty choices are replaced by the literal "None"
    # so that every choice yields a tokenizable text.
    choices: list[tuple[str, ...]] = [
        tuple(
            (c if isinstance(c, str) and c != "" else "None").strip()
            for c in row
        )
        for row in zip(*[examples[c] for c in choice_columns])
    ]
    paragraphs: list[str] = examples[columns["context"]]
    # we prepend \n to questions as paragraphs is more likely to be truncated
    questions: list[str] = ["\n" + q for q in examples[columns["question"]]]

    # Nothing to tokenize (and no probe text available) for an empty batch.
    if len(paragraphs) == 0:
        return {"input_ids": [], "span_masks": []}

    # Probe which special tokens the tokenizer adds at either end.
    probe = tokenizer(choices[0][0], return_special_tokens_mask=True)
    special_mask = probe["special_tokens_mask"]
    # "" (not False) so that prepending is a no-op for tokenizers without a
    # leading special token; None (not False) so that a token id of 0 is kept.
    start_special_token = ""
    end_special_token_id = None
    if len(special_mask) > 0 and special_mask[0] == 1:
        start_special_token = tokenizer.convert_ids_to_tokens(probe["input_ids"][0])
    if len(special_mask) > 0 and special_mask[-1] == 1:
        end_special_token_id = probe["input_ids"][-1]

    paragraphs = [start_special_token + p for p in paragraphs]

    # tokenize everything
    paragraphs_tokenized = tokenizer(
        paragraphs, add_special_tokens=False, return_attention_mask=False
    )["input_ids"]
    questions_tokenized = tokenizer(
        questions, add_special_tokens=False, return_attention_mask=False
    )["input_ids"]

    choices_flattened = [c for row in choices for c in row]
    choices_tokenized = tokenizer(
        choices_flattened, add_special_tokens=False, return_attention_mask=False
    )["input_ids"]

    # One "\n{i} " separator per choice position (1-based numbering).
    separators = [
        tokenizer(f"\n{i} ", add_special_tokens=False, return_attention_mask=False)[
            "input_ids"
        ]
        for i in range(1, num_choices + 1)
    ]

    # Per example: [sep1, choice1, sep2, choice2, ..., (end token)].
    choices_separators = []
    for example_idx in range(len(paragraphs)):
        # Choices were flattened row-major: 0, 4, 8, ... for four choices.
        offset = example_idx * num_choices
        segments = []
        for j in range(num_choices):
            segments.append(separators[j])
            segments.append(choices_tokenized[offset + j])
        if end_special_token_id is not None:
            segments.append([end_special_token_id])
        choices_separators.append(segments)

    # Truncation logic to ensure tokenized inputs fit within `max_length`.
    resulting_tokens = []
    ranges = []
    for paragraph_tokens, question_tokens, segments in zip(
        paragraphs_tokenized, questions_tokenized, choices_separators
    ):
        total_length = (
            len(paragraph_tokens)
            + len(question_tokens)
            + sum(len(segment) for segment in segments)
        )
        overflow = max(total_length - max_length, 0)

        # Drop tokens from the end of the paragraph first ...
        removed = min(overflow, len(paragraph_tokens))
        paragraph_tokens = paragraph_tokens[: len(paragraph_tokens) - removed]
        overflow -= removed
        # ... and only then from the end of the question.
        removed = min(overflow, len(question_tokens))
        question_tokens = question_tokens[: len(question_tokens) - removed]

        # Reconstruct the final sequence and record where each choice lands.
        final_tokens = list(paragraph_tokens) + list(question_tokens)
        choice_ranges = []
        for position, segment in enumerate(segments, 1):
            final_tokens.extend(segment)
            # Even positions are choices; odd ones are separators / end token.
            if position % 2 == 0:
                choice_ranges.append(
                    [len(final_tokens) - len(segment), len(final_tokens)]
                )
        ranges.append(choice_ranges)
        resulting_tokens.append(final_tokens)

    # One full-length 0/1 mask per choice, set to 1 on the choice's tokens.
    masks = []
    for tokens, choice_ranges in zip(resulting_tokens, ranges):
        line_masks = []
        for start, end in choice_ranges:
            mask = [0] * len(tokens)
            mask[start:end] = [1] * (end - start)
            line_masks.append(mask)
        masks.append(line_masks)

    if debug:
        for tokens, line_masks, line_choices in zip(resulting_tokens, masks, choices):
            for mask, choice in zip(line_masks, line_choices):
                extracted = [tok for tok, flag in zip(tokens, mask) if flag]
                decoded = tokenizer.decode(extracted)
                if decoded.strip() != choice.strip():
                    # Raised explicitly so the check also fires under `python -O`.
                    raise AssertionError(
                        "Span mask does not reproduce the choice: "
                        f"expected {choice!r}, decoded {decoded!r}"
                    )

    return {"input_ids": resulting_tokens, "span_masks": masks}


def pad_and_stack_masks(masks: list[list[torch.Tensor]]) -> torch.Tensor:
    """
    Combine per-example, per-choice span masks into one zero-padded tensor.

    Args:
        masks (list[list[torch.Tensor]]): For each example, a sequence of 1-D
            masks (one per choice). A 2-D tensor per example works as well,
            since iterating it yields its rows.

    Returns:
        torch.Tensor: Long tensor of shape (B, L, C) where B is the number of
        examples, L the longest mask and C the largest number of choices.
        Positions not covered by a mask are 0.
    """
    num_examples = len(masks)
    widest = max(len(example_masks) for example_masks in masks)
    longest = max(
        max(single.size(0) for single in example_masks) for example_masks in masks
    )

    stacked = torch.zeros(num_examples, longest, widest, dtype=torch.long)
    for row, example_masks in enumerate(masks):
        for col, single in enumerate(example_masks):
            span = single.size(0)
            # Write the mask along the length axis; the tail stays zero.
            stacked[row, :span, col] = single
    return stacked

def collate_fn(
    examples: list[dict],
    tokenizer: PreTrainedTokenizerFast,
    tokenize_kwargs: "dict | None" = None,
):
    """
    Collate function for preparing a batch of examples for multiple-choice question answering.

    Args:
        examples (list[dict]): List of preprocessed examples. Each must provide
            ``input_ids``, ``span_masks`` and ``correct_answer_num``.
        tokenizer (PreTrainedTokenizerFast): Tokenizer to pad input IDs.
        tokenize_kwargs (dict | None): Keyword arguments forwarded to
            ``tokenizer.pad``. Defaults to
            ``{"return_tensors": "pt", "padding": True}`` when None.

    Returns:
        dict: The padded batch returned by ``tokenizer.pad`` extended with
        ``span_mask`` (padded span masks), ``labels`` (zero-based answer
        indices) and ``return_dict`` (always False).
    """
    # Built per call instead of using a shared mutable default argument.
    if tokenize_kwargs is None:
        tokenize_kwargs = {"return_tensors": "pt", "padding": True}

    input_ids: list[dict[str, list[int]]] = [
        {"input_ids": line["input_ids"]} for line in examples
    ]
    batch = tokenizer.pad(input_ids, **tokenize_kwargs)

    # the inner lists of masks are all of equal length and comprise masks
    # [0, 0, 1, 1, 1, 0, 0, ...], so each example becomes one 2-D tensor whose
    # rows are the per-choice masks.
    masks: list[torch.Tensor] = [
        torch.LongTensor(line["span_masks"]) for line in examples
    ]
    batch["span_mask"] = pad_and_stack_masks(masks)
    # `correct_answer_num` is 1-based; labels are shifted to 0-based.
    batch["labels"] = torch.LongTensor(
        [int(line["correct_answer_num"]) - 1 for line in examples]
    )
    batch["return_dict"] = False
    return batch


def repl():
    """
    Scratchpad exercising the preprocessing, mapping and modeling utilities.

    NOTE(review): this looks like it is meant to be stepped through
    interactively rather than called: `batch` is first used in
    `mc_model(**batch)` but only assigned on the last line, so calling
    `repl()` as-is raises `UnboundLocalError` at that point.
    """
    # tokenizer = AutoTokenizer.from_pretrained("roberta-large")
    from src.fleurs.utils import get_tokenizer_with_pad_token_id

    tokenizer = get_tokenizer_with_pad_token_id(
        "LLM2Vec-Meta-Llama-3.1-8B-Instruct-mntp-unsup-simcse"
    )
    # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    # idea: tokenize everything separately, then stitch together
    # if context length exceeds model, truncate from paragraph, then question

    from datasets import load_dataset
    eng_Latn = load_dataset("wuenlp/fleurs-belebele", "eng_Latn", split="test")
    # First three rows as a column-wise batch; overwritten further below.
    examples = eng_Latn[:3]

    # Column names handed to `preprocess`.
    columns = {
        "choices": ["mc_answer1", "mc_answer2", "mc_answer3", "mc_answer4"],
        "context": "flores_passage",
        "question": "question",
        "label": "correct_answer_num",
    }

    # Tokenize the whole split with the span-mask debug check enabled.
    dataset = eng_Latn.map(
        preprocess,
        fn_kwargs={
            "tokenizer": tokenizer,
            "columns": columns,
            "debug": True,
            "max_length": 512,
        },
        batched=True,
        batch_size=50,
    )
    # Print the first ten stitched inputs and the text selected by each span mask.
    for i in range(10):
        input_ids = dataset[i]['input_ids']
        print(tokenizer.decode(input_ids))
        span_masks = dataset[i]["span_masks"]
        for j, mask in enumerate(span_masks):
            # The comprehension's `i` is local to it and does not affect the loop index.
            print(j, tokenizer.decode([i for i, m in zip(input_ids, mask) if m == 1]))

    # Run the CER-based audio selection over the same split; `e` is not used afterwards.
    mapper = select_audio_mapper("eng_Latn")
    e = eng_Latn.map(
        mapper, batched=True, batch_size=30, remove_columns=["sentence_data"]
    )


    # Row-wise list of the first ten preprocessed examples.
    examples = [dataset[i] for i in range(10)]

    # Assuming `masks` is a list of lists of LongTensors
    # Example: masks = [[torch.LongTensor([0, 0, 1, 1, 1]), torch.LongTensor([0, 1, 1])], ...]

    from transformers import AutoModel, AutoConfig
    from transformers.models.roberta.modeling_roberta import RobertaModel
    from src.fleurs.modeling.multiple_choice import ModelForMultipleChoice

    # Tiny, randomly initialised RoBERTa wrapped in the multiple-choice head.
    config = AutoConfig.from_pretrained(
        "roberta-large",
        num_hidden_layers=2,
        hidden_size=8,
        intermediate_size=16,
        num_attention_heads=4,
    )
    model = RobertaModel(config)
    mc_model = ModelForMultipleChoice(model)

    # NOTE(review): `batch` is not defined yet at this point (assigned on the last line).
    mc_model(**batch)

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Load and resolve the dataspec config, then pull one batch from its dataloader.
    with open("./configs/dataspec/belebele_text_train_val.yaml") as f:
        s = f.read()
    cfg = OmegaConf.create(s)
    OmegaConf.resolve(cfg)
    from trident.core.dataspec import TridentDataspec

    dataspec = TridentDataspec(cfg.test, "test")
    dl = dataspec.get_dataloader()

    batch = next(iter(dl))
