import collections
import json
import logging
import os
import random
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers import AutoTokenizer, DataCollatorWithPadding, EvalPrediction, default_data_collator

from MQSP_evaluation.distributed_utils import DistGroups

from ..base.dataset import broadcast_data, get_sp_method

logger = logging.getLogger(__name__)


def get_qa_dataloader(
    datasets_path,
    model_name_or_path=None,
    pad_to_max_length=False,
    test_size=0.1,
    preprocessing_num_workers=4,
    use_fp16=False,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    doc_stride=128,
    max_seq_length=8192,
    seed=None,
):
    """Build train / eval / test dataloaders for extractive question answering.

    The ``train`` split saved at ``datasets_path`` is divided into a train part and a held-out
    test part (``test_size``); the saved ``validation`` split is used for evaluation.
    SQuAD-style columns (``question`` / ``context`` / ``answers`` dict) and hotpot_qa-style
    columns (``context`` as ``{"title": [...], "sentences": [[...]]}`` and a plain-string
    ``answer``) are handled; otherwise columns 0/1/2 are taken as question/context/answers.
    A ``supports`` column, when present, is used as the context.

    Args:
        datasets_path: directory loaded with ``datasets.load_from_disk``.
        model_name_or_path: tokenizer to load with ``AutoTokenizer`` (fast tokenizer).
        pad_to_max_length: use ``default_data_collator`` instead of dynamic padding.
        test_size: passed to ``train_test_split`` to carve the test split out of ``train``.
        preprocessing_num_workers: ``num_proc`` for the train-split ``datasets.map``.
        use_fp16: with dynamic padding, pad batches to a multiple of 8.
        per_device_train_batch_size: train DataLoader batch size.
        per_device_eval_batch_size: eval/test DataLoader batch size.
        doc_stride: token overlap between consecutive windows of a long context.
        max_seq_length: maximum tokens per feature; clipped to ``tokenizer.model_max_length``.
        seed: seed for the train/test split. When given, the split is reproducible (and
            identical on every rank), so the tokenized train split is cached next to the
            dataset under a name keyed on all tokenization settings. When ``None`` the split
            is random, so no explicit cache file is used (a fixed-name cache would hold the
            train rows of a *previous* split and leak them into this run's test split).

    Returns:
        ``(train_dataloader, eval_dataloader, test_dataloader, tokenizer)``. The eval and
        test loaders also carry ``.examples`` (raw rows) and ``.origin_dataset`` (features
        with ``example_id`` / ``offset_mapping``) for post-processing.
    """
    raw_datasets = load_from_disk(datasets_path, keep_in_memory=True)
    column_names = raw_datasets["train"].column_names

    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    context_column_name = "supports" if "supports" in column_names else context_column_name
    # hotpot_qa stores a single answer string under "answer" instead of an "answers" dict.
    answer_column_name = "answers" if "answers" in column_names else column_names[2]
    answer_column_name = "answer" if "answer" in column_names else answer_column_name
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"
    # Index returned by `sequence_ids()` for tokens that belong to the context.
    context_sequence_id = 1 if pad_on_right else 0
    print("init dataset pad_on_right", pad_on_right)
    if max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(max_seq_length, tokenizer.model_max_length)

    def normalize_contexts(rows):
        """Flatten hotpot_qa dict contexts and list-of-passage contexts into plain strings."""
        if rows and isinstance(rows[0], dict):  # hotpot_qa {"title":[],"sentences":[["xx"]]}
            rows = ["\n".join("\n".join(sentences) for sentences in item["sentences"]) for item in rows]
        if rows and isinstance(rows[0], list):
            rows = ["\n".join(item) for item in rows]
        return rows

    def tokenize_examples(examples):
        """Tokenize a batch into overlapping windows; return the encoding and flattened contexts."""
        # Leading whitespace in questions wastes tokens and can make context truncation fail.
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
        questions = examples[question_column_name]
        # Always flatten the *context* column, whichever side of the pair it is placed on.
        contexts = normalize_contexts(examples[context_column_name])
        # Long contexts overflow into several features that overlap by `doc_stride` tokens.
        tokenized = tokenizer(
            questions if pad_on_right else contexts,
            contexts if pad_on_right else questions,
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding=True,
            pad_to_multiple_of=dist.get_world_size(),  # pad to world size so same with ddp
        )
        return tokenized, contexts

    def locate_answer(offsets, sequence_ids, start_char, end_char, cls_index):
        """Map a character span to (start, end) token labels, or (cls, cls) if not fully in this window."""
        context_tokens = [k for k, s in enumerate(sequence_ids) if s == context_sequence_id]
        if not context_tokens:
            return cls_index, cls_index
        context_start, context_end = context_tokens[0], context_tokens[-1]
        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            return cls_index, cls_index
        # Step past every context token that starts at or before start_char. The scan is bounded
        # by the context end: special/pad tokens after it carry (0, 0) offsets and would otherwise
        # push the label out of the context when the answer starts on the last context token.
        token_index = context_start
        while token_index <= context_end and offsets[token_index][0] <= start_char:
            token_index += 1
        start_position = token_index - 1
        token_index = context_end
        while token_index >= context_start and offsets[token_index][1] >= end_char:
            token_index -= 1
        end_position = token_index + 1
        # Kept from the original labelling: the end label is forced strictly after the start label.
        if end_position <= start_position:
            end_position = start_position + 1
        return start_position, end_position

    def prepare_train_features(examples):
        """Tokenize a training batch and attach start/end token labels to every feature."""
        tokenized_examples, contexts = tokenize_examples(examples)
        # One example may yield several features; this maps each feature back to its example.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # Token -> character offsets into the flattened context, used to place the labels.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        tokenized_examples["token_length"] = []
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        cls_token_id = tokenizer.cls_token_id

        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            # Unanswerable windows are labelled with the CLS token; tokenizers without one fall
            # back to position 0, the slot post-processing scores as the null answer.
            cls_index = input_ids.index(cls_token_id) if cls_token_id in input_ids else 0
            sample_index = sample_mapping[i]
            context_text = contexts[sample_index]
            # NOTE: this is the *character* length of the flattened context, despite the name.
            tokenized_examples["token_length"].append(len(context_text))

            # squad: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}
            # hotpot_qa: "Arthur's Magazine" (plain string, no position)
            answers = examples[answer_column_name][sample_index]
            if answers is None:
                # Still emit a (CLS) label so every output column keeps one entry per feature.
                answers = {"text": [], "answer_start": []}
            elif not isinstance(answers, dict):
                answer_start = context_text.lower().find(answers.lower())
                # find() returns -1 when absent; 0 is a valid match at the start of the context.
                answers = {"text": [answers], "answer_start": [answer_start] if answer_start >= 0 else []}

            if len(answers["answer_start"]) == 0:
                start_position, end_position = cls_index, cls_index
            else:
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                start_position, end_position = locate_answer(
                    offsets, tokenized_examples.sequence_ids(i), start_char, end_char, cls_index
                )
            tokenized_examples["start_positions"].append(start_position)
            tokenized_examples["end_positions"].append(end_position)

        return tokenized_examples

    def prepare_validation_features(examples):
        """Tokenize an eval batch, keeping example ids and context-only offset mappings."""
        tokenized_examples, _ = tokenize_examples(examples)
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        tokenized_examples["example_id"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            sequence_ids = tokenized_examples.sequence_ids(i)
            tokenized_examples["example_id"].append(examples["id"][sample_mapping[i]])
            # Offsets outside the context become None so post-processing can skip them.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_sequence_id else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    print("raw_datasets", raw_datasets)
    # A labelled test split is carved out of train (e.g. hotpot_qa's own test split has no answers).
    splits = raw_datasets["train"].train_test_split(test_size=test_size, seed=seed)
    raw_datasets["train"] = splits["train"]
    raw_datasets["test"] = splits["test"]
    train_size = len(raw_datasets["train"])
    for index in random.sample(range(train_size), min(3, train_size)):
        logger.info(f"Sample {index} of the training set: {raw_datasets['train'][index]}.")

    # Only a seeded split is reproducible, so only then is an explicit on-disk cache safe. The
    # name covers everything that changes the features (tokenizer, lengths, ddp padding, split).
    train_cache_file = None
    if seed is not None:
        tokenizer_tag = os.path.basename(os.path.normpath(str(model_name_or_path)))
        train_cache_file = os.path.join(
            datasets_path,
            "train",
            f"qa-{tokenizer_tag}-len{max_seq_length}-stride{doc_stride}-ws{dist.get_world_size()}"
            f"-test{test_size}-seed{seed}.arrow",
        )
    train_ds = raw_datasets["train"].map(
        prepare_train_features,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
        cache_file_name=train_cache_file,
        desc="Running tokenizer on train dataset",
    )

    def describe_labels(examples):
        """Collect per-feature label statistics (answer span length in tokens, context length)."""
        return {
            "answer_length": [end - start for start, end in zip(examples["start_positions"], examples["end_positions"])],
            "token_length": examples["token_length"],
            "start_positions": examples["start_positions"],
            "end_positions": examples["end_positions"],
        }

    statistic_ds = train_ds.map(describe_labels, batched=True)
    if DistGroups["dp"].rank() == 0:
        print(statistic_ds.to_pandas().describe())

    eval_examples = raw_datasets["validation"]
    test_examples = raw_datasets["test"]
    eval_dataset = eval_examples.map(
        prepare_validation_features,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on validation dataset",
    )
    test_dataset = test_examples.map(
        prepare_validation_features,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on test dataset",
    )
    if pad_to_max_length:
        # Features are already padded; the default collator only converts to tensors.
        data_collator = default_data_collator
    else:
        # Dynamic padding to the longest sample of each batch; multiples of 8 help fp16 Tensor Cores.
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if use_fp16 else None))
    sampler = DistributedSampler(
        train_ds, shuffle=False, num_replicas=DistGroups["dp"].size(), rank=DistGroups["dp"].rank()
    )
    train_dataloader = DataLoader(
        train_ds,
        shuffle=False,  # ordering is delegated to the DistributedSampler
        collate_fn=data_collator,
        sampler=sampler,
        batch_size=per_device_train_batch_size,
    )
    if DistGroups["dp"].rank() == 0 and len(train_ds) > 9:
        print(train_ds[9])

    # The model must not receive the post-processing-only columns.
    eval_dataloader = DataLoader(
        eval_dataset.remove_columns(["example_id", "offset_mapping"]),
        collate_fn=data_collator,
        batch_size=per_device_eval_batch_size,
    )
    # Keep raw examples and full features on the loader for the eval function's post-processing.
    eval_dataloader.examples = eval_examples
    eval_dataloader.origin_dataset = eval_dataset
    test_dataloader = DataLoader(
        test_dataset.remove_columns(["example_id", "offset_mapping"]),
        collate_fn=data_collator,
        batch_size=per_device_eval_batch_size,
    )
    test_dataloader.examples = test_examples
    test_dataloader.origin_dataset = test_dataset

    return train_dataloader, eval_dataloader, test_dataloader, tokenizer


