# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

try:
    import boto3
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3"
    )
try:
    import hf_transfer
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer"
    )
import torch
from glob import glob

from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model
from pathlib import Path
from pprint import pformat


def check_checkpoint_args(neox_args, checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint."""

    assert isinstance(checkpoint_args, dict), "args stored in checkpoint is a dict"
    for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
        args_value = getattr(neox_args, checkpoint_arg_name)
        error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
            checkpoint_arg_name, checkpoint_arg_value, args_value
        )
        assert checkpoint_arg_value == args_value, error_message


def do_forward_pass(neox_args, model, inference=False):

    # set to eval mode
    model_was_in_train = model.training
    model.eval()

    # get context tokens
    # always forward full batch size
    context_tokens_tensor = (
        torch.arange(neox_args.seq_length + 1)
        .repeat((neox_args.train_micro_batch_size_per_gpu, 1))
        .cuda()
    )

    # forward
    if inference:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        model_inputs = (
            tokens,
            position_ids,
            attention_mask,
            torch.Tensor(),
        )
        logits, _ = forward_model(neox_args, model, model_inputs)
    elif neox_args.is_pipe_parallel:
        data_iterator = iter([{"text": context_tokens_tensor}])
        _, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
    else:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        logits = model((tokens, position_ids, attention_mask))

    # reset to train mode, if model was in training before
    if model_was_in_train:
        model.train()

    if logits is not None:
        logits = logits.detach().cpu()[
            0
        ]  # just return first batch item (they are all equal)

    return logits


def check_forward_pass(neox_args, model, checkpoint_logits, inference):
    # do forward pass with loaded checkpoint
    logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

    # check
    if (
        logits is not None and checkpoint_logits is not None
    ):  # this could be the case for non-final pipeline stages
        if not (logits == checkpoint_logits).all().item():
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result"
                )
            assert (
                torch.isclose(logits, checkpoint_logits).all().item()
            ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exists."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}".format(
            mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
        ),
        "model_optim_rng.pt",
    )


def get_checkpoint_tag(iteration: int) -> str:
    return f"global_step{iteration}"


def delete_old_checkpoints(save_dir, n_to_keep):
    if torch.distributed.get_rank() == 0:
        ckpt_dir_regex = r"global_step[\d]*"
        if save_dir.endswith("/"):
            save_dir = save_dir.strip("/")
        all_ckpts = natural_sort(
            [
                i
                for i in glob(f"{save_dir}/*")
                if os.path.isdir(i) and re.search(ckpt_dir_regex, i)
            ]
        )
        n_to_delete = len(all_ckpts) - n_to_keep
        if n_to_delete > 0:
            to_delete = all_ckpts[:n_to_delete]
            print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
            for ckpt in to_delete:
                try:
                    shutil.rmtree(ckpt)
                except FileNotFoundError:
                    pass


def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint."""
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }
    # rng states.
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = get_checkpoint_tag(iteration)

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    # save config files
    if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
        configs_directory = os.path.join(neox_args.save, tag, "configs")
        os.makedirs(configs_directory, exist_ok=True)
        for config_filename, config_data in neox_args.config_files.items():
            with open(os.path.join(configs_directory, config_filename), "w") as f:
                if isinstance(config_data, str):
                    f.write(config_data)
                else:
                    json.dump(config_data, f)


def multiprocessing_starmap(func, args, num_processes=None):
    """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
    Args:
        func (Callable): Function to call
        args (Iterable): Iterable of arguments to pass to `func`
        num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
    """
    import multiprocessing

    num_processes = num_processes or (multiprocessing.cpu_count() - 1)
    with multiprocessing.get_context("spawn").Pool(
        processes=num_processes
    ) as process_pool:
        process_pool.starmap(func, args)
        process_pool.terminate()
        process_pool.join()
        del process_pool


