# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import (
    BertConfig,
    BertForMaskedLM,
    GPT2Config,
    GPT2ForSequenceClassification,
)

from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.test_utils import torch_device
from accelerate.utils import DistributedType, set_seed


# Short model name -> (model class, config class, sequence length). The third
# entry is the length of the random token-id inputs built in
# get_model_and_data_for_text.
model_to_config = dict(
    bert=(BertForMaskedLM, BertConfig, 512),
    gpt2=(GPT2ForSequenceClassification, GPT2Config, 1024),
)


def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
    """Create a text model from its default config together with random token ids.

    The model is constructed directly from a default config object (no
    pretrained weights are loaded) and is left wherever its constructor puts it;
    only the input tensors are placed on ``device``.

    Args:
        model_name: Key into ``model_to_config`` (``"bert"`` or ``"gpt2"``).
        device: Device on which the random input tensors are allocated.
        num_processes: Number of rows in the inference batch. The callers in
            this file pass their batch size here.

    Returns:
        A ``(model, trace_input, inference_inputs)`` tuple. ``trace_input`` has
        shape ``(1, seq_len)`` and ``inference_inputs`` has shape
        ``(num_processes, seq_len)``. Both are ``int64`` token ids drawn from
        ``[0, vocab_size)``.
    """
    model_cls, config_cls, seq_len = model_to_config[model_name]
    # Batched GPT-2 inference (bs != 1) will eventually need a padding token,
    # e.g. ``config_cls(pad_token_id=0)`` when ``model_name == "gpt2"``.
    model_config = config_cls()
    model = model_cls(model_config)

    def random_token_ids(rows):
        return torch.randint(
            low=0,
            high=model_config.vocab_size,
            size=(rows, seq_len),
            device=device,
            dtype=torch.int64,
            requires_grad=False,
        )

    # Draw the trace example before the inference batch so the RNG stream is
    # consumed in the same order as before.
    trace_input = random_token_ids(1)
    inference_inputs = random_token_ids(num_processes)
    return model, trace_input, inference_inputs


def _run_pippy_text_model(model_name: str, batch_size: int) -> None:
    """Pipeline ``model_name`` with pippy, run one batch, and check where outputs land.

    The model and its random inputs are built on CPU. ``prepare_pippy`` is given
    a batch-size-1 example (``trace_input``) to trace the model. The inference
    batch is then moved to ``torch_device`` and pushed through the pipelined model.

    Args:
        model_name: Key into ``model_to_config`` (``"bert"`` or ``"gpt2"``).
        batch_size: Number of rows in the inference batch.

    Raises:
        AssertionError: If a non-last process receives an output, or if the
            last process receives none.
    """
    set_seed(42)
    state = PartialState()
    model, trace_input, inference_inputs = get_model_and_data_for_text(model_name, "cpu", batch_size)
    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
    # The pipelined model is called with the input batch as a single positional argument.
    inputs = inference_inputs.to(torch_device)
    with torch.no_grad():
        output = model(inputs)
    # Zach: Check that we just grab the real outputs we need at the end -- only
    # the last process should hold a result, every other process gets ``None``.
    if not state.is_last_process:
        assert output is None, "Output was not generated on just the last process!"
    else:
        assert output is not None, "Output was not generated in the last process!"


def test_bert(batch_size: int = 2):
    """Run a randomly built ``BertForMaskedLM`` through the pippy pipeline."""
    _run_pippy_text_model("bert", batch_size)


def test_gpt2(batch_size: int = 2):
    """Run a randomly built ``GPT2ForSequenceClassification`` through the pippy pipeline."""
    _run_pippy_text_model("gpt2", batch_size)


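# Currently disabled; re-enable once the PyTorch pippy interface can trace a resnet34.
# Re-enabling also requires importing `resnet34` (e.g. from torchvision.models) and
# `send_to_device` (from accelerate.utils), neither of which is imported in this file.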
# def test_resnet(batch_size: int = 2):
#     set_seed(42)
#     state = PartialState()
#     model = resnet34()
#     input_tensor = torch.rand(1, 3, 224, 224)
#     model = prepare_pippy(
#         model,
#         example_args=(input_tensor,),
#     )
#     inference_inputs = torch.rand(batch_size, 3, 224, 224)
#     inputs = send_to_device(inference_inputs, torch_device)
#     with torch.no_grad():
#         output = model(inputs)
#     # Zach: Check that we just grab the real outputs we need at the end
#     if not state.is_last_process:
#         assert output is None, "Output was not generated on just the last process!"
#     else:
#         assert output is not None, "Output was not generated in the last process!"


if __name__ == "__main__":
    state = PartialState()
    state.print("Testing pippy integration...")
    # Only these multi-accelerator backends are exercised; anything else is skipped.
    supported_types = (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_HPU)
    try:
        if state.distributed_type in supported_types:
            state.print("Testing GPT2...")
            test_gpt2()
            # Batched GPT-2 inference (e.g. test_gpt2(3)) stays disabled. When modifying
            # the tokenizer/config for batch GPT2 inference, tracing fails due to references:
            # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
            # test_gpt2(3)
            state.print("Testing BERT...")
            test_bert()
        else:
            # Route through state.print like every other message here; the
            # builtin print would be executed by every rank.
            state.print("Less than two GPUs found, not running tests!")
    finally:
        state.destroy_process_group()
