#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
on a text file or a dataset without using HuggingFace Trainer.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
# Code copied from the Transformers repository: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py

import argparse
import json
import logging
import math
import os
import deepspeed
from time import time

import datasets
import torch
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from collections import defaultdict

from tqdm.auto import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    SchedulerType,
    default_data_collator,
    get_scheduler,
    mpu
)
from transformers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from torch.optim import AdamW
import torch.distributed as dist
from torch.distributed import get_rank
from typing import Optional, Tuple

from minillm.pipelines import LMPipeline
from minillm.losses import Loss
from utils import print_args, initialize, load_parallel, get_tokenizer, parallel_model_map
import torch_npu
from torch_npu.contrib import transfer_to_npu

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.36.0.dev0")

# Module-level logger obtained from accelerate's logging helper.
logger = get_logger(__name__)

# Fail at import time, with an install hint, when `datasets` is older than 2.14.0.
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

# Config classes registered in MODEL_MAPPING and their `model_type` strings;
# MODEL_TYPES is used as the `choices` list of the --model_type argument.
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def save_rank(log_str, save_path, rank=0):
    """Append ``log_str`` followed by a newline to ``save_path`` on one process.

    The line is written when torch.distributed has not been initialized, or
    when the calling process's distributed rank equals ``rank``. Every other
    process returns without opening the file.
    """
    if dist.is_initialized() and dist.get_rank() != rank:
        return
    with open(save_path, "a") as log_file:
        log_file.write(log_str + "\n")


def print_rank(*args, rank=0, **kwargs):
    """Forward ``*args`` and ``**kwargs`` to ``print`` on a single process.

    Printing happens when torch.distributed has not been initialized, or when
    the calling process's distributed rank equals ``rank``; otherwise the call
    is a no-op.
    """
    if dist.is_initialized() and dist.get_rank() != rank:
        return
    print(*args, **kwargs)

def save_parallel(model, save_dir):
    """Save this model-parallel rank's ``state_dict`` under ``save_dir``.

    The file is written to ``<save_dir>/mp<world_size>/pytorch_model_<mp_rank>.bin``,
    where both numbers are read from ``mpu``. The sub-directory is created if
    it does not exist, and a confirmation line is printed by every caller.
    """
    mp_rank = mpu.get_model_parallel_rank()
    shard_dir = os.path.join(save_dir, f"mp{mpu.get_model_parallel_world_size()}")
    os.makedirs(shard_dir, exist_ok=True)
    checkpoint_name = os.path.join(shard_dir, f"pytorch_model_{mp_rank}.bin")
    torch.save(model.state_dict(), checkpoint_name)
    print(f"Rank {get_rank()}: {checkpoint_name} saved.")


def save(directory: Optional[str] = None, global_iter_count='0', model_parallel=False, model=None, tokenizer=None):
    """Write a model and tokenizer checkpoint to ``<directory>/<global_iter_count>``.

    Only the model (config/weights) and the tokenizer are written; optimizer
    and scheduler state are not saved by this function.

    Args:
        directory: Base checkpoint directory. Required despite the ``None``
            default, which is kept only for signature compatibility.
        global_iter_count: Name of the checkpoint sub-directory. Any value is
            accepted and converted with ``str()``.
        model_parallel: When true, global rank 0 writes ``config.json`` and the
            tokenizer, and every process whose data-parallel rank is 0 writes
            its shard through ``save_parallel``. When false, global rank 0
            calls ``model.save_pretrained`` and saves the tokenizer.
        model: Model (or engine wrapping it) to checkpoint.
        tokenizer: Tokenizer saved next to the model.

    Raises:
        ValueError: If ``directory`` is ``None``.
    """
    if directory is None:
        raise ValueError("save() requires a checkpoint directory, but `directory` is None.")

    # str() lets callers pass the step/epoch counter as an int as well as a str.
    ckpt_dir = os.path.join(directory, str(global_iter_count))
    # Every rank creates the directory; exist_ok makes the concurrent calls harmless.
    os.makedirs(ckpt_dir, exist_ok=True)

    if model_parallel:
        if get_rank() == 0:
            model.module.config.to_json_file(os.path.join(ckpt_dir, "config.json"))
            tokenizer.save_pretrained(ckpt_dir)
        if mpu.get_data_parallel_rank() == 0:
            save_parallel(model.module.base_model, ckpt_dir)
    elif get_rank() == 0:
        model.save_pretrained(ckpt_dir, safe_serialization=False)
        print_rank(f"Model save to {ckpt_dir}")
        tokenizer.save_pretrained(ckpt_dir)

