# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
from copy import deepcopy

import datasets
import evaluate
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
from accelerate.utils import is_torch_xla_available, set_seed


# Set at import time, i.e. before any tokenizer or model is loaded by the functions below.
# Presumably this silences transformers' advisory warnings — confirm against the transformers docs.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


class ListHandler(logging.Handler):
    """Logging handler that stores every record it is handed in an in-memory list.

    The collected `logging.LogRecord` objects are available, in emission order, through the
    `logs` attribute, which lets a test inspect what was logged while the handler was attached.
    """

    def __init__(self, *args, **kwargs):
        # The buffer does not depend on the base-class setup, so create it first and then
        # forward every argument untouched to `logging.Handler`.
        self.logs = []
        super().__init__(*args, **kwargs)

    def emit(self, record):
        """Keep `record` as-is (no formatting is applied)."""
        self.logs.append(record)


def get_basic_setup(accelerator, num_samples=82, batch_size=16):
    """Build the model/dataloader trio used by the basic regression checks.

    Args:
        accelerator: The `Accelerator` whose device and `prepare` method are used.
        num_samples: Length passed to `RegressionDataset`.
        batch_size: Batch size of the (unprepared) `DataLoader`.

    Returns:
        A tuple `(model, ddp_model, dataloader)`: the plain model moved to `accelerator.device`,
        and a deep copy of it plus the dataloader after both went through `accelerator.prepare`.
    """
    set_seed(42)
    reference_model = RegressionModel()
    # Copy before the reference model is moved or anything is prepared.
    prepared_model = deepcopy(reference_model)
    loader = DataLoader(RegressionDataset(length=num_samples), batch_size=batch_size)
    reference_model.to(accelerator.device)
    prepared_model, loader = accelerator.prepare(prepared_model, loader)
    return reference_model, prepared_model, loader


def get_dataloader(accelerator: Accelerator, use_longest=False):
    """Create an unshuffled, batch-size-16 `DataLoader` over the tokenized GLUE/MRPC validation split.

    Args:
        accelerator: Only used for its `main_process_first()` context around the tokenization step.
        use_longest: When truthy, each batch is padded to its longest sequence; otherwise every
            batch is padded to a fixed length of 128.

    Returns:
        A `torch.utils.data.DataLoader` whose batches are produced by `tokenizer.pad(..., return_tensors="pt")`.
    """
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased")
    raw_validation = load_dataset("glue", "mrpc", split="validation")

    def tokenize_function(examples):
        # Encode the two sentences of each example as a pair.
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Presumably lets the main process tokenize first so the others can reuse its result — TODO confirm.
    with accelerator.main_process_first():
        encoded = raw_validation.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    encoded = encoded.rename_column("label", "labels")

    # The padding strategy is fixed for the lifetime of the dataloader, so pick it once.
    if use_longest:
        pad_options = {"padding": "longest"}
    else:
        pad_options = {"padding": "max_length", "max_length": 128}

    def collate_fn(examples):
        return tokenizer.pad(examples, return_tensors="pt", **pad_options)

    return DataLoader(encoded, shuffle=False, collate_fn=collate_fn, batch_size=16)


def get_mrpc_setup(dispatch_batches, split_batches):
    """Create the baseline and prepared model/dataloader pairs for the MRPC comparison.

    Args:
        dispatch_batches: Forwarded to `DataLoaderConfiguration`; its negation also selects
            "longest" padding in `get_dataloader`.
        split_batches: Forwarded to `DataLoaderConfiguration`.

    Returns:
        A tuple `(setup, accelerator)` where `setup` maps `"ddp"` to
        `[prepared model, prepared dataloader, torch_device]` and `"no"` to
        `[model, dataloader, accelerator.device]`.
    """
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)
    )
    baseline_dataloader = get_dataloader(accelerator, not dispatch_batches)
    baseline_model = AutoModelForSequenceClassification.from_pretrained(
        "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
    )
    prepared_model, prepared_dataloader = accelerator.prepare(baseline_model, baseline_dataloader)
    setup = {
        "ddp": [prepared_model, prepared_dataloader, torch_device],
        "no": [baseline_model, baseline_dataloader, accelerator.device],
    }
    return setup, accelerator


