# Copyright (c) 2021, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import json
import os
import re
import shutil
import random
import sys
import numpy as np

import torch
from glob import glob

from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model
from pathlib import Path
from pprint import pformat


def check_checkpoint_args(neox_args, checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the ones retrieved from the checkpoint.

    Every key of ``checkpoint_args`` is looked up as an attribute on
    ``neox_args`` and compared for equality with the stored value.

    Args:
        neox_args: Object exposing the currently set arguments as attributes.
        checkpoint_args: Dict mapping argument names to the values that were
            stored in the checkpoint.

    Raises:
        AssertionError: If ``checkpoint_args`` is not a dict, or if a stored
            value is not equal to the currently set value.
        AttributeError: If ``neox_args`` has no attribute for a stored name.
    """
    # Raise explicitly rather than via ``assert`` so that the validation is not
    # silently removed when Python runs with -O. The exception type is kept.
    if not isinstance(checkpoint_args, dict):
        raise AssertionError("args stored in checkpoint is a dict")

    for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
        args_value = getattr(neox_args, checkpoint_arg_name)
        if not (checkpoint_arg_value == args_value):
            raise AssertionError(
                "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
                    checkpoint_arg_name, checkpoint_arg_value, args_value
                )
            )


def do_forward_pass(neox_args, model, inference=False):
    """Run one forward pass on a fixed synthetic batch and return its logits.

    The batch is ``torch.arange(neox_args.seq_length + 1)`` repeated
    ``neox_args.train_micro_batch_size_per_gpu`` times and moved to CUDA, so
    every row of the batch is identical.

    Args:
        neox_args: Arguments object; ``seq_length``,
            ``train_micro_batch_size_per_gpu`` and ``is_pipe_parallel`` are read.
        model: The model to run. It is switched to eval mode for the pass and
            switched back to train mode afterwards if it was training before.
        inference: If True, the pass goes through ``forward_model``; otherwise
            through ``model.eval_batch`` (pipeline parallel) or a direct call.

    Returns:
        The detached logits of the first batch item on the CPU, or ``None`` if
        the forward pass produced no logits.
    """

    # set to eval mode
    model_was_in_train = model.training
    model.eval()

    try:
        # get context tokens
        # always forward full batch size
        context_tokens_tensor = (
            torch.arange(neox_args.seq_length + 1)
            .repeat((neox_args.train_micro_batch_size_per_gpu, 1))
            .cuda()
        )

        # forward
        if inference:
            tokens, attention_mask, position_ids = get_batch(
                neox_args, context_tokens_tensor[:, : neox_args.seq_length]
            )
            model_inputs = (
                tokens,
                position_ids,
                attention_mask,
                torch.Tensor(),
            )
            logits, _ = forward_model(neox_args, model, model_inputs)
        elif neox_args.is_pipe_parallel:
            data_iterator = iter([{"text": context_tokens_tensor}])
            _, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
        else:
            tokens, attention_mask, position_ids = get_batch(
                neox_args, context_tokens_tensor[:, : neox_args.seq_length]
            )
            logits = model((tokens, position_ids, attention_mask))
    finally:
        # reset to train mode, if model was in training before; done in a
        # ``finally`` so a failing forward pass cannot leave the model in eval mode
        if model_was_in_train:
            model.train()

    if logits is not None:
        # just return first batch item (they are all equal)
        logits = logits.detach().cpu()[0]

    return logits


def check_forward_pass(neox_args, model, checkpoint_logits, inference):
    """Compare a fresh forward pass against logits stored in a checkpoint.

    Nothing is checked when either the fresh logits or ``checkpoint_logits``
    is ``None``. An inexact match prints a warning on data parallel rank 0;
    a match that is not even close fails the assertion.

    Args:
        neox_args: Arguments object forwarded to ``do_forward_pass``.
        model: The model that was just loaded.
        checkpoint_logits: Logits stored at save time, or ``None``.
        inference: Forwarded to ``do_forward_pass``.

    Raises:
        AssertionError: If the two sets of logits are not element-wise close.
    """
    # forward pass with the freshly loaded weights
    current_logits = do_forward_pass(
        neox_args=neox_args, model=model, inference=inference
    )

    # this could be the case for non-final pipeline stages
    if current_logits is None or checkpoint_logits is None:
        return

    identical = (current_logits == checkpoint_logits).all().item()
    if identical:
        return

    if mpu.get_data_parallel_rank() == 0:
        print(
            " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result"
        )
    close_enough = torch.isclose(current_logits, checkpoint_logits).all().item()
    assert (
        close_enough
    ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist.

    Args:
        filename: Path of a file whose parent directory should exist. The file
            itself is never created.

    A bare filename (no directory component) is a no-op, since there is
    nothing to create.
    """
    dirname = os.path.dirname(filename)
    # ``dirname`` is empty for a bare filename, and os.makedirs("") would raise.
    if dirname and not os.path.exists(dirname):
        # exist_ok tolerates another process creating the directory between
        # the existence check above and this call.
        os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    """Build the unified path of a checkpoint file.

    Args:
        checkpoints_path: Root directory of all checkpoints.
        iteration: Iteration number, rendered zero-padded to seven digits.
            Not used when ``release`` is true.
        release: If true, the ``release`` directory is used instead of an
            iteration directory.
        mp_rank: Model parallel rank to put in the path; when ``None`` the
            rank reported by ``mpu.get_model_parallel_rank()`` is used.

    Returns:
        ``<checkpoints_path>/<iter dir or "release">/mp_rank_XX/model_optim_rng.pt``
    """
    step_directory = "release" if release else f"iter_{iteration:07d}"

    rank = mp_rank
    if rank is None:
        rank = mpu.get_model_parallel_rank()
    rank_directory = f"mp_rank_{rank:02d}"

    return os.path.join(
        checkpoints_path, step_directory, rank_directory, "model_optim_rng.pt"
    )


def delete_old_checkpoints(save_dir, n_to_keep):
    """Delete all but the last ``n_to_keep`` checkpoint directories in ``save_dir``.

    Only the process with global rank 0 touches the filesystem. A checkpoint
    directory is a direct child of ``save_dir`` whose own name is exactly
    ``global_step<digits>``. Candidates are ordered with ``natural_sort``
    (presumably ascending by step number -- confirm against megatron.utils) and
    the leading entries are removed.

    Args:
        save_dir: Directory that contains the ``global_step<N>`` directories.
        n_to_keep: Number of checkpoint directories to leave in place.
    """
    if torch.distributed.get_rank() != 0:
        return

    # Only strip *trailing* separators: stripping leading ones as well would
    # turn an absolute path such as "/checkpoints/" into a relative one.
    if save_dir.endswith("/"):
        save_dir = save_dir.rstrip("/")

    # Match against the directory's own name, and in full, so that neither a
    # parent directory containing "global_step" nor a sibling such as
    # "global_step100_backup" makes an unrelated directory eligible for deletion.
    ckpt_dir_regex = r"global_step\d+"
    all_ckpts = natural_sort(
        [
            path
            for path in glob(f"{save_dir}/*")
            if os.path.isdir(path)
            and re.fullmatch(ckpt_dir_regex, os.path.basename(path))
        ]
    )

    n_to_delete = len(all_ckpts) - n_to_keep
    if n_to_delete <= 0:
        return

    to_delete = all_ckpts[:n_to_delete]
    print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
    for ckpt in to_delete:
        try:
            shutil.rmtree(ckpt)
        except FileNotFoundError:
            # already gone -- deletion is best effort
            pass


def save_ds_checkpoint(iteration, model, neox_args, data_sampling_weights):
    """Save a model checkpoint through ``model.save_checkpoint``.

    The checkpoint is written under ``neox_args.save`` with the tag
    ``global_step<iteration>``. Alongside it this function writes:

    * the config files from ``neox_args.config_files`` into a ``configs``
      sub-directory (global rank 0 only),
    * ``data_sampling_weights.pt`` if ``data_sampling_weights`` is not ``None``
      (global rank 0 only),
    * ``local_samples_seen_per_dataset_rank_<rank>.json`` on every rank when
      ``neox_args.use_named_train_datasets`` is set.

    Args:
        iteration: Current training iteration; stored in the client state and
            used to build the checkpoint tag.
        model: Object providing ``save_checkpoint(save_dir, tag=, client_state=)``.
        neox_args: Arguments object.
        data_sampling_weights: Object saved with ``torch.save``, or ``None`` to
            skip saving.
    """
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }
    # rng states.
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    # reference logits, so that a later load can be validated with a forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = f"global_step{iteration}"

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    if torch.distributed.get_rank() == 0:
        # save config files
        if neox_args.config_files is not None:
            configs_directory = os.path.join(neox_args.save, tag, "configs")
            os.makedirs(configs_directory, exist_ok=True)
            for config_filename, config_data in neox_args.config_files.items():
                with open(os.path.join(configs_directory, config_filename), "w") as f:
                    if isinstance(config_data, str):
                        f.write(config_data)
                    else:
                        json.dump(config_data, f)

        # Saved independently of the config files, and tested against None
        # rather than by truthiness: the truth value of a multi-element tensor
        # or array is ambiguous and raises.
        if data_sampling_weights is not None:
            os.makedirs(os.path.join(neox_args.save, tag), exist_ok=True)
            # save the entire data_sampling_weights object
            torch.save(data_sampling_weights, os.path.join(neox_args.save, tag, "data_sampling_weights.pt"))

    # save current iterations per dataset
    if neox_args.use_named_train_datasets:
        # wait for all processes to make sure configs directory is created
        torch.distributed.barrier()
        with open(os.path.join(neox_args.save, tag, f"local_samples_seen_per_dataset_rank_{torch.distributed.get_rank()}.json"), "w") as f:
            json.dump(neox_args.local_samples_seen_per_dataset, f)


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler, data_sampling_weights=None):
    """Save a model checkpoint and optionally prune older ones.

    Args:
        neox_args: Arguments object; ``deepspeed``, ``save`` and
            ``keep_last_n_checkpoints`` are read here.
        iteration: Current training iteration.
        model: Model handed to ``save_ds_checkpoint``.
        optimizer: Not used by this function.
        lr_scheduler: Not used by this function.
        data_sampling_weights: Forwarded to ``save_ds_checkpoint``.

    Raises:
        ValueError: If ``neox_args.deepspeed`` is not set.
    """
    if not neox_args.deepspeed:
        raise ValueError("Must be using deepspeed to use neox")

    save_ds_checkpoint(iteration, model, neox_args, data_sampling_weights)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()

    keep_last_n = neox_args.keep_last_n_checkpoints
    if keep_last_n is not None:
        delete_old_checkpoints(neox_args.save, keep_last_n)

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def load_checkpoint(
    neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
    """Load a model checkpoint and return the iteration.

    Args:
        neox_args: Arguments object.
        model: Object providing ``load_checkpoint``; it is given
            ``neox_args.load`` and the tag of the requested iteration.
        optimizer: Not used by this function.
        lr_scheduler: Not used by this function.
        inference: Forwarded to the forward-pass validation.
        iteration: If given, the ``global_step<iteration>`` checkpoint is
            requested and failing to load it is an error; if ``None`` no tag is
            passed to ``model.load_checkpoint``.

    Returns:
        The iteration stored in the checkpoint; ``0`` when
        ``neox_args.finetune`` is set or when no checkpoint could be loaded
        and no specific iteration was requested.

    Raises:
        ValueError: If deepspeed is not enabled, if a specifically requested
            iteration cannot be loaded, or if the checkpoint stores no
            iteration.
        AssertionError: If stored arguments or validation logits do not match.
        SystemExit: If rng state should be restored but is missing from the
            checkpoint.
    """
    if not neox_args.deepspeed:
        raise ValueError("Must be using deepspeed to use neox")

    load_optim_and_scheduler = (
        not neox_args.no_load_optim
    )  # TODO: These should be configured by separate args
    if neox_args.finetune:
        load_optim_and_scheduler = False
    tag = f"global_step{iteration}" if iteration is not None else None
    checkpoint_name, state_dict = model.load_checkpoint(
        neox_args.load,
        load_optimizer_states=load_optim_and_scheduler,
        load_lr_scheduler_states=load_optim_and_scheduler,
        tag=tag,
    )

    if checkpoint_name is None:
        # if an iteration is specified, we want to raise an error here rather than
        # continuing silently, since we are trying to load a specific checkpoint
        if iteration is not None:
            prefix = "global_step"
            # Skip directories whose suffix is not a plain number (e.g. manual
            # backups) so that they cannot mask the real error with an
            # int() conversion failure.
            available_checkpoints = sorted(
                int(path.name[len(prefix):])
                for path in Path(neox_args.load).glob(f"{prefix}*")
                if path.name[len(prefix):].isdecimal()
            )
            raise ValueError(
                f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
            )
        if mpu.get_data_parallel_rank() == 0:
            print("Unable to load checkpoint.")

        return 0  # iteration 0, if not checkpoint loaded

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        # Compare against None explicitly: a stored iteration of 0 is valid and
        # must not be treated as missing.
        iteration = state_dict.get("iteration")
        if iteration is None:
            # total_iters backward compatible with older checkpoints
            iteration = state_dict.get("total_iters")
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
            )

    # Check arguments.
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(" > could not find arguments in the checkpoint for validation...")

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
                        checkpoint_name
                    )
                )

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            # The message names the rng state, which is what is actually
            # missing here (it previously referred to the optimizer).
            print_rank_0(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name)
            )
            # NOTE(review): sys.exit() without an argument exits with status 0;
            # consider a non-zero status so schedulers see this as a failure.
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print("  successfully loaded {}".format(checkpoint_name))

    return iteration

