# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import os
from copy import deepcopy

import pytest
import torch
from packaging.version import Version as PkgVersion
from pytest_mock import mocker

import megatron.core.parallel_state as ps
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.models.T5.t5_model import T5Model
from megatron.core.models.T5.t5_spec import (
    get_t5_decoder_with_local_block_spec,
    get_t5_decoder_with_transformer_engine_block_spec,
    get_t5_encoder_with_local_block_spec,
    get_t5_encoder_with_transformer_engine_block_spec,
)
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestT5Model:

    def setup_method(self, method):
        tp = 4
        pp = 1
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tp,
            pipeline_model_parallel_size=pp,
            encoder_pipeline_model_parallel_size=pp,
        )
        model_parallel_cuda_manual_seed(123)
        transformer_config = TransformerConfig(
            num_layers=12,
            hidden_size=768,
            num_attention_heads=12,
            kv_channels=64,
            ffn_hidden_size=3072,
            use_cpu_initialization=True,
            pipeline_dtype=torch.bfloat16,
            tensor_model_parallel_size=tp,
            pipeline_model_parallel_size=pp,
        )
        rank = ps.get_pipeline_model_parallel_rank()
        world_size = ps.get_pipeline_model_parallel_world_size()
        en_block_spec = get_t5_encoder_with_transformer_engine_block_spec(12)
        de_block_spec = get_t5_decoder_with_transformer_engine_block_spec(12)

        first_decoder_rank = pp
        pre_process = rank == 0 or rank == first_decoder_rank
        post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
        add_encoder = ps.is_inside_encoder(rank)
        add_decoder = ps.is_inside_decoder(rank)

        self.t5_model = T5Model(
            encoder_config=transformer_config,
            config=transformer_config,
            transformer_encoder_layer_spec=en_block_spec,
            transformer_decoder_layer_spec=de_block_spec,
            vocab_size=29184,
            max_sequence_length=4,
            pre_process=pre_process,
            post_process=post_process,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
        )

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_constructor(self):
        assert isinstance(self.t5_model, T5Model)
        assert Utils.world_size == 8

        assert self.t5_model.max_sequence_length == 4
        if self.t5_model.add_encoder:
            assert not self.t5_model.add_decoder
            assert self.t5_model.encoder.num_layers_per_pipeline_rank == 12
            assert self.t5_model.pre_process
            assert self.t5_model.post_process
        else:
            assert self.t5_model.add_decoder
            assert self.t5_model.decoder.num_layers_per_pipeline_rank == 12
            assert self.t5_model.pre_process
            assert self.t5_model.post_process

    def test_set_input_tensor(self):
        config: TransformerConfig = self.t5_model.config
        sequence_length = self.t5_model.max_sequence_length
        micro_batch_size = 2

        # [sequence length, batch size, hidden size]
        input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size))

        self.t5_model.set_input_tensor(input_tensor)

        if self.t5_model.add_encoder:
            assert self.t5_model.encoder.input_tensor.shape[0] == sequence_length
            assert self.t5_model.encoder.input_tensor.shape[1] == micro_batch_size
            assert self.t5_model.encoder.input_tensor.shape[2] == config.hidden_size
        else:
            assert self.t5_model.encoder is None
            assert self.t5_model.encoder_hidden_state.shape[0] == sequence_length
            assert self.t5_model.encoder_hidden_state.shape[1] == micro_batch_size
            assert self.t5_model.encoder_hidden_state.shape[2] == config.hidden_size

    @pytest.mark.flaky_in_dev
    def test_post_process_forward(self):
        config: TransformerConfig = self.t5_model.config
        sequence_length = self.t5_model.max_sequence_length
        micro_batch_size = 2

        self.t5_model.cuda()

        data = list(range(sequence_length))
        encoder_input_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        )
        decoder_input_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        )
        encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
        decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
        encoder_decoder_attn_mask = torch.ones(
            (1, sequence_length, sequence_length), dtype=bool
        ).cuda()

        if self.t5_model.add_decoder:
            encoder_hidden_states = torch.zeros(
                (sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32
            ).cuda()
        else:
            encoder_hidden_states = None

        output = self.t5_model.forward(
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            encoder_attn_mask=encoder_attn_mask,
            decoder_attn_mask=decoder_attn_mask,
            encoder_decoder_attn_mask=encoder_decoder_attn_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        if self.t5_model.add_decoder:
            logits = output
            assert logits.shape[0] == micro_batch_size
            assert logits.shape[1] == sequence_length
            assert (
                logits.shape[2]
                == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size()
            )
        else:
            encoder_hidden_states = output
            assert encoder_hidden_states.shape[0] == sequence_length
            assert encoder_hidden_states.shape[1] == micro_batch_size
            assert encoder_hidden_states.shape[2] == config.hidden_size

    @pytest.mark.flaky_in_dev
    def test_forward_output_encoder_hidden_only(self):
        config: TransformerConfig = self.t5_model.config
        sequence_length = self.t5_model.max_sequence_length
        micro_batch_size = 2

        self.t5_model.cuda()

        data = list(range(sequence_length))
        encoder_input_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        )
        decoder_input_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        )
        encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
        decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
        encoder_decoder_attn_mask = torch.ones(
            (1, sequence_length, sequence_length), dtype=bool
        ).cuda()

        encoder_hidden_states = self.t5_model.forward(
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            encoder_attn_mask=encoder_attn_mask,
            decoder_attn_mask=decoder_attn_mask,
            encoder_decoder_attn_mask=encoder_decoder_attn_mask,
            output_encoder_hidden_only=True,
        )
        if self.t5_model.add_decoder:
            assert encoder_hidden_states is None
        else:
            assert encoder_hidden_states.shape[0] == sequence_length
            assert encoder_hidden_states.shape[1] == micro_batch_size
            assert encoder_hidden_states.shape[2] == config.hidden_size

    @pytest.mark.flaky_in_dev
    def test_forward_with_encoder_hidden_states(self):
        config: TransformerConfig = self.t5_model.config
        sequence_length = self.t5_model.max_sequence_length
        micro_batch_size = 2

        self.t5_model.cuda()

        data = list(range(sequence_length))
        encoder_input_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        )
        decoder_input_ids = (
            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        )
        encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
        decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
        encoder_decoder_attn_mask = torch.ones(
            (1, sequence_length, sequence_length), dtype=bool
        ).cuda()
        encoder_hidden_states = torch.zeros(
            (sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32
        ).cuda()

        output = self.t5_model.forward(
            encoder_input_ids=None,
            decoder_input_ids=decoder_input_ids,
            encoder_attn_mask=encoder_attn_mask,
            decoder_attn_mask=decoder_attn_mask,
            encoder_decoder_attn_mask=encoder_decoder_attn_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        if self.t5_model.add_decoder:
            logits = output
            assert logits.shape[0] == micro_batch_size
            assert logits.shape[1] == sequence_length
            assert (
                logits.shape[2]
                == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size()
            )
        else:
            encoder_hidden_states = output
            assert encoder_hidden_states.shape[0] == sequence_length
            assert encoder_hidden_states.shape[1] == micro_batch_size
            assert encoder_hidden_states.shape[2] == config.hidden_size

    def test_no_post_process_forward(self):
        pass

    def test_no_preprocess_forward(self):
        pass

    def test_state_dict_for_save_checkpoint(self):
        pass

    def test_load_state_dict(self):
        pass


class TestT5ModelAttentionDimensions:

    def teardown_method(self, method):
        os.environ.pop('NVTE_FUSED_ATTN', None)
        os.environ.pop('NVTE_FLASH_ATTN', None)
        os.environ.pop('NVTE_UNFUSED_ATTN', None)

    def setup_method(self, method):
        self.bs = 4
        self.seq_len = 512
        self.seq_len_dec = 128
        self.encoder_tokens = torch.ones([self.bs, self.seq_len])
        self.decoder_tokens = torch.ones([self.bs, self.seq_len_dec])
        self.encoder_mask = torch.ones([self.bs, self.seq_len]) < 0.5
        self.decoder_mask = torch.ones([self.bs, self.seq_len_dec]) < 0.5

    @pytest.mark.internal
    def test_local_spec(self):
        encoder_mask, decoder_mask, encoder_decoder_mask = (
            T5MaskedWordPieceDataset.config_attention_mask(
                self.encoder_tokens,
                self.decoder_tokens,
                self.encoder_mask,
                self.decoder_mask,
                use_local=True,
            )
        )

        assert list(encoder_mask.shape) == [self.bs, 1, self.seq_len, self.seq_len]
        assert list(decoder_mask.shape) == [self.bs, 1, self.seq_len_dec, self.seq_len_dec]
        assert list(encoder_decoder_mask.shape) == [self.bs, 1, self.seq_len_dec, self.seq_len]

    @pytest.mark.internal
    def test_transformer_engine_version_1_10(self):
        encoder_mask, decoder_mask, encoder_decoder_mask = (
            T5MaskedWordPieceDataset.config_attention_mask(
                self.encoder_tokens,
                self.decoder_tokens,
                self.encoder_mask,
                self.decoder_mask,
                use_local=False,
                test_te_version="1.10",
            )
        )

        assert list(encoder_mask.shape) == [self.bs, 1, 1, self.seq_len]
        assert decoder_mask is None
        assert list(encoder_decoder_mask[0].shape) == [self.bs, 1, 1, self.seq_len_dec]
        assert list(encoder_decoder_mask[1].shape) == [self.bs, 1, 1, self.seq_len]

    @pytest.mark.internal
    def test_transformer_engine_version_1_7_to_1_10_flashfused_attn(self):
        os.environ['NVTE_FLASH_ATTN'] = '1'
        os.environ['NVTE_FUSED_ATTN'] = '1'

        encoder_mask, decoder_mask, encoder_decoder_mask = (
            T5MaskedWordPieceDataset.config_attention_mask(
                self.encoder_tokens,
                self.decoder_tokens,
                self.encoder_mask,
                self.decoder_mask,
                use_local=False,
                test_te_version="1.8",
            )
        )

        assert list(encoder_mask.shape) == [self.bs, 1, 1, self.seq_len]
        assert decoder_mask is None
        assert list(encoder_decoder_mask[0].shape) == [self.bs, 1, 1, self.seq_len_dec]
        assert list(encoder_decoder_mask[1].shape) == [self.bs, 1, 1, self.seq_len]

    @pytest.mark.internal
    def test_transformer_engine_version_1_7_to_1_10_unfused_attention(self):
        os.environ['NVTE_FLASH_ATTN'] = '0'
        os.environ['NVTE_FUSED_ATTN'] = '0'

        encoder_mask, decoder_mask, encoder_decoder_mask = (
            T5MaskedWordPieceDataset.config_attention_mask(
                self.encoder_tokens,
                self.decoder_tokens,
                self.encoder_mask,
                self.decoder_mask,
                use_local=False,
                test_te_version="1.8",
            )
        )

        assert list(encoder_mask.shape) == [self.bs, 1, self.seq_len, self.seq_len]
        assert decoder_mask is None
        assert list(encoder_decoder_mask.shape) == [self.bs, 1, self.seq_len_dec, self.seq_len]

    @pytest.mark.internal
    def test_transformer_engine_version_less_than_1_7(self):
        os.environ['NVTE_FLASH_ATTN'] = '1'
        with pytest.raises(Exception) as exc_info:
            encoder_mask, decoder_mask, encoder_decoder_mask = (
                T5MaskedWordPieceDataset.config_attention_mask(
                    self.encoder_tokens,
                    self.decoder_tokens,
                    self.encoder_mask,
                    self.decoder_mask,
                    use_local=False,
                    test_te_version="1.5",
                )
            )

        assert str(exc_info.value) == (
            "Flash and fused attention is not supported with transformer "
            "engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0"
            "or upgrade transformer engine >= 1.7"
        )
