#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import argparse
import logging
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from typing import Dict, Optional
import math
import numpy as np

import torch
import datasets
import evaluate
from optimizer import add_optimizer_params
import transformers
from transformers import AutoModelForCausalLM, HfArgumentParser, set_seed, Trainer, TrainingArguments, GPT2Tokenizer
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig, get_peft_model
from gpu import add_gpu_params, distributed_sync
from data_utils import FT_Dataset, padding_tokens
from exps.modules.loss_modified import Trainer_modified
# from nltk.translate.bleu_score import corpus_bleu
from nlgeval import NLGEval
import inspect


# Module-level logger; its level is set from the training arguments inside main().
logger = logging.getLogger(__name__)
# Print tensors with up to 100000 elements in full rather than summarized.
torch.set_printoptions(threshold=100000)


@dataclass
class DataTrainingArguments:
    """
    Arguments describing the data used for training and evaluation.

    Parsed from the command line by `HfArgumentParser`; each field becomes a `--<name>` option.
    `__post_init__` normalizes and validates the combination of options.
    """

    # Optional task identifier (lower-cased in __post_init__).
    task_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the task to train on: "
        },
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={
            "help": "The name of the dataset to use (via the datasets library)."
        }
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        }
    )
    # Passed to FT_Dataset as its maximum sequence length in main().
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={
            "help": "Overwrite the cached preprocessed datasets or not."
        }
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={
            "help": "A csv or a json file containing the training data."
        }
    )
    validation_file: Optional[str] = field(
        default=None, metadata={
            "help": "A csv or a json file containing the validation data."
        }
    )
    test_file: Optional[str] = field(default=None, metadata={
        "help": "A csv or a json file containing the test data."
    })
    # 'jlm' enables joint LM training in FT_Dataset; anything else is treated as 'clm' downstream,
    # so the value is validated in __post_init__.
    training_obj: str = field(default='clm', metadata={
        "help": "language model training objective - choices: [jlm, clm]"
    })

    # max_tokens_per_batch: Optional[int] = field(
    #     default=0,
    #     metadata={
    #         "help": "dynamic batching. Override batch size when larger than 0"
    #     },
    # )

    def __post_init__(self):
        """Normalize and validate the parsed data arguments.

        Raises:
            ValueError: if `training_obj` is not 'jlm'/'clm'; if `task_name` is unknown (only
                checked when a module-level `task_to_keys` mapping exists); if no task, dataset
                name or train+validation file pair is given; or if the file extensions are
                unsupported or differ between train and validation files.
        """
        if self.training_obj not in ("jlm", "clm"):
            raise ValueError(f"`training_obj` should be one of ['jlm', 'clm'], got {self.training_obj!r}.")

        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            # `task_to_keys` is not defined in this script (the check appears to be carried over
            # from the HF GLUE example); referencing it unconditionally raised NameError.
            known_tasks = globals().get("task_to_keys")
            if known_tasks is not None and self.task_name not in known_tasks:
                raise ValueError("Unknown task, you should pick one in " + ",".join(known_tasks))
        elif self.dataset_name is not None:
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            # Explicit raises instead of `assert`, which is stripped under `python -O`.
            train_extension = self.train_file.split(".")[-1]
            if train_extension not in ["csv", "json", "jsonl"]:
                raise ValueError("`train_file` should be a csv or a json file.")
            validation_extension = self.validation_file.split(".")[-1]
            if validation_extension != train_extension:
                raise ValueError("`validation_file` should have the same extension (csv or json) as `train_file`.")