def load_data_sampling_weights(neox_args):
    """Load the data sampling weights stored with the checkpoint being resumed.

    Args:
        neox_args: Arguments object; ``load`` and ``iteration`` are read.

    Returns:
        The object read with ``torch.load`` from
        ``<load>/global_step<iteration>/data_sampling_weights.pt``, or ``None``
        when ``load`` is ``None``, ``iteration`` is not positive, or the file
        does not exist.
    """
    resuming = neox_args.load is not None and neox_args.iteration > 0
    if not resuming:
        return None

    weights_file = os.path.join(
        neox_args.load,
        f"global_step{neox_args.iteration}",
        "data_sampling_weights.pt",
    )
    if not os.path.exists(weights_file):
        return None

    print_rank_0(f"Loading data_sampling_weights from {weights_file}")
    return torch.load(weights_file)

def load_local_samples_seen_per_dataset(neox_args):
    """Load this rank's per-dataset sample counters from the checkpoint being resumed.

    Args:
        neox_args: Arguments object; ``load`` and ``iteration`` are read.

    Returns:
        The parsed JSON content of
        ``<load>/global_step<iteration>/local_samples_seen_per_dataset_rank_<rank>.json``
        for the calling process' global rank, or ``None`` when ``load`` is
        ``None``, ``iteration`` is not positive, or the file does not exist.
    """
    if neox_args.load is None or not (neox_args.iteration > 0):
        return None

    # one counter file per global rank
    counters_file = os.path.join(
        neox_args.load,
        f"global_step{neox_args.iteration}",
        f"local_samples_seen_per_dataset_rank_{torch.distributed.get_rank()}.json",
    )
    if not os.path.exists(counters_file):
        return None

    print_rank_0(f"Loading local_samples_seen_per_dataset from {counters_file}")
    with open(counters_file, "r") as f:
        return json.load(f)

