# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

""" Storage writer for PyT Distributed format allowing asynchronous save. """

import heapq
import logging
import os
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple

import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future

logger = logging.getLogger(__name__)

# Layout of one bucket, as built by `FileSystemWriterAsync.prepare_write_data`:
#   (absolute file path, file name used as the storage key,
#    (bytes_data, tensor_data))
# where bytes_data and tensor_data are lists of (WriteItem, resolved data) pairs.
WriteBucket = Tuple[Path, str, Tuple[list, list]]  # represents writes to a single file


class FileSystemWriterAsync(FileSystemWriter):
    """
    Async-enabled implementation of FileSystemWriter using file IO.

    This class doesn't spawn the async process itself, relies on the external async mechanism.

    Flow:
    1. Call `prepare_write_data` (copies tensors to CPU and plans the per-file writes)
    2. Externally start async process with `get_save_function_and_args` function and args
    3. The async function to call is `write_preloaded_data_multiproc` which calls
       `write_preloaded_data` in multiple processes

    After saving is finalized on all ranks:
    4. Call `super().finish` with the results returned by `retrieve_write_results`

    Note that step (3) above can also be called synchronously.

    Currently, it's assumed that a separate writer is created for each ckpt save
    (intermediate state is stored as writer attributes).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.single_file_per_rank:
            raise NotImplementedError(
                'single_file_per_rank=False is not supported for FileSystemWriterAsync'
            )

        # Intermediate state between preparation and finalization.
        # One WriteBucket per output file, filled by `prepare_write_data`.
        self.write_buckets: Optional[List[WriteBucket]] = None
        # Local writer process index -> WriteResults produced by that process.
        self.write_results: Optional[Dict[int, List[WriteResult]]] = None

    def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
        """
        First stage of async saving. Copy data to CPU and plan the local saving.

        Args:
            plan (SavePlan): save plan generated by the PyT Distributed compatible planner
            planner (SavePlanner): save planner used to resolve the bytes and tensor data

        Returns: None, but stores the save plan in `self.write_buckets` and
            initializes the results container in `self.write_results`
        """
        storage_plan: _StoragePrefix = plan.storage_data
        start = time()
        logger.debug(f"thread_count: {self.thread_count}, time: {start}")
        item_buckets = _split_by_size_and_type(self.thread_count, plan.items)
        logger.debug(f"bucket_prep, time: {time() - start}")

        start = time()
        # move tensors from GPU to CPU before starting async writing
        # We do D2H synchronously for now
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
        self.write_buckets = []
        for bucket in item_buckets:
            bytes_data = [
                (item, planner.resolve_data(item))
                for item in bucket
                if item.type == WriteItemType.BYTE_IO
            ]
            tensor_data = [
                (item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
                for item in bucket
                if item.type != WriteItemType.BYTE_IO
            ]
            if bytes_data or tensor_data:
                file_name = gen_file()
                self.write_buckets.append(
                    (self.path / file_name, file_name, (bytes_data, tensor_data))
                )

        # A non_blocking D2H copy may still be in flight when `.to(...)` returns.
        # Wait for the issued copies (on the current device) so that the writer
        # processes, which may be forked right after this call, never read
        # partially-copied CPU tensors.
        if torch.cuda.is_initialized():
            torch.cuda.synchronize()

        # Check if there is anything to write on this rank
        if self.write_buckets:
            assert len(self.write_buckets) <= self.thread_count, (
                len(self.write_buckets),
                self.thread_count,
            )
            # A Manager-backed dict lets forked writer processes report results back.
            ctx = mp.get_context('fork')
            self.write_results = ctx.Manager().dict()
        else:
            self.write_results = {}
        logger.debug(f"D2H and push, time: {time() - start}")

    def get_save_function_and_args(self) -> Optional[Tuple[Callable, Tuple]]:
        """
        Get function that saves the data to storage along with its arguments.
        Allows the external caller to apply the save function synchronously or asynchronously.

        Returns: None (if there is nothing to write on this rank) or a tuple of:
            - the function that saves the data
            - arguments to that function
        """
        if not self.write_buckets:
            return None
        return (self.write_preloaded_data_multiproc, (self.write_buckets, self.write_results))

    @staticmethod
    def write_preloaded_data_multiproc(
        write_buckets: List[WriteBucket], write_results: Dict[int, List[WriteResult]]
    ) -> None:
        """
        Performs saving data to storage with multiple processes.

        One forked process is started per write bucket; this call blocks until all
        of them have exited. A failed worker is logged here and detected later by
        `retrieve_write_results` through its missing entry in `write_results`.

        Args:
            write_buckets (List[WriteBucket]): write plan
            write_results: (Dict[int, List[WriteResult]]): dict to store the write results to.
                Assumes multiprocessing save, so keys are local process indices
        Returns: None
        """
        w_start = time()
        ctx = mp.get_context('fork')
        p_list = [
            ctx.Process(
                target=FileSystemWriterAsync.write_preloaded_data,
                args=(i, write_bucket, write_results, True),
            )
            for i, write_bucket in enumerate(write_buckets)
        ]
        for p in p_list:
            p.start()
        for p in p_list:
            p.join()

        for i, p in enumerate(p_list):
            if p.exitcode != 0:
                logger.warning(f"Writer process {i} exited with code {p.exitcode}")

        w_end = time()
        logger.debug(
            f"{w_end}, rank: {torch.distributed.get_rank()}, write(sync,parallel): {w_end - w_start}"
        )

    @staticmethod
    def write_preloaded_data(
        local_proc_idx: int,
        write_bucket: WriteBucket,
        write_results: Dict[int, List[WriteResult]],
        use_fsync: bool,
    ) -> None:
        """
        Performs actual data saving to storage.

        Args:
            local_proc_idx (int): index of a local process that performs writing
            write_bucket (WriteBucket): data to write to storage
            write_results (Dict[int, List[WriteResult]]): dict to store the write results to.
                Assumes multiprocessing save, so keys are local process indices
            use_fsync (bool): if True, flushes and calls os.fsync at the end of saving

        Returns: None, the write result are written to the `write_results` dict
        """
        mem_before = _process_memory()

        local_results = []
        file_name, storage_key, (bytes_data, tensor_data) = write_bucket
        with open(file_name, "wb") as stream:
            for write_item, data in bytes_data:
                local_results.append(_write_item(stream, data, write_item, storage_key))

            for write_item, tensor in tensor_data:
                assert tensor.is_cpu
                local_results.append(_write_item(stream, tensor, write_item, storage_key))

            if use_fsync:
                # `stream` is a buffered file object: hand Python's internal buffer
                # to the OS first, otherwise fsync would not cover the buffered tail.
                stream.flush()
                os.fsync(stream.fileno())
        # Results are published only after the file is fully written (and synced).
        write_results[local_proc_idx] = local_results
        mem_after = _process_memory()
        logger.debug(
            f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
        )

    def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
        """Not supported; use `prepare_write_data` + `get_save_function_and_args` instead."""
        raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')

    def retrieve_write_results(self) -> List[WriteResult]:
        """
        Turn self.write_results into a single results lists. Includes error check.

        Returns (List[WriteResult]): the list of write results from all local processes performing the save.

        Raises:
            RuntimeError: if some writer process did not report its results.
        """
        assert self.write_results is not None
        assert self.write_buckets is not None
        if len(self.write_results) != len(self.write_buckets):
            raise RuntimeError(
                f'Incomplete worker results (expected {len(self.write_buckets)},'
                f' got {len(self.write_results)}). This probably indicates a worker failure.'
            )
        return list(chain.from_iterable(self.write_results.values()))


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    """
    Splits write items according to item size into close to uniform bins.

    Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
    but with a fixed _item_size function.

    BYTE_IO items are distributed round-robin. Tensor items are processed from
    largest to smallest, and each goes to the bin with the smallest accumulated
    tensor size. Ties are broken by the lowest bin index.

    Args:
        bins (int): numbers of bins to split to
        items (List[WriteItem]): list of write items

    Returns (List[List[WriteItem]]): write items split to bins
    """
    if bins == 1:
        return [items]

    bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]

    tensor_items.sort(key=_item_size, reverse=True)

    # Assign bytes with a simple round-robin
    for i, item in enumerate(bytes_items):
        buckets[i % bins].append(item)

    # Then, assign tensors according to their sizes.
    # Min-heap of (accumulated size, bin index). The tuple ordering selects the
    # least-loaded bin and, on ties, the lowest index. That is the same choice as a
    # linear `min` scan over the bins, but it costs O(log bins) per item.
    # All-zero sizes with ascending indices already satisfy the heap invariant.
    bin_heap = [(0, idx) for idx in range(bins)]
    for item in tensor_items:
        size, idx = bin_heap[0]
        buckets[idx].append(item)
        heapq.heapreplace(bin_heap, (size + _item_size(item), idx))

    return buckets


def _item_size(item: WriteItem) -> int:
    """
    Return the number of bytes occupied by the tensor chunk of a single write item.

    Mirrors torch.distributed.checkpoint.filesystem._item_size, but derives the
    element count from the chunk shape (item.tensor_data.chunk.sizes).

    Args:
        item (WriteItem): tensor write item; its `tensor_data` must be set

    Returns (int): chunk element count multiplied by the dtype element size in bytes
    """
    tensor_data = item.tensor_data
    assert tensor_data is not None
    # Explicit product loop instead of math.prod, kept for older-Python compatibility.
    num_elements = 1
    for dim in tensor_data.chunk.sizes:
        num_elements *= dim
    element_bytes = torch._utils._element_size(tensor_data.properties.dtype)
    return num_elements * element_bytes


def _process_memory() -> int:
    """
    Report the resident set size (RSS) of the calling process.

    Returns (int): RSS of the current process, as reported by psutil's memory_info()
    """
    return psutil.Process(os.getpid()).memory_info().rss
