import torch
from tests.unit_tests.test_utilities import Utils
from megatron.core import ModelParallelConfig
import megatron.core.pipeline_parallel.schedules as schedule
from pytest_mock import mocker 
import pytest

# Rank exposed by the shared test utilities; the tests below use it to build
# the expected 'loss_reduced' values.
# NOTE(review): presumably equal to the LOCAL_RANK environment variable that
# the forward step functions read -- confirm in tests.unit_tests.test_utilities.
rank = Utils.rank
 
def test_get_forward_backward_func():
    """Check which schedule ``get_forward_backward_func`` selects per layout.

    Each case initializes model parallelism through ``Utils`` with the given
    sizes, compares the selected schedule with the expected function, and
    tears the state down again. The teardown runs in a ``finally`` clause so
    a failing case does not leave model-parallel state initialized for
    whatever test runs next.
    """
    cases = (
        (
            dict(tensor_model_parallel_size=2, pipeline_model_parallel_size=1),
            schedule.forward_backward_no_pipelining,
        ),
        (
            dict(tensor_model_parallel_size=2, pipeline_model_parallel_size=4),
            schedule.forward_backward_pipelining_without_interleaving,
        ),
        (
            dict(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=4,
                virtual_pipeline_model_parallel_size=2,
            ),
            schedule.forward_backward_pipelining_with_interleaving,
        ),
        (
            dict(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=2,
                virtual_pipeline_model_parallel_size=4,
            ),
            schedule.forward_backward_pipelining_with_interleaving,
        ),
    )

    for parallel_sizes, expected_func in cases:
        Utils.initialize_model_parallel(**parallel_sizes)
        try:
            # The parallel sizes are attached as the assertion message so a
            # failure identifies which layout selected the wrong schedule.
            assert schedule.get_forward_backward_func() == expected_func, parallel_sizes
        finally:
            Utils.destroy_model_parallel()

def test_deallocate_output_tensor():
    """Call ``deallocate_output_tensor`` with default arguments on a 2x3 tensor.

    The test expects the tensor to still report six elements afterwards.
    """
    tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
    schedule.deallocate_output_tensor(tensor)
    element_count = tensor.nelement()
    assert element_count == 6

def test_forward_backward_func_without_pipeline_parallel(mocker):
    """Run the no-pipelining schedule forward-only and check the reported losses.

    Model parallelism is initialized with tensor parallel size 2 and pipeline
    parallel size 1. The schedule is run over 4 microbatches with a stub
    ``Linear(4, 1)`` model whose loss function reports the ``LOCAL_RANK``
    environment value. Every returned entry's ``'loss_reduced'`` is compared
    with the module-level ``rank``.

    Args:
        mocker: pytest-mock fixture, used to patch ``custom_backward`` in the
            schedules module.
    """
    import os

    from megatron.core.pipeline_parallel import get_forward_backward_func

    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
    # Teardown is in ``finally`` so a failing assertion does not leave
    # model-parallel state initialized for subsequent tests.
    try:
        def forward_step_func(data_iterator, model):
            # ``data_iterator`` is ignored; a fixed all-ones input is fed instead.
            local_rank = int(os.environ['LOCAL_RANK'])
            dummy_data = torch.ones(1, 4)

            def loss_func(output_tensor):
                return local_rank, {'loss_reduced': local_rank}

            return model(dummy_data), loss_func

        def set_input_tensor(input_tensor):
            return None

        model = torch.nn.Linear(4, 1)
        model.model_type = 'unit-test'
        model.set_input_tensor = set_input_tensor

        forward_backward_func = get_forward_backward_func()
        assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining

        # Replace ``custom_backward`` in the schedules module with a mock returning 2.
        mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
        model.config = ModelParallelConfig(pipeline_model_parallel_size=1)

        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=range(0, 100),
            model=[model],
            num_microbatches=4,
            seq_length=None,
            micro_batch_size=None,
            forward_only=True,
        )

        loss_reduced_expected = [{'loss_reduced': rank} for _ in range(4)]

        # ``zip`` stops at the shorter sequence, so only entries present in
        # both lists are compared.
        for actual, expected in zip(losses_reduced, loss_reduced_expected):
            assert actual['loss_reduced'] == expected['loss_reduced'], losses_reduced
    finally:
        Utils.destroy_model_parallel()