def load_global_samples_seen_per_dataset(neox_args):
    """Sum the per-rank sample counters of a checkpoint into global counters.

    Every ``local_samples_seen_per_dataset_rank_*.json`` file in
    ``<load>/global_step<iteration>`` is read and its counts are added up per
    dataset name. Nothing is read when ``neox_args.load`` is ``None`` or
    ``neox_args.iteration`` is not positive.

    Args:
        neox_args: Arguments object; ``load``, ``iteration`` and
            ``train_batch_size`` are read.

    Returns:
        Dict mapping dataset name to the total number of samples seen across
        all ranks (empty if nothing was read).

    Raises:
        AssertionError: If the total does not equal
            ``iteration * train_batch_size``.
    """
    global_samples_seen_per_dataset = {}
    if neox_args.load is not None and neox_args.iteration > 0:
        checkpoint_dir = Path(neox_args.load) / f"global_step{neox_args.iteration}"
        # Restrict to the ".json" files the save path writes, so stray files
        # sharing the prefix are not counted.
        for path in checkpoint_dir.glob("local_samples_seen_per_dataset_rank_*.json"):
            print_rank_0(f"Loading local_samples_seen_per_dataset from {path}")
            with path.open() as f:
                samples_seen = json.load(f)
            for dataset_name, count in samples_seen.items():
                global_samples_seen_per_dataset[dataset_name] = (
                    global_samples_seen_per_dataset.get(dataset_name, 0) + count
                )

    # The counters must add up to the number of samples consumed so far.
    # Raised explicitly (not via ``assert``) so the check survives python -O
    # and reports the two numbers that disagree.
    total_samples_seen = sum(global_samples_seen_per_dataset.values())
    expected_samples_seen = neox_args.iteration * neox_args.train_batch_size
    if not (total_samples_seen == expected_samples_seen):
        raise AssertionError(
            f"samples seen across all datasets ({total_samples_seen}) does not match "
            f"iteration * train_batch_size ({expected_samples_seen})"
        )
    return global_samples_seen_per_dataset