import math
import logging

import torch
from functorch.compile import aot_module

from meshflow.torch import meshflow_shard, set_device_mesh
from meshflow.utils.testing import TorchMockDeviceMesh

from meshflow.torch.model import FeedForward, GPT, GPTLayer, SelfAttention
from meshflow.torch.model import GATLayer
from meshflow.torch.model import resnet18
from meshflow.torch.model import LLAMAConfig, LLAMA

# Configure the root logger at import time so DEBUG-level records are shown,
# each prefixed with a timestamp and level name. Per the stdlib contract,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%m/%d %H:%M:%S')


def test_gpt():
    """Compile a single GPTLayer with meshflow sharding and run a forward pass.

    ``meshflow_shard`` is registered with ``aot_module`` as both the forward
    and the backward compiler. The output shape is printed; nothing is
    asserted.

    Alternative modules that were previously toggled here via commented-out
    lines: ``FeedForward(hidden_size=1024, ratio=4)``,
    ``SelfAttention(dim=1024, num_heads=4)`` and
    ``GPT(depth=4, dim=1024, num_heads=4)``.

    NOTE(review): the device mesh is only installed in the ``__main__`` block;
    if pytest collects this function directly, confirm meshflow has a usable
    default mesh.
    """
    hidden = 1024
    layer = GPTLayer(dim=hidden, num_heads=4)
    sharded = aot_module(layer,
                         fw_compiler=meshflow_shard,
                         bw_compiler=meshflow_shard)

    # All-ones input whose last axis matches the layer's ``dim``; presumably
    # laid out as (batch, seq, dim) — confirm against GPTLayer.
    inputs = torch.ones(2, 256, hidden)
    out = sharded(inputs)

    print(out.shape)


def test_gat():
    """Compile a GATLayer with meshflow sharding and run one forward pass.

    The layer receives two inputs:

    - an all-ones feature matrix ``h`` of shape (node_num, 1024);
    - an all-ones (node_num, node_num) ``adj`` matrix, presumably a dense
      adjacency in which every node is connected (confirm against GATLayer).

    The output shape is printed; nothing is asserted.
    """
    features = 1024
    node_num = 2048
    # NOTE(review): positional args are presumably (in_features, out_features).
    layer = GATLayer(features, features)

    sharded = aot_module(layer,
                         fw_compiler=meshflow_shard,
                         bw_compiler=meshflow_shard)

    h = torch.ones(node_num, features)
    adj = torch.ones(node_num, node_num)

    out = sharded(h, adj)

    print(out.shape)


def test_resnet():
    """Compile ResNet-18 with meshflow sharding and run one forward pass.

    The input is an all-ones tensor of shape (4, 3, 224, 224), presumably a
    batch of NCHW images (confirm). The output shape is printed; nothing is
    asserted.
    """
    model = resnet18()
    sharded = aot_module(model,
                         fw_compiler=meshflow_shard,
                         bw_compiler=meshflow_shard)
    images = torch.ones(4, 3, 224, 224)
    print(sharded(images).shape)


def test_llama():
    """Compile a LLAMA model with meshflow sharding and run one forward pass.

    The model is built from ``LLAMAConfig(seq_len=seq_len)``. It is fed an
    all-ones input whose second axis is ``seq_len``, so the input and the
    config stay in agreement. The output shape is printed; nothing is
    asserted.
    """
    seq_len = 256
    config = LLAMAConfig(seq_len=seq_len)
    torch_module = LLAMA(config)
    compiled_module = aot_module(torch_module,
                                 fw_compiler=meshflow_shard,
                                 bw_compiler=meshflow_shard)

    # Axis 1 was hard-coded to 256, duplicating seq_len. Use seq_len instead,
    # so editing it no longer silently desynchronises the input from the
    # config (this assumes axis 1 is the sequence axis).
    # NOTE(review): 512 is presumably the model's hidden size under the
    # LLAMAConfig defaults — confirm, and derive it from config if possible.
    x = torch.ones(4, seq_len, 512)
    y = compiled_module(x)

    print(y.shape)


if __name__ == '__main__':
    # Install a mock device mesh before any module is compiled. The arguments
    # (1, 4) presumably describe a 1 x 4 mesh — confirm against
    # TorchMockDeviceMesh. Only test_gpt runs from here; call the other
    # test_* functions explicitly to exercise them.
    mesh = TorchMockDeviceMesh(1, 4)
    set_device_mesh(mesh)
    test_gpt()
