# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import fixed_init_model, mps_ignored_test

from torchtune.generation._generation import (
    generate,
    get_causal_mask_from_padding_mask,
    get_position_ids_from_padding_mask,
    sample,
)
from torchtune.models.llama2 import llama2


class TestGenerate:
    """
    Test class for text generation functionality in :func:`~torchtune.generation.generate`.
    """

    @staticmethod
    def _build_model():
        """
        Build the small llama2 model shared by every model fixture in this class.

        Weights are set with ``fixed_init_model`` so the generated tokens are
        deterministic for a given seed. Caches are NOT set up here; each fixture
        does that (if needed) after initialization, matching the order
        init -> setup_caches -> eval.
        """
        model = llama2(
            vocab_size=4_000,
            embed_dim=128,
            num_layers=2,
            num_heads=4,
            num_kv_heads=4,
            max_seq_len=2048,
        )
        fixed_init_model(model)
        return model

    @pytest.fixture
    def generation_model_no_kv_cache(self):
        """Model in eval mode with no KV caches set up."""
        model = self._build_model()
        model.eval()
        return model

    @pytest.fixture
    def generation_model_kv_cache(self):
        """Model in eval mode with float32 KV caches for a batch size of 1."""
        model = self._build_model()
        model.setup_caches(batch_size=1, dtype=torch.float32)
        model.eval()
        return model

    @pytest.fixture
    def generation_model_kv_cache_batched(self):
        """Model in eval mode with float32 KV caches for a batch size of 3."""
        model = self._build_model()
        model.setup_caches(batch_size=3, dtype=torch.float32)
        model.eval()
        return model

    @pytest.fixture
    def generation_model_batched_fixed_cache_seq_len(self, dtype=torch.float32):
        """
        Model in eval mode with KV caches for a batch size of 3 and an explicit
        ``decoder_max_seq_len=1024`` (smaller than the model's ``max_seq_len`` of 2048).
        """
        # `dtype` has a default value, so pytest does not treat it as a fixture request.
        model = self._build_model()
        model.setup_caches(batch_size=3, dtype=dtype, decoder_max_seq_len=1024)
        model.eval()
        return model

    @pytest.fixture
    def prompt_tokens(self):
        """
        Pytest fixture returning a 1D tensor of prompt tokens ``[2, 3, ..., 9]`` (length 8).
        """
        return torch.arange(2, 10)

    @pytest.fixture
    def prompt_tokens_batched(self):
        """
        Pytest fixture returning the prompt ``[2, ..., 9]`` repeated for a batch of three
        (shape ``[3, 8]``).
        """
        return torch.arange(2, 10).repeat(3, 1)

    @pytest.fixture
    def prompt_tokens_left_padded(self):
        """
        Pytest fixture returning the prompt ``[2, ..., 9]`` preceded by two ``0`` tokens
        used as left padding (1D, length 10).
        """
        return torch.cat([torch.tensor([0, 0]), torch.arange(2, 10)])

    @pytest.fixture
    def prompt_tokens_batched_left_padded(self):
        """
        Pytest fixture returning the left-padded prompt ``[0, 0, 2, ..., 9]`` repeated for
        a batch of three (shape ``[3, 10]``).
        """
        return torch.cat([torch.tensor([0, 0]), torch.arange(2, 10)]).repeat(3, 1)

    @pytest.fixture
    def expected_tokens(self):
        """
        Expected ``generate`` output (shape ``[1, 18]``) for ``prompt_tokens`` with the
        constantly initialized model, ``temperature=0.6``, ``top_k=100`` and the manual
        seed set to 42: the 8 prompt tokens (2 through 9) followed by the first 10
        generated tokens. The generated ids do not correspond to "recognizable" tokens.
        """
        prompt = list(range(2, 10))
        generated = [3987, 3991, 3953, 3957, 3983, 3964, 3928, 3932, 3986, 3982]
        return torch.tensor([prompt + generated])

    @pytest.fixture
    def expected_tokens_batched(self):
        """
        Expected ``generate`` output (shape ``[3, 18]``) for ``prompt_tokens_batched`` with
        the constantly initialized model, ``temperature=0.6``, ``top_k=100`` and the manual
        seed set to 42: each row is the 8 prompt tokens (2 through 9) followed by the first
        10 tokens generated for that sequence. The generated ids do not correspond to
        "recognizable" tokens.
        """
        prompt = list(range(2, 10))
        generated = [
            [3987, 3991, 3953, 3957, 3983, 3964, 3928, 3932, 3986, 3982],
            [3958, 3979, 3934, 3945, 3993, 3904, 3950, 3988, 3948, 3999],
            [3989, 3976, 3997, 3960, 3989, 3956, 3917, 3949, 3917, 3987],
        ]
        return torch.tensor([prompt + row for row in generated])

    def test_sample_consistency(self):
        """
        Test token sampling produces the right output.
        """
        # token 100 has the single largest logit, so restricting sampling to the
        # top-1 candidate must always select it
        logits = torch.zeros(2000)
        logits[100] = 1
        token = sample(logits, top_k=1)
        assert token.item() == 100

    @pytest.mark.parametrize(
        "model1",
        ["generation_model_no_kv_cache", "generation_model_kv_cache"],
    )
    @pytest.mark.parametrize(
        "model2",
        ["generation_model_no_kv_cache", "generation_model_kv_cache"],
    )
    def test_reproducibility(self, request, model1, model2, prompt_tokens):
        """
        Test to check if the generate function produces the same output when run with the same
        fixed seed and models with and without kv-caching.
        """

        model1 = request.getfixturevalue(model1)
        model2 = request.getfixturevalue(model2)

        temperature = 0.6
        top_k = 100

        torch.manual_seed(42)
        outputs_first, logits_first = generate(
            model=model1,
            prompt=prompt_tokens,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
        )

        torch.manual_seed(42)
        outputs_second, logits_second = generate(
            model=model2,
            prompt=prompt_tokens,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
        )
        assert torch.equal(outputs_first, outputs_second)
        torch.testing.assert_close(logits_first, logits_second)

    @pytest.mark.parametrize(
        "model1",
        [
            "generation_model_no_kv_cache",
            "generation_model_kv_cache_batched",
            "generation_model_batched_fixed_cache_seq_len",
        ],
    )
    @pytest.mark.parametrize(
        "model2",
        [
            "generation_model_no_kv_cache",
            "generation_model_kv_cache_batched",
            "generation_model_batched_fixed_cache_seq_len",
        ],
    )
    @pytest.mark.parametrize(
        "prompt1", ["prompt_tokens_batched", "prompt_tokens_batched_left_padded"]
    )
    @pytest.mark.parametrize(
        "prompt2", ["prompt_tokens_batched", "prompt_tokens_batched_left_padded"]
    )
    def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2):
        """
        Test to check if the generate function produces the same output when run with the same
        fixed seed, models with and without kv-caching, and batched inputs with and without left-padding.
        """

        model1 = request.getfixturevalue(model1)
        model2 = request.getfixturevalue(model2)
        prompt1 = request.getfixturevalue(prompt1)
        prompt2 = request.getfixturevalue(prompt2)

        temperature = 0.6
        top_k = 100

        torch.manual_seed(42)
        outputs_first, logits_first = generate(
            model=model1,
            prompt=prompt1,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
        )

        torch.manual_seed(42)
        outputs_second, logits_second = generate(
            model=model2,
            prompt=prompt2,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
        )

        # slicing for the last 18 tokens - this is the whole sequence for unpadded inputs
        # and excludes the first two tokens for padded inputs, which are padding tokens
        assert torch.equal(outputs_first[:, -18:], outputs_second[:, -18:])
        # logits are only ever returned for the generated tokens, so no slicing needed
        torch.testing.assert_close(logits_first, logits_second, atol=1e-4, rtol=1e-6)

    @pytest.mark.parametrize(
        "model",
        ["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
    )
    @pytest.mark.parametrize(
        "prompt", ["prompt_tokens_batched", "prompt_tokens_batched_left_padded"]
    )
    @mps_ignored_test()
    def test_stop_tokens_batched(self, request, model, prompt, expected_tokens_batched):
        """
        Test to check if the `generate` function produces the right output when stop tokens are
        provided for a batch in which every sequence stops at the same step.
        """
        model = request.getfixturevalue(model)
        prompt = request.getfixturevalue(prompt)
        temperature = 0.6
        top_k = 100

        # These are the first tokens generated for each of the three sequences,
        # so every sequence stops immediately, resulting in only a single
        # token being generated per sequence
        stop_tokens = [3987, 3958, 3989]

        torch.manual_seed(42)

        outputs, _ = generate(
            model=model,
            prompt=prompt,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
            stop_tokens=stop_tokens,
        )

        # the last 9 columns are the 8 prompt tokens plus the single generated token;
        # this skips the left padding when the padded prompt is used
        assert torch.equal(outputs[:, -9:], expected_tokens_batched[:, :9])

    @pytest.mark.parametrize(
        "model",
        ["generation_model_no_kv_cache", "generation_model_kv_cache"],
    )
    @mps_ignored_test()
    def test_stop_tokens(self, request, model, prompt_tokens, expected_tokens):
        """
        Test to check if the `generate` function produces the right output when stop tokens are
        provided.
        """
        model = request.getfixturevalue(model)
        temperature = 0.6
        top_k = 100

        # This is the first token generated by the model
        # so it should stop immediately, leaving the prompt plus that one token
        stop_tokens = [3987]

        torch.manual_seed(42)

        outputs, _ = generate(
            model=model,
            prompt=prompt_tokens,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
            stop_tokens=stop_tokens,
        )

        assert torch.equal(outputs, expected_tokens[:, :9])

    @pytest.mark.parametrize(
        "model",
        ["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
    )
    @mps_ignored_test()
    def test_stop_tokens_batched_uneven_stopping(
        self, request, model, prompt_tokens_batched
    ):
        """
        Test to check if the `generate` function produces the right output when different sequences
        in the batch stop at different lengths.
        """
        model = request.getfixturevalue(model)
        temperature = 0.6
        top_k = 100

        # 3953 is sequence 0's third generated token, 3979 is sequence 1's second,
        # and 3989 is sequence 2's first, so each sequence stops at a different step
        stop_tokens = [3953, 3979, 3989]

        torch.manual_seed(42)

        outputs, _ = generate(
            model=model,
            prompt=prompt_tokens_batched,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
            stop_tokens=stop_tokens,
        )

        # positions after a sequence has stopped are filled with 0
        expected_output = torch.tensor(
            [
                [2, 3, 4, 5, 6, 7, 8, 9, 3987, 3991, 3953],
                [2, 3, 4, 5, 6, 7, 8, 9, 3958, 3979, 0],
                [2, 3, 4, 5, 6, 7, 8, 9, 3989, 0, 0],
            ]
        )

        assert torch.equal(outputs, expected_output)

    @pytest.mark.parametrize(
        "model",
        ["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
    )
    @mps_ignored_test()
    def test_stop_tokens_batched_uneven_stopping_left_padded(
        self, request, model, prompt_tokens_batched_left_padded
    ):
        """
        Test to check if the `generate` function produces the right output when different sequences
        in a left-padded batch stop at different lengths; the leading padding tokens are kept in
        the output.
        """
        model = request.getfixturevalue(model)
        temperature = 0.6
        top_k = 100

        # 3953 is sequence 0's third generated token, 3979 is sequence 1's second,
        # and 3989 is sequence 2's first, so each sequence stops at a different step
        stop_tokens = [3953, 3979, 3989]

        torch.manual_seed(42)

        outputs, _ = generate(
            model=model,
            prompt=prompt_tokens_batched_left_padded,
            max_generated_tokens=10,
            temperature=temperature,
            top_k=top_k,
            stop_tokens=stop_tokens,
        )

        # positions after a sequence has stopped are filled with 0
        expected_output = torch.tensor(
            [
                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3987, 3991, 3953],
                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3958, 3979, 0],
                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3989, 0, 0],
            ]
        )
        assert torch.equal(outputs, expected_output)