def qa_eval_fn(model, args, eval_iter_or_dataset, eval_length, get_batch_fn, metric):
    """Run ``eval_length`` QA eval steps and, on the rank holding the data, score the answers.

    Each step's start/end logits are SUM-all-reduced over ``DistGroups["sp"]`` and divided by
    that group's size. A rank called with ``eval_iter_or_dataset=None`` only takes part in the
    forward passes and collectives and returns ``0.0``. Otherwise ``eval_iter_or_dataset`` must
    carry ``.examples`` and ``.origin_dataset`` (as set by ``get_qa_dataloader``); the logits
    are post-processed into answer strings and scored with ``metric``.

    Args:
        model: module whose outputs expose ``start_logits`` and ``end_logits``.
        args: unused in this function.
        eval_iter_or_dataset: eval dataloader, or ``None`` on ranks that only join collectives.
        eval_length: number of batches to evaluate.
        get_batch_fn: called with the iterator (or ``None``) to obtain each batch dict of tensors.
        metric: object with ``compute(predictions=..., references=...)``; required.

    Returns:
        ``0.0`` when ``eval_iter_or_dataset`` is ``None``, else the metric dict plus a
        ``"perplexity"`` entry equal to ``100.0 - f1``.
    """
    assert metric is not None, "need metric_fn on qa task"
    device = torch.cuda.current_device()
    with torch.no_grad():
        # max_length = model.config.n_positions
        all_start_logits = []
        all_end_logits = []
        eval_iter = iter(eval_iter_or_dataset) if eval_iter_or_dataset is not None else None

        model.eval()
        try:
            for step in tqdm(
                range(eval_length),
                desc="eval",
                disable=not (dist.get_rank() == 0),
            ):
                batch = get_batch_fn(eval_iter)
                batch = {k: v.to(device=device, non_blocking=True) for k, v in batch.items()}
                outputs = model(**batch)
                start_logits = outputs.start_logits
                end_logits = outputs.end_logits

                # Average the logits over the "sp" group: SUM all-reduce here, divide below.
                # TODO: all reduce to get mean? or just use rank0 for predict?
                dist.all_reduce(start_logits, op=dist.ReduceOp.SUM, group=DistGroups["sp"])
                dist.all_reduce(end_logits, op=dist.ReduceOp.SUM, group=DistGroups["sp"])

                sp_size = DistGroups["sp"].size()
                all_start_logits.append((start_logits / sp_size).cpu().numpy())
                all_end_logits.append((end_logits / sp_size).cpu().numpy())
                if step == 0:
                    print(
                        "start_logits",
                        start_logits.shape,
                        end_logits.shape,
                        all_start_logits[0].shape,
                        all_end_logits[0].shape,
                    )
        finally:
            # Restore training mode even if a batch, the forward pass or a collective raises.
            model.train()

        if eval_iter_or_dataset is None:
            # only compute eval on rank0
            return 0.0
        # Batches may differ in sequence length (dim 1); concatenate them into one array.
        max_len = max(x.shape[1] for x in all_start_logits)
        predict_length = sum(x.shape[0] for x in all_start_logits)
        # NOTE(review): create_and_fill_np_array is defined outside this view; presumably it pads
        # each batch to max_len and stacks into (predict_length, max_len) — confirm.
        start_logits_concat = create_and_fill_np_array(all_start_logits, predict_length, max_len)
        end_logits_concat = create_and_fill_np_array(all_end_logits, predict_length, max_len)
        # Free the per-batch copies before post-processing.
        del all_start_logits
        del all_end_logits
        outputs_numpy = (start_logits_concat, end_logits_concat)
        post_processing_function = get_post_process_func()

        prediction = post_processing_function(
            eval_iter_or_dataset.examples, eval_iter_or_dataset.origin_dataset, outputs_numpy
        )
        eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
        # e.g. {'exact_match': 0.0, 'f1': 2.22565430929634478}
        print(eval_metric)
        # Lower-is-better score stored under "perplexity"; presumably so callers that track
        # perplexity can consume QA results unchanged — confirm with the caller.
        eval_metric["perplexity"] = 100.0 - eval_metric["f1"]
        return eval_metric


