# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from transformers.utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    XLA_FSDPV2_MIN_VERSION,
    PushInProgress,
    PushToHubMixin,
    can_return_loss,
    find_labels,
    is_accelerate_available,
    is_apex_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_ipex_available,
    is_liger_kernel_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_torch_compile_available,
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    logging,
    strtobool,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )
import torch.nn.functional as F
from utils import *
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaForCausalLM
from dataclasses import dataclass, field
import json
import math
import pathlib
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

import sys
from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class ModelArguments:
    """Command-line options describing the backbone and its early-exit heads."""

    # Backbone checkpoint name or local path.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Path handed to every early-exit head constructor.
    split_model_path: Optional[str] = "facebook/opt-125m"
    trust_remote_code: bool = field(
        default=False,
        metadata=dict(
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files"
        ),
    )
    padding_side: str = field(
        default="right",
        metadata=dict(help="The padding side in tokenizer"),
    )
    num_ee_block: int = field(
        default=4,
        metadata=dict(help="The early exit layer blocks"),
    )
    headclass: str = field(
        default="trm",
        metadata=dict(help="The headclass"),
    )


@dataclass
class DataArguments:
    """Command-line options locating the training and evaluation data.

    Both paths default to ``None``, so they are annotated ``Optional[str]``
    rather than ``str``; the defaults themselves are unchanged.
    """

    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    # Declared without help text; defaults to eager preprocessing.
    lazy_preprocess: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """`transformers.TrainingArguments` extended with script-specific fields.

    Adds ``cache_dir`` and ``model_max_length``, and declares ``optim`` with a
    default of ``"adamw_torch"`` and ``num_train_epochs`` with a default of
    ``None``.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=512,
        metadata=dict(
            help="Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        ),
    )
    num_train_epochs: float = field(
        default=None,
        metadata=dict(help="Total number of training epochs to perform."),
    )


# Rank of this process on its node; stays None until train() assigns it.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only when the module-level ``local_rank`` equals 0."""
    if local_rank != 0:
        return
    print(*args)


def trainer_save_model_safe(trainer: transformers.Trainer):
    """Call ``trainer.save_model()`` under an FSDP full-state-dict context.

    The context is configured with ``offload_to_cpu=True`` and
    ``rank0_only=True``.

    Args:
        trainer: Trainer whose ``model`` is passed to ``FSDP.state_dict_type``.
    """
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel,
        StateDictType,
    )

    full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    full_state_ctx = FullyShardedDataParallel.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, full_state_cfg
    )
    with full_state_ctx:
        trainer.save_model()