def generate_predictions(model, dataloader, accelerator):
    """Run `model` over `dataloader` and return all gathered outputs and targets.

    Args:
        model: Callable applied to the first value of every batch.
        dataloader: Iterable of mappings; each batch must hold exactly two values
            (model input first, target second).
        accelerator: Object providing `gather_for_metrics`, applied to each `(output, target)` pair.

    Returns:
        A tuple `(logits, targets)`, each the `torch.cat` of the per-batch gathered tensors.
    """
    gathered_logits, gathered_targets = [], []
    for batch in dataloader:
        features, labels = batch.values()
        # Only the forward pass and the gather run with gradient tracking disabled.
        with torch.no_grad():
            prediction = model(features)
            prediction, labels = accelerator.gather_for_metrics((prediction, labels))
        gathered_logits.append(prediction)
        gathered_targets.append(labels)
    return torch.cat(gathered_logits), torch.cat(gathered_targets)


def test_torch_metrics(
    accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16
):
    """Check that gathering predictions yields exactly `num_samples` rows.

    `dispatch_batches` and `split_batches` are accepted but not read here; the dataloader
    behaviour comes from how `accelerator` itself was configured by the caller.
    """
    _reference_model, prepared_model, loader = get_basic_setup(accelerator, num_samples, batch_size)
    # Only the gathered logits matter for the length check; the targets are discarded.
    logits = generate_predictions(prepared_model, loader, accelerator)[0]
    assert len(logits) == num_samples, (
        f"Unexpected number of inputs:\n    Expected: {num_samples}\n    Actual: {len(logits)}"
    )


def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
    """Compare GLUE/MRPC metrics from a plain evaluation loop against a prepared one.

    The baseline pass evaluates the unprepared model and dataloader directly. The second pass
    uses the objects returned by `accelerator.prepare` and routes predictions and labels through
    `gather_for_metrics`. Both passes must produce matching `accuracy` and `f1`.
    """
    metric = evaluate.load("glue", "mrpc")
    setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches)

    # Baseline pass: unprepared objects, batches moved to the device by hand.
    plain_model, plain_loader, plain_device = setup["no"]
    plain_model.to(plain_device)
    plain_model.eval()
    for batch in plain_loader:
        batch.to(plain_device)
        with torch.inference_mode():
            logits = plain_model(**batch).logits
        metric.add_batch(predictions=logits.argmax(dim=-1), references=batch["labels"])
    baseline = metric.compute()

    # Second pass: prepared objects, predictions and labels gathered before scoring.
    # The device stored alongside them is not needed here.
    prepared_model, prepared_loader, _ = setup["ddp"]
    prepared_model.eval()
    for batch in prepared_loader:
        with torch.inference_mode():
            logits = prepared_model(**batch).logits
        gathered_preds, gathered_refs = accelerator.gather_for_metrics((logits.argmax(dim=-1), batch["labels"]))
        metric.add_batch(predictions=gathered_preds, references=gathered_refs)
    distributed = metric.compute()

    for key in ("accuracy", "f1"):
        assert math.isclose(baseline[key], distributed[key]), (
            f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"
        )


def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
    """Check `gather_for_metrics` on an iterable dataset built from plain Python ints.

    A 30-item iterable dataset is batched by 4, prepared, and every batch is gathered.
    The test asserts that exactly 30 items come back and, on the main process, that nothing
    was logged to the `accelerate.accelerator` logger while gathering.

    Raises:
        AssertionError: If the gathered item count is not 30 or any record was logged.
    """

    class DummyIterableDataset(IterableDataset):
        """Iterable dataset that replays an in-memory sequence and reports its length."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            yield from self.data

    iterable_dataset = DummyIterableDataset(list(range(30)))
    dataloader = DataLoader(iterable_dataset, batch_size=4)
    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(dataloader)

    # Only the main process records log output; the handler stays `None` everywhere else.
    logger = None
    list_handler = None
    if accelerator.is_main_process:
        # Use the public lookup rather than reaching into the logging manager's internal dict.
        logger = logging.getLogger("accelerate.accelerator")
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    try:
        batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]

        assert torch.cat(batches_for_metrics).size(0) == 30

        if list_handler is not None:
            assert len(list_handler.logs) == 0
    finally:
        # Detach the handler even when an assertion fails so it does not linger on the shared logger.
        if list_handler is not None:
            logger.removeHandler(list_handler)


def test_gather_for_metrics_with_iterable_dataset():
    """Check `gather_for_metrics` on an iterable dataset backed by a tensor.

    A 30-element tensor is wrapped in an iterable dataset, batched by 4 and prepared; the
    prepared loader must be a `DataLoaderDispatcher`. Every batch is gathered, and the test
    asserts that exactly 30 items come back and, on the main process, that nothing was logged
    to the `accelerate.accelerator` logger while gathering.

    Raises:
        AssertionError: If the prepared loader has an unexpected type, the gathered item count
            is not 30, or any record was logged.
    """

    class DummyIterableDataset(IterableDataset):
        """Iterable dataset that replays an in-memory sequence and reports its length."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            yield from self.data

    iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30)))
    dataloader = DataLoader(iterable_dataset, batch_size=4)

    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(dataloader)

    assert isinstance(prepared_dataloader, DataLoaderDispatcher)

    # Only the main process records log output; the handler stays `None` everywhere else.
    logger = None
    list_handler = None
    if accelerator.is_main_process:
        # Use the public lookup rather than reaching into the logging manager's internal dict.
        logger = logging.getLogger("accelerate.accelerator")
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    try:
        batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]

        assert torch.cat(batches_for_metrics).size(0) == 30

        if list_handler is not None:
            assert len(list_handler.logs) == 0
    finally:
        # Detach the handler even when an assertion fails so it does not linger on the shared logger.
        if list_handler is not None:
            logger.removeHandler(list_handler)