# Post-processing:


def postprocess_qa_predictions(
    examples,
    features,
    predictions: Tuple[np.ndarray, np.ndarray],
    version_2_with_negative: bool = False,
    n_best_size: int = 20,
    max_answer_length: int = 30,
    null_score_diff_threshold: float = 0.0,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
):
    """
    Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
    original contexts. This is the base postprocessing functions for models that only return start and end logits.

    Args:
        examples: The non-preprocessed dataset. Rows need an ``id`` and a ``supports`` or ``context`` field (string,
            list of passages, or hotpot_qa ``{"title", "sentences"}`` dict); ``examples["id"]`` must return the id column.
        features: The processed dataset. Rows need ``example_id`` and ``offset_mapping`` (``None`` for non-context
            tokens) and may carry ``token_is_max_context``.
        predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
            The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
            first dimension may be shorter than :obj:`features` (partial eval); features beyond it are ignored.
        version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the underlying dataset contains examples with no answers.
        n_best_size (:obj:`int`, `optional`, defaults to 20):
            The total number of n-best predictions to generate when looking for an answer.
        max_answer_length (:obj:`int`, `optional`, defaults to 30):
            The maximum length of an answer that can be generated. This is needed because the start and end predictions
            are not conditioned on one another.
        null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
            The threshold used to select the null answer: if the best answer has a score that is less than the score of
            the null answer minus this threshold, the null answer is selected for this example (note that the score of
            the null answer for an example giving several features is the minimum of the scores for the null answer on
            each feature: all features must be aligned on the fact they `want` to predict a null answer).

            Only useful when :obj:`version_2_with_negative` is :obj:`True`.
        output_dir (:obj:`str`, `optional`):
            If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
            :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
            answers, are saved in `output_dir`.
        prefix (:obj:`str`, `optional`):
            If provided, the dictionaries mentioned above are saved with `prefix` added to their names.

    Returns:
        ``OrderedDict`` mapping example id to the predicted (lower-cased) answer text. Examples with no usable
        prediction are omitted.

    Raises:
        ValueError: if ``predictions`` is not a pair or has more rows than ``features``.
        EnvironmentError: if ``output_dir`` is given but is not a directory.
    """
    if len(predictions) != 2:
        raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
    all_start_logits, all_end_logits = predictions

    if len(predictions[0]) > len(features):
        raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")

    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill (scores_diff_json only when version_2_with_negative).
    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
    predict_length = len(all_start_logits)
    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_prediction = None
        prelim_predictions = []
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            if feature_index > predict_length - 1:  # out of predict (partial eval)
                continue
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # Maps logit positions to character spans of the original context.
            offset_mapping = features[feature_index]["offset_mapping"]
            # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
            # available in the current feature.
            token_is_max_context = features[feature_index].get("token_is_max_context", None)

            # Update minimum null prediction (position 0 is scored as the null answer).
            feature_null_score = start_logits[0] + end_logits[0]
            if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or len(offset_mapping[start_index]) < 2
                        or offset_mapping[end_index] is None
                        or len(offset_mapping[end_index]) < 2
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Don't consider answer that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue

                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_logits[start_index] + end_logits[end_index],
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )
        # min_null_prediction stays None when every feature of this example was beyond the predictions.
        use_null = version_2_with_negative and min_null_prediction is not None
        if use_null:
            # Add the minimum null prediction
            prelim_predictions.append(min_null_prediction)
            null_score = min_null_prediction["score"]

        # Only keep the best `n_best_size` predictions.
        predictions_list = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
        # Add back the minimum null prediction if it was removed because of its low score.
        if use_null and not any(p["offsets"] == (0, 0) for p in predictions_list):
            predictions_list.append(min_null_prediction)

        # Rebuild the context exactly as the dataloader flattened it; "supports" takes precedence there too.
        context = example["supports"] if "supports" in example else example.get("context")
        if isinstance(context, dict):  # hotpot_qa {"title":[],"sentences":[["xx"]]}
            context = "\n".join("\n".join(sentences) for sentences in context["sentences"])
        elif isinstance(context, list):
            context = "\n".join(context)
        for pred in predictions_list:
            offsets = pred.pop("offsets")
            # lower? no, metric script use lower text
            pred["text"] = context[offsets[0] : offsets[1]].lower()

        # With no usable prediction this example is skipped entirely (partial eval) rather than
        # given a fake answer.
        if len(predictions_list) == 0 or (len(predictions_list) == 1 and predictions_list[0]["text"] == ""):
            continue
        # Softmax over the candidate scores (numpy, shifted by the max for numerical stability).
        scores = np.array([pred.pop("score") for pred in predictions_list])
        exp_scores = np.exp(scores - np.max(scores))
        probs = exp_scores / exp_scores.sum()

        # Include the probabilities in our predictions.
        for prob, pred in zip(probs, predictions_list):
            pred["probability"] = prob

        # Pick the best prediction. If the null answer is not possible, this is easy.
        if not version_2_with_negative:
            all_predictions[example["id"]] = predictions_list[0]["text"]
        else:
            # Otherwise we first need to find the best non-empty prediction.
            best_non_null_pred = next((p for p in predictions_list if p["text"] != ""), None)
            if best_non_null_pred is None or not use_null:
                all_predictions[example["id"]] = "" if best_non_null_pred is None else best_non_null_pred["text"]
            else:
                # Then we compare to the null prediction using the threshold.
                score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
                scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
                if score_diff > null_score_diff_threshold:
                    all_predictions[example["id"]] = ""
                else:
                    all_predictions[example["id"]] = best_non_null_pred["text"]

        # Make `predictions` JSON-serializable by casting np.float back to float.
        all_nbest_json[example["id"]] = [
            {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
            for pred in predictions_list
        ]

    # If we have an output_dir, let's save all those dicts.
    if output_dir is not None:
        if not os.path.isdir(output_dir):
            raise EnvironmentError(f"{output_dir} is not a directory.")

        prediction_file = os.path.join(
            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
        )
        nbest_file = os.path.join(
            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
        )
        logger.info(f"Saving predictions to {prediction_file}.")
        with open(prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4) + "\n")
        logger.info(f"Saving nbest_preds to {nbest_file}.")
        with open(nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
        if version_2_with_negative:
            null_odds_file = os.path.join(
                output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
            )
            logger.info(f"Saving null_odds to {null_odds_file}.")
            with open(null_odds_file, "w") as writer:
                writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    return all_predictions


def get_post_process_func(n_best_size=20, max_answer_length=30, null_score_diff_threshold=0.0):
    """Build the callback that turns raw QA logits into metric-ready predictions and references.

    Args:
        n_best_size: forwarded to ``postprocess_qa_predictions``.
        max_answer_length: forwarded to ``postprocess_qa_predictions``.
        null_score_diff_threshold: forwarded to ``postprocess_qa_predictions``.

    Returns:
        A function ``post_processing_function(examples, features, predictions, stage="eval")``.
        It returns an ``EvalPrediction`` with two lists:
        ``predictions``, made of ``{"id", "prediction_text"}`` dicts, and
        ``label_ids``, made of ``{"id", "answers"}`` dicts.
        References are kept only for examples that received a prediction, so both lists line up.
    """

    def post_processing_function(examples, features, predictions, stage="eval"):
        # Match the start/end logits back to answer spans in the original contexts.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=False,
            n_best_size=n_best_size,
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            output_dir=None,
            prefix=stage,
        )
        # Format the result the way the metric expects.
        formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
        predicted_ids = {pred["id"] for pred in formatted_predictions}
        references = []
        show_invalid_count = 3  # cap on how many "answer not found" diagnostics get printed
        for ex in examples:
            # Keep only predicted examples so predictions/references have the same length for the metric.
            if ex["id"] not in predicted_ids:
                continue
            answers = ex.get("answers", ex.get("answer"))
            context_text = ex.get("context", ex.get("supports"))
            if isinstance(context_text, dict):  # hotpot_qa: {"title": [...], "sentences": [["..."], ...]}
                context_text = "\n".join("\n".join(sentences) for sentences in context_text["sentences"])
            if isinstance(context_text, list):
                context_text = "\n".join(context_text)

            if answers is None:  # no gold answer on this example; nothing to compare against
                continue
            if not isinstance(answers, dict):
                # Plain-string answer: recover a character offset with a case-insensitive search.
                # str.find returns -1 when absent, so 0 is a valid start (answer at the very beginning).
                # NOTE: the offset is computed on lower-cased text; it can drift for characters whose
                # lower-case form has a different length.
                if context_text is None:
                    answer_start = -1
                else:
                    answer_start = context_text.lower().find(answers.lower())

                if answer_start >= 0:
                    answers = {"text": [answers], "answer_start": [answer_start]}
                else:
                    if show_invalid_count > 0:
                        print("not found answer in context", answers, len(answers), context_text)
                        show_invalid_count -= 1
                    answers = {"text": [answers], "answer_start": []}
            references.append({"id": ex["id"], "answers": answers})
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)

    return post_processing_function


