# torchrun --nproc_per_node 2 --master_port 26543 ./examples/torch/test_sharding_model.py

import logging
import os

import numpy
import torch
from functorch.compile import aot_module
from spmd import DeviceMesh

from meshflow.torch import (enable_transform, get_input_strategy, meshflow_shard, set_device_mesh,
                            shard_module)
from meshflow.torch.model import GPT, resnet18
from meshflow.utils.testing import TorchMockDeviceMesh, setup_testing

# Configure logging once at import time: INFO level, with every record
# rendered as "<month/day hour:minute:second> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
)


def test_gpt(mesh):
    """Run a depth-1 GPT model through the meshflow-compiled AOT module twice.

    The first pass feeds a plain CUDA tensor to the compiled module under
    whatever device mesh is set when this function is called. The second
    pass installs ``mesh`` via ``set_device_mesh``, hands the original module
    and the input to ``shard_module``, and feeds the first element of its
    result back into the compiled module. The eager output and the final
    compiled output are printed; nothing is asserted.

    Args:
        mesh: device mesh passed to ``set_device_mesh`` before the second pass.
    """
    enable_transform()

    model = GPT(depth=1, dim=1024, num_heads=4, mlp_ratio=4).cuda()
    # The same meshflow callback is used for the forward and backward compilers.
    compilers = dict(fw_compiler=meshflow_shard, bw_compiler=meshflow_shard)
    compiled = aot_module(model, **compilers)
    sample = torch.rand(2, 256, 1024).cuda()

    # Eager output, kept only so it can be printed next to the final result.
    reference = model(sample)

    # First pass: plain tensor input.
    output = compiled(sample)

    print("Second Run with DTensor")

    # Second pass: switch to the caller-supplied mesh and replace the input
    # with the first element returned by shard_module.
    set_device_mesh(mesh)
    original = compiled.orig_module
    strategy = get_input_strategy()
    sample = shard_module(original, [sample], strategy)[0]

    output = compiled(sample)

    print("Global: ", reference)
    print(output)


def test_resnet(mesh):
    """Run ResNet-18 through the meshflow-compiled AOT module twice.

    The compiled module is first called with a plain CUDA tensor of ones
    under whatever device mesh is set when this function is called. Then
    ``mesh`` is installed via ``set_device_mesh``, the original module and
    the input are handed to ``shard_module``, and the first element of its
    result is fed back into the compiled module. Only the shape of the final
    output is printed; nothing is asserted.

    Args:
        mesh: device mesh passed to ``set_device_mesh`` before the second pass.
    """
    enable_transform()

    model = resnet18().cuda()
    # The same meshflow callback is used for the forward and backward compilers.
    compilers = dict(fw_compiler=meshflow_shard, bw_compiler=meshflow_shard)
    compiled = aot_module(model, **compilers)
    batch = torch.ones(4, 3, 224, 224).cuda()

    # First pass: plain tensor input.
    output = compiled(batch)

    print("Second Run with DTensor")

    # Second pass: switch to the caller-supplied mesh and replace the input
    # with the first element returned by shard_module.
    set_device_mesh(mesh)
    original = compiled.orig_module
    strategy = get_input_strategy()
    batch = shard_module(original, [batch], strategy)[0]

    output = compiled(batch)

    print(output.shape)


if __name__ == '__main__':
    # Script entry point; meant to be started once per process by the
    # launcher shown in the usage comment at the top of this file.
    setup_testing(backend="torch", device="cuda")

    torch.distributed.init_process_group(backend="nccl")
    # Bind this process to the GPU whose index is given by LOCAL_RANK.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    world_size = torch.distributed.get_world_size()

    # Lay out all ranks as a single row: [[0, 1, ..., world_size - 1]].
    rank_grid = numpy.arange(world_size).reshape(1, -1)
    mesh = DeviceMesh("cuda", rank_grid.tolist())

    # A mock mesh with the same (1, world_size) dimensions is installed
    # first; test_gpt switches to the real mesh for its second pass.
    mock_mesh = TorchMockDeviceMesh(*rank_grid.shape)
    set_device_mesh(mock_mesh)

    test_gpt(mesh)

    # Release the process group on a clean exit. This is deliberately not in
    # a ``finally``: if this rank failed while its peers are still inside a
    # collective, tearing down here could block instead of letting the
    # launcher terminate the job.
    torch.distributed.destroy_process_group()
