# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

from ...data import SFTDataCollatorWith4DAttentionMask,SFTDataCollatorWith4DBlockAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor, has_tokenized_data
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer, load_model_for_mtp
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
import torch,json
import os
import sys
from loguru import logger
from datasets import load_from_disk
from ast import literal_eval


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


def run_mtp(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    """Run the MTP training / evaluation / prediction workflow.

    Steps performed, in order:

    1. Load the tokenizer, force left padding and left truncation, and fall
       back to the EOS token when no pad token is defined.
    2. Build the dataset with ``stage="sft"`` and rewrite every training
       sample so that it consists of the prompt followed by
       ``model_args.special_tokens_num`` ``<emb_i>`` tokens. The labels at
       those special positions are the first ``special_tokens_num`` response
       tokens (padded with ``IGNORE_INDEX``); all prompt positions are ignored.
       Samples without any supervised label are dropped beforehand.
    3. Load the model, optionally freeze the parameters of layers ``< 24``.
    4. Build the 4D-attention-mask data collator and the trainer, then run
       training, evaluation and prediction as requested by ``training_args``.

    Args:
        model_args: Model options. This function reads ``if_block_diag_attn``,
            ``block_diag_attn`` (set to ``True`` when ``if_block_diag_attn`` is
            truthy), ``use_same_special``, ``special_tokens_num``,
            ``if_freeze_layer`` and ``compute_dtype``.
        data_args: Data options (``preprocessing_num_workers``,
            ``ignore_pad_token_for_loss``, ``cutoff_len``, ``eval_num_beams``).
        training_args: Trainer arguments; some generation-related fields and
            ``remove_unused_columns`` are overridden in place.
        finetuning_args: Fine-tuning options (``compute_accuracy``,
            ``plot_loss``).
        generating_args: Generation options forwarded to evaluate / predict.
        callbacks: Optional trainer callbacks.
    """
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    tokenizer.padding_side = 'left'
    tokenizer.truncation_side = 'left'  # important: truncate from the left
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print('tokenizer.eos_token is ', tokenizer.eos_token, tokenizer.eos_token_id, 'tokenizer.pad_token is ', tokenizer.pad_token)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    # Token ids that should be ignored when they show up in the output labels.
    skip_ids = [IGNORE_INDEX] + list(tokenizer.added_tokens_encoder.values())
    print('skip_ids is:', skip_ids)

    print('model_args.if_block_diag_attn', model_args.if_block_diag_attn)
    if model_args.if_block_diag_attn:
        model_args.block_diag_attn = True
        print('model_args.block_diag_attn modified to ', model_args.block_diag_attn)

    # Per-sample invariants, computed once instead of inside the map callback.
    num_special = model_args.special_tokens_num
    if model_args.use_same_special:
        emb_tokens = ["<emb_0>"] + ["<emb_1>"] * (num_special - 1)
    else:
        emb_tokens = [f"<emb_{i}>" for i in range(num_special)]
    emb_ids = tokenizer.convert_tokens_to_ids(emb_tokens)
    # Special tokens get mask value 2 when block-diagonal attention is enabled,
    # otherwise the regular value 1.
    special_mask_value = 2 if model_args.block_diag_attn else 1
    # Added tokens are masked out of the labels, except for EOS.
    masked_label_ids = sorted(set(skip_ids) - {tokenizer.eos_token_id})

    with training_args.main_process_first(desc="load dataset"):

        def has_supervised_label(sample):
            """Return True when the sample has at least one non-ignored label."""
            return any(label != IGNORE_INDEX for label in sample['labels'])

        # Predict tokens from the hidden states of the appended special tokens.
        def filter_labels(sample):
            """Rewrite one SFT sample into prompt + special tokens with MTP labels."""
            input_ids = torch.tensor(sample['input_ids'])
            labels = torch.tensor(sample['labels'])
            attention_mask = torch.tensor(sample['attention_mask'])

            # Step 1: keep only the prompt part (positions labelled IGNORE_INDEX).
            prompt_input_ids = input_ids[labels == IGNORE_INDEX]
            prompt_len = prompt_input_ids.shape[0]
            prompt_attention_mask = attention_mask[:prompt_len]

            # Step 2: append the special tokens after the prompt.
            special_token_tensor = torch.tensor(emb_ids, dtype=prompt_input_ids.dtype)
            new_input_ids = torch.cat([prompt_input_ids, special_token_tensor])

            # Step 3: extend the attention mask for the special tokens.
            special_attention_mask = torch.full((num_special,), special_mask_value, dtype=attention_mask.dtype)
            new_attention_mask = torch.cat([prompt_attention_mask, special_attention_mask])

            # Step 4: extract the supervised labels (everything but IGNORE_INDEX).
            valid_labels = labels[labels != IGNORE_INDEX]

            # Step 5: encode the original response, wrapped in a keyword-extraction
            # chat prompt, as `original_labels`.
            output_text = tokenizer.decode(valid_labels).replace(tokenizer.eos_token, '')
            sentences = [[{'role': 'system', 'content': 'You are a concise responder. You should always response as fewer keywords as possible to answer the question.'},
                          {'role': 'user', 'content': "Input:【"+output_text+"】\n\n"+"Instruction: 【Please provide the keywords for the input.】"}]]
            original_labels_ids = tokenizer.apply_chat_template(sentences, add_generation_prompt=True, tokenize=True)[0]
            original_labels = torch.tensor(original_labels_ids, dtype=input_ids.dtype)
            labels_attention_mask = torch.ones_like(original_labels, dtype=attention_mask.dtype)

            # Step 6: truncate or pad the labels to exactly `num_special` entries.
            target_labels = valid_labels[:num_special]
            padding_length = num_special - len(target_labels)
            if padding_length > 0:
                padding = torch.full((padding_length,), IGNORE_INDEX, dtype=labels.dtype)
                target_labels = torch.cat([target_labels, padding])

            # Step 7: ignore every prompt position; the special-token positions
            # carry the target labels. Indexing from `prompt_len` (instead of a
            # negative offset) stays correct even when `num_special` is 0.
            new_labels = torch.full_like(new_input_ids, IGNORE_INDEX)
            new_labels[prompt_len:] = target_labels

            # Step 8: mask out added/special tokens in the labels (EOS is kept).
            new_labels[torch.isin(new_labels, torch.tensor(masked_label_ids))] = IGNORE_INDEX

            return {
                'input_ids': new_input_ids.tolist(),
                'labels': new_labels.tolist(),
                'attention_mask': new_attention_mask.tolist(),
                'original_labels': original_labels.tolist(),
                'labels_attention_mask': labels_attention_mask.tolist()
            }

        print('正在filter_labels函数中')
        train_dataset = dataset_module['train_dataset']

        # Returning None from a `map` callback does not drop the row, so samples
        # without any supervised label are removed with an explicit `filter`.
        num_before = len(train_dataset)
        train_dataset = train_dataset.filter(has_supervised_label,
                                    num_proc=data_args.preprocessing_num_workers,
                                    desc='Dropping samples without labels...')
        num_dropped = num_before - len(train_dataset)
        if num_dropped > 0:
            print(f'Dropped {num_dropped} samples without any supervised label')

        dataset_module['train_dataset'] = train_dataset.map(filter_labels,
                                    batched=False,
                                    num_proc=data_args.preprocessing_num_workers,
                                    desc='Filtering labels...')
        # NOTE(review): only the train split is rewritten; an `eval_dataset`, if
        # present, is passed to the trainer untouched — confirm this is intended.

        print(dataset_module['train_dataset'][:3])

    model = load_model_for_mtp(tokenizer, model_args, finetuning_args, training_args.do_train)

    # Optionally freeze the parameters of the lower layers.
    if model_args.if_freeze_layer:
        frozen_layer_cutoff = 24  # layers with an index below this are frozen
        for name, param in model.named_parameters():
            parts = name.split('.')
            # Only names of the form "...layers.<number>..." identify a layer.
            if "layers" not in parts:
                continue
            layer_pos = parts.index("layers")
            if layer_pos + 1 < len(parts) and parts[layer_pos + 1].isdigit():
                if int(parts[layer_pos + 1]) < frozen_layer_cutoff:
                    param.requires_grad = False

    if getattr(model, "is_quantized", False) and not training_args.do_train:
        setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

    # Both collators take the same arguments; only the class differs.
    if not model_args.block_diag_attn:
        print('now use 普通 4D attention')
        collator_cls = SFTDataCollatorWith4DAttentionMask
    else:  # block-diagonal variant
        print('now use block4D attention')
        collator_cls = SFTDataCollatorWith4DBlockAttentionMask
    data_collator = collator_cls(
        template=template,
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
    training_args.remove_unused_columns = False  # important for multimodal dataset

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Initialize our Trainer
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
        **metric_module,
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"  # use left-padding in generation

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