def _str2bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    The parser holds the options inherited from the HuggingFace
    ``run_clm_no_trainer`` example, the extra training options used by this
    script, and the arguments registered by ``deepspeed.add_config_arguments``.

    Returns:
        The parsed ``argparse.Namespace``. Unrecognised command-line arguments
        are ignored (``parse_known_args``) rather than reported as errors.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--data_file", type=str, default=None, help="A csv, txt, bin or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_split_percentage",
        # Without an explicit type a value given on the command line would be a str.
        type=int,
        default=5,
        help="The percentage of the train set used as validation set in case there's no validation split",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--block_size",
        type=int,
        default=None,
        help=(
            "Optional input sequence length after tokenization. The training dataset will be truncated in block of"
            " this size for training. Default to the model max input length for single sentence inputs (take into"
            " account special tokens)."
        ),
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="The number of processes to load data.",
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--trust_remote_code",
        # `type=bool` would turn every non-empty string, including "False",
        # into True; parse the token explicitly instead.
        type=_str2bool,
        default=True,
        help=(
            "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        ),
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=10000,
        help="Number of update steps between two intermediate checkpoint saves.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--low_cpu_mem_usage",
        action="store_true",
        help=(
            "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
            "If passed, LLM loading time and RAM consumption will be benefited."
        ),
    )
    # Options specific to this script (argparse exposes --train-num as args.train_num).
    parser.add_argument("--train-num", type=int, default=-1)
    parser.add_argument("--dev-num", type=int, default=-1)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--seed_lm", type=int, default=7)
    parser.add_argument("--model_parallel", action="store_true")
    parser.add_argument("--clip_grad", type=float, default=1.0)
    parser.add_argument("--fp32", action="store_true")
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--save", type=str, default=None)

    # Let DeepSpeed register its own arguments (e.g. --deepspeed_config).
    parser = deepspeed.add_config_arguments(parser)

    # Unknown arguments are tolerated instead of aborting the run.
    args, _unknown = parser.parse_known_args()

    return args