def _upload(
    file_path: str,
    s3_key: str,
    chunk_size: int = 104_857_600,
    max_files: int = 64,
    parallel_failures: int = 63,
    max_retries: int = 5,
):
    """Upload local file to S3 using `hf_transfer` library
    Args:
        file_path (str): Local filename to upload
        s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
        chunk_size (int, optional): Chunk size to use for multipart upload.
            Defaults to 100MiB = 104_857_600
        max_files (int, optional):  Number of open file handles, which determines
            the maximum number of parallel downloads. Defaults to 64
        parallel_failures (int, optional): Number of maximum failures of different
            chunks in parallel (cannot exceed max_files). Defaults to 63
        max_retries (int, optional): Number of retries for each chunk. Defaults to 5
    """
    s3 = boto3.client("s3")
    bucket = s3_key.split("s3://")[1].split("/")[0]
    key = s3_key.split(bucket)[1].lstrip("/")

    # 1. Init multipart upload and obtain unique upload identifier
    upload = s3.create_multipart_upload(
        ACL="bucket-owner-full-control",
        Bucket=bucket,
        Key=key,
    )
    upload_id = upload["UploadId"]

    # 2. Generate presigned URLs for each part
    file_size = os.stat(file_path).st_size
    urls = []
    nb_parts = math.ceil(file_size / chunk_size)
    for part_number in range(1, nb_parts + 1):
        params = {
            "Bucket": bucket,
            "Key": key,
            "PartNumber": part_number,
            "UploadId": upload_id,
        }
        urls.append(
            s3.generate_presigned_url(
                ClientMethod="upload_part", Params=params, ExpiresIn=86400
            )
        )

    # 3. Upload parts in parallel
    responses = hf_transfer.multipart_upload(
        file_path=file_path,
        parts_urls=urls,
        chunk_size=chunk_size,
        max_files=max_files,
        parallel_failures=parallel_failures,
        max_retries=max_retries,
    )

    # 4. Complete multipart upload request with ETag values
    etag_with_parts = []
    for part_number, header in enumerate(responses):
        etag = header.get("etag")
        etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
    parts = {"Parts": etag_with_parts}
    s3.complete_multipart_upload(
        Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
    )


def upload_checkpoint(iteration, neox_args):
    local_checkpoint_path = os.path.join(
        os.path.abspath(neox_args.save), get_checkpoint_tag(iteration)
    )
    local_checkpoint_list = sorted(
        filter(
            lambda x: os.path.isfile(x),
            [str(p) for p in Path(local_checkpoint_path).rglob("*")],
        )
    )
    remote_checkpoint_path = os.path.join(
        neox_args.s3_path,
        os.path.basename(neox_args.save),
        get_checkpoint_tag(iteration),
    )
    remote_checkpoint_list = [
        os.path.join(
            remote_checkpoint_path,
            os.path.relpath(local_checkpoint, local_checkpoint_path),
        )
        for local_checkpoint in local_checkpoint_list
    ]
    inputs = zip(
        local_checkpoint_list,
        remote_checkpoint_list,
        [neox_args.s3_chunk_size] * len(local_checkpoint_list),
    )

    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`..."
    )
    start = time.time()
    multiprocessing_starmap(_upload, inputs)
    total_time = time.time() - start
    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s"
    )


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""

    if neox_args.deepspeed:
        save_ds_checkpoint(iteration, model, neox_args)
    else:
        raise ValueError("Must be using deepspeed to use neox")

    torch.distributed.barrier()
    upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
    if upload_to_s3:
        upload_checkpoint(iteration, neox_args)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if neox_args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints)

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def load_checkpoint(
    neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = (
            not neox_args.no_load_optim
        )  # TODO: These should be configured by separate args
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = get_checkpoint_tag(iteration)
        else:
            tag = None
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            load_module_only=not load_optim_and_scheduler,
            tag=tag,
        )

        if checkpoint_name is None:
            # if an iteration is specified, we want to raise an error here rather than
            # continuing silently, since we are trying to load a specific checkpoint
            if iteration is not None:
                available_checkpoints = sorted(
                    [
                        int(i.name.replace("global_step", ""))
                        for i in Path(neox_args.load).glob("global_step*")
                    ]
                )
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
                )
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0  # iteration 0, if not checkpoint loaded
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        if "iteration" in state_dict:
            iteration = state_dict["iteration"]
        else:
            iteration = state_dict.get(
                "total_iters"
            )  # total_iters backward compatible with older checkpoints
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
            )

    # Check arguments.
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(" > could not find arguments in the checkpoint for validation...")

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
                        checkpoint_name
                    )
                )

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name)
            )
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print("  successfully loaded {}".format(checkpoint_name))

    return iteration
