# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import AutoTokenizer, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow


if is_flax_available():
    import jax.numpy as jnp
    from transformers import FlaxXLMRobertaModel


@require_sentencepiece
@require_tokenizers
@require_flax
class FlaxXLMRobertaModelIntegrationTest(unittest.TestCase):
    @slow
    def test_flax_xlm_roberta_base(self):
        model = FlaxXLMRobertaModel.from_pretrained("xlm-roberta-base")
        tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
        text = "The dog is cute and lives in the garden house"
        input_ids = jnp.array([tokenizer.encode(text)])

        expected_output_shape = (1, 12, 768)  # batch_size, sequence_length, embedding_vector_dim
        expected_output_values_last_dim = jnp.array(
            [[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]]
        )

        output = model(input_ids)["last_hidden_state"]
        self.assertEqual(output.shape, expected_output_shape)
        # compare the actual values for a slice of last dim
        self.assertTrue(jnp.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