def create_and_fill_np_array(start_or_end_logits, eval_length, max_len):
    """
    Stack per-batch logits into a single ``(eval_length, max_len)`` numpy array.

    Batches are written one after another along the row axis. Any cell that receives no logit
    keeps the fill value -100. That covers columns past a batch's width and rows after the last batch.

    Args:
        start_or_end_logits: iterable of per-batch 2-D logit arrays/tensors.
            Each has shape ``(batch_size, cols)`` with ``cols <= max_len``.
            Pass either the start logits or the end logits, not both.
        eval_length(:obj:`int`): number of rows to keep.
            Rows supplied beyond this count are dropped; presumably these are padding
            added when gathering across processes.
        max_len(:obj:`int`):
            Width of the output array, i.e. the longest per-batch column count.

    Returns:
        ``np.ndarray`` of dtype float64 and shape ``(eval_length, max_len)``.
    """
    # Create the output array pre-filled with -100.
    logits_concat = np.full((eval_length, max_len), -100, dtype=np.float64)
    step = 0
    for output_logit in start_or_end_logits:
        if step >= eval_length:
            # Every row is already filled; the remaining batches are surplus.
            break
        cols = output_logit.shape[1]
        # Clamp to the rows still available. Without this, a negative slice would be assigned
        # into an empty target once `step` overshoots `eval_length`.
        rows = min(output_logit.shape[0], eval_length - step)
        logits_concat[step : step + rows, :cols] = output_logit[:rows]
        step += rows

    return logits_concat


