#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import logging
import math
import os
import sys
import yaml
import warnings
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
import evaluate
import torch
from datasets import load_dataset, load_from_disk
from safetensors.torch import load_file as safe_load

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from contextlm_arguments import ModelArguments, DataTrainingArguments

# Raises if the installed Transformers is older than the minimum this script expects.
# Remove this check at your own risk.
check_min_version("4.41.0.dev0")

require_version(
    "datasets>=2.14.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

logger = logging.getLogger(__name__)


# Keys of the causal-LM model mapping, and the `model_type` attribute of each key.
MODEL_CONFIG_CLASSES = [*MODEL_FOR_CAUSAL_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)



class LossLoggingCallback(transformers.TrainerCallback):
    """Trainer callback that writes each reported training loss to the script logger.

    NOTE(review): the script passed ``LossLoggingCallback`` to the ``Trainer`` without
    defining or importing it, which raised ``NameError`` on every run. This minimal
    implementation is a stand-in; if the project has its own callback of this name,
    import that one instead.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log the ``loss`` entry of ``logs`` (when present) on the main process."""
        if logs and "loss" in logs and state.is_world_process_zero:
            logger.info(f"step {state.global_step}: loss = {logs['loss']}")


def _parse_arguments():
    """Build ``(model_args, data_args, training_args)`` from a YAML file or the CLI.

    When the last command-line argument ends with ``.yaml`` the file is read and its
    ``model_args`` / ``data_args`` / ``training_args`` sections are used as keyword
    arguments of the matching dataclass. Otherwise the command line is parsed.
    """
    config_path = sys.argv[-1]
    if not config_path.endswith(".yaml"):
        parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
        return parser.parse_args_into_dataclasses()

    with open(config_path, "r", encoding="utf-8") as file:
        # An empty file parses to None; treat it like an empty mapping.
        args = yaml.safe_load(file) or {}

    # `or {}` also covers a section that is present but empty (parsed as None).
    return (
        ModelArguments(**(args.get("model_args") or {})),
        DataTrainingArguments(**(args.get("data_args") or {})),
        TrainingArguments(**(args.get("training_args") or {})),
    )


def _build_model(model_args, training_args, config, config_kwargs):
    """Instantiate the model to train or evaluate.

    Without ``model_name_or_path`` a fresh model is created from ``config``.
    With it, the architecture is chosen from substrings of
    ``training_args.output_dir`` (``gpt2`` or ``pythia``), as this script has always done.

    Raises:
        ValueError: if ``model_name_or_path`` is set but ``output_dir`` selects no
            supported architecture (previously an ``UnboundLocalError`` further down).
    """
    if not model_args.model_name_or_path:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
        # Key by data_ptr so parameters that share storage are counted once.
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
        return model

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    # The config is re-read from the model path, so a config loaded from `config_name`
    # is not used here. The same hub options as every other load in this script
    # (cache dir, revision, token, remote code) are now forwarded.
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    output_dir = training_args.output_dir

    if "gpt2" in output_dir:
        config.w_size = model_args.w_size
        config.encoder_layers = model_args.feature_extractor_layers
        config.use_context_lm = model_args.use_context_lm
        config.context_lm_layers = model_args.context_lm_layers

        if not model_args.hlm_n_head:
            # No explicit override: reuse the base model's head count and width.
            config.hlm_n_head = config.n_head
            config.hlm_n_embd = config.n_embd
        else:
            config.hlm_n_head = model_args.hlm_n_head
            config.hlm_n_embd = model_args.hlm_n_embd

        from models.modeling_gpt2_contextlm import TokenGPT2LMHeadModel

        return TokenGPT2LMHeadModel(config)

    if "pythia" in output_dir:
        config.w_size = model_args.w_size
        config.encoder_layers = model_args.feature_extractor_layers
        config.use_context_lm = model_args.use_context_lm
        config.context_type = model_args.context_type
        config.context_lm_layers = model_args.context_lm_layers

        if "contextlm" in output_dir or "baseline" in output_dir:
            from models.modeling_pythia_contextlm import GPTNeoXForCausalLM as model_class
        else:
            from transformers import GPTNeoXForCausalLM as model_class

        if training_args.do_train:
            # Training always starts from randomly initialised weights built from the config.
            model = model_class(config)
            print(f"Train from scratch models from :{model_args.model_name_or_path}")
        else:
            # torch_dtype is now honored for both model classes (it was dropped for the stock one).
            model = model_class.from_pretrained(
                model_args.model_name_or_path, torch_dtype=torch_dtype, **config_kwargs
            )
            print(f"Loading models from :{model_args.model_name_or_path}")
        return model

    raise ValueError(
        f"Cannot select a model architecture: output_dir ({output_dir}) must contain "
        "'gpt2' or 'pythia' when --model_name_or_path is set."
    )


def _log_parameter_summary(model):
    """Log total/trainable parameter counts, then frozen and trainable parameter names."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Training new model from scratch - Total Param size={total_params/2**20:.2f}M params")
    logger.info(f"Training new model from scratch - Total Trainable Param size={trainable_params/2**20:.2f}M params")

    frozen_names, trainable_names = [], []
    for name, param in model.named_parameters():
        (trainable_names if param.requires_grad else frozen_names).append(name)
    # Frozen names are reported first, then trainable ones.
    for name in frozen_names:
        logger.info(f'No Trainable Param is: {name}')
    for name in trainable_names:
        logger.info(f'Trainable Param is: {name}')


def _load_split(data_args, model_args, split, purpose):
    """Load the ``split`` ("train" or "validation") of the configured dataset.

    Raises:
        NotImplementedError: if ``data_args.dataset_name`` is not ``pile`` or ``openwebtext``.
    """
    from data.data import ArrowDataset, PreprocessedDataset

    if data_args.dataset_name == "pile":
        return ArrowDataset(data_args.dataset_path, block_size=data_args.block_size, split=split)
    if data_args.dataset_name == "openwebtext":
        # This script requests the held-out OpenWebText split under the name "val".
        owt_split = "val" if split == "validation" else split
        return PreprocessedDataset(
            data_args.dataset_path, block_size=data_args.block_size, split=owt_split, task=model_args.mode,
        )
    raise NotImplementedError(f"dataset [{data_args.dataset_name}] not supported for {purpose}")


def _limit_samples(dataset, max_samples):
    """Return ``dataset`` truncated to its first ``max_samples`` items (no-op for ``None``)."""
    if max_samples is None:
        return dataset
    limit = min(len(dataset), max_samples)
    if hasattr(dataset, "select"):
        return dataset.select(range(limit))
    # Datasets without a `select` method are wrapped in a Subset instead.
    return torch.utils.data.Subset(dataset, range(limit))


def main():
    """Run causal-LM training and/or evaluation as configured by the parsed arguments.

    Steps: parse arguments, configure logging, detect a resumable checkpoint, load the
    config/tokenizer/model, load the datasets, run ``Trainer.train`` and/or
    ``Trainer.evaluate``, then push to the hub or write a model card.
    """
    model_args, data_args, training_args = _parse_arguments()

    # os.environ only accepts strings, so skip the assignment when no project is configured.
    if training_args.do_train and model_args.wandb_project is not None:
        os.environ["WANDB_PROJECT"] = model_args.wandb_project

    send_example_telemetry("run_clm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        from transformers import LlamaConfig

        config = LlamaConfig()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    model = _build_model(model_args, training_args, config, config_kwargs)
    _log_parameter_summary(model)

    # Datasets are only loaded for the phases that were requested.
    train_dataset = None
    eval_dataset = None
    if training_args.do_train:
        train_dataset = _load_split(data_args, model_args, "train", "training")
    if training_args.do_eval:
        eval_dataset = _load_split(data_args, model_args, "validation", "evaluation")

    if training_args.do_train:
        train_dataset = _limit_samples(train_dataset, data_args.max_train_samples)
    if training_args.do_eval:
        eval_dataset = _limit_samples(eval_dataset, data_args.max_eval_samples)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        callbacks=[LossLoggingCallback],
        data_collator=default_data_collator,
    )

    # Training
    if training_args.do_train:
        # An explicit --resume_from_checkpoint wins over an auto-detected checkpoint.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is not None and model_args.use_lora:
            if hasattr(model, "merge_and_unload"):
                # Merge the LoRA weights into the base model, then save the merged model.
                merged_model = model.merge_and_unload()
                merged_model.save_pretrained(os.path.join(training_args.output_dir, "merged_model"))
            else:
                # This script never wraps the model with an adapter, so there may be nothing to merge.
                logger.warning(
                    "use_lora is set but the model has no `merge_and_unload` method; skipping the LoRA merge."
                )

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            # exp() overflows for very large losses; report infinite perplexity instead.
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    """Entry point for xla_spawn (TPUs): runs ``main()``; ``index`` is accepted but unused."""
    main()


# Run the training/evaluation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()