# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import time

from golden_configs.lm_wikitext2 import MOE as MOEConfig
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import utils

MPI_PORT = 29500


def benchmark_single_process(config_class, args):
    """Benchmark a given model using a single process and multiple devices."""

    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert world_size > 0
    benchmark_config = utils.create_benchmark_config(args.model_name, config_class)
    model_specs = utils.get_model_specs(args.model_name, config_class)

    mp.spawn(train, args=(world_size, benchmark_config, model_specs, args), nprocs=world_size, join=True)


def train(rank, world_size, benchmark_config, model_specs, args):
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    utils.init_random_seed(rank)

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )
    logger.info("train, rank={}".format(rank))
    device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")

    criterion = benchmark_config["criterion"]

    model_config = utils.create_model_config(
        args, benchmark_config=benchmark_config, model_specs=model_specs, device=device
    )
    # vocab_size may change in create_model_config() due to input data
    vocab_size = model_specs["vocab_size"]
    model = model_config["model"]
    model.train()
    optimizer = model_config["optimizer"]
    optimizer = optimizer(model.parameters())
    group = model.group if hasattr(model, "group") else None
    utils.log_number_of_parameters(model, logger)

    total_loss = 0.0
    word_counter = 0
    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2

    total_elapsed = 0.0

    model = DDP(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
    lm_dataloader, _, _ = utils.get_data_loader(
        model_config["dataset_info"], args, benchmark_config, model_specs, num_replicas=world_size, rank=rank
    )

    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += batch.numel()

        start_time = time.time()
        optimizer.zero_grad()
        source, target = get_batch(batch)
        source = source.to(device)
        target = target.to(device)
        try:
            output = model(source.to(device))
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
            optimizer.step()
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        elapsed = time.time() - start_time
        total_elapsed += elapsed
        log_interval = 1
        total_tokens_per_log_interval += batch.numel()
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            logger.debug(
                "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                    i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                )
            )
            total_tokens_per_log_interval = 0
            total_loss = 0

    wps = total_tokens / total_elapsed

    logger.debug("rank {}, wps: {}".format(rank, wps))
    logger.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )


if __name__ == "__main__":
    args = utils.init_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    logging.info(f"Running single process benchmark with args: {args}")
    benchmark_single_process(MOEConfig, args)