class CustomModel(LlamaForCausalLM):
    """Llama causal LM that also runs early-exit heads on intermediate layers.

    Forward hooks on evenly spaced decoder layers capture each layer's output
    during the regular forward pass; every captured output is then passed to
    its own head. The number of exits and the path handed to each head come
    from the module-level globals ``num_ee_block`` and ``split_model_path``,
    which ``train()`` assigns before the model is built.

    Head ``i`` is paired with exit ``i`` (shallowest exit first).

    NOTE(review): ``Transformerhead`` and ``Normhead`` are not defined in this
    file; presumably they come from ``from utils import *`` -- confirm.
    """

    def ee_hook(self, module, input, output):
        """Forward hook: store the hooked layer's raw output for the heads.

        NOTE(review): the output is stored untouched. Depending on the
        transformers version a decoder layer may return a tuple rather than a
        tensor -- confirm what the head classes expect.
        """
        self.activation.append(output)

    def __init__(self, config):
        """Build the backbone, one head per exit, and the capture hooks.

        Args:
            config: Llama config carrying the extra ``headclass`` attribute
                (``"trm"`` selects ``Transformerhead``, anything else
                ``Normhead``).

        Raises:
            ValueError: If ``num_ee_block`` is not between 1 and
                ``config.num_hidden_layers``.
        """
        super().__init__(config)

        if not 1 <= num_ee_block <= config.num_hidden_layers:
            raise ValueError(
                f"num_ee_block must be between 1 and {config.num_hidden_layers}, got {num_ee_block}"
            )

        head_cls = Transformerhead if config.headclass == "trm" else Normhead
        # A ModuleList (not a plain list) registers the heads as submodules so
        # their parameters are trained, moved across devices, cast and saved.
        # NOTE(review): under from_pretrained() the head parameters are absent
        # from the checkpoint -- confirm that transformers' missing-key
        # initialisation does not overwrite weights the head constructors load.
        self.heads = torch.nn.ModuleList(
            [head_cls(config, split_model_path) for _ in range(num_ee_block)]
        )

        # Filled by ee_hook during each forward pass, in layer order.
        self.activation = []
        # Handles are kept so the hooks can be removed later if needed.
        self.hooks = []
        interval = config.num_hidden_layers // num_ee_block
        for exit_idx in range(num_ee_block):
            # Hook the last decoder layer of each block of `interval` layers.
            layer = self.model.layers[interval * (exit_idx + 1) - 1]
            self.hooks.append(layer.register_forward_hook(self.ee_hook))

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the backbone, then every early-exit head on its captured layer.

        ``images`` and ``image_sizes`` are accepted for call-site compatibility
        but are not forwarded: the Llama text model has no use for them and
        versions whose ``forward`` lacks ``**kwargs`` reject them.

        Returns:
            A ``(eeoutputs, outputs)`` tuple: the list of head outputs in exit
            order (also stored on ``self.eeoutputs``) and whatever
            ``LlamaForCausalLM.forward`` returned.

        Raises:
            RuntimeError: If the hooks did not capture exactly one activation
                per head during this call.
        """
        # Drop leftovers from an earlier call (for example a forward that
        # raised part-way) so every head sees activations from this call only.
        self.activation.clear()

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        captured = list(self.activation)
        # Release the references so the buffer never grows across calls.
        self.activation.clear()
        if len(captured) != len(self.heads):
            raise RuntimeError(
                f"expected {len(self.heads)} hooked activations, captured {len(captured)}"
            )

        self.eeoutputs = [head(hidden) for head, hidden in zip(self.heads, captured)]
        return self.eeoutputs, outputs

class CustomTrainer(Trainer):
    """Trainer that distils the backbone's logits into the early-exit heads.

    The model is expected to return ``(student_logits_list, teacher_outputs)``
    where ``teacher_outputs.logits`` are the backbone logits. One loss is
    computed per head; the per-head losses are summed and back-propagated in a
    single backward pass.
    """

    def training_step(
        self, model: torch.nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. The dictionary is unpacked
                before being fed to the model; `labels` is consumed by
                `compute_loss`.
            num_items_in_batch:
                Forwarded to `compute_loss`.

        Return:
            `torch.Tensor`: The detached sum of the per-head losses on this
            batch (divided by the gradient accumulation steps on the non-apex
            path when the trainer normalises the loss itself).

        Raises:
            ValueError: If `compute_loss` returns no per-head losses.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            # Imported lazily: this file does not import the helper at module
            # level, and it is only needed on the SageMaker model-parallel path.
            from transformers.trainer_pt_utils import smp_forward_backward

            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            head_losses = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        if not head_losses:
            raise ValueError("compute_loss returned no early-exit losses to optimise")

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            self._release_device_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            # mean() to average on multi-gpu parallel training
            head_losses = [head_loss.mean() for head_loss in head_losses]

        # Sum first and call backward once: the gradients equal the sum of the
        # per-head gradients, and a single call avoids stepping a DeepSpeed
        # engine (or re-walking a shared graph) once per head.
        loss = head_losses[0]
        for head_loss in head_losses[1:]:
            loss = loss + head_loss

        if self.use_apex:
            # Imported lazily so apex is only required when it is in use.
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # Finally we need to normalize the loss for reporting
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps

            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
            # https://github.com/huggingface/transformers/pull/35808
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            self.accelerator.backward(loss, **kwargs)

        # Returned on both paths so the apex branch no longer yields None.
        return loss.detach()

    def _release_device_cache(self) -> None:
        """Empty the allocator cache of whichever accelerator backend is active."""
        try:
            from transformers.utils import is_torch_hpu_available
        except ImportError:
            # This transformers release has no HPU probe; treat HPU as absent.
            def is_torch_hpu_available():
                return False

        if is_torch_xpu_available():
            torch.xpu.empty_cache()
        elif is_torch_mlu_available():
            torch.mlu.empty_cache()
        elif is_torch_musa_available():
            torch.musa.empty_cache()
        elif is_torch_npu_available():
            torch.npu.empty_cache()
        elif is_torch_mps_available(min_version="2.0"):
            torch.mps.empty_cache()
        elif is_torch_hpu_available():
            logging.get_logger(__name__).warning(
                "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
            )
        else:
            torch.cuda.empty_cache()

    def forward_kl(self, logits, teacher_logits, labels):
        """Masked cross-entropy of the student against the teacher distribution.

        Computes ``-sum_v p_teacher(v) * log p_student(v)`` per position and
        averages it over the positions whose label is not the ignore index.

        Args:
            logits: Student logits.
            teacher_logits: Teacher logits with the same shape as ``logits``.
            labels: Label ids; positions equal to ``IGNORE_TOKEN_ID`` are
                excluded from the average.

        Returns:
            A scalar float32 tensor; zero (on the logits' device) when every
            position is ignored.
        """
        # NOTE(review): softmax is shift-invariant, so adding eps to the logits
        # does not change the result; it is kept to preserve exact numerics.
        eps = 1e-8
        teacher_probs = F.softmax(teacher_logits + eps, dim=-1, dtype=torch.float32)
        inf_mask = torch.isinf(logits)
        student_logprobs = F.log_softmax(logits + eps, dim=-1, dtype=torch.float32)
        # Zero out entries whose student logit is infinite.
        prod_probs = torch.masked_fill(teacher_probs * student_logprobs, inf_mask, 0)
        per_position = torch.sum(prod_probs, dim=-1).view(-1)

        mask = (labels != IGNORE_TOKEN_ID).int().view(-1)
        valid_count = torch.sum(mask, dim=0)
        if valid_count == 0:
            # Checked before dividing, and created on the same device as the
            # other loss terms instead of on the CPU.
            return per_position.new_zeros(())
        return -torch.sum(per_position * mask, dim=0) / valid_count

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute one distillation loss per early-exit head.

        Each head's loss is ``forward_kl`` against the teacher logits plus an
        MSE between the student's and the teacher's probability of the
        teacher's arg-max token.

        Args:
            model: Model returning ``(student_logits_list, teacher_outputs)``.
            inputs: Batch dict; ``labels`` is popped from it (the dict is
                mutated) and the rest is passed to the model.
            return_outputs: When true, return ``(total_loss, outputs)`` with
                the per-head losses summed into one scalar, which is the shape
                the base ``Trainer`` expects from this call.
            num_items_in_batch: Accepted for interface compatibility; unused.

        Returns:
            The list of per-head losses, or ``(total_loss, outputs)`` when
            ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        temperature = 1
        eps = 1e-8

        student_logits_list, teacher_outputs = model(**inputs)
        teacher_logits = teacher_outputs.logits
        teacher_probs = F.softmax(teacher_logits / temperature + eps, dim=-1)

        # The teacher's hard targets do not depend on the head, so they are
        # computed once instead of once per head.
        teacher_index = torch.argmax(teacher_probs, dim=-1).unsqueeze(-1)
        teacher_max_probs = teacher_probs.gather(-1, teacher_index).squeeze(-1)

        losses = []
        for student_logits in student_logits_list:
            student_probs = F.softmax(student_logits / temperature + eps, dim=-1)
            kl_loss = self.forward_kl(student_logits, teacher_logits, labels)
            student_at_teacher = student_probs.gather(-1, teacher_index).squeeze(-1)
            hard_loss = F.mse_loss(student_at_teacher, teacher_max_probs)
            losses.append(kl_loss + hard_loss)

        if return_outputs:
            total_loss = losses[0]
            for head_loss in losses[1:]:
                total_loss = total_loss + head_loss
            return total_loss, (student_logits_list, teacher_outputs)
        return losses
        


def train():
    """Parse CLI arguments, build the early-exit model, train it and save it.

    Steps: parse ``ModelArguments`` / ``DataArguments`` / ``TrainingArguments``;
    publish ``local_rank``, ``split_model_path`` and ``num_ee_block`` as module
    globals (``CustomModel`` and ``rank0_print`` read them); load config, model
    and tokenizer; freeze the backbone and ``lm_head``; train with
    ``CustomTrainer`` (resuming when ``output_dir`` already holds a
    ``checkpoint-*`` directory); save the model and the heads' weights.
    """
    global local_rank, split_model_path, num_ee_block

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # These must be set before CustomModel is instantiated below.
    local_rank = training_args.local_rank
    split_model_path = model_args.split_model_path
    num_ee_block = model_args.num_ee_block

    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    # Set RoPE scaling factor when training beyond the original context length.
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
        scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    config.use_cache = False
    config.headclass = model_args.headclass

    # Load model and tokenizer.
    # NOTE(review): model_args.trust_remote_code is parsed but deliberately not
    # forwarded here, matching the original behaviour.
    model = CustomModel.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch.bfloat16,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        # Honour --padding_side (default "right") instead of hard-coding it.
        padding_side=model_args.padding_side,
        use_fast=True,
    )

    # Pad with the unk token, but never replace the pad token with None when
    # the tokenizer defines no unk token.
    if tokenizer.unk_token is not None and tokenizer.pad_token != tokenizer.unk_token:
        tokenizer.pad_token = tokenizer.unk_token

    # Freeze the backbone and the original LM head; only what remains (the
    # early-exit heads) is left trainable.
    model.model.requires_grad_(False)
    model.lm_head.requires_grad_(False)

    # Load data.
    # NOTE(review): make_supervised_data_module is not defined in this file
    # (presumably from `utils`); lazy_preprocess is forced to True and
    # data_args.eval_data_path / data_args.lazy_preprocess are not used.
    data_module = make_supervised_data_module(
        tokenizer=tokenizer, data_path=data_args.data_path, lazy_preprocess=True
    )

    # Start trainer
    trainer = CustomTrainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"可训练参数: {name}, 形状: {param.shape}")

    output_dir = pathlib.Path(training_args.output_dir)
    if any(output_dir.glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save model
    model.config.use_cache = True
    trainer.save_model()

    # Save the early-exit heads on their own. Wrapping in a ModuleList yields
    # "<index>.<param>" keys whether `model.heads` is a list or a ModuleList.
    # The state dict is built on every rank; only ranks that are allowed to
    # write to disk save it, so processes do not race on the same file.
    heads_state = torch.nn.ModuleList(model.heads).state_dict()
    if training_args.should_save:
        torch.save(heads_state, output_dir / "lm_head4.pth")


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    train()