def get_qa_train_batch(data_iterator):
    """Fetch one QA training batch, broadcast it, and keep this rank's slice of the sequence.

    Args:
        data_iterator: iterator over batches, or ``None``. When ``None``, ``data=None`` is handed to
            ``broadcast_data`` — presumably on ranks that only receive the broadcast (TODO confirm).

    Returns:
        dict with:
          - ``input_ids`` / ``token_type_ids``: this rank's contiguous chunk along dim 1, cast to int64.
          - ``attention_mask``: the full mask for the "colai"/"megatron"/"single" sp methods,
            or this rank's chunk for "qasp*" methods.
          - ``start_positions`` / ``end_positions``: unsliced, cast to int64.

    Raises:
        ValueError: if ``get_sp_method()`` returns an unsupported value.
    """
    keys = ["attention_mask", "end_positions", "input_ids", "start_positions", "token_type_ids"]
    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, torch.int32)

    # Shard the sequence dimension evenly across the sequence-parallel group.
    sp_group = DistGroups["sp"]
    sp_world_size = sp_group.size()
    sp_rank = sp_group.rank()
    seq_length = data_b["input_ids"].size(1)
    # NOTE: floor division -- if seq_length is not divisible by the group size, trailing tokens are dropped.
    sub_seq_length = seq_length // sp_world_size
    logger.debug("qa train batch seq_length=%d sub_seq_length=%d", seq_length, sub_seq_length)
    sub_seq_start = sp_rank * sub_seq_length
    sub_seq_end = sub_seq_start + sub_seq_length

    input_ids = data_b["input_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()
    token_type_ids = data_b["token_type_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()

    sp_method = get_sp_method()
    if sp_method in ("colai", "megatron", "single"):
        attention_mask = data_b["attention_mask"].contiguous()
    elif isinstance(sp_method, str) and sp_method.startswith("qasp"):
        attention_mask = data_b["attention_mask"][:, sub_seq_start:sub_seq_end].contiguous()
    else:
        raise ValueError(f"Unsupported sp_method: {sp_method!r}")

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "start_positions": data_b["start_positions"].long().contiguous(),
        "end_positions": data_b["end_positions"].long().contiguous(),
    }


def get_qa_eval_batch(data_iterator):
    """Fetch one QA evaluation batch, broadcast it, and keep this rank's slice of the sequence.

    Args:
        data_iterator: iterator over batches, or ``None``. When ``None``, ``data=None`` is handed to
            ``broadcast_data`` — presumably on ranks that only receive the broadcast (TODO confirm).

    Returns:
        dict with:
          - ``input_ids`` / ``token_type_ids``: this rank's contiguous chunk along dim 1, cast to int64.
          - ``attention_mask``: the full mask for the "colai"/"megatron"/"single" sp methods,
            or this rank's chunk for "qasp*" methods.
        No answer-position labels are included.

    Raises:
        ValueError: if ``get_sp_method()`` returns an unsupported value.
    """
    keys = ["attention_mask", "input_ids", "token_type_ids"]
    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, torch.int32)

    # Shard the sequence dimension evenly across the sequence-parallel group.
    sp_group = DistGroups["sp"]
    sp_world_size = sp_group.size()
    sp_rank = sp_group.rank()
    seq_length = data_b["input_ids"].size(1)
    # NOTE: floor division -- if seq_length is not divisible by the group size, trailing tokens are dropped.
    sub_seq_length = seq_length // sp_world_size
    sub_seq_start = sp_rank * sub_seq_length
    sub_seq_end = sub_seq_start + sub_seq_length

    input_ids = data_b["input_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()
    token_type_ids = data_b["token_type_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()

    sp_method = get_sp_method()
    if sp_method in ("colai", "megatron", "single"):
        attention_mask = data_b["attention_mask"].contiguous()
    elif isinstance(sp_method, str) and sp_method.startswith("qasp"):
        attention_mask = data_b["attention_mask"][:, sub_seq_start:sub_seq_end].contiguous()
    else:
        raise ValueError(f"Unsupported sp_method: {sp_method!r}")

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }
