# Launch with: torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/bench_torch.py

import logging
import os
import sys

import numpy
import torch
import torch.optim as optim
import torch.utils._pytree as pytree
from functorch.compile import aot_module
from spmd import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP

from meshflow.torch import (enable_transform, get_input_strategy, meshflow_shard, set_device_mesh,
                            shard_module)
from meshflow.torch.model import GPT, GATLayer, wresnet50
from meshflow.utils.testing import TorchMockDeviceMesh, setup_testing
from meshflow.utils.timer import MFTimer
from meshflow.autoflow.solver import set_max_memory

# Put the parent of this script's directory on sys.path so the
# ``benchmark.bench_case`` import below resolves when the script is launched
# by path. The path of the script file itself is not an importable location.
# NOTE(review): assumes this file lives in <repo>/benchmark/, as in the launch
# command at the top of the file -- confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from benchmark.bench_case import GPTCase, ResNetCase, GATCase

# Root logger setup: INFO level, "<month/day time> - <level> - <message>" lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%m/%d %H:%M:%S',
)


def get_gpt_case(device="cuda"):
    """Build the GPT benchmark model together with an all-ones input.

    The model is configured from a ``GPTCase`` (``num_layers``, ``hidden_dim``,
    ``num_heads``); the input has shape ``(batch_size, seq_size, hidden_dim)``.

    Args:
        device: Device that the model and the input are moved to.

    Returns:
        Tuple ``(model, input)``.
    """
    cfg = GPTCase()
    gpt = GPT(depth=cfg.num_layers, dim=cfg.hidden_dim, num_heads=cfg.num_heads)
    sample = torch.ones((cfg.batch_size, cfg.seq_size, cfg.hidden_dim))

    # Stored on the module itself; bench_meshflow reads ``model.device``.
    gpt.device = torch.device(device)
    return gpt.to(device=device), sample.to(device=device)


def get_resnet_case(device="cuda"):
    """Build the ``wresnet50`` benchmark model together with an all-ones input.

    The input has shape ``(batch_size, 3, 224, 224)`` with ``batch_size`` taken
    from ``ResNetCase``.

    Args:
        device: Device that the model and the input are moved to.

    Returns:
        Tuple ``(model, input)``.
    """
    cfg = ResNetCase()
    resnet = wresnet50()
    images = torch.ones((cfg.batch_size, 3, 224, 224))

    # Stored on the module itself; bench_meshflow reads ``model.device``.
    resnet.device = torch.device(device)
    return resnet.to(device=device), images.to(device=device)


def get_gat_case(device="cuda"):
    """Build the ``GATLayer`` benchmark model together with its two inputs.

    Both inputs are all-ones tensors sized from ``GATCase``: a
    ``(num_node, in_feature)`` tensor and a ``(num_node, num_node)`` tensor.

    Args:
        device: Device that the model and the inputs are moved to.

    Returns:
        Tuple ``(model, [features, adj])``.
    """
    cfg = GATCase()
    layer = GATLayer(cfg.in_feature, cfg.out_feature)
    features = torch.ones((cfg.num_node, cfg.in_feature))
    adjacency = torch.ones((cfg.num_node, cfg.num_node))

    # Stored on the module itself; bench_meshflow reads ``model.device``.
    layer.device = torch.device(device)
    moved_layer = layer.to(device=device)
    moved_inputs = [features.to(device=device), adjacency.to(device=device)]
    return moved_layer, moved_inputs


def bench_ddp(model, input_):
    """Benchmark a training step of ``model`` wrapped in DistributedDataParallel.

    Every input tensor is split with ``torch.chunk`` into ``world_size`` pieces
    and the first piece is fed to the model. The step (zero grads, forward,
    backward with an all-ones output gradient, SGD update) is timed with
    ``MFTimer``; the peak CUDA memory and the measured time are printed.

    Args:
        model: Module to wrap with DDP.
        input_: A single tensor, or a list of tensors passed positionally to
            the model. A list passed by the caller is not modified.
    """
    inputs = input_ if isinstance(input_, list) else [input_]

    world_size = torch.distributed.get_world_size()
    # Build a fresh list rather than assigning into ``input_``: the caller's
    # list must stay intact so it can be reused for another benchmark.
    # NOTE(review): bench_fsdp skips this chunking for GATLayer models while
    # this function does not -- confirm the difference is intended.
    local_inputs = [torch.chunk(tensor, world_size)[0] for tensor in inputs]

    ddp_model = DDP(model)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    def train_step():
        optimizer.zero_grad()
        output_ = ddp_model(*local_inputs)
        # The output may be non-scalar, so backward needs an explicit gradient.
        output_grad = torch.ones_like(output_)
        output_.backward(output_grad)
        optimizer.step()

    # Start peak-memory tracking from the current allocation level.
    torch.cuda.reset_peak_memory_stats()

    timer = MFTimer(train_step, in_ms=False)

    elaps_time = timer.time()
    peak_memory = torch.cuda.max_memory_allocated()

    print(f"Memory: {peak_memory / 1024 / 1024 / 1024} GB")
    print(f"Time: {elaps_time}")


