import sys

import pytest
import torch
import torch.nn as nn
import numpy as np

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.linear import FMoELinear
from fmoe.megatron.layers import _megatron_init_method


def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().max()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > precision:
            sys.stderr.write(f"=========== {name} moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write(f"=========== {name} raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
            assert False


class MyExpert(nn.Module):
    r"""
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.

    The expert accepts a multi-input container holding two tensors, ``x`` and
    ``y``. Both go through the same ``htoh4 -> activation -> h4toh`` pipeline,
    using the same weights. The result is returned in a container of the same
    kind (dict or list).
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def _ffn(self, t, fwd_expert_count):
        r"""Apply expand -> activation -> shrink to a single tensor."""
        t = self.htoh4(t, fwd_expert_count)
        t = self.activation(t)
        return self.h4toh(t, fwd_expert_count)

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand input to 4h (the hidden size is variable, but is called h4
        for convenience). Then perform activation. Finally shrink back to h.

        Args:
            inp: either a dict with keys ``"x"`` and ``"y"``, or a list whose
                first two elements are ``x`` and ``y``.
            fwd_expert_count: passed unchanged to both FMoELinear layers.

        Returns:
            ``{"x": x_out, "y": y_out}`` for dict input, ``[x_out, y_out]`` for
            list input.

        Raises:
            NotImplementedError: if ``inp`` is neither a dict nor a list.
        """
        if isinstance(inp, dict):
            x, y = inp["x"], inp["y"]
        elif isinstance(inp, list):
            x, y = inp[0], inp[1]
        else:
            raise NotImplementedError
        x = self._ffn(x, fwd_expert_count)
        y = self._ffn(y, fwd_expert_count)
        if isinstance(inp, dict):
            return {"x": x, "y": y}
        return [x, y]


class MyGate(NaiveGate):
    r"""
    NaiveGate variant for multi-input MoE layers.

    The gate score is computed from the ``x`` entry of the input only; any
    other entries (e.g. ``y``) are not seen by this gate.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2, gate_bias=True):
        super().__init__(d_model, num_expert, world_size, top_k, gate_bias=gate_bias)

    def forward(self, inp, return_all_scores=False):
        r"""
        Select ``x`` from ``inp`` and delegate to ``NaiveGate.forward``.

        Args:
            inp: a dict with key ``"x"``, or a list whose first element is ``x``.
            return_all_scores: forwarded unchanged to ``NaiveGate.forward``.

        Raises:
            NotImplementedError: if ``inp`` is neither a dict nor a list.
        """
        if isinstance(inp, dict):
            x = inp["x"]
        elif isinstance(inp, list):
            x = inp[0]
        else:
            raise NotImplementedError
        return super().forward(x, return_all_scores)


class MyMoE(FMoE):
    r"""
    FMoE layer that routes with ``MyGate`` and computes with a single fused
    ``MyExpert`` module, whose two linear layers are initialised with a
    fixed-seed NumPy generator for reproducibility.
    """

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            gate=MyGate,
            num_expert=num_expert,
            d_model=d_model,
            top_k=top_k,
            world_size=world_size,
            mp_group=mp_group,
        )
        self.experts = MyExpert(num_expert, d_model, d_hidden, activation)

        # One generator with a fixed seed is shared by both layers, htoh4 first,
        # so the initial weights are the same on every run.
        weight_rng = np.random.default_rng(1234)
        for linear in (self.experts.htoh4, self.experts.h4toh):
            _megatron_init_method(linear, weight_rng, 1.0)


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="test moves model and inputs to CUDA"
)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize(
    "data_type", [torch.float32, torch.float16, torch.bfloat16, torch.double]
)
@pytest.mark.parametrize("list_input", [False, True])
def test_fmoe_mimo_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    list_input,
    activation=torch.nn.functional.gelu,
):
    r"""
    Feed two identical copies of one input through a multi-input MoE layer
    and check that the two outputs agree.

    The gate only looks at ``x``, and both entries use the same expert weights.
    The outputs for ``x`` and ``y`` are therefore expected to match within
    ``_assert_numerical``'s default precision. ``dp_group`` and
    ``world_group`` are accepted for parametrization but not used here.
    """

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert=num_expert,
        d_model=d_model,
        # Honour the parametrized hidden size (previously hard-coded to 4 * d_model).
        d_hidden=d_hidden,
        world_size=world_size,
        mp_group=mp_group,
        top_k=top_k,
        activation=activation,
    ).cuda().to(data_type)

    x = torch.rand(batch_size, d_model).cuda().to(data_type)
    inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()}
    moe_out = moe(inp)

    if list_input:
        _assert_numerical(["x"], [moe_out[0]], [moe_out[1]], rank)
    else:
        _assert_numerical(["x"], [moe_out["x"]], [moe_out["y"]], rank)


if __name__ == "__main__":
    # Run one small bfloat16 configuration directly, without going through pytest.
    debug_kwargs = {
        "rank": 0,
        "world_size": 1,
        "mp_group": None,
        "dp_group": None,
        "world_group": None,
        "num_expert": 2,
        "top_k": 2,
        "batch_size": 2,
        "d_model": 2,
        "d_hidden": 16,
        "data_type": torch.bfloat16,
        "list_input": True,
    }
    test_fmoe_mimo_linear(**debug_kwargs)
