# torchrun --nproc_per_node 2 --master_port 26543 ./examples/torch/test_sharding_simple.py

import logging
import os

import numpy
import torch
from functorch.compile import aot_function
from spmd import DeviceMesh, Replicate, distribute_tensor

from meshflow.torch import set_device_mesh, meshflow_shard, enable_transform

# Configure root logging at import time: DEBUG and above, with each record
# rendered as "<month/day hour:minute:second> - <level> - <message>".
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%m/%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
)


def fn(x, y):
    """Evaluate ``exp(tanh(x)) @ y + tanh(x)`` with gradient recording off.

    Args:
        x: Tensor fed through ``torch.tanh``; the activation is used both
            inside the matrix product and as the additive term.
        y: Right-hand operand of the ``torch.mm`` matrix product.

    Returns:
        The tensor ``torch.mm(torch.exp(torch.tanh(x)), y) + torch.tanh(x)``,
        computed inside a ``torch.no_grad()`` context.
    """
    with torch.no_grad():
        # tanh(x) is computed once and reused for both terms.
        activated = torch.tanh(x)
        projected = torch.mm(torch.exp(activated), y)
        return projected + activated


def main():
    """Compile ``fn`` with ``aot_function`` + ``meshflow_shard`` and run it twice.

    The first run feeds plain CUDA tensors; the second run feeds the same
    tensors after wrapping them with ``distribute_tensor`` using a
    ``Replicate`` placement for each of the two mesh dimensions.

    The NCCL process group is initialised on entry and always destroyed on
    exit, including when an error is raised part-way through.

    Reads ``LOCAL_RANK`` from the environment (see the torchrun command at
    the top of this file).

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
        ValueError: If the world size is odd, since the ranks are arranged
            into a mesh with exactly two rows.
    """
    torch.distributed.init_process_group(backend="nccl")
    try:
        print(torch.distributed.get_rank(), torch.distributed.get_world_size())
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        world_size = torch.distributed.get_world_size()

        # Ranks are laid out as a 2 x (world_size / 2) mesh; fail with a
        # clear message instead of numpy's generic reshape error.
        if world_size % 2 != 0:
            raise ValueError(
                f"world size must be even to build a 2-row device mesh, "
                f"got {world_size}"
            )
        mesh_shape = numpy.arange(world_size).reshape(2, -1).tolist()
        mesh = DeviceMesh("cuda", mesh_shape)
        set_device_mesh(mesh)

        # NOTE: requires_grad is set before .cuda(), so the CUDA tensors are
        # the result of a copy op rather than freshly created leaves.
        x = torch.randn(16, 16, requires_grad=True).cuda()
        y = torch.randn(16, 16, requires_grad=True).cuda()

        # Hand meshflow_shard to aot_function as both the forward and the
        # backward compiler.
        enable_transform()
        compiled_fn = aot_function(
            fn, fw_compiler=meshflow_shard, bw_compiler=meshflow_shard
        )
        res = compiled_fn(x, y)
        print(res.shape)

        # Backward pass is intentionally left disabled (fn wraps its body in
        # torch.no_grad()):
        # grad_res = torch.ones_like(res)
        # res.backward(grad_res)

        # Second run: same inputs, now as distributed tensors with one
        # Replicate placement per mesh dimension.
        x = distribute_tensor(x, mesh, [Replicate(), Replicate()])
        y = distribute_tensor(y, mesh, [Replicate(), Replicate()])

        print("Second Run with DTensor")
        res = compiled_fn(x, y)
        # NOTE(review): _local_tensor is a private attribute; presumably the
        # rank-local tensor backing the distributed result -- confirm.
        print(res._local_tensor.shape)
    finally:
        # Release the process group's resources even when a step above fails.
        torch.distributed.destroy_process_group()


# Run the demo only when this file is executed as a script (for example via
# the torchrun command at the top of the file), not when it is imported.
if __name__ == '__main__':
    main()