@dataclass
class ModelArguments:
    """
    Model-side options: which pretrained checkpoint/config/tokenizer to start from and how
    they are fetched (cache directory, revision, auth token, head-size mismatch tolerance).
    """

    # Has no default, so dataclass rules require it to precede every defaulted field.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )

    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )

    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )

    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from huggingface.co"),
    )

    use_fast_tokenizer: bool = field(
        default=True,
        metadata=dict(help="Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."),
    )

    model_revision: str = field(
        default="main",
        metadata=dict(help="The specific model version to use (can be a branch name, tag name or commit id)."),
    )

    use_auth_token: bool = field(
        default=False,
        metadata=dict(
            help="Will use the token generated when running `huggingface-cli login` "
                 "(necessary to use this script with private models)."
        ),
    )

    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata=dict(help="Will enable to load a pretrained model whose head dimensions are different."),
    )


@dataclass
class LoraConfig_modified:
    """Command-line LoRA options plus this project's "modified" extensions.

    `task_type`, `inference_mode`, `r`, `lora_alpha` and `lora_dropout` are forwarded to
    `peft.LoraConfig` in `main()`. `modified_dropout_pattern` / `modified_dropout_rate` are
    comma-separated strings that `main()` parses into lists and zips into `modified_dropout`.
    `modified_aug_loss*` are not consumed by the visible code (the Trainer_modified arguments
    that would use them are commented out in `main()`).
    """
    # Optional[...] because these default to None (HfArgumentParser unwraps Optional,
    # so the generated CLI options are unchanged).
    task_type: Optional[str] = field(default=None, metadata={
        "help": "Task type"
    })
    inference_mode: bool = field(default=False, metadata={
        "help": "Whether to use inference mode"
    })
    r: int = field(default=8, metadata={
        "help": "Lora attention dimension"
    })
    lora_alpha: Optional[int] = field(default=None, metadata={
        "help": "Lora alpha"
    })
    lora_dropout: Optional[float] = field(default=None, metadata={
        "help": "Lora dropout"
    })

    # Comma-separated lists on the command line; main() replaces them with parsed lists.
    modified_dropout_pattern: str = field(default="", metadata={
        "help": "Dropout pattern"
    })
    modified_dropout_rate: str = field(default="", metadata={
        "help": "Dropout rate"
    })
    # Filled by main() as {pattern: rate}; default_factory gives each instance its own dict.
    modified_dropout: Dict[str, float] = field(default_factory=dict,
                                              metadata={
                                                  "help": "Dropout {pattern: rate}. will be parsed to dict by json"
                                              })

    modified_aug_loss: str = field(default="none", metadata={
        "help": "kind of augmented loss"
    })
    modified_aug_loss_weight: float = field(default=0.0, metadata={
        "help": "augmented loss weight"
    })


@dataclass
class DatasetConfig:
    """Dataset locations and batch sizes.

    Not referenced by the visible code in this script: `main()` takes the equivalent settings
    from `DataTrainingArguments` (train/validation files) and `TrainingArguments` (batch sizes).
    """
    # Optional[str] because the paths default to None.
    train_data: Optional[str] = field(default=None, metadata={
        "help": "location of training data corpus"
    })
    valid_data: Optional[str] = field(default=None, metadata={
        "help": "location of validation data corpus"
    })
    # test_data: str = field(default=None, metadata={"help": "location of test data corpus"})
    train_batch_size: int = field(default=8, metadata={
        "help": "training batch size"
    })
    valid_batch_size: int = field(default=4, metadata={
        "help": "validation batch size"
    })
    # test_batch_size: int = field(default=4, metadata={"help": "test batch size"})