class TestGetPositionIDsFromPaddingMask:
    """
    Tests for :func:`get_position_ids_from_padding_mask`.

    Position ids count up from 0 starting at the first non-padding token; padding
    positions get id 0. Masks are boolean (``True`` = real token), matching how
    padding masks are produced from prompts (``tokens != pad_id``).
    """

    def test_get_position_ids_padding(self):
        padding_mask = torch.tensor([False, False, False, True, True, True, True, True])
        outputs = get_position_ids_from_padding_mask(padding_mask)
        # compare as Python lists: exact values and shape, independent of the
        # returned dtype (torch.equal's cross-dtype behavior varies across versions)
        assert outputs.tolist() == [0, 0, 0, 0, 1, 2, 3, 4]

    def test_get_position_ids_no_padding(self):
        padding_mask = torch.tensor([True, True, True, True, True, True, True, True])
        outputs = get_position_ids_from_padding_mask(padding_mask)
        assert outputs.tolist() == [0, 1, 2, 3, 4, 5, 6, 7]

    def test_get_position_ids_batched(self):
        padding_mask = torch.tensor(
            [
                [False, False, False, True, True, True, True, True],
                [False, False, False, False, False, False, True, True],
                [True, True, True, True, True, True, True, True],
            ]
        )
        outputs = get_position_ids_from_padding_mask(padding_mask)
        expected_outputs = [
            [0, 0, 0, 0, 1, 2, 3, 4],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 1, 2, 3, 4, 5, 6, 7],
        ]
        assert outputs.tolist() == expected_outputs


