import pytest
import torch
import functorch
from functorch.compile import aot_function
import rich

from meshflow.utils.testing import setup_testing
from meshflow.torch import MFTorchShardingAnn
from meshflow.torch.passes import fix_addmm, fix_batch_norm, eliminate_detach


def fn_1(x, y):
    """Concatenate ``x`` and ``y`` along dimension 1."""
    pieces = [x, y]
    return torch.concat(pieces, dim=1)


def fn_2(x, y):
    """Return the matrix product of ``exp(tanh(x))`` with ``y``."""
    squashed = torch.tanh(x)
    lhs = torch.exp(squashed)
    return torch.mm(lhs, y)


@functorch.compile.make_boxed_compiler
def compiler_fn(fx_module: torch.fx.GraphModule, inps):
    """Compiler callback that rewrites the traced graph and reports sharding.

    The graph is passed through ``fix_batch_norm``, ``fix_addmm`` and
    ``eliminate_detach`` (in that order), recompiled and printed. A
    ``MFTorchShardingAnn`` interpreter is then run over ``inps`` and the two
    results it returns are printed with ``rich``.

    Args:
        fx_module: The FX graph module handed over by the AOT tracer.
        inps: Example inputs; they are unpacked into the interpreter's ``run``.

    Returns:
        The rewritten (and recompiled) graph module.
    """
    rewritten = fx_module
    # The passes are applied sequentially; each one receives the previous
    # pass's result.
    for graph_pass in (fix_batch_norm, fix_addmm, eliminate_detach):
        rewritten = graph_pass(rewritten)
    rewritten.recompile()
    print(rewritten.graph)

    annotator = MFTorchShardingAnn(rewritten)
    sharding_info, fwd_shape_info = annotator.run(*inps)
    for label, payload in (
        ("sharding_info:\n", sharding_info),
        ("fwd_shape_info:\n", fwd_shape_info),
    ):
        rich.print(label, payload)

    return rewritten


@pytest.mark.parametrize("fn", [fn_1, fn_2])
def test_simple_case(fn):
    """Smoke-test ``fn`` through ``aot_function`` with ``compiler_fn``.

    ``fn`` is wrapped so that ``compiler_fn`` is used for both the forward and
    the backward graph. The wrapped function is called on two 10x10 random
    inputs that require grad, and a backward pass is then run with an
    all-ones gradient. There are no value assertions: the test passes as long
    as nothing raises.
    """
    setup_testing("torch")

    # Both operands are 10x10 and track gradients so a backward graph exists.
    shape = (10, 10)
    lhs = torch.randn(*shape, requires_grad=True)
    rhs = torch.randn(*shape, requires_grad=True)

    compiled = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
    out = compiled(lhs, rhs)

    # Seed the backward pass with a gradient of ones shaped like the output.
    out.backward(torch.ones_like(out))


if __name__ == '__main__':
    # test_simple_case requires its ``fn`` argument (pytest normally supplies
    # it through parametrization), so calling it bare raises TypeError. Run it
    # once for each function in the parametrized list instead.
    for _fn in (fn_1, fn_2):
        test_simple_case(_fn)