def generate_parser():
    """Build an argparse parser for the GPT-2 fine-tuning options.

    GPU and optimizer option groups come from the project helpers `add_gpu_params` and
    `add_optimizer_params`; the remaining training-loop knobs (gradient accumulation, clipping,
    logging/eval/save intervals, output folder, objective, ...) are registered here, in the
    same order they appear in `--help`.
    """
    parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')

    for register_group in (add_gpu_params, add_optimizer_params):
        register_group(parser)

    option_specs = [
        ('--grad_acc', dict(type=int, default=1, help='gradient accumulation steps')),
        ('--clip', dict(type=float, default=0.0, help='gradient clip')),
        ('--seq_len', dict(type=int, default=512, help='number of tokens to predict.')),
        ('--model_card', dict(default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                              help='model names')),
        ('--init_checkpoint', dict(default=None, help='pretrained checkpoint path')),
        ('--fp16', dict(action='store_true', help='train model with fp16')),
        ('--log_interval', dict(type=int, default=100, help='log interval')),
        ('--eval_interval', dict(type=int, default=2000, help='eval interval')),
        ('--save_interval', dict(type=int, default=500, help='save interval')),
        # Output folder defaults to $PT_OUTPUT_DIR when set.
        ('--work_dir', dict(type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                            help='working folder.')),
        ('--obj', dict(default='clm', choices=['jlm', 'clm'],
                       help='language model training objective')),
        ('--label_smooth', dict(default=0.0, type=float, help='label smoothing')),
        ('--roll_interval', dict(type=int, default=-1, help='rolling interval')),
        ('--roll_lr', dict(type=float, default=0.00001, help='rolling learning rate')),
        ('--roll_step', dict(type=int, default=100, help='rolling step')),
        ('--eval_epoch', dict(type=int, default=1, help='eval per number of epochs')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser


# Pretty-print the parsed arguments (only when args.rank == 0).
def print_args(args):
    """Print every attribute of `args` as an indented `- name : value` line.

    The listing is framed by two 100-character `=` rules. Nothing is printed unless
    `args.rank == 0`.
    """
    if args.rank != 0:
        return
    rule = '=' * 100
    lines = [f'        - {name} : {value}' for name, value in vars(args).items()]
    print('\n'.join([rule, *lines, rule]))


def __getitem_modified__(self, item):
    """Return one example from an `FT_Dataset` in the format used by the HF `Trainer`.

    Installed as `FT_Dataset.__getitem__` in the `__main__` guard. The token layout and the
    loss mask follow the original recipe (prefix tokens + context + infix tokens, then the
    completion). The tensors are mapped onto Hugging Face causal-LM conventions:

    * `input_ids`: prompt+completion token ids, truncated/right-padded (pad id 0) to
      `self.max_seq_length`.
    * `attention_mask`: 1.0 for real tokens, 0.0 for padding.
    * `label_ids`: *unshifted* labels, because HF causal-LM models shift labels internally.
      Positions excluded from the loss are -100, the default `ignore_index` of
      `CrossEntropyLoss`. Excluded positions are the context in clm mode and padding in
      both modes. Presumably HF's default data collator forwards this key as `labels`;
      confirm if a custom collator is used.
    * `id`: index of the example actually used.

    Args:
        item: example index. Indices `>= self.num_examples` are remapped to a random real
            example drawn with `self.rng`.
    """
    if item >= self.num_examples:
        item = self.rng.randint(0, self.num_examples - 1)

    example = self.ft_samples[item]
    context = example[0]
    completion = example[1]

    pretokens = [i + self.prefix_cursor for i in range(self.prefix_len)]
    intokens = [i + self.infix_cursor for i in range(self.infix_len)]
    conditions = pretokens + context + intokens

    _input, _input_len = padding_tokens(conditions + completion, self.max_seq_length, 0, 1)

    # Next-token targets shifted left by one: _target[t] is the token predicted from position t.
    # Prefix/infix slots are replaced by 0 as in the original recipe.
    pad_targets = [0] * self.prefix_len + context + [0] * self.infix_len + completion
    _target, _ = padding_tokens(pad_targets[1:], self.max_seq_length, 0, 1)

    # Loss mask aligned with _target: completion tokens only (clm) or every real token (jlm).
    if not self.joint_lm:
        _msk = [0.0] * (len(conditions) - 1) + [1.0] * (_input_len - len(conditions))
    else:
        _msk = [1.0] * (_input_len - 1)
    _msk, _ = padding_tokens(_msk, self.max_seq_length, 0.0, 1)

    # HF causal-LM models score logits[t] against labels[t + 1] themselves, so undo the shift:
    # labels[t + 1] = _target[t] where the loss mask is on, otherwise -100 (ignored). Passing the
    # pre-shifted _target directly would train the model to predict two tokens ahead.
    labels = [-100] + [
        target if keep > 0 else -100
        for target, keep in zip(_target[:-1], _msk[:-1])
    ]

    # The attention mask covers every real token. The loss mask must not be used here, or the
    # context tokens would be hidden from attention.
    attention = [1.0] * _input_len + [0.0] * (self.max_seq_length - _input_len)

    return {
        "id": torch.tensor(item, dtype=torch.long),
        "input_ids": torch.tensor(_input, dtype=torch.long),
        "label_ids": torch.tensor(labels, dtype=torch.long),
        "attention_mask": torch.tensor(attention, dtype=torch.float),
    }


def _set_signature_columns_if_needed_modified(self):
    if self._signature_columns is None:
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        self._signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        # self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
        if "input_ids" not in self._signature_columns:
            self._signature_columns += ["input_ids"]
        if "attention_mask" not in self._signature_columns:
            self._signature_columns += ["attention_mask"]
        if "label_ids" not in self._signature_columns:
            self._signature_columns += ["label_ids"]


# input_ids,past_key_values,attention_mask,token_type_ids,position_ids,head_mask,inputs_embeds,encoder_hidden_states,encoder_attention_mask,labels,use_cache,output_attentions,output_hidden_states,return_dict,label_ids,labels,label.

class AverageMeter:
    """Track the latest value together with a count-weighted running mean.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted by `n` samples, and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update=True):
    """Backpropagate `_loss` and, when `is_update`, clip gradients and apply an optimizer step.

    The scheduler (if any) is stepped on every call, including accumulation-only calls with
    `is_update=False`.

    NOTE(review): the fp16 path uses `amp` (apex), which is not imported in this file (the
    import in main() is commented out), so `args.fp16=True` raises NameError here.
    """
    use_fp16 = args.fp16
    if use_fp16:
        with amp.scale_loss(_loss, _optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        _loss.backward()

    if is_update:
        if args.clip > 0:
            clip_params = amp.master_params(_optimizer) if use_fp16 else _model.parameters()
            torch.nn.utils.clip_grad_norm_(clip_params, args.clip)
        _optimizer.step()
        _optimizer.zero_grad()

    if _schedule is not None:
        _schedule.step()


def evaluate(model, valid_loader, args):
    """Compute the mean validation loss and perplexity for the standalone training loop.

    NOTE: this function shadows the `evaluate` module imported at the top of the file;
    within this module the name refers to this function.

    Args:
        model: called as `model(input, lm_labels=target, lm_mask=mask)` and expected to
            return `(logits, loss)`. This is not the HF interface; presumably it is the
            project's own GPT-2 LM module.
        valid_loader: iterable of batches indexable by `'input'`, `'target'` and `'mask'`.
        args: namespace providing `device`.

    Returns:
        `(avg_loss, perplexity)`. The perplexity is `float('inf')` when `exp(avg_loss)`
        overflows.
    """
    model.eval()
    avg_lm_loss = AverageMeter()

    with torch.no_grad():
        for idx, data in enumerate(valid_loader):
            _input = data['input'].to(args.device)
            _target = data['target'].to(args.device)
            _msk = data['mask'].to(args.device)

            _lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
            loss = _loss.mean()

            avg_lm_loss.update(loss.item())

            if idx % 100 == 0:
                print('eval samples:', idx, 'loss:', loss.float())

        print('average loss', avg_lm_loss.avg)

    # math.exp raises OverflowError for arguments above ~709; report an infinite perplexity instead.
    try:
        ppl = math.exp(avg_lm_loss.avg)
    except OverflowError:
        ppl = float('inf')
    return avg_lm_loss.avg, ppl


def train_validate(
        model,
        optimizer,
        scheduler,
        train_loader,
        valid_loader,
        args,
        train_step=0,
        epoch=0
):
    """Run one epoch of the standalone training loop.

    The loop performs periodic logging, checkpointing and evaluation.

    Per batch it computes the LM loss, backpropagates via `optimizer_step` (stepping the
    optimizer every `args.grad_acc` batches), and then:

    * logs every `args.log_interval` steps;
    * saves the LoRA weights every `args.save_interval` steps;
    * evaluates on `valid_loader` every `args.eval_interval` steps.

    It stops early once `train_step == args.max_step`. At the end it saves the full model
    state dict.

    Args:
        model: project LM module called as `model(input, lm_labels=..., lm_mask=..., label_smooth=...)`.
        optimizer, scheduler: passed through to `optimizer_step`.
        train_loader, valid_loader: iterables of batches with `'input'`, `'target'`, `'mask'`.
        args: namespace with the intervals above plus `grad_acc`, `label_smooth`, `rank`,
            `work_dir`, `device`, `max_step`.
        train_step: global step counter to continue from.
        epoch: epoch number (used for logging and sampler reshuffling).

    Returns:
        The updated global `train_step`.
    """
    def _perplexity(loss):
        # math.exp raises OverflowError for arguments above ~709; log inf instead of crashing.
        try:
            return math.exp(loss)
        except OverflowError:
            return float('inf')

    model.train()
    avg_lm_loss = AverageMeter()
    print('start to train the model................', epoch)
    log_start_time = time.time()
    best_val_ppl = None

    # Samplers such as DistributedSampler reshuffle per epoch via set_epoch; plain samplers lack
    # it, so only call it when available instead of raising AttributeError.
    sampler = getattr(train_loader, 'sampler', None)
    if hasattr(sampler, 'set_epoch'):
        sampler.set_epoch(epoch)

    for idx, data in enumerate(train_loader):
        _input = data['input'].to(args.device)
        _target = data['target'].to(args.device)
        _msk = data['mask'].to(args.device)

        _lm_logits, _lm_loss = model(
            _input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth
        )

        _lm_loss = _lm_loss.mean()

        train_step += 1
        is_update = train_step % args.grad_acc == 0
        avg_lm_loss.update(_lm_loss.item())
        # Scale by grad_acc so accumulated gradients average over the micro-batches.
        optimizer_step(
            _lm_loss / (args.grad_acc), optimizer, model, scheduler, args, is_update=is_update
        )

        if train_step % args.log_interval == 0:
            elapsed = time.time() - log_start_time
            lr = optimizer.param_groups[0]['lr']
            log_str = f'| epoch {epoch:3d} step {train_step:>8d} | {idx + 1:>6d} batches | ' \
                      f'lr {lr:.3g} | ms/batch {elapsed * 1000 / args.log_interval:5.2f} | ' \
                      f'loss {avg_lm_loss.val:5.2f} | avg loss {avg_lm_loss.avg:5.2f} | ' \
                      f'ppl {_perplexity(avg_lm_loss.avg):5.2f}'

            if args.rank == 0:
                print(log_str)
            log_start_time = time.time()
            avg_lm_loss.reset()

        if train_step % args.save_interval == 0:
            if args.rank == 0:
                model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
                print('saving checkpoint', model_path)
                # Save only the LoRA weights: keys containing 'lora_', the same filter main() uses
                # when counting LoRA parameters. `lora` (loralib) is not imported in this file,
                # so the previous `lora.lora_state_dict(model)` call raised NameError.
                lora_state = {k: v for k, v in model.state_dict().items() if 'lora_' in k}
                torch.save({
                    'model_state_dict': lora_state
                }, model_path)
            distributed_sync(args)

        # evaluation interval
        if train_step % args.eval_interval == 0:
            eval_start_time = time.time()

            valid_loss, valid_ppl = evaluate(model, valid_loader, args)

            if best_val_ppl is None or valid_ppl < best_val_ppl:
                best_val_ppl = valid_ppl

            log_str = f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | ' \
                      f'time: {time.time() - eval_start_time:5.2f}s | valid loss {valid_loss:5.2f} | ' \
                      f'valid ppl {valid_ppl:5.2f} | best ppl {best_val_ppl:5.2f} '

            if args.rank == 0:
                print('-' * 100)
                print(log_str)
                print('-' * 100)

            model.train()
            distributed_sync(args)

        if train_step == args.max_step:
            break

    if args.rank == 0:
        model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
        print('saving checkpoint', model_path)
        torch.save({
            'model_state_dict': model.state_dict()
        }, model_path)
    distributed_sync(args)
    return train_step


def _parse_modified_dropout(peft_args_modified):
    """Parse the comma-separated dropout pattern/rate strings in place.

    `modified_dropout_rate` becomes a list of floats, `modified_dropout_pattern` a list of
    strings, and `modified_dropout` the `{pattern: rate}` dict built from them. Empty items
    are skipped.

    Raises:
        ValueError: if a rate is not a valid float, or if the numbers of patterns and rates
            differ (`zip` would otherwise silently drop the extras).
    """
    rates = [float(tok.strip()) for tok in peft_args_modified.modified_dropout_rate.split(",")
             if tok.strip() != ""]
    patterns = [tok.strip() for tok in peft_args_modified.modified_dropout_pattern.split(",")
                if tok.strip() != ""]
    if len(patterns) != len(rates):
        raise ValueError(
            f"modified_dropout_pattern has {len(patterns)} entries but modified_dropout_rate has "
            f"{len(rates)}; they must pair up one-to-one."
        )
    peft_args_modified.modified_dropout_rate = rates
    peft_args_modified.modified_dropout_pattern = patterns
    peft_args_modified.modified_dropout = dict(zip(patterns, rates))


def main():
    """Fine-tune a causal LM with LoRA (PEFT) using the Hugging Face `Trainer`.

    Steps:

    1. Parse Model/DataTraining/Training/LoRA arguments from the command line.
    2. Build the `FT_Dataset` train and validation sets.
    3. Wrap the pretrained model with `get_peft_model`.
    4. Run training and/or evaluation according to `--do_train` / `--do_eval`.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, LoraConfig_modified))
    model_args, data_args, training_args, peft_args_modified = parser.parse_args_into_dataclasses()
    # convert modified_dropout-relevant info to list/dict form (validates that they pair up)
    _parse_modified_dropout(peft_args_modified)
    peft_config = LoraConfig(task_type=peft_args_modified.task_type,
                             inference_mode=peft_args_modified.inference_mode,
                             r=peft_args_modified.r,
                             lora_alpha=peft_args_modified.lora_alpha,
                             lora_dropout=peft_args_modified.lora_dropout)

    # print configuration
    print("modified: model_args: ", asdict(model_args))
    print("modified: data_args: ", asdict(data_args))
    print("modified: training_args: ", asdict(training_args))
    print("modified: peft_args_modified: ", asdict(peft_args_modified))
    print("modified: peft_config: ", asdict(peft_config))

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -  %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # if args.fp16: # todo: check if this is necessary
    #     try:
    #         from apex import amp
    #     except Exception as e:
    #         warnings.warn('Could not import amp, apex may not be installed')

    train_data = FT_Dataset(
        data_args.train_file, training_args.per_device_train_batch_size, data_args.max_seq_length,
        joint_lm=data_args.training_obj == 'jlm'
    )
    valid_data = FT_Dataset(
        data_args.validation_file, training_args.per_device_eval_batch_size, data_args.max_seq_length,
    )
    # test_data = FT_Dataset(
    #     data_args.test_file, training_args.per_device_eval_batch_size, data_args.max_seq_length,
    # )

    # Honor --tokenizer_name / --cache_dir (both default to None, i.e. the previous behavior).
    tokenizer = GPT2Tokenizer.from_pretrained(
        model_args.tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # define metrics (shadows the module-level compute_metrics inside main)
    def compute_metrics(eval_pred):
        # NOTE(review): Trainer passes logits as predictions unless preprocess_logits_for_metrics
        # turns them into ids, and labels may contain -100; confirm before enabling Trainer_modified.
        predictions, labels = eval_pred
        predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        nlgeval = NLGEval()
        # NLGEval.compute_metrics(ref_list, hyp_list): ref_list is a list of reference *sets*,
        # each aligned with the hypotheses; here there is a single set.
        scores = nlgeval.compute_metrics([labels], predictions)
        return scores

    # Load pretrained model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        # from_tf=bool(".ckpt" in model_args.model_name_or_path),
        # config=config,
        cache_dir=model_args.cache_dir,
        # revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    model = get_peft_model(model, peft_config)

    # Disable_Dropout_modified = True
    # if Disable_Dropout_modified:
    #     model = disable_dropout_modified(model)
    #
    # Replace_Dropout_modified = True
    # if Replace_Dropout_modified:
    #     model = replace_dropout_modified(model, modified_dropout=peft_args_modified.modified_dropout, )

    # model.print_trainable_parameters()
    print(model)
    params_num = 0
    for k, v in model.named_parameters():
        if v.requires_grad and "lora_" in k:
            print(k, v.shape)
            params_num += v.numel()
    print(f"The number of trainable parameters in LoRA: {params_num}")

    # Initialize our Trainer
    Enable_Trainer_modified = False
    if Enable_Trainer_modified:
        trainer = Trainer_modified(
            model=model,
            args=training_args,
            train_dataset=train_data if training_args.do_train else None,
            eval_dataset=valid_data if training_args.do_eval else None,
            compute_metrics=compute_metrics,
            # tokenizer=tokenizer,
            # data_collator=data_collator,
            # modified_aug_loss=peft_args_modified.modified_aug_loss,
            # modified_aug_loss_weight=peft_args_modified.modified_aug_loss_weight,
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_data if training_args.do_train else None,
            # Evaluate on the validation split (was train_data, so in-training evals measured training loss).
            eval_dataset=valid_data if training_args.do_eval else None,

            # compute_metrics=compute_metrics,
            # tokenizer=tokenizer,
            # data_collator=data_collator,
        )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_data)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_data))

        trainer.save_model()  # Saves the tokenizer too for easy upload
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        eval_dataset = valid_data
        metrics = trainer.evaluate(eval_dataset=eval_dataset)
        max_eval_samples = (
            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        )
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # # Test
    # if training_args.do_predict:
    #     logger.info("*** Test ***")
    #     test_dataset = test_data
    #     test_result = trainer.predict(test_dataset, metric_key_prefix="test")
    #     metrics = test_result.metrics
    #     test_grd_ids = test_result.label_ids
    #     test_pred_ids = test_result.predictions
    #     generated_text = model.generate(input_ids=preds.input_ids, num_beams=5, max_length=50)
    #
    #     metrics["test_samples"] = len(test_dataset)
    #     trainer.log_metrics("test", metrics)
    #     trainer.save_metrics("test", metrics)


