# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Functions for benchmarks.
"""

import numpy as np
import pprint
import torch
import tqdm
from fvcore.common.timer import Timer

import timesformer.utils.logging as logging
import timesformer.utils.misc as misc
from timesformer.datasets import loader
from timesformer.utils.env import setup_environment

# Module-level logger, created through the project's logging utilities
# (timesformer.utils.logging) rather than the stdlib `logging` module.
logger = logging.get_logger(__name__)


def _mean_and_std(durations):
    """
    Compute the mean and standard deviation of a list of durations.
    Args:
        durations (list): measured durations in seconds; may be empty.
    Returns:
        (tuple): `(mean, std)` as floats. Both are NaN when `durations` is
            empty.
    """
    # np.mean / np.std on an empty sequence emit RuntimeWarnings before
    # returning NaN; short-circuit so that a short run keeps the log clean
    # while still printing "nan" in the summary line.
    if not durations:
        return float("nan"), float("nan")
    return float(np.mean(durations)), float(np.std(durations))


def benchmark_data_loading(cfg):
    """
    Benchmark the speed of data loading.

    Builds the "train" data loader, iterates over it for
    `cfg.BENCHMARK.NUM_EPOCHS` epochs without running any model, and logs the
    time spent per window of `cfg.BENCHMARK.LOG_PERIOD` iterations, the time
    per epoch, and the CPU RAM usage reported by `misc.cpu_mem_usage()`.
    Args:
        cfg (CfgNode): configs. Fields read here: RNG_SEED, OUTPUT_DIR,
            TRAIN.BATCH_SIZE, NUM_SHARDS, BENCHMARK.LOG_PERIOD,
            BENCHMARK.NUM_EPOCHS and BENCHMARK.SHUFFLE.
    """
    # Set up environment.
    setup_environment()
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Print config.
    logger.info("Benchmark data loading with config:")
    logger.info(pprint.pformat(cfg))

    # Time only the construction of the loader.
    init_timer = Timer()
    dataloader = loader.construct_loader(cfg, "train")
    logger.info(
        "Initialize loader using {:.2f} seconds.".format(init_timer.seconds())
    )
    # Total batch size across different machines.
    batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS
    log_period = cfg.BENCHMARK.LOG_PERIOD
    epoch_times = []
    # Test for a few epochs.
    for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS):
        # `timer` measures one window of `log_period` iterations and is reset
        # after each report; `timer_epoch` runs for the whole epoch.
        timer = Timer()
        timer_epoch = Timer()
        iter_times = []
        if cfg.BENCHMARK.SHUFFLE:
            loader.shuffle_dataset(dataloader, cur_epoch)
        for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)):
            # `cur_iter` is zero-based, so `cur_iter + 1` batches have been
            # loaded at this point. Reporting when that count is a multiple
            # of `log_period` makes every window (including the first one)
            # cover exactly `log_period` iterations, as the message states.
            if (cur_iter + 1) % log_period == 0:
                iter_times.append(timer.seconds())
                ram_usage, ram_total = misc.cpu_mem_usage()
                logger.info(
                    "Epoch {}: {} iters ({} videos) in {:.2f} seconds. "
                    "RAM Usage: {:.2f}/{:.2f} GB.".format(
                        cur_epoch,
                        log_period,
                        log_period * batch_size,
                        iter_times[-1],
                        ram_usage,
                        ram_total,
                    )
                )
                timer.reset()
        epoch_times.append(timer_epoch.seconds())
        ram_usage, ram_total = misc.cpu_mem_usage()
        logger.info(
            "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. "
            "RAM Usage: {:.2f}/{:.2f} GB.".format(
                cur_epoch,
                len(dataloader),
                len(dataloader) * batch_size,
                epoch_times[-1],
                ram_usage,
                ram_total,
            )
        )
        # An epoch shorter than `log_period` yields no window samples.
        iter_mean, iter_std = _mean_and_std(iter_times)
        logger.info(
            "Epoch {}: on average every {} iters ({} videos) take {:.2f}/{:.2f} "
            "(avg/std) seconds.".format(
                cur_epoch,
                log_period,
                log_period * batch_size,
                iter_mean,
                iter_std,
            )
        )
    # `epoch_times` is empty when BENCHMARK.NUM_EPOCHS is 0.
    epoch_mean, epoch_std = _mean_and_std(epoch_times)
    logger.info(
        "On average every epoch ({} videos) takes {:.2f}/{:.2f} "
        "(avg/std) seconds.".format(
            len(dataloader) * batch_size,
            epoch_mean,
            epoch_std,
        )
    )