def main():
    """Fine-tune a causal language model with DeepSpeed and evaluate it after every epoch.

    Steps, in order:
      1. Parse arguments, call ``initialize(args)`` and create the run
         directory ``<output_dir>/<name>``, where ``<name>`` is derived from
         ``--config_name`` (falling back to ``--model_name_or_path`` and then
         ``--model_type``).
      2. Load the DeepSpeed JSON config and override its gradient accumulation,
         micro batch size, gradient clipping and print interval from the
         command-line arguments.
      3. Load config, tokenizer, data pipelines and model; build the AdamW
         optimizer and the learning-rate scheduler; wrap everything with
         ``deepspeed.initialize``.
      4. For each epoch: train, save a checkpoint named after the epoch, then
         compute the mean evaluation loss and perplexity on this process's
         evaluation batches.

    Raises:
        ValueError: If ``--output_dir`` or ``--deepspeed_config`` is missing,
            if no name can be derived for the run directory, if no tokenizer
            source is given, or if ``--data_file`` is missing.
    """
    args = parse_args()
    initialize(args)

    # --config_name is optional everywhere else in this script, so fall back to
    # the model identifier when naming the run directory instead of crashing
    # on `None.split(...)`.
    name_source = args.config_name or args.model_name_or_path or args.model_type
    if args.output_dir is None:
        raise ValueError("--output_dir is required.")
    if name_source is None:
        raise ValueError(
            "One of --config_name, --model_name_or_path or --model_type is required to name the output directory."
        )
    args.output_dir = os.path.join(args.output_dir, name_source.split('/')[-1].split('.')[0])
    os.makedirs(args.output_dir, exist_ok=True)

    if args.deepspeed_config is None:
        raise ValueError("--deepspeed_config is required: path to the DeepSpeed JSON configuration file.")
    with open(args.deepspeed_config, "r") as f:
        ds_config = json.load(f)

    # Command-line values take precedence over the ones in the JSON file.
    ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
    ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
    # Gradient clipping tutorial: https://zhuanlan.zhihu.com/p/557949443
    # Upper bound on the norm of the network parameter gradients.
    ds_config["gradient_clipping"] = args.clip_grad
    ds_config["steps_per_print"] = 10000000

    # A config without an "fp16" section is treated as full precision.
    args.fp32 = not ds_config.get("fp16", {}).get("enabled", False)
    # The in-memory dict is handed to deepspeed.initialize via `config_params`;
    # presumably the path is cleared so DeepSpeed does not see two config sources.
    args.deepspeed_config = None

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm_no_trainer", args)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(
            args.config_name,
            trust_remote_code=args.trust_remote_code,
        )
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=args.trust_remote_code,
        )
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
        )
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Read data and create the DataLoaders.
    if args.data_file is None:
        raise ValueError(
            "This script does not support to download dataset from the hub, and you should download datasets to the local disk."
        )
    lm_pipeline = LMPipeline(args, tokenizer, "train", args.data_file, num=args.train_num)
    eval_lm_pipeline = LMPipeline(args, tokenizer, "valid", args.data_file, num=args.dev_num)
    train_dataloader = lm_pipeline.create_loader(
        args.per_device_train_batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True)
    eval_dataloader = eval_lm_pipeline.create_loader(
        args.per_device_eval_batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False)

    if args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
            trust_remote_code=args.trust_remote_code,
        )
    else:
        print_rank("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=args.trust_remote_code)

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Accelerate training with DeepSpeed. The scheduler is handed to the
    # engine, which then steps it from `model.step()`.
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        lr_scheduler=lr_scheduler,
        mpu=mpu if args.model_parallel else None,
        config_params=ds_config
    )
    loss = Loss(args, trainer=None)

    # Derive the number of epochs from the (possibly user-supplied) step budget.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    device = torch.cuda.current_device()

    iter_count = 1  # micro-step counter (one per batch)
    global_iter_count = 1  # advanced once every `gradient_accumulation_steps` micro-steps
    total_iter_count = int(
        args.num_train_epochs
        * len(train_dataloader)
        / args.gradient_accumulation_steps
    )
    for epoch in range(args.num_train_epochs):
        model.train()
        for lm_batch in train_dataloader:
            lm_pipeline.move_to_device(*lm_batch, device)
            model_batch, _ = lm_batch
            # forward
            elapsed_time = time()
            outputs = model(**model_batch, return_dict=True, use_cache=False)
            lm_loss = loss.my_pt_loss(lm_batch, outputs.logits)
            # backward
            model.backward(lm_loss)
            # step: the engine also advances the LR scheduler it was given, so
            # it must not be stepped a second time here.
            model.step()

            if args.gradient_checkpointing:
                # NOTE(review): `set_force_gradient_checkpointing` is not defined in
                # this file; presumably provided by the wrapped model - confirm.
                model.module.set_force_gradient_checkpointing(False)

            # Save temporary results. The accumulation-boundary test keeps the
            # checkpoint from being rewritten on every micro-step that shares
            # the same `global_iter_count`.
            if (
                iter_count % args.gradient_accumulation_steps == 0
                and global_iter_count % args.save_interval == 0
            ):
                save(args.output_dir, str(global_iter_count), args.model_parallel, model, tokenizer)

            elapsed_time = time() - elapsed_time

            if iter_count % 50 == 0:
                # print logging information
                train_log_info = "[Train] | data_epochs: {:2d}/{:2d} | iter: {:6d}/{:6d} | loss: {:.4f} | time: {:.4f}".format(
                    epoch,
                    args.num_train_epochs,
                    iter_count,
                    total_iter_count,
                    lm_loss.item(),
                    elapsed_time,
                )
                print_rank(train_log_info)
                save_rank(train_log_info, os.path.join(args.output_dir, "log.txt"))
            iter_count += 1
            if iter_count % args.gradient_accumulation_steps == 0:
                global_iter_count += 1

        # save epoch results
        save(args.output_dir, str(epoch), args.model_parallel, model, tokenizer)

        model.eval()
        eval_losses = []
        for lm_batch in tqdm(eval_dataloader, desc="LM Evaluation", disable=(not get_rank() == 0)):
            eval_lm_pipeline.move_to_device(*lm_batch, device)
            model_batch, _ = lm_batch
            with torch.no_grad():
                outputs = model(**model_batch, return_dict=True, use_cache=False)
                lm_loss = loss.my_pt_loss(lm_batch, outputs.logits)
            eval_losses.append(lm_loss.item())

        if not eval_losses:
            # Nothing to average; avoid crashing on an empty evaluation loader.
            print_rank("No evaluation batches available; skipping evaluation for epoch {:2d}".format(epoch))
            continue

        eval_loss = sum(eval_losses) / len(eval_losses)
        try:
            perplexity = math.exp(eval_loss)
        except OverflowError:
            perplexity = float("inf")

        eval_log_info = "epoch: {:2d} | perplexity: {:.4f} | eval_loss: {:.4f}".format(
            epoch,
            perplexity,
            eval_loss
        )
        print_rank("*" * 100)
        print_rank(eval_log_info)
        print_rank("*" * 100)


# Run the training entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()