import torch
import evaluate
from tqdm import tqdm
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs


def test_mrpc(model, dataset, args, round_idx):
    """
    Evaluate the global model on GLUE-MRPC.

    Parameters
    ----------
    model :
        The evaluated model (PEFT-enhanced BERT in this example). It is put
        into eval mode for the duration of the call; its previous
        ``training`` flag is restored before returning, so a caller that
        keeps training the same module is not left with dropout disabled.
    dataset :
        Validation or test split of GLUE-MRPC. Batches produced by
        ``args.data_collator`` must contain a ``"labels"`` entry and be
        accepted as keyword arguments by ``model``.
    args :
        Runtime arguments. Reads ``args.batch_size``, ``args.data_collator``
        and ``args.logger`` (whose ``info`` must accept
        ``main_process_only``).
    round_idx : int
        Federated communication round index (used only in the progress-bar
        description).

    Returns
    -------
    f1 : float
        F1-score on MRPC evaluation set.
    acc : float
        Accuracy on MRPC evaluation set.
    loss_placeholder : None
        Returned to match the caller signature.
    """

    metric = evaluate.load("glue", "mrpc")

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=args.data_collator,
    )

    # Remember the caller's mode so evaluation does not leak eval() into
    # subsequent training of the same module.
    was_training = model.training
    model.eval()
    try:
        # Keep the caller's (unwrapped) module reference intact; `prepare`
        # may return a wrapper around it.
        dataloader, prepared_model = accelerator.prepare(dataloader, model)

        with torch.no_grad():
            for batch in tqdm(
                dataloader,
                desc=f"Evaluating Round {round_idx}",
                total=len(dataloader),
                disable=not accelerator.is_local_main_process,
            ):
                outputs = prepared_model(**batch)
                preds = outputs.logits.argmax(dim=-1)
                # gather_for_metrics collects predictions/labels from all
                # processes (and drops samples padded in by the sampler).
                preds, refs = accelerator.gather_for_metrics(
                    (preds, batch["labels"])
                )
                metric.add_batch(predictions=preds, references=refs)
    finally:
        model.train(was_training)

    results = metric.compute()
    acc, f1 = results["accuracy"], results["f1"]

    args.logger.info(results, main_process_only=True)

    return f1, acc, None     # None - placeholder for test_loss
