import os
import sys

import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from fairscale.nn.model_parallel import get_data_parallel_group

from meshflow.torch.model.gpt_tp import GPT
from meshflow.utils.timer import MFTimer
from meshflow.utils.testing import setup_testing

sys.path.append(os.path.abspath(__file__))
from benchmark.bench_case import GPTCase


def get_gpt_case(cuda=True):
    """Build the benchmark GPT model and an all-ones input batch.

    Model depth, hidden size and head count come from ``GPTCase``. The input
    is a tensor of ones with shape
    ``(case.batch_size, case.seq_size, case.hidden_dim)``.

    Args:
        cuda: when true, move both the model and the input with ``.cuda()``
            before returning them.

    Returns:
        A ``(model, input_)`` tuple.
    """
    cfg = GPTCase()
    net = GPT(depth=cfg.num_layers, dim=cfg.hidden_dim, num_heads=cfg.num_heads)
    batch = torch.ones(cfg.batch_size, cfg.seq_size, cfg.hidden_dim)

    if not cuda:
        return net, batch

    return net.cuda(), batch.cuda()


def bench_tp(model, input_):
    """Time a tensor-parallel training step and report peak GPU memory.

    The model is wrapped in ``DistributedDataParallel`` over fairscale's
    data-parallel group and trained with SGD (``lr=0.001``). Each step
    back-propagates an all-ones gradient through the output. The step is
    timed by ``MFTimer``. Results are printed, not returned.

    Args:
        model: module to benchmark. Assumed to already be on the current CUDA
            device (``main`` builds it with ``cuda=True``).
        input_: tensor passed to every forward call.
    """
    wrapped = DDP(model, process_group=get_data_parallel_group())
    sgd = optim.SGD(wrapped.parameters(), lr=0.001)

    def train_step():
        # One iteration: clear grads, forward, backward with ones, update.
        # ``ones`` is kept as a local until the step ends so the memory
        # footprint matches the measured configuration.
        sgd.zero_grad()
        out = wrapped(input_)
        ones = torch.ones_like(out)
        out.backward(ones)
        sgd.step()

    # Reset the peak counter so the reported maximum covers the timed
    # steps (on top of whatever is already allocated, e.g. parameters).
    torch.cuda.reset_peak_memory_stats()

    # in_ms=False: presumably reports seconds rather than milliseconds.
    elapsed = MFTimer(train_step, in_ms=False).time()
    peak_bytes = torch.cuda.max_memory_allocated()

    peak_gib = peak_bytes / 1024 / 1024 / 1024
    print(f"Memory: {peak_gib} GB")
    print(f"Time: {elapsed}")


def main():
    """Run the tensor-parallel GPT benchmark on this process's GPU.

    Expects one process per GPU, started by a launcher that sets the
    ``LOCAL_RANK`` environment variable (e.g. ``torchrun``), plus whatever
    rendezvous settings ``init_process_group`` reads.

    All ``world_size`` ranks are passed to ``initialize_model_parallel`` as a
    single model-parallel group. Presumably this leaves a data-parallel group
    of size 1 for ``bench_tp`` (confirm against fairscale).
    """
    setup_testing(backend="torch", device="cuda")

    # Bind this process to its GPU *before* any NCCL process group exists
    # (the default group here, and the model-parallel groups below).
    # Otherwise NCCL communicators may be set up on the default device.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # setup distributed
    torch.distributed.init_process_group(backend="nccl")
    try:
        world_size = torch.distributed.get_world_size()
        initialize_model_parallel(world_size)

        model, input_ = get_gpt_case(cuda=True)

        bench_tp(model, input_)
    finally:
        # Tear down every process group explicitly so NCCL resources are
        # released before interpreter shutdown, even if the benchmark fails.
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # Script entry point. main() expects a launcher that sets LOCAL_RANK
    # (e.g. torchrun), one process per GPU.
    main()
