import torch
from transformers import GenerationConfig
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.trainer import *
import numpy as np
import torch.nn as nn
import pickle

from collator import SUPPORTED_DECODER_MODELS, check_model
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM

def nested_truncate(tensors, limit):
    """Keep at most the first `limit` entries of `tensors`.

    Lists and tuples are walked recursively and rebuilt with the same container
    type; any other object is sliced directly with ``[:limit]``.
    """
    if not isinstance(tensors, (list, tuple)):
        return tensors[:limit]
    container_type = type(tensors)
    return container_type(nested_truncate(item, limit) for item in tensors)

def skip_instructions(model, predictions_ids, tokenizer, ignore_idx=-100):
    """Decode predicted token ids into a list of strings.

    Every position equal to `ignore_idx` is replaced by ``tokenizer.pad_token_id``
    before the batch is handed to ``tokenizer.batch_decode`` (special tokens are
    skipped and tokenization spaces cleaned up). `model` is accepted but not used.
    """
    pad_id = tokenizer.pad_token_id
    is_ignored = predictions_ids == ignore_idx
    decodable_ids = np.where(is_ignored, pad_id, predictions_ids)
    return tokenizer.batch_decode(
        decodable_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

def create_memory_replay_generators(task, task_list, replay_data_dict, split='train_mem'):
    """Create the memory buffers (iterators) for the tasks that precede `task`.

    For every entry of `task_list` located before the first occurrence of `task`,
    an iterator over ``replay_data_dict[entry]`` is created. `split` is accepted
    but not used.

    Returns a dict mapping each previous task to its iterator.
    """
    print('Creating generators for previous tasks ...')
    current_position = task_list.index(task)
    return {
        earlier_task: iter(replay_data_dict[earlier_task])
        for earlier_task in task_list[:current_position]
    }

class Trainer(Seq2SeqTrainer):
    
    def project_into_vocabluary(self, vector, E, tokenizer, top_k=20, bottom_k=-1):
        """
        Project a vector into the vocabulary space and return the top_k tokens.
        :param vector: D dimensional vector
        :param E: Language model embedding matrix (V, D)
        :param tokenizer: Model tokenizer
        :param top_k: How many top tokens to return
        :param bottom_k: How many bottom tokens to return. If -1, return top_k tokens
        :return:
        """
        vector = vector.to(torch.float32).to('cuda')
        E = E.to(torch.float32).to('cuda')
        vocab_ranking = torch.matmul(E, vector)     # (V,)
        sorted_token_ids = np.argsort(vocab_ranking.detach().cpu().numpy())[::-1]  # Descending order
        if bottom_k == -1:
            sorted_tokens = [tokenizer.decode(x).strip() for x in sorted_token_ids[:top_k]]
        else :
            sorted_tokens = [tokenizer.decode(x).strip() for x in sorted_token_ids[-bottom_k:][::-1]]  # Least score to most score
        return sorted_tokens

    def __init__(self, model, args, train_dataset, eval_dataset=None, forget_dataset=None, tokenizer=None, data_collator=None, compute_metrics=None):
        """
        Build the trainer plus the extra state used by the custom losses.

        `model`, `args`, `train_dataset`, `eval_dataset`, `tokenizer`, `data_collator` and
        `compute_metrics` are forwarded unchanged to ``Seq2SeqTrainer.__init__``.

        Additional set-up performed here:

        * ``args.forget_freq != -1``: a seeded, randomly sampled ``DataLoader`` over
          `forget_dataset` is created (``self.forget_dataloader``) together with an
          iterator over it (``self.forget_iterator``).
        * ``args.retain_weight > 0``: a reference copy of ``args.model_name_or_path`` is
          loaded into ``self.reference_model`` and moved to ``args.device`` in bfloat16.
        * The pickled language spaces and per-language directions are read from
          ``../language_space/{args.model_id}/{args.space_tag}/``.

        :param forget_dataset: Dataset used for the periodic forgetting step; required
            whenever ``args.forget_freq != -1``.
        :raises ValueError: if forgetting is enabled but `forget_dataset` is ``None``.
        """
        super().__init__(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics)

        self.forget_dataset = forget_dataset

        if self.args.forget_freq != -1:
            if forget_dataset is None:
                # Previously this fell through to iter(None) and died with an opaque
                # TypeError; fail early with an actionable message instead.
                raise ValueError(
                    "forget_dataset must be provided when args.forget_freq != -1"
                )

            seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
            # The deepspeed and non-deepspeed paths used the exact same seeded sampler,
            # so a single construction covers both.
            generator = torch.Generator()
            generator.manual_seed(seed)
            sampler = RandomSampler(forget_dataset, generator=generator)

            self.forget_dataloader = DataLoader(
                forget_dataset,
                batch_size=self.args.forget_batch_size,
                sampler=sampler,
                collate_fn=data_collator,
                drop_last=self.args.dataloader_drop_last,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=False,
                worker_init_fn=seed_worker)
            self.forget_iterator = iter(self.forget_dataloader)

        if self.args.retain_weight > 0:
            print("Loading Reference Model ...")
            self.reference_model = AutoModelForCausalLM.from_pretrained(self.args.model_name_or_path,
                                                                        use_safetensors=True,
                                                                        torch_dtype="auto",
                                                                        )
            self.reference_model.to(self.args.device).to(torch.bfloat16)

        # NOTE(security): pickle.load can execute arbitrary code embedded in the file;
        # only point model_id / space_tag at language-space files from a trusted source.
        print(f"Loading Language Space from {args.space_tag}")
        with open(f'../language_space/{args.model_id}/{args.space_tag}/lang_specific_space_last.pkl', 'rb') as f:
            self.lang_spec_space = pickle.load(f)

        with open(f'../language_space/{args.model_id}/{args.space_tag}/lang_shared_space_last.pkl', 'rb') as f:
            # Per the original author's note each entry has shape 1 * d -- TODO confirm.
            self.lang_comm_space = pickle.load(f)

        # One stacked direction tensor per language, in args.lang_list order.
        self.lang_direction = []

        for lang in args.lang_list:
            with open(f'../language_space/{args.model_id}/{args.space_tag}/{lang}_direction_last.pkl', 'rb') as f:
                lang_dir = torch.stack(pickle.load(f))
                # Directions are fixed targets and must not receive gradients.
                lang_dir.requires_grad = False
                self.lang_direction.append(lang_dir)
                # print(lang, lang_dir[-1])

    def projection(self, emb, lang_dir):
        
        lang_dir_norm = lang_dir / torch.linalg.norm(lang_dir, axis=1, keepdims=True).to(emb.dtype)
        proj = torch.matmul(emb, lang_dir_norm.T)
        
        return torch.matmul(proj, lang_dir_norm)

    def orthogonality_loss(self, u, v):
        # 计算每一对向量的内积
        inner_product = torch.sum(u * v, dim=1)  # (B,)

        # 计算正交性损失（内积的平方和）
        loss = torch.mean(inner_product ** 2)  # 标量

        return loss

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        
        outputs = []
        for lang in self.args.lang_list:
            outputs.append(model(**{'input_ids': inputs[f"input_ids_{lang}"], 'attention_mask': inputs[f"attention_mask_{lang}"]}, output_hidden_states=True))

        if self.args.retain_weight > 0:
            outputs_ref = []
            for lang in self.args.lang_list:
                with torch.no_grad():
                    outputs_ref.append(self.reference_model(**{'input_ids': inputs[f"input_ids_{lang}"], 'attention_mask': inputs[f"attention_mask_{lang}"]}, output_hidden_states=True))
            
            en_retain_loss = nn.MSELoss()(outputs[0].hidden_states[-1].mean(dim=1), outputs_ref[0].hidden_states[-1].mean(dim=1))
            print(f"{self.args.retain_weight} * en_retain_loss:", self.args.retain_weight * en_retain_loss.item())

            loss = self.args.retain_weight * en_retain_loss
        
        if self.args.pull_weight['en'] > 0:
            lang_spec_proj_list = []
            for lang, outputs_lang, outputs_lang_ref, lang_direction in zip(self.args.lang_list[1:], outputs[1:], outputs_ref[1:], self.lang_direction[1:]):
                if self.args.layer_wise:
                    pull_loss = []

                    for layer_idx in range(self.args.pull_end_layer + 1, 33):
                        lang_spec_proj = self.projection(outputs_lang.hidden_states[layer_idx][:, -1, :], self.lang_spec_space[layer_idx-1].to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                        ori_lang_spec_proj = self.projection(outputs_lang_ref.hidden_states[layer_idx][:, -1, :], self.lang_spec_space[layer_idx-1].to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))

                        pull_loss.append(nn.MSELoss()(lang_spec_proj - ori_lang_spec_proj, self.args.pull_weight[lang] * lang_direction[layer_idx-1].repeat(lang_spec_proj.shape[0], 1).to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device)))

                    pull_loss = torch.stack(pull_loss).sum()
                    loss += pull_loss
                    print(f"{self.args.pull_weight[lang]} * {lang}_pull_loss:", self.args.pull_weight[lang] * pull_loss.item())
                else:
                    lang_spec_proj = self.projection(outputs_lang.hidden_states[-1][:, -1, :], self.lang_spec_space[-1].to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                    ori_lang_spec_proj = self.projection(outputs_lang_ref.hidden_states[-1][:, -1, :], self.lang_spec_space[-1].to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                    pull_loss = nn.MSELoss()(lang_spec_proj - ori_lang_spec_proj, self.args.pull_weight[lang] * lang_direction[-1].repeat(lang_spec_proj.shape[0], 1).to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                    loss += pull_loss
                    print(f"{self.args.pull_weight[lang]} * {lang}_pull_loss:", self.args.pull_weight[lang] * pull_loss.item()) 

                    lang_spec_proj_list.append(lang_spec_proj)
        
        if self.args.push_weight > 0:
            # en_lang_spec_proj = self.projection(outputs[0].hidden_states[-1][:, -1, :], self.lang_spec_space[-1].to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
            for lang, outputs_lang, lang_spec_proj in zip(self.args.lang_list[1:], outputs[1:], lang_spec_proj_list):
                if self.args.layer_wise:
                    push_loss = []
                    for layer_idx in range(self.args.pull_end_layer + 1, 33):
                        en_comm_proj = self.projection(outputs[0].hidden_states[layer_idx][:, -1, :].unsqueeze(0), self.lang_comm_space[layer_idx-1].unsqueeze(dim=0).to(outputs[0].hidden_states[-1].dtype).to(outputs[0].hidden_states[-1].device))
                        lang_comm_proj = self.projection(outputs_lang.hidden_states[layer_idx][:, -1, :].unsqueeze(0), self.lang_comm_space[layer_idx-1].unsqueeze(dim=0).to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                        push_loss.append(nn.MSELoss()(en_comm_proj, lang_comm_proj))
                    
                    push_loss = torch.stack(push_loss).sum()
                    loss += self.args.push_weight * push_loss
                    print(f"{self.args.push_weight} * {lang}_push_loss:", self.args.push_weight * push_loss.item())
                else:
                    en_comm_proj = self.projection(outputs[0].hidden_states[-1][:, -1, :], self.lang_comm_space[-1].unsqueeze(dim=0).to(outputs[0].hidden_states[-1].dtype).to(outputs[0].hidden_states[-1].device))
                    lang_comm_proj = self.projection(outputs_lang.hidden_states[-1][:, -1, :], self.lang_comm_space[-1].unsqueeze(dim=0).to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                    # en_comm_proj = self.projection(outputs[0].hidden_states[-1][:, -1, :], self.lang_comm_space[-1].to(outputs[0].hidden_states[-1].dtype).to(outputs[0].hidden_states[-1].device))
                    # lang_comm_proj = self.projection(outputs_lang.hidden_states[-1][:, -1, :], self.lang_comm_space[-1].to(outputs_lang.hidden_states[-1].dtype).to(outputs_lang.hidden_states[-1].device))
                    # en_comm_proj = outputs_lang.hidden_states[-1][:, -1, :] - en_lang_spec_proj
                    # lang_comm_proj = outputs_lang.hidden_states[-1][:, -1, :] - lang_spec_proj
                    
                    # orthogonal_loss = self.orthogonality_loss(lang_comm_proj, lang_spec_proj)
                    push_loss = nn.MSELoss()(en_comm_proj, lang_comm_proj)
                    loss += self.args.push_weight * push_loss
                    print(f"{self.args.push_weight} * {lang}_push_loss:", self.args.push_weight * push_loss.item())
                    # print(f"{self.args.push_weight} * {lang}_orthogonal_loss:", self.args.push_weight * orthogonal_loss.item())

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs_en[self.args.past_index]
        
        return (loss, outputs_en) if return_outputs else loss

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs, optionally followed by a
        gradient-ascent ("forgetting") step on a batch from the forget set.

        The regular loss from ``compute_loss`` is back-propagated first. Then, when
        ``args.forget_freq != -1`` and ``state.global_step`` is a multiple of it, one batch
        is drawn from ``self.forget_iterator`` and the loss
        ``-1.0 * args.forget_ratio * compute_loss(...)`` is back-propagated as well, so
        both gradients accumulate before the optimizer step.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs of the model; moved to the right device via
                ``_prepare_inputs`` and handed to ``compute_loss``.

        Return:
            `torch.Tensor`: The detached regular training loss on this batch divided by
            ``args.gradient_accumulation_steps`` (the forgetting loss is not included).
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # NOTE(review): retain_graph=True is kept from the original. Nothing below
            # reuses this graph (the forget pass runs its own forward), so it is probably
            # droppable to save memory -- confirm before removing.
            self.accelerator.backward(loss, retain_graph=True)

        if self.args.forget_freq != -1 and self.state.global_step % self.args.forget_freq == 0:
            print("Forgetting on Harmful Content ...")
            try:
                # Samples the batch
                forget_batch = next(self.forget_iterator)
            except StopIteration:
                # Forget set exhausted: start a new pass and remember the new iterator.
                # The original stored it under `self.replay_iterator`, so the exhausted
                # iterator stayed in place and every later step rebuilt the iterator and
                # only ever consumed its first batch.
                self.forget_iterator = iter(self.forget_dataloader)
                forget_batch = next(self.forget_iterator)

            forget_inputs = self._prepare_inputs(forget_batch)

            with self.compute_loss_context_manager():
                # Negated loss: ascend on the forget batch.
                forget_loss = -1.0 * self.args.forget_ratio * self.compute_loss(model, forget_inputs)

            if self.args.n_gpu > 1:
                forget_loss = forget_loss.mean()  # mean() to average on multi-gpu parallel training

            if self.use_apex:
                with amp.scale_loss(forget_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                self.accelerator.backward(forget_loss)
            print("Done. Forgetting loss is", round(forget_loss.item(), 4))

        return loss.detach() / self.args.gradient_accumulation_steps

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Each batch goes through ``self.prediction_step``; losses, predictions and labels
        are gathered, accumulated on the host, and moved to numpy either every
        ``args.eval_accumulation_steps`` steps or at the end. The results are truncated to
        the number of samples before metrics are computed.

        Args:
            dataloader: Batches to evaluate.
            description: Label used in the log output.
            prediction_loss_only: Overrides ``args.prediction_loss_only`` when not ``None``.
            ignore_keys: Forwarded to ``prediction_step``.
            metric_key_prefix: Prefix applied to every metric key.

        Returns:
            ``EvalLoopOutput`` with the accumulated predictions, label ids, the metrics
            dict (including ``global_step`` and, when available, the mean loss) and the
            number of samples.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.is_deepspeed_enabled:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, # inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader.dataset):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        # NOTE(review): the switch to eval mode is disabled here, so the model keeps
        # whatever train/eval mode the caller left it in -- confirm this is intended.
        # model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = dataloader.dataset
        
        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            # Update containers on host
            if loss is not None:
                # The batch loss is repeated batch_size times so that it yields one entry
                # per example once gathered.
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self.accelerator.pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if logits is not None:
                logits = self.accelerator.pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None
                
        # NOTE(review): this is a truthiness test, so it is also true for past_index == -1
        # and false for past_index == 0, unlike the `>= 0` check used before the loop --
        # confirm which one is meant.
        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
            num_samples = eval_dataset.num_examples
        else:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
            
        # Metrics!
        # compute_metrics is called with keyword arguments (dataset, preds, model,
        # save_prefix); the labels are required to exist but are not passed along.
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(dataset=eval_dataset, preds=all_preds, model=model, save_prefix=metric_key_prefix)
        else:
            metrics = {}

        # Record the training step at which these metrics were produced.
        metrics["global_step"] = self.state.global_step

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        
        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

    
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Without ``args.predict_with_generate`` (or when only the loss is wanted) this
        defers to the parent implementation. Otherwise text is sampled with
        ``self.model.generate`` using fixed generation settings (100 new tokens,
        temperature 0.6, top-p 0.9), and the generated ids are padded to a common length.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Must contain ``input_ids``; the
                ``attention_mask`` entry, when present, is forwarded to ``generate``.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys:
                Only forwarded to the parent implementation.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, generated
            token ids and labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # some encoder-decoder models can have varying encoders and thus
        # varying model input names
        uses_encoder_input = (
            hasattr(self.model, "encoder")
            and self.model.encoder.main_input_name != self.model.main_input_name
        )

        gen_kwargs = {
            "max_new_tokens": 100,
            "do_sample": True,
            "temperature": 0.6,
            "top_p": 0.9,
        }
        # NOTE(review): the special-token ids are hard-coded per model family; verify
        # them against the tokenizer actually in use.
        if uses_encoder_input:
            gen_kwargs.update({"decoder_start_token_id": 0, "eos_token_id": 1, "pad_token_id": 0})
        else:
            gen_kwargs.update({"eos_token_id": 2, "pad_token_id": 1})

        generation_config = GenerationConfig(**gen_kwargs)

        # prepare generation inputs
        if uses_encoder_input:
            generation_inputs = inputs[self.model.encoder.main_input_name]
        else:
            generation_inputs = inputs[self.model.main_input_name]

        # The attention mask is a model input, not a generation setting: it has to be
        # passed to generate() itself. Storing it on the GenerationConfig (as before)
        # meant the mask never reached the model.
        generate_kwargs = {}
        if "attention_mask" in inputs:
            generate_kwargs["attention_mask"] = inputs["attention_mask"]

        generated_tokens = self.model.generate(
            input_ids=generation_inputs,
            generation_config=generation_config,
            **generate_kwargs,
        )

        _, source_len = inputs['input_ids'].shape
        # in case the batch is shorter than max length, the output should be padded
        if check_model(self.model.config._name_or_path, SUPPORTED_DECODER_MODELS):
            # Decoder-only output also contains the prompt tokens.
            max_length = source_len + gen_kwargs["max_new_tokens"]
        else:
            max_length = gen_kwargs["max_new_tokens"]

        if generated_tokens.shape[-1] < max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, max_length)

        with torch.no_grad():
            if has_labels:
                with self.autocast_smart_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["labels"]
            if labels.shape[-1] < gen_kwargs["max_new_tokens"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_new_tokens"])
        else:
            labels = None

        return (loss, generated_tokens, labels)