class TestGetCausalMaskFromPaddingMask:
    """
    Tests for :func:`get_causal_mask_from_padding_mask`.

    Per the expected masks below: a padding position attends only to itself, and a
    real token attends to every non-padding token at or before its own position.
    """

    @pytest.fixture
    def left_padded_prompt_tokens(self):
        """
        Pytest fixture returning a single left-padded prompt ``[[0, 0, 2, 3, 4, 5]]``
        (shape ``[1, 6]``), where ``0`` is the padding token.
        """
        return torch.cat([torch.tensor([0, 0]), torch.arange(2, 6)]).unsqueeze(0)

    @pytest.fixture
    def left_padded_prompt_tokens_batched(self):
        """
        Pytest fixture returning a batch of three left-padded prompts (shape ``[3, 6]``)
        with 3, 1 and 5 leading padding tokens respectively, where ``0`` is the padding token.
        """
        return torch.tensor(
            [[0, 0, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 1]]
        )

    def test_get_causal_mask_for_left_padded_inputs(self, left_padded_prompt_tokens):
        """
        Test to check if the `get_causal_mask_from_padding_mask` function produces the right
        output for left-padded prompts.
        """
        expected_causal_mask = torch.tensor(
            [
                [True, False, False, False, False, False],
                [False, True, False, False, False, False],
                [False, False, True, False, False, False],
                [False, False, True, True, False, False],
                [False, False, True, True, True, False],
                [False, False, True, True, True, True],
            ]
        ).unsqueeze(0)

        causal_mask = get_causal_mask_from_padding_mask(left_padded_prompt_tokens != 0)
        assert torch.equal(causal_mask, expected_causal_mask)

    def test_get_causal_mask_for_left_padded_inputs_batched(
        self, left_padded_prompt_tokens_batched
    ):
        """
        Test to check if the `get_causal_mask_from_padding_mask` function produces the right
        output for left-padded batched prompts.
        """
        expected_causal_mask = torch.tensor(
            [
                [
                    [True, False, False, False, False, False],
                    [False, True, False, False, False, False],
                    [False, False, True, False, False, False],
                    [False, False, False, True, False, False],
                    [False, False, False, True, True, False],
                    [False, False, False, True, True, True],
                ],
                [
                    [True, False, False, False, False, False],
                    [False, True, False, False, False, False],
                    [False, True, True, False, False, False],
                    [False, True, True, True, False, False],
                    [False, True, True, True, True, False],
                    [False, True, True, True, True, True],
                ],
                [
                    [True, False, False, False, False, False],
                    [False, True, False, False, False, False],
                    [False, False, True, False, False, False],
                    [False, False, False, True, False, False],
                    [False, False, False, False, True, False],
                    [False, False, False, False, False, True],
                ],
            ]
        )

        causal_mask = get_causal_mask_from_padding_mask(
            left_padded_prompt_tokens_batched != 0
        )
        assert torch.equal(causal_mask, expected_causal_mask)