def test_forward_backward_func_with_pipeline_parallel(mocker):
    """Run the non-interleaved pipelining schedule forward-only and check losses.

    Model parallelism is initialized with tensor parallel size 1 and pipeline
    parallel size 4. The forward step returns a random CUDA tensor of shape
    ``(sequence_length, micro_batch_size, hidden_size)`` and a loss function
    that reports the ``LOCAL_RANK`` environment value. Returned entries are
    compared with the module-level ``rank``.

    Args:
        mocker: pytest-mock fixture; accepted but not used by this test.
    """
    import os

    from megatron.core.pipeline_parallel import get_forward_backward_func

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4)
    # Teardown is in ``finally`` so a failing assertion does not leave
    # model-parallel state initialized for subsequent tests.
    try:
        def forward_step_func(data_iterator, model):
            # Neither argument is used; the step emits a random tensor.
            local_rank = int(os.environ['LOCAL_RANK'])

            def loss_func(output_tensor):
                return local_rank, {'loss_reduced': local_rank}

            output = torch.rand(sequence_length, micro_batch_size, hidden_size).cuda()
            return output, loss_func

        def set_input_tensor(input_tensor):
            return None

        model = torch.nn.Linear(4, 1)
        model.model_type = 'unit-test'
        model.set_input_tensor = set_input_tensor

        forward_backward_func = get_forward_backward_func()
        assert (
            schedule.get_forward_backward_func()
            == schedule.forward_backward_pipelining_without_interleaving
        )

        config = ModelParallelConfig(
            pipeline_model_parallel_size=4,
            sequence_parallel=False,
            pipeline_dtype=torch.float,
        )
        config.hidden_size = hidden_size
        model.config = config

        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=None,
            model=[model],
            # The microbatch count reuses the value of micro_batch_size (8).
            num_microbatches=micro_batch_size,
            seq_length=sequence_length,
            micro_batch_size=micro_batch_size,
            forward_only=True,
        )

        loss_reduced_expected = [{'loss_reduced': rank} for _ in range(4)]

        # ``zip`` stops at the shorter sequence: at most four entries are
        # compared, and none if ``losses_reduced`` is empty.
        for actual, expected in zip(losses_reduced, loss_reduced_expected):
            assert actual['loss_reduced'] == expected['loss_reduced'], losses_reduced
    finally:
        Utils.destroy_model_parallel()


def test_forward_backward_func_with_interleaving(mocker):
    """Exercise the interleaved pipelining schedule, including rejected inputs.

    Model parallelism is initialized with tensor parallel size 1, pipeline
    parallel size 4 and virtual pipeline size 2. Three calls are expected to
    raise ``RuntimeError``; each differs from the baseline arguments in one
    respect:

    * ``model.model_type`` is ``ModelType.encoder_and_decoder``;
    * ``decoder_seq_length`` (256) differs from ``seq_length`` (512);
    * ``num_microbatches`` is 7.

    A final call with two data iterators is expected to succeed, and its
    returned entries are compared with the module-level ``rank``.

    Args:
        mocker: pytest-mock fixture, used to patch ``custom_backward`` in the
            schedules module.
    """
    import os

    from megatron.core.pipeline_parallel import get_forward_backward_func
    from megatron.core.enums import ModelType

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    Utils.initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=4,
        virtual_pipeline_model_parallel_size=2,
    )
    # Teardown is in ``finally`` so a failing assertion does not leave
    # model-parallel state initialized for subsequent tests.
    try:
        def forward_step_func(data_iterator, model):
            # Neither argument is used; the step emits a random tensor.
            local_rank = int(os.environ['LOCAL_RANK'])

            def loss_func(output_tensor):
                return local_rank, {'loss_reduced': local_rank}

            output = torch.rand(sequence_length, micro_batch_size, hidden_size).cuda()
            return output, loss_func

        def set_input_tensor(input_tensor):
            return None

        model = torch.nn.Linear(4, 1)
        model.set_input_tensor = set_input_tensor

        forward_backward_func = get_forward_backward_func()
        assert (
            schedule.get_forward_backward_func()
            == schedule.forward_backward_pipelining_with_interleaving
        )

        config = ModelParallelConfig(
            pipeline_model_parallel_size=4,
            sequence_parallel=False,
            pipeline_dtype=torch.float,
        )
        config.hidden_size = hidden_size
        model.config = config

        # Replace ``custom_backward`` in the schedules module with a mock returning 2.
        mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)

        def run_schedule(**overrides):
            """Call the schedule with the baseline arguments, updated by ``overrides``."""
            call_kwargs = dict(
                forward_step_func=forward_step_func,
                data_iterator=[range(0, 100)],
                model=[model, model],
                # The microbatch count reuses the value of micro_batch_size (8).
                num_microbatches=micro_batch_size,
                seq_length=sequence_length,
                micro_batch_size=micro_batch_size,
                decoder_seq_length=sequence_length,
                forward_only=True,
            )
            call_kwargs.update(overrides)
            return forward_backward_func(**call_kwargs)

        # The attribute assignments sit outside the ``pytest.raises`` blocks so
        # that only the schedule call itself can satisfy the expectation.
        model.model_type = ModelType.encoder_and_decoder
        with pytest.raises(RuntimeError):
            run_schedule()

        model.model_type = ModelType.encoder_or_decoder
        with pytest.raises(RuntimeError):
            run_schedule(decoder_seq_length=256)

        with pytest.raises(RuntimeError):
            run_schedule(num_microbatches=7)

        # Passing case: two data iterators, matching the two model entries.
        losses_reduced = run_schedule(data_iterator=[range(0, 100), range(0, 100)])

        loss_reduced_expected = [{'loss_reduced': rank} for _ in range(4)]

        # ``zip`` stops at the shorter sequence: at most four entries are
        # compared, and none if ``losses_reduced`` is empty.
        for actual, expected in zip(losses_reduced, loss_reduced_expected):
            assert actual['loss_reduced'] == expected['loss_reduced'], losses_reduced
    finally:
        Utils.destroy_model_parallel()