def bench_fsdp(model, input_):
    """Benchmark a training step of ``model`` wrapped in FullyShardedDataParallel.

    Unless ``model`` is a ``GATLayer``, every input tensor is split with
    ``torch.chunk`` into ``world_size`` pieces and the first piece is fed to
    the model. The step (zero grads, forward, backward with an all-ones output
    gradient, SGD update) is timed with ``MFTimer``; the peak CUDA memory and
    the measured time are printed.

    Args:
        model: Module to wrap with FSDP.
        input_: A single tensor, or a list of tensors passed positionally to
            the model. A list passed by the caller is not modified.
    """
    inputs = input_ if isinstance(input_, list) else [input_]

    if not isinstance(model, GATLayer):
        world_size = torch.distributed.get_world_size()
        # Build a fresh list rather than assigning into ``input_``: the
        # caller's list must stay intact so it can be reused afterwards.
        inputs = [torch.chunk(tensor, world_size)[0] for tensor in inputs]

    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)

    def train_step():
        optimizer.zero_grad()
        output_ = fsdp_model(*inputs)
        # The output may be non-scalar, so backward needs an explicit gradient.
        output_grad = torch.ones_like(output_)
        output_.backward(output_grad)
        optimizer.step()

    # Start peak-memory tracking from the current allocation level.
    torch.cuda.reset_peak_memory_stats()

    timer = MFTimer(train_step, in_ms=False)

    elaps_time = timer.time()
    peak_memory = torch.cuda.max_memory_allocated()

    print(f"Memory: {peak_memory / 1024 / 1024 / 1024} GB")
    print(f"Time: {elaps_time}")


def to_meta(_node_output):
    """Return a ``meta``-device copy of a tensor, or the value unchanged.

    Only objects whose exact type is ``torch.Tensor`` or
    ``torch.nn.parameter.Parameter`` are converted; any other value (including
    instances of other tensor subclasses) is returned as is.
    """
    convertible = (torch.Tensor, torch.nn.parameter.Parameter)
    # Exact type match on purpose: no isinstance(), so subclasses pass through.
    if type(_node_output) not in convertible:
        return _node_output
    return _node_output.to(device="meta")


def bench_meshflow(model, input_):
    """Benchmark a training step of ``model`` compiled and sharded by meshflow.

    Steps:
      1. Install a ``TorchMockDeviceMesh`` of shape ``(1, world_size)`` and
         call ``enable_transform()``.
      2. Wrap the model with ``aot_module`` using ``meshflow_shard`` as both the
         forward and the backward compiler, then run one forward/backward pass
         (presumably what triggers compilation -- TODO confirm).
      3. Install the real ``DeviceMesh`` and shard the original module and the
         inputs with ``shard_module`` / ``get_input_strategy()``.
      4. Time a training step (zero grads, forward, backward with an all-ones
         output gradient, SGD update) with ``MFTimer`` and print the peak CUDA
         memory and the measured time.

    Args:
        model: Module to compile; must expose a ``device`` attribute.
        input_: A single tensor, or a list of tensors passed positionally to
            the model.
    """
    world_size = torch.distributed.get_world_size()

    # Rank ids 0..world_size-1 laid out as a single row.
    mesh_shape = numpy.arange(world_size).reshape(1, -1)
    mesh = DeviceMesh("cuda", mesh_shape.tolist())

    mock_mesh = TorchMockDeviceMesh(*(mesh_shape.shape))
    set_device_mesh(mock_mesh)

    enable_transform()

    compiled_module = aot_module(model, fw_compiler=meshflow_shard, bw_compiler=meshflow_shard)

    if not isinstance(input_, list):
        input_ = [input_]

    if model.device == torch.device("meta"):
        input_ = pytree.tree_map(to_meta, input_)

    output_ = compiled_module(*input_)
    try:
        output_grad = torch.ones_like(output_)
        output_.backward(output_grad)
    except Exception as exc:
        # Best effort: a failure of this first backward pass must not abort the
        # benchmark, but it is reported instead of being silently dropped.
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        logging.warning("Backward pass of the initial run failed, continuing: %s", exc)

    set_device_mesh(mesh)
    input_ = shard_module(compiled_module.orig_module, input_, get_input_strategy())

    optimizer = optim.SGD(compiled_module.orig_module.parameters(), lr=0.001)

    def train_step():
        optimizer.zero_grad()
        output_ = compiled_module(*input_)
        # The output may be non-scalar, so backward needs an explicit gradient.
        output_grad = torch.ones_like(output_)
        output_.backward(output_grad)
        optimizer.step()

    # Drop cached blocks and restart peak tracking before the timed run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    timer = MFTimer(train_step, in_ms=False)

    elaps_time = timer.time()
    peak_memory = torch.cuda.max_memory_allocated()

    print(f"Memory: {peak_memory / 1024 / 1024 / 1024} GB")
    print(f"Time: {elaps_time}")


def main():
    """Run the meshflow benchmark on the GPT case.

    Initialises the NCCL process group, binds this process to the CUDA device
    given by the ``LOCAL_RANK`` environment variable, runs ``bench_meshflow``
    on the GPT case and finally destroys the process group.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    # setup meshflow
    setup_testing(backend="torch", device="cuda")

    # setup distributed
    torch.distributed.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model, input_ = get_gpt_case(device="cuda")

    bench_meshflow(model, input_)

    # Release the process group once the benchmark has finished. Done on the
    # success path only, so a failing rank still propagates its exception
    # without waiting on the other ranks.
    torch.distributed.destroy_process_group()


# Script entry point: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()