def test_gather_for_metrics_drop_last():
    """Check that gathering the second batch returns one full batch per process under `drop_last=True`.

    The dataset holds `10 * num_processes + 1` items and is batched by 5 with `drop_last=True`.
    The first batch is discarded, the second one is gathered, and the gathered size must equal
    `5 * num_processes`.
    """
    accelerator = Accelerator()
    per_device_batch_size = 5
    world_size = accelerator.num_processes
    num_items = (10 * world_size) + 1
    prepared_loader = accelerator.prepare(
        DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True)
    )

    batches = iter(prepared_loader)
    next(batches)  # The first batch is intentionally ignored.
    second_batch = next(batches)
    gathered_items = accelerator.gather_for_metrics(second_batch)

    # Every process should contribute one complete batch to the gathered result.
    num_expected_items = per_device_batch_size * world_size
    assert gathered_items.size(0) == (num_expected_items), (
        f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}"
    )


def main():
    """Run every metrics-gathering check in this file, in a fixed order.

    Order of scenarios:
      1. MRPC baseline-vs-prepared comparison and the two iterable-dataset checks — only when the
         device is not CPU and torch-xla is not available.
      2. `test_torch_metrics` with 99 samples for each `split_batches` / `dispatch_batches`
         combination — skipped when the distributed type is XLA.
      3. `test_torch_metrics` with 512 samples on a default `Accelerator`.
      4. `test_gather_for_metrics_drop_last`.

    `accelerator.state._reset_state()` is called after each scenario, before the next
    `Accelerator` is constructed.
    """
    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    # Keep warnings visible on the local main process only; other processes report errors only.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # TorchXLA does not support batch dispatching. 'put_on_device' is always False for
    # TorchXLA, which can cause a value error in the 'prepare_data_loader' function.
    dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False]

    # This test is temporarily disabled for TorchXLA due to the 'Cannot set version_counter for
    # inference tensor' error in inference mode. Re-enable it after TorchXLA fixes this bug.
    # These are a bit slower so they should only be run on the GPU or TPU.
    if accelerator.device.type != "cpu" and not is_torch_xla_available():
        if accelerator.is_local_main_process:
            print("**Testing gather_for_metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
                # `test_mrpc` builds its own Accelerator from these two options.
                test_mrpc(dispatch_batches, split_batches)
                accelerator.state._reset_state()
        # NOTE(review): unlike the other progress messages, these two prints are not guarded by
        # `is_local_main_process`, so every process emits them.
        print("test_gather_for_metrics_with_iterable_dataset")
        test_gather_for_metrics_with_iterable_dataset()
        print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
        test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()

    # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache.
    # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended.
    # Skip this test when TorchXLA is enabled.
    if accelerator.state.distributed_type != DistributedType.XLA:
        if accelerator.is_local_main_process:
            print("**Test torch metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                # A fresh Accelerator is created for each combination of dataloader options.
                dataloader_config = DataLoaderConfiguration(
                    split_batches=split_batches, dispatch_batches=dispatch_batches
                )
                accelerator = Accelerator(dataloader_config=dataloader_config)
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
                test_torch_metrics(accelerator, 99)
                accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test last batch is not dropped when perfectly divisible**")
    # Default dataloader configuration for the 512-sample check.
    accelerator = Accelerator()
    test_torch_metrics(accelerator, 512)
    accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test that `drop_last` is taken into account**")
    # This test constructs its own Accelerator internally.
    test_gather_for_metrics_drop_last()
    accelerator.end_training()
    accelerator.state._reset_state()


def _mp_fn(index):
    # Entry point for xla_spawn (TPUs). `index` is accepted but not used: every call simply runs `main()`.
    main()


# Run the full sequence of checks when this file is executed as a script.
if __name__ == "__main__":
    main()