def compute_metrics(predictions, labels, tokenizer, nlgeval=None):
    """Decode token ids and score the hypotheses against the references with nlg-eval.

    Args:
        predictions: batch of generated token-id sequences (the hypotheses).
        labels: batch of reference token-id sequences, aligned with `predictions`.
        tokenizer: object providing `batch_decode` (e.g. the GPT-2 tokenizer).
        nlgeval: optional `NLGEval` instance to reuse. Per the nlg-eval docs, constructing
            one loads its scoring models, so repeated callers may want to pass one in.
            Defaults to a fresh `NLGEval()`.

    Returns:
        The metric dict produced by `NLGEval.compute_metrics`.
    """
    hypotheses = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    references = tokenizer.batch_decode(labels, skip_special_tokens=True)
    if nlgeval is None:
        nlgeval = NLGEval()
    # NLGEval.compute_metrics(ref_list, hyp_list): ref_list is a list of reference *sets*, each
    # aligned with hyp_list. Here there is one set. The previous keyword names
    # (hypothesis=/references=) do not exist on this method.
    return nlgeval.compute_metrics([references], hypotheses)
    # Scratch notes from earlier experiments (never executed; kept for reference):
    # optimizer = create_adam_optimizer_from_args(lm_net, args)
    #
    # if args.max_step is None:
    #     args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
    #     print('set max_step:', args.max_step)
    #
    # scheduler = create_optimizer_scheduler(optimizer, args)
    # # if args.fp16: # todo: amp fp16
    # #     lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
    #
    # train_step = 0
    # for epoch in itertools.count(start=1):
    #     train_step = train_validate(
    #         lm_net, optimizer, scheduler, train_loader, valid_loader, args,
    #         train_step=train_step, epoch=epoch
    #     )

    # # Define the data preprocessing function
    # def preprocess_function(examples):
    #     # Encode the input text into token ids
    #     inputs = tokenizer(examples['meaning_representation'],  truncation=True, max_length=512,add_eos=True)
    #     # Encode the target text into token ids
    #     targets = tokenizer(examples['human_reference'], truncation=True, max_length=512)
    #     # Return the encodings as a dict
    #     return {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'labels': targets['input_ids']}
    #
    # # Preprocess the dataset
    # dataset = dataset.map(preprocess_function, batched=False)
    #
    # # Define the Trainer object
    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=dataset,
    #     data_collator=lambda data: {'input_ids': torch.stack([item['input_ids'] for item in data]),
    #                                'attention_mask': torch.stack([item['attention_mask'] for item in data]),
    #                                'labels': torch.stack([item['labels'] for item in data])},
    #     tokenizer=tokenizer,
    # )

    # data_files = {
    #     'train'     : './data/e2e/train.jsonl',
    #     'validation': './data/e2e/valid.jsonl',
    #     'test'      : './data/e2e/test.jsonl'
    # }
    # dataset = load_dataset('json', data_files=data_files)
    # tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium', )
    #
    # def encode_example(example):
    #     input_ids = tokenizer.encode(example['input'], max_length=512, truncation=True, padding='max_length')
    #     target_ids = tokenizer.encode(example['output'], max_length=512, truncation=True, padding='max_length')
    #     return {
    #         'input_ids': input_ids,
    #         'attention_mask': [1] * len(input_ids),
    #         'target_ids': target_ids
    #     }
    #
    # dataset = dataset.map(encode_example, batched=True)

    # train_loader = DataLoader(
    #     train_data, batch_size=training_args.per_device_train_batch_size, num_workers=0,
    #     shuffle=False, pin_memory=False, drop_last=True,
    #     # sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed)
    # )
    # valid_loader = DataLoader(
    #     valid_data, batch_size=training_args.per_device_eval_batch_size, num_workers=0,
    #     shuffle=False, pin_memory=False, drop_last=False,
    #     # sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
    # )
    # test_loader = DataLoader(
    #     test_data, batch_size=training_args.per_device_eval_batch_size, num_workers=0,
    #     shuffle=False, pin_memory=False, drop_last=False,
    #     # sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
    # )


if __name__ == '__main__':
    # Install the overrides defined above before main() constructs any FT_Dataset or Trainer.
    Trainer._set_signature_columns_if_needed = _set_signature_columns_if_needed_modified
    FT_Dataset.__getitem__ = __getitem_modified__

    main()
