import logging

import torch
from functorch.compile import aot_function

from meshflow.torch import meshflow_shard, set_device_mesh
from meshflow.utils.testing import TorchMockDeviceMesh

# Configure the root logger at DEBUG so that debug-level messages from any
# library logging through the standard `logging` module are emitted.
logging.basicConfig(level=logging.DEBUG,
                    datefmt='%m/%d %H:%M:%S',
                    format='%(asctime)s - %(levelname)s - %(message)s')


def fn(x, y):
    """Compute ``mm(exp(tanh(x)), y) + tanh(x)``; this is the function traced in ``main``.

    ``tanh(x)`` is computed once and reused in both terms.

    Args:
        x: Tensor; ``torch.mm`` requires it (after tanh/exp) to be 2-D.
        y: 2-D tensor whose row count matches ``x``'s column count. The final
            addition also requires the product (rows of ``x`` by columns of
            ``y``) to broadcast against ``tanh(x)``, e.g. a square ``y``.

    Returns:
        The tensor ``torch.mm(torch.exp(tanh(x)), y) + tanh(x)``.
    """
    activated = torch.tanh(x)
    product = torch.mm(torch.exp(activated), y)
    return product + activated


def main():
    """Run ``fn`` forward and backward through meshflow's AOT sharding compiler.

    Registers a ``TorchMockDeviceMesh(1, 4)`` as the active mesh, which is
    presumably a 1x4 mock mesh (confirm against meshflow). Then wraps ``fn``
    with ``aot_function``, using ``meshflow_shard`` as both the forward and
    the backward compiler. Finally it prints the output shape and
    backpropagates an all-ones upstream gradient.
    """
    set_device_mesh(TorchMockDeviceMesh(1, 4))

    # Two 10x10 leaf tensors that require grad, created in the order x, y.
    inputs = [torch.randn(10, 10, requires_grad=True) for _ in range(2)]

    # meshflow_shard is handed to aot_function for both the forward and the
    # backward graphs.
    sharded_fn = aot_function(fn, fw_compiler=meshflow_shard, bw_compiler=meshflow_shard)
    out = sharded_fn(*inputs)
    print(out.shape)

    # Upstream gradient of ones: equivalent to backpropagating out.sum().
    out.backward(torch.ones_like(out))


if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
