# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
from collections import defaultdict
import copy
import json
import os

from os.path import exists, join, isdir
from dataclasses import dataclass, field
import sys
from typing import Optional, Dict, Sequence
import numpy as np
import datasets
from tqdm import tqdm
import logging
# import bitsandbytes as bnb
import pandas as pd
import importlib
from packaging import version
from read_json import sort_dict_by_value, print_dict, divide_dict
from packaging.version import parse

from accelerate.hooks import add_hook_to_module

import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.utils.import_utils import is_sagemaker_mp_enabled
from torch.nn.utils.rnn import pad_sequence
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    GenerationConfig,
    LlamaTokenizer, TrainerState, TrainerControl,
)
from datasets import load_dataset, Dataset, load_from_disk
import evaluate

from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    AdaLoraConfig,
    get_peft_model,
    PeftModel
)
from peft.tuners.lora import LoraLayer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from torch import nn

# from trl import DataCollatorForCompletionOnlyLM

from transformers import Trainer

# Optional SageMaker model-parallel support: when enabled, import the smdistributed runtime and
# record whether its version is >= 1.10 in IS_SAGEMAKER_MP_POST_1_10 (False otherwise).
# NOTE(review): `from .trainer_pt_utils import ...` is a relative import, so it only resolves when
# this file is imported as part of a package that provides `trainer_pt_utils`; running it as a
# plain script with SageMaker MP enabled would fail here — confirm before relying on this path.
# The imported smp_* helpers are not referenced anywhere in this chunk of the file.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

# Module-level logger.
logger = logging.getLogger(__name__)


def init_wandb(project_name):
    """Export ``project_name`` as the ``WANDB_PROJECT`` environment variable.

    Args:
        project_name: Value stored in ``os.environ["WANDB_PROJECT"]``.
    """
    os.environ.update(WANDB_PROJECT=project_name)


def is_ipex_available():
    """Report whether a usable Intel Extension for PyTorch (IPEX) is installed.

    IPEX is considered usable when the ``intel_extension_for_pytorch`` package can be
    found and its ``major.minor`` version equals the installed torch ``major.minor``.

    Returns:
        True if IPEX is installed and version-compatible with torch, otherwise False.
        On a version mismatch a ``UserWarning`` is emitted before returning False.
    """
    # `import importlib` alone does not guarantee the `metadata`/`util` submodules are loaded,
    # and `warnings` was previously used without being imported (NameError on mismatch).
    import importlib.metadata
    import importlib.util
    import warnings

    def get_major_and_minor_from_version(full_version):
        # Reduce e.g. "2.1.0+cpu" to "2.1" for the compatibility comparison.
        parsed = version.parse(full_version)
        return f"{parsed.major}.{parsed.minor}"

    # Cheap check first: bail out before touching any version metadata.
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    try:
        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
    except importlib.metadata.PackageNotFoundError:
        # Importable but without distribution metadata: treat as unavailable.
        return False
    _torch_version = importlib.metadata.version("torch")

    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor != ipex_major_and_minor:
        warnings.warn(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
    return True


# Allow TF32 tensor-core math for CUDA matmuls when a GPU is present (faster, lower-precision matmul).
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True

# -100 matches the default `ignore_index` of torch.nn.CrossEntropyLoss; presumably used to mask
# label positions that should not contribute to the loss — confirm at the use sites.
IGNORE_INDEX = -100
# Pad token string; presumably added to tokenizers that lack one — confirm at the use sites.
DEFAULT_PAD_TOKEN = "<pad>"


@dataclass
class ModelArguments:
    """Options selecting the pretrained model to load and how it may be fetched."""

    # Hub id or local path of the base model. Optional[str] is the same as Union[str, None].
    model_name_or_path: Optional[str] = "EleutherAI/pythia-12b"
    # Opt-in flag; see the help text for what it enables.
    trust_remote_code: Optional[bool] = field(
        metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."},
        default=False,
    )
    # Opt-in flag for authenticated Hub access.
    use_auth_token: Optional[bool] = field(
        metadata={"help": "Enables using Huggingface auth token from Git Credentials."},
        default=False,
    )

@dataclass
class DataArguments:
    """Options describing the fine-tuning dataset, its size limits and sequence lengths."""

    # Size of the held-out validation split.
    eval_dataset_size: int = field(metadata={"help": "Size of validation dataset."}, default=1024)
    # Optional caps on the number of train / eval examples (None means no cap).
    max_train_samples: Optional[int] = field(
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples "
                    "to this value if set."
        },
        default=None,
    )
    max_eval_samples: Optional[int] = field(
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples "
                    "to this value if set."
        },
        default=None,
    )
    # Token budgets for the prompt (source) and the response (target).
    source_max_len: int = field(
        metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
        default=1024,
    )
    target_max_len: int = field(
        metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
        default=256,
    )
    # Dataset name and (optional) explicit format identifier.
    dataset: str = field(
        metadata={"help": "Which dataset to finetune on. See datamodule for options."},
        default='alpaca',
    )
    dataset_format: Optional[str] = field(
        metadata={"help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]"},
        default=None,
    )
    do_generate: Optional[bool] = field(
        metadata={"help": "Whether to generate the output."},
        default=True,
    )

@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    """Seq2Seq training arguments extended with the options used by this script.

    Adds quantization, LoRA/AdaLoRA, localization and layer-freezing options, and
    re-declares several upstream TrainingArguments fields (output_dir, optim, batch
    sizes, schedule, checkpoint cadence, ...) with this project's defaults.
    """

    cache_dir: Optional[str] = field(
        default=None
    )
    train_on_source: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to train on the input in addition to the target text."}
    )
    train_without_system: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to train without system responses."}
    )
    mmlu_split: Optional[str] = field(
        default='eval',
        metadata={"help": "The MMLU split to run on"}
    )
    mmlu_dataset: Optional[str] = field(
        default='mmlu-fs',
        metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."}
    )
    do_mmlu_eval: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to run the MMLU evaluation."}
    )
    max_mmlu_samples: Optional[int] = field(
        default=None,
        metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."}
    )
    mmlu_source_max_len: int = field(
        default=2048,
        metadata={"help": "Maximum source sequence length for mmlu."}
    )
    full_finetune: bool = field(
        default=False,
        metadata={"help": "Finetune the entire model without adapters."}
    )
    adam8bit: bool = field(
        default=False,
        metadata={"help": "Use 8-bit adam."}
    )
    # --- quantization ---
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=32,
        metadata={"help": "How many bits to use."}
    )
    # --- LoRA / AdaLoRA ---
    lora_r: int = field(
        default=32,
        metadata={"help": "Lora R dimension."}
    )
    target_r: int = field(
        default=32,
        metadata={"help": "Target R dimension for AdaLoRA."}
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": " Lora alpha."}
    )
    lora_dropout: float = field(
        default=0.0,
        metadata={"help": "Lora dropout."}
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate."}
    )
    max_memory_MB: int = field(
        default=80000,
        metadata={"help": "Free memory per gpu."}
    )
    report_to: str = field(
        default='none',
        metadata={"help": "To use wandb or something else for reporting."}
    )
    run_name: str = field(
        default='none',
        metadata={"help": "To use wandb, and the running name."}
    )
    whether_quantize: bool = field(default=False, metadata={"help": 'Whether to quantize the model.'})
    gamma_learning_ratio: float = field(default=1, metadata={"help": 'The learning rate for the gamma parameters'})
    weight_learning_ratio: float = field(default=1.0, metadata={"help": 'The learning rate for the weight parameters'})

    # --- upstream TrainingArguments fields re-declared with project defaults ---
    results_dir: str = field(default='./results', metadata={"help": 'The output dir for logs and checkpoints'})
    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
    optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'})
    per_device_train_batch_size: int = field(default=1, metadata={
        "help": 'The training batch size per GPU. Increase for better speed.'})
    per_device_eval_batch_size: int = field(default=1, metadata={
        "help": 'The evaluation batch size per GPU. Increase for better speed.'})
    gradient_accumulation_steps: int = field(default=16, metadata={
        "help": 'How many gradients to accumulate before to perform an optimizer step'})
    max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'})
    weight_decay: float = field(default=0.0, metadata={
        "help": 'The L2 weight decay rate of AdamW'})  # use lora dropout instead for regularization if needed
    learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate'})
    lora_learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate for lora'})
    remove_unused_columns: bool = field(default=False,
                                        metadata={"help": 'Removed unused columns. Needed to make this codebase work.'})
    max_grad_norm: float = field(default=1.0, metadata={
        "help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
    gradient_checkpointing: bool = field(default=False,
                                         metadata={"help": 'Use gradient checkpointing. You want to use this.'})
    do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
    lr_scheduler_type: str = field(default='constant', metadata={
        "help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
    warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'})
    logging_steps: int = field(default=10,
                               metadata={"help": 'The frequency of update steps after which to log the loss'})
    group_by_length: bool = field(default=True, metadata={
        "help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
    save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'})
    save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
    save_total_limit: int = field(default=40,
                                  metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})

    distributed_training: bool = field(default=True, metadata={"help": 'Whether to use distributed training'})

    # --- localization / freezing ---
    whether_localization: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to run localization."}
    )

    block_wise: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use block-wise for localization."}
    )

    add_weight: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to add weight."}
    )

    freeze_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to freeze lora."}
    )
    freeze_lora_weights: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to freeze lora weights."}
    )
    freeze_out_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Freeze out dir."}
    )
    peft_path: Optional[str] = field(
        default=None,
        metadata={"help": "Remote PEFT path."}
    )
    lw_init_value: Optional[float] = field(
        default=1.0,
        metadata={"help": "Initial value for lora weight."}
    )
    further_tune_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to further tune lora."}
    )
    further_tune_gamma: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to further tune lora gamma."}
    )
    reset_low_score_layers: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to freeze low score layers."}
    )
    peft_type: Optional[str] = field(
        default='lora',
        metadata={"help": "PEFT type."}
    )
    orth_reg_weight: Optional[float] = field(
        default=0.5,
        metadata={"help": "Orthogonal regularization weight."}
    )
    deltaT: Optional[int] = field(
        default=1,
        metadata={"help": "DeltaT for AdaLoRA."}
    )
    tinit: Optional[int] = field(
        default=500,
        metadata={"help": "Tinit for AdaLoRA."}
    )
    tfinal: Optional[int] = field(
        default=8000,
        metadata={"help": "Tfinal for AdaLoRA."}
    )
    gen_freq: Optional[int] = field(
        default=100,
        metadata={"help": "Generation frequency."}
    )
    freeze_pad_layer_embedding: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to freeze pad layer embedding."}
    )
    # Must be typed bool: with a str annotation HfArgumentParser passes `--train_with_input False`
    # through as the (truthy) string "False" instead of converting it to a boolean.
    train_with_input: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to train with input."}
    )


@dataclass
class GenerationArguments:
    """Decoding hyper-parameters, mirroring fields of ``transformers.GenerationConfig``.

    For more hyperparameters check:
    https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    """

    # Length arguments
    max_new_tokens: Optional[int] = field(
        default=128,
        # The two literals are concatenated; the trailing space keeps "loops" and "if" apart.
        metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops "
                          "if predict_with_generate is set."}
    )
    min_new_tokens: Optional[int] = field(
        default=None,
        metadata={"help": "Minimum number of new tokens to generate."}
    )

    # Generation strategy
    do_sample: Optional[bool] = field(default=False)
    num_beams: Optional[int] = field(default=1)
    num_beam_groups: Optional[int] = field(default=1)
    penalty_alpha: Optional[float] = field(default=None)
    use_cache: Optional[bool] = field(default=True)

    # Hyperparameters for logit manipulation
    temperature: Optional[float] = field(default=1.0)
    top_k: Optional[int] = field(default=50)
    top_p: Optional[float] = field(default=1.0)
    typical_p: Optional[float] = field(default=1.0)
    diversity_penalty: Optional[float] = field(default=0.0)
    repetition_penalty: Optional[float] = field(default=1.0)
    length_penalty: Optional[float] = field(default=1.0)
    no_repeat_ngram_size: Optional[int] = field(default=0)


@dataclass
class ExtraArguments:
    """Project-specific options: logging/MMLU cadence, layer localization, LoRA scheduling."""

    project_name: Optional[str] = field(
        default='PEFT'
    )
    mmlu_freq: Optional[int] = field(
        default=200,
        metadata={"help": "Frequency of MMLU evaluation."}
    )
    lora_drop_prob: Optional[float] = field(
        default=0.0,
        metadata={"help": "Stocastic drop LoRA Layers."}
    )

    # --- localization input files ---
    layer_json_file: Optional[str] = field(
        default=None,
        metadata={"help": "Layer json file for localization."}
    )

    distribution_json_file: Optional[str] = field(
        default=None,
        metadata={"help": "Distribution json file for localization."}
    )

    layer_score_file: Optional[str] = field(
        default=None,
        metadata={"help": "Layer json file for localization."}
    )
    only_permutation: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use only permutation."}
    )

    # --- LoRA-score based learning-rate scaling ---
    desired_std_ratio: Optional[float] = field(
        default=1.0,
        metadata={"help": "Desired std for lora score."}
    )
    max_learning_rate_ratio: Optional[float] = field(
        default=1.2,
        metadata={"help": "Max learning rate ratio for lora score."}
    )
    min_learning_rate_ratio: Optional[float] = field(
        default=0.5,
        metadata={"help": "Min learning rate ratio for lora score."}
    )
    desired_mean: Optional[float] = field(
        default=1.0,
        metadata={"help": "Desired mean for lora score."}
    )
    dynamic_lr: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use dynamic learning rate."}
    )

    # --- layer selection ---
    topK: Optional[int] = field(
        default=0,
        metadata={"help": "TopK for localization."}
    )
    divide_ratio: Optional[float] = field(
        default=0.0,
        metadata={"help": "Divide ratio for localization. If 0.0, use topK."}
    )
    first_ratio: Optional[float] = field(
        default=0.0,
        metadata={"help": "First ratio for localization."}
    )
    last_ratio: Optional[float] = field(
        default=0.0,
        metadata={"help": "Last ratio for localization."}
    )
    ratio_only_for_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use ratio only for lora."}
    )

    total_steps: Optional[int] = field(
        default=2000,
        metadata={"help": "Total steps for localization."}
    )
    whether_unfreeze_normal: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to unfreeze normal layers."}
    )

    # --- training schedule ---
    interval: Optional[int] = field(
        default=400,
        metadata={"help": "Interval for training lora"}
    )
    prob_interval: Optional[int] = field(
        default=1,
        metadata={"help": "Interval for training lora"}
    )
    lora_gamma_interval: Optional[int] = field(
        default=5,
        metadata={"help": "Interval for training lora gamma."}
    )
    upperbound: Optional[float] = field(
        default=1.0
    )
    default_localization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use default localization."}
    )
    Random: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use random Selection task-specific layers."}
    )
    start_iter: Optional[int] = field(
        default=0,
        metadata={"help": "Start Iteration for training lora"}
    )
    all_layers: Optional[str] = field(
        default='ALL',
        metadata={"help": "Whether to use all layers. ALL, ATT, FFN"}
    )
    ratio: Optional[float] = field(
        default=32.0,
        metadata={"help": "Ratio for Norm Lora layers."}
    )
    use_sigmoid: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use sigmoid."}
    )
    sigmoid_init_value: Optional[float] = field(
        default=5.0,
        metadata={"help": "Sigmoid init value."}
    )
    seperate_training: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use seperate training."}
    )
    combine_training: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use combine training."}
    )
    use_gradient_score: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use gradient score."}
    )
    save_interval: Optional[int] = field(
        default=100,
        metadata={"help": "Save interval for gradient score."}
    )


class SKillCallback(transformers.TrainerCallback):
    """At the end of epoch 1, pick a fraction of LoRA layers, dump their statistics, and freeze them.

    Layers are ranked by ``sort_dict_by_value`` over each ``lora_gamma`` module's
    ``lowerbound_down_times``, and the first ``int(n_layers * divide_ratio)`` are selected
    (the ranking direction is defined by ``sort_dict_by_value`` — verify there). Their
    statistics are written to ``<output_dir>/lora_score_<epoch>.json``. Localization is then
    switched off on every ``lora_gamma`` module, and the selected layers are reset and frozen.
    """

    def __init__(self, divide_ratio, interval):
        # Fraction of lora_gamma layers to drop after epoch 1.
        self.divide_ratio = divide_ratio
        # Stored for callers; not read by this callback.
        self.interval = interval

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # NOTE(review): everything below runs on the world-process-zero rank only, including the
        # parameter reset/freeze, so other ranks keep training those layers in multi-process runs.
        if not state.is_world_process_zero:
            return
        epoch = state.epoch
        model = kwargs["model"]
        if epoch != 1:
            return

        # Rank the localizable layers by how often they hit the lower bound.
        lowerbound_hits = {
            mod_name: mod.lowerbound_down_times
            for mod_name, mod in model.named_modules()
            if hasattr(mod, 'lora_gamma')
        }
        ranked_names = list(sort_dict_by_value(lowerbound_hits).keys())
        cutoff = int(len(ranked_names) * self.divide_ratio)
        drop_layers = ranked_names[:cutoff]

        # Collect per-layer statistics for the JSON dump (key order is kept stable).
        stat_names = ('lora_score', 'diff_up', 'diff_down', 'upperbound_up_times', 'lowerbound_down_times')
        file_content = {stat: {} for stat in stat_names}
        file_content['lora_gamma_value_list'] = {}
        file_content['drop_layers'] = drop_layers
        for mod_name, mod in model.named_modules():
            if hasattr(mod, 'lora_score'):
                for stat in stat_names:
                    file_content[stat][mod_name] = getattr(mod, stat)
            if hasattr(mod, 'lora_gamma_value_list'):
                file_content['lora_gamma_value_list'][mod_name] = mod.lora_gamma_value_list

        score_path = os.path.join(args.output_dir, f'lora_score_{state.epoch}.json')
        with open(score_path, 'w') as fout:
            fout.write(json.dumps(file_content))

        # Turn localization off on every localizable layer.
        for _, mod in model.named_modules():
            if hasattr(mod, 'lora_gamma'):
                mod.init_localization()
                mod.whether_localization = False

        # Reset the selected layers' LoRA weights and stop them from training.
        drop_set = set(drop_layers)
        for mod_name, mod in model.named_modules():
            if mod_name not in drop_set or not hasattr(mod, 'lora_gamma'):
                continue
            mod.reset_lora_parameters('default')
            for param_name, param in mod.named_parameters():
                print(f"Freezing Layer {mod_name}-{param_name}!")
                param.requires_grad = False


class GradientCallback(transformers.TrainerCallback):
    """Track per-parameter squared gradient norms and periodically dump them to JSON.

    For every parameter with a gradient, ``(grad ** 2).sum()`` is appended to
    ``self.gradient_norm_map[name]`` once per optimizer step. Despite the attribute name,
    this is the squared L2 norm. Every ``save_interval`` global steps the whole accumulated
    history is written to ``<args.results_dir>/gradient_score_<global_step>.json``.
    Only the world-process-zero rank records and writes.

    NOTE: gradients are read in ``on_optimizer_step``. The Trainer clears gradients
    (``model.zero_grad()``) before dispatching ``on_step_end``, so reading them there
    sees ``None``/zero gradients. ``on_optimizer_step`` is only dispatched by transformers
    versions that provide that event; on older versions nothing is recorded.
    """

    def __init__(self, save_interval):
        # Global steps between JSON dumps; 0 disables dumping (instead of ZeroDivisionError).
        self.save_interval = save_interval
        # parameter name -> list of squared gradient norms, one entry per optimizer step
        self.gradient_norm_map = {}

    def _record_gradients(self, model):
        """Append the squared L2 norm of every populated ``param.grad`` to the history map."""
        for name, param in model.named_parameters():
            if param.grad is not None:
                self.gradient_norm_map.setdefault(name, []).append((param.grad ** 2).sum().item())

    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Fired after optimizer.step() but before gradients are zeroed: the last point they are readable.
        if state.is_world_process_zero:
            self._record_gradients(kwargs["model"])

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if not state.is_world_process_zero:
            return
        global_step = state.global_step
        if self.save_interval and global_step % self.save_interval == 0:
            # Create the results directory on demand so a missing path doesn't crash training mid-run.
            os.makedirs(args.results_dir, exist_ok=True)
            with open(os.path.join(args.results_dir, f'gradient_score_{global_step}.json'), 'w') as fout:
                fout.write(json.dumps(self.gradient_norm_map))


def get_time_string():
    """Return the current local time as a ``YYYYmmdd-HHMMSS`` string (e.g. ``20240131-235959``)."""
    from time import strftime

    timestamp_format = "%Y%m%d-%H%M%S"
    return strftime(timestamp_format)


def mkdir_if_not_exists(path):
    """Create directory ``path`` (including parents) unless something already exists there.

    ``exist_ok=True`` makes the call safe when several processes (e.g. distributed ranks)
    race to create the same directory: a concurrent creation between the ``exists`` check
    and ``makedirs`` no longer raises ``FileExistsError``. An existing non-directory at
    ``path`` is left untouched, as before.

    Args:
        path: Directory to create.

    Returns:
        ``path``, unchanged.
    """
    if not exists(path):
        os.makedirs(path, exist_ok=True)
    return path


def find_all_linear_names(args, model, include_lm_head=False):
    """Return the distinct leaf names of every ``torch.nn.Linear`` submodule of ``model``.

    The leaf name is the last dotted component of the module path (e.g.
    ``model.layers.0.self_attn.q_proj`` -> ``q_proj``).

    Args:
        args: Unused; kept for call-site compatibility.
        model: Module to scan.
        include_lm_head: When False, ``'lm_head'`` is dropped from the result.

    Returns:
        A list of unique leaf names, in arbitrary order.
    """
    # Only plain nn.Linear is matched; the bitsandbytes-based variant selection is disabled.
    # cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
    linear_cls = torch.nn.Linear
    leaf_names = {
        qualified.split('.')[-1]
        for qualified, submodule in model.named_modules()
        if isinstance(submodule, linear_cls)
    }
    if not include_lm_head:
        leaf_names.discard('lm_head')  # needed for 16-bit
    return list(leaf_names)


def find_all_norm_layers_for_llama(args, model):
    """Return the distinct leaf names of every ``LlamaRMSNorm`` submodule of ``model``.

    Args:
        args: Unused; kept for call-site compatibility.
        model: Module to scan.

    Returns:
        A list of unique leaf names (last dotted path component), in arbitrary order.
    """
    norm_leaf_names = set()
    for qualified, submodule in model.named_modules():
        if not isinstance(submodule, LlamaRMSNorm):
            continue
        norm_leaf_names.add(qualified.split('.')[-1])
    return list(norm_leaf_names)


def unfreeze_all_normal_layers(model):
    """Make the ``weight``/``bias`` of every module whose name contains ``'norm'`` trainable.

    Matching is a plain substring test on the qualified module name. Attributes that are
    missing or ``None`` are skipped. Examples are a bias-less norm, or ``LayerNorm`` with
    ``elementwise_affine=False``, which registers ``weight``/``bias`` as ``None``. Previously
    ``hasattr`` was True for these and ``None.requires_grad`` raised ``AttributeError``.

    Args:
        model: Module whose norm parameters get ``requires_grad = True`` (mutated in place).
    """
    for name, module in model.named_modules():
        if 'norm' not in name:
            continue
        for attr in ('weight', 'bias'):
            param = getattr(module, attr, None)
            if param is not None:
                param.requires_grad = True

class FreezerCallback(transformers.TrainerCallback):
    """Every ``interval`` global steps, randomly choose which parameter groups stay trainable.

    ``prob_dict`` maps a group key (as produced by :meth:`get_key`) to the probability that
    parameters in that group have ``requires_grad=True`` until the next re-draw. Parameters
    whose key is not in ``prob_dict`` are never touched.
    """

    def __init__(self, prob_dict, full_finetune, interval):
        super().__init__()
        self.prob_dict = prob_dict
        self.full_finetune = full_finetune
        self.interval = interval

    def get_key(self, name):
        """Map a parameter name to its group key in ``prob_dict``.

        Full fine-tuning: drop the final component and prefix ``'base_model.model.'``.
        Adapter training: for LoRA parameters drop the last three components
        (e.g. ``...q_proj.lora_A.default.weight`` -> ``...q_proj``); other names pass through.
        """
        parts = name.split('.')
        if self.full_finetune:
            return 'base_model.model.' + '.'.join(parts[:-1])
        if 'lora' in name:
            return '.'.join(parts[:-3])
        return name

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step % self.interval != 0:
            return
        model = kwargs["model"]
        # One Bernoulli draw per group, in prob_dict iteration order.
        trainable = {group: random.uniform(0, 1) <= prob for group, prob in self.prob_dict.items()}
        for param_name, param in model.named_parameters():
            group = self.get_key(param_name)
            if group in self.prob_dict:
                param.requires_grad = trainable[group]


class GenerationCallback(transformers.TrainerCallback):
    """Generate answers to a question set when training ends.

    On ``on_train_end`` (main process only) the model answers ``self.eval_data``
    or, when that is ``None``, ``DEFAULT_EVAL_QUESTIONS``. The prompt/answer
    pairs are written to ``<results_dir>/output_<epoch>.json``.
    ``save_model`` is a helper that saves the PEFT adapter and optional
    per-module score dumps.
    """

    DEFAULT_EVAL_QUESTIONS = (
        "How can I improve my time management skills?",
        "What are the main differences between Python and JavaScript programming languages?",
        "Can you explain the basics of quantum computing?",
        "What are the differences between plant-based and animal-based protein sources?",
        "What are the most effective strategies for conflict resolution in the workplace?",
        "How can governments utilize fiscal and monetary policies to combat economic recessions?",
        "How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?",
        "Explain the process of natural selection and how it contributes to the evolution and adaptation of species.",
        "As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?",
        "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.",
        "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?",
        "How can you determine if a person is genuinely interested in a conversation or simply being polite?",
        "How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?",
        "How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.",
        "What if the Black Death had not occurred in the 14th century?",
        "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2).",
        "Write a script for a YouTube video exploring the history and cultural significance of jazz.",
    )

    def __init__(self, train_dataset, model_id, dataset_name, max_new_tokens=64):
        super().__init__()
        self.train_dataset = train_dataset
        self.model_id = model_id
        self.gen_config = GenerationConfig.from_pretrained(model_id)
        self.max_new_tokens = max_new_tokens
        self.dataset = dataset_name
        self.prepare_for_generation(self.dataset)

    def prepare_for_generation(self, dataset):
        """Set ``self.prompt`` to a chat template with a ``{user_question}`` slot.

        Raises:
            ValueError: if ``dataset`` is not a supported dataset name.

        For a supported dataset whose ``model_id`` matches none of the known
        model families, ``self.prompt`` stays ``None`` and ``generate`` raises.
        """
        # Always define both attributes so generate() can report an unsupported
        # model with its ValueError rather than an AttributeError.
        self.prompt = None
        self.eval_data = None
        if dataset not in ('LIMA', 'no_robots', 'alpaca', 'alpaca-clean', 'alpaca-gpt4'):
            raise ValueError(f"Dataset {dataset} not supported.")

        if 'Llama' in self.model_id or 'llama' in self.model_id or 'Mistral' in self.model_id:
            system = ("A chat between a curious human and an artificial intelligence assistant. "
                      "The assistant gives helpful, detailed, and polite answers to the human's questions.")
            # The system text is baked in now; '{user_question}' is filled per call.
            template = '<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'.format(system=system)
            self.prompt = template + '{user_question} [/INST] '
        elif 'gemma' in self.model_id:
            user_turn = "<bos><start_of_turn>user\n"
            model_turn = "<end_of_turn>\n<start_of_turn>model\n"
            self.prompt = user_turn + '{user_question} ' + model_turn

    def generate(self, model, tokenizer, user_question):
        """Generate and print an answer to one question.

        Returns:
            ``(prompt, text)``: the formatted prompt and the decoded
            continuation with the prompt tokens stripped.

        Raises:
            ValueError: if no prompt template is available for this model.
        """
        if self.prompt is None:
            raise ValueError(f"Prompt of dataset {self.dataset} not supported.")
        prompt = self.prompt.format(user_question=user_question)
        # The templates already start with a BOS marker, so the tokenizer must not add one.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to('cuda')

        self.gen_config.max_new_tokens = self.max_new_tokens
        with torch.inference_mode():
            outputs = model.generate(**inputs, generation_config=self.gen_config)

        input_ids = inputs['input_ids']
        text = tokenizer.decode(outputs[0][len(input_ids[0]):], spaces_between_special_tokens=False)

        print('----------------------------------------------------------------------------')
        print(f'{prompt + text}')

        return prompt, text

    def generate_wiht_input(self, args, model, tokenizer, ouput_path, iteration):
        """Answer every evaluation question and dump the results to JSON.

        Writes ``<ouput_path>/output_<iteration>.json`` as a list of
        ``{"question": prompt, "answer": text}`` records. ``args`` is not read.
        The model is switched to eval mode and restored to train mode afterwards,
        even if generation fails.
        """
        questions = self.eval_data if self.eval_data is not None else self.DEFAULT_EVAL_QUESTIONS
        save_path = join(ouput_path, f"output_{iteration}.json")

        model.eval()
        try:
            records = []
            for user_question in questions:
                prompt, text = self.generate(model, tokenizer, user_question)
                records.append({"question": prompt, "answer": text})

            with open(save_path, "w") as fout:
                fout.write(json.dumps(records))
        finally:
            model.train()

    @staticmethod
    def _collect_localization_scores(model, block_wise):
        """Gather per-module localization statistics into a JSON-ready dict."""
        file_content = {'block_score': {}} if block_wise else {}
        for key in ('lora_score', 'diff_up', 'diff_down', 'upperbound_up_times',
                    'lowerbound_down_times', 'lora_gamma_value_list'):
            file_content[key] = {}

        if block_wise:
            # lowerbound_down_times is filled for block-scored modules first,
            # then (again) for LoRA-scored modules below.
            for name, module in model.named_modules():
                if hasattr(module, 'block_score'):
                    file_content['block_score'][name] = module.block_score
                    file_content['lowerbound_down_times'][name] = module.lowerbound_down_times

        for name, module in model.named_modules():
            if hasattr(module, 'lora_score'):
                file_content['lora_score'][name] = module.lora_score
                file_content['diff_up'][name] = module.diff_up
                file_content['diff_down'][name] = module.diff_down
                file_content['upperbound_up_times'][name] = module.upperbound_up_times
                file_content['lowerbound_down_times'][name] = module.lowerbound_down_times
            if hasattr(module, 'lora_gamma_value_list'):
                file_content['lora_gamma_value_list'][name] = module.lora_gamma_value_list
        return file_content

    def save_model(self, args, state, kwargs, suffix=None):
        """Save the PEFT adapter plus optional score dumps; return the adapter path.

        The target folder is ``<best_model_checkpoint>/adapter_model`` when a best
        checkpoint is tracked. Otherwise it is ``<output_dir>/checkpoint-<suffix>``,
        or ``<freeze_out_dir>/checkpoint-<suffix>`` when ``freeze_lora`` or
        ``freeze_lora_weights`` is set. ``suffix`` defaults to the global step.
        The adapter goes into an ``adapter_model`` sub-folder of the target, and
        any ``pytorch_model.bin`` in the target is removed.
        """
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            if suffix is None:
                suffix = f"{state.global_step}"
            base_dir = args.freeze_out_dir if (args.freeze_lora or args.freeze_lora_weights) else args.output_dir
            # Interpolate the suffix so each save gets its own folder.
            checkpoint_folder = os.path.join(base_dir, f"{PREFIX_CHECKPOINT_DIR}-{suffix}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

        if args.add_weight:
            weight_score = {}
            for name, module in kwargs["model"].named_modules():
                if hasattr(module, 'lora_weight'):
                    weight_score[name] = module.lora_weight['default'].item()
            with open(os.path.join(args.output_dir, f'weight_score_{state.epoch}.json'), 'w') as fout:
                fout.write(json.dumps({'weight_score': weight_score}))

        if args.whether_localization:
            file_content = self._collect_localization_scores(kwargs["model"], args.block_wise)
            with open(os.path.join(args.results_dir, f'lora_score_{state.global_step}.json'), 'w') as fout:
                fout.write(json.dumps(file_content))
            with open(os.path.join(args.results_dir, f'lowerbound_{state.global_step}.json'), 'w') as fout:
                fout.write(json.dumps(file_content['lowerbound_down_times']))

        return peft_model_path

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Run the generation pass on the main process once training is done."""
        if not state.is_world_process_zero:
            return
        super().on_train_end(args, state, control, **kwargs)
        # Newer transformers versions pass the tokenizer as ``processing_class``.
        tokenizer = kwargs.get("tokenizer")
        if tokenizer is None:
            tokenizer = kwargs.get("processing_class")
        self.generate_wiht_input(args, kwargs["model"], tokenizer, args.results_dir, state.epoch)

    # def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    #     if state.is_world_process_zero:
    #         super().on_train_end(args, state, control, **kwargs)
    #         print("Training is done!")
    #         self.generate_wiht_input(args, kwargs["model"], kwargs["tokenizer"], args.results_dir, state.epoch)
    #         print("Generation is done!")
    #         print("Saving the model...")
    # def touch(fname, times=None):
    #     with open(fname, 'a'):
    #         os.utime(fname, times)
    #
    # touch(join(args.output_dir, 'completed'))
    # self.save_model(args, state, kwargs, suffix='final_v2')


class SavePeftModelCallback(transformers.TrainerCallback):
    """Save only the PEFT adapter (plus optional score dumps) at checkpoints.

    On every ``on_save`` the adapter is written with ``save_pretrained`` and any
    ``pytorch_model.bin`` in the checkpoint folder is deleted. When training ends,
    the main process touches a ``completed`` marker in ``output_dir`` and makes
    a final save.
    """

    def __init__(self, train_dataset=None):
        self.train_dataset = train_dataset
        # Not read within this class; kept as a public attribute for compatibility.
        self.interval = 100000

    @staticmethod
    def _collect_localization_scores(model, block_wise):
        """Gather per-module localization statistics into a JSON-ready dict."""
        file_content = {'block_score': {}} if block_wise else {}
        for key in ('lora_score', 'diff_up', 'diff_down', 'upperbound_up_times',
                    'lowerbound_down_times', 'lora_gamma_value_list'):
            file_content[key] = {}

        if block_wise:
            # lowerbound_down_times is filled for block-scored modules first,
            # then (again) for LoRA-scored modules below.
            for name, module in model.named_modules():
                if hasattr(module, 'block_score'):
                    file_content['block_score'][name] = module.block_score
                    file_content['lowerbound_down_times'][name] = module.lowerbound_down_times

        for name, module in model.named_modules():
            if hasattr(module, 'lora_score'):
                file_content['lora_score'][name] = module.lora_score
                file_content['diff_up'][name] = module.diff_up
                file_content['diff_down'][name] = module.diff_down
                file_content['upperbound_up_times'][name] = module.upperbound_up_times
                file_content['lowerbound_down_times'][name] = module.lowerbound_down_times
            if hasattr(module, 'lora_gamma_value_list'):
                file_content['lora_gamma_value_list'][name] = module.lora_gamma_value_list
        return file_content

    def save_model(self, args, state, kwargs, suffix=None):
        """Save the PEFT adapter plus optional score dumps; return the adapter path.

        The target folder is ``<best_model_checkpoint>/adapter_model`` when a best
        checkpoint is tracked. Otherwise it is ``<output_dir>/checkpoint-<suffix>``,
        or ``<freeze_out_dir>/checkpoint-<suffix>`` when ``freeze_lora`` or
        ``freeze_lora_weights`` is set. ``suffix`` defaults to the global step.
        The adapter goes into an ``adapter_model`` sub-folder of the target, and
        any ``pytorch_model.bin`` in the target is removed.
        """
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            if suffix is None:
                suffix = f"{state.global_step}"
            base_dir = args.freeze_out_dir if (args.freeze_lora or args.freeze_lora_weights) else args.output_dir
            checkpoint_folder = os.path.join(base_dir, f"{PREFIX_CHECKPOINT_DIR}-{suffix}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

        if args.add_weight:
            weight_score = {}
            for name, module in kwargs["model"].named_modules():
                if hasattr(module, 'lora_weight'):
                    weight_score[name] = module.lora_weight['default'].item()
            with open(os.path.join(args.output_dir, f'weight_score_{state.epoch}.json'), 'w') as fout:
                fout.write(json.dumps({'weight_score': weight_score}))

        if args.whether_localization:
            file_content = self._collect_localization_scores(kwargs["model"], args.block_wise)
            with open(os.path.join(args.results_dir, f'lora_score_{state.global_step}.json'), 'w') as fout:
                fout.write(json.dumps(file_content))
            with open(os.path.join(args.results_dir, f'lowerbound_{state.global_step}.json'), 'w') as fout:
                fout.write(json.dumps(file_content['lowerbound_down_times']))

        return peft_model_path

    def on_save(self, args, state, control, **kwargs):
        """Save the adapter alongside each Trainer checkpoint."""
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        """Mark the run as completed and make a final adapter save (main process only)."""
        if not state.is_world_process_zero:
            return
        marker = join(args.output_dir, 'completed')
        # `touch`: create the file if missing, then set its mtime to now.
        with open(marker, 'a'):
            os.utime(marker, None)
        # NOTE(review): suffix='suffix' produces a folder literally named
        # "checkpoint-suffix" (unless a best checkpoint is tracked) — confirm
        # downstream loaders expect this name.
        self.save_model(args, state, kwargs, suffix='suffix')


class AdaLoRACallback(transformers.TrainerCallback):
    """Call ``update_and_allocate`` on the model after every training step."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Presumably the model is a PEFT AdaLoRA model exposing update_and_allocate — confirm.
        current_step = state.global_step
        kwargs["model"].update_and_allocate(current_step)


def get_freeze_layers_from_constant(args):
    """Flatten the layer lists of the first ``args.topK`` groups in a JSON file.

    ``args.layer_json_file`` must hold a JSON object mapping a group key to a
    list of layer names. The lists of the first ``args.topK`` keys (in file
    order) are concatenated and returned.
    """
    with open(args.layer_json_file, 'r') as handle:
        groups = json.load(handle)

    selected_groups = list(groups)[:args.topK]
    return [layer for group in selected_groups for layer in groups[group]]


def get_layers_prob_distribution(args):
    """Load per-layer scores and rescale them to sum to ``topK``.

    Each score ``v`` becomes ``v * topK / sum(scores)``, where
    ``topK = int(number_of_layers * args.divide_ratio)``. Individual values can
    exceed 1. Returns ``{}`` when ``args.distribution_json_file`` is ``None``.
    """
    if args.distribution_json_file is None:
        return {}

    with open(args.distribution_json_file, 'r') as handle:
        raw_scores = json.load(handle)

    target_total = int(len(raw_scores) * args.divide_ratio)
    score_sum = sum(raw_scores.values())
    return {layer: score * target_total / score_sum for layer, score in raw_scores.items()}


def get_freeze_layers(args):
    """Select the layer names to freeze from a layer-score JSON file.

    Returns ``[]`` when ``args.layer_json_file`` is ``None`` or
    ``args.further_tune_lora`` is set. The ordering mode is chosen from the
    file path:

    - ``'mix'`` / ``'lowerbound'``: ``largest_first=True``;
    - ``'gs'``: ``largest_first=False``;
    - ``'constant'``: delegates to ``get_freeze_layers_from_constant``.

    ``topK = int(num_layers * args.divide_ratio)`` layers are selected:

    - ``args.Random``: a random ``topK`` subset;
    - ``args.first_ratio > 0``: the first ``topK`` keys in file order;
    - ``args.last_ratio > 0``: the last ``topK`` keys in file order;
    - otherwise: the first ``topK`` keys after ``sort_dict_by_value``.

    When neither ratio is positive, the keys are sorted with
    ``sort_dict_by_value`` before the random shuffle as well.

    Raises:
        ValueError: if the path matches none of the known modes.
    """
    if args.layer_json_file is None or args.further_tune_lora:
        return []

    layer_file = args.layer_json_file
    # NOTE(review): substring checks run on the whole path, so a directory name
    # containing e.g. 'gs' or 'mix' also matches — confirm this is intended.
    if 'mix' in layer_file or 'lowerbound' in layer_file:
        largest_first = True
    elif 'gs' in layer_file:
        largest_first = False
    elif 'constant' in layer_file:
        return get_freeze_layers_from_constant(args)
    else:
        raise ValueError(f'Unknown layer json file: {args.layer_json_file}')

    with open(layer_file, 'r') as f:
        score = json.load(f)

    top_k = int(len(score) * args.divide_ratio)
    use_position = args.first_ratio > 0.0 or args.last_ratio > 0.0
    if use_position:
        # Positional selection works on the file order.
        ordered = list(score.keys())
    else:
        ordered = list(sort_dict_by_value(score, largest_first=largest_first).keys())

    if args.Random:
        random.shuffle(ordered)
        return ordered[:top_k]
    if args.first_ratio > 0.0:
        return ordered[:top_k]
    if args.last_ratio > 0.0:
        # Explicit start index: `[-top_k:]` would return everything for top_k == 0.
        return ordered[max(len(ordered) - top_k, 0):]
    return ordered[:top_k]


def get_accelerate_model(args, checkpoint_dir):
    fixed_layers = get_freeze_layers(args)

    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
    if is_ipex_available() and torch.xpu.is_available():
        n_gpus = torch.xpu.device_count()

    max_memory = f'{args.max_memory_MB}MB'
    max_memory = {i: max_memory for i in range(n_gpus)}
    #
    #
    # if we are in a distributed setting, we need to set the device map and max memory per device
    # if os.environ.get('LOCAL_RANK') is not None:
    #     local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    #     device_map = {'': local_rank}
    #     max_memory = {'': max_memory[local_rank]}

    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    device_map = {'': local_rank}
    using_device_map = True
    # read json
    if not args.distributed_training:
        device_map = "auto"
        using_device_map = True
    else:
        if args.deepspeed is not None:
            with open(args.deepspeed, 'r') as f:
                deepspeed_config = json.load(f)
                if deepspeed_config['zero_optimization']['stage'] == 3:
                    using_device_map = False

    if args.full_finetune: assert args.bits in [16, 32]


    # print(f'loading base model {args.model_name_or_path}...')
    compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
    if args.whether_quantize:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            cache_dir=args.cache_dir,
            load_in_4bit=args.bits == 4,
            load_in_8bit=args.bits == 8,
            device_map=device_map,
            max_memory=max_memory,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=args.bits == 4,
                load_in_8bit=args.bits == 8,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=args.double_quant,
                bnb_4bit_quant_type=args.quant_type,
            ),
            torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
        )
    else:
        if args.full_finetune and not args.do_train and checkpoint_dir is not None:
            if using_device_map:
                model = AutoModelForCausalLM.from_pretrained(
                    checkpoint_dir,
                    cache_dir=args.cache_dir,
                    device_map=device_map,
                    # device_map=device_map,
                    # max_memory=max_memory
                    torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
                    attention_dropout = args.dropout,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    checkpoint_dir,
                    cache_dir=args.cache_dir,
                    torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
                    attention_dropout=args.dropout,
                )
        else:
            if using_device_map:
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    cache_dir=args.cache_dir,
                    device_map=device_map,
                    # device_map=device_map,
                    # max_memory=max_memory,
                    torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
                    attention_dropout=args.dropout,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    cache_dir=args.cache_dir,
                    # device_map={"": int(os.environ.get("LOCAL_RANK") or 0)},
                    # device_map=device_map,
                    # max_memory=max_memory,
                    torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
                    attention_dropout=args.dropout,
                )

    if compute_dtype == torch.float16 and args.bits == 4:
        if torch.cuda.is_bf16_supported():
            print('=' * 80)
            print('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
            print('=' * 80)

    if compute_dtype == torch.float16 and (is_ipex_available() and torch.xpu.is_available()):
        compute_dtype = torch.bfloat16
        # print('Intel XPU does not support float16 yet, so switching to bfloat16')

    if args.deepspeed is None:
        setattr(model, 'model_parallel', True)
        setattr(model, 'is_parallelizable', True)

    model.config.torch_dtype = (torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        padding_side="right",
        use_fast=False,
        add_eos_token=True,
        add_bos_token=True,
        # Needed for HF name change
    )

    if tokenizer._pad_token is None and tokenizer._unk_token is not None:
        # tokenizer.add_special_tokens(
        #     {"pad_token": tokenizer.unk_token}
        # )
        tokenizer.pad_token_id = tokenizer.unk_token_id
        model.config.pad_token_id = tokenizer.unk_token_id
    elif tokenizer._pad_token is None and tokenizer._unk_token is None:
        smart_tokenizer_and_embedding_resize(
            args,
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )


    # for name, param in model.named_parameters():
    #     print(f'{name}: {param.requires_grad}')

    # print(f'len(tokenizer): {len(tokenizer)}')
    # if tokenizer._pad_token is None:
    #     smart_tokenizer_and_embedding_resize(
    #         args,
    #         special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
    #         tokenizer=tokenizer,
    #         model=model,
    #     )

    # tokenizer.add_special_tokens({
    #     "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
    #     "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id)
    # })

    # print(f'len(tokenizer): {len(tokenizer)}')
    # print(f'eos: {model.config.eos_token_id}')
    # print(f'bos: {model.config.bos_token_id}')
    # print(f'pad: {model.config.pad_token_id}')

    if args.do_train:
        if not args.full_finetune:
            model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)

            if args.peft_path is not None:
                checkpoint_dir = args.peft_path

            if checkpoint_dir is not None:
                # print("Loading adapters from checkpoint.")

                model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)

                if args.freeze_lora:
                    for name, param in model.named_parameters():
                        if 'lora_weight' not in name:
                            param.requires_grad = False
                            # print(f'!!! Freeze parameter: {name}')
                        else:
                            param.data.fill_(args.lw_init_value)
                            # print(f'!!! Unfreeze parameter: {name}: {param}')

                    for name, m in model.named_modules():
                        if hasattr(m, 'lora_weight'):
                            m.add_weight = True
                            if args.seperate_training:
                                m.seperate_training = True
                            m.interval = args.interval

                if args.freeze_lora_weights:
                    for name, param in model.named_parameters():
                        param.requires_grad = False

                    for name, param in model.named_parameters():
                        if 'lora' in name and 'lora_weight' not in name and 'lora_gamma' not in name and 'lora_buffer' not in name:
                            param.requires_grad = True
                            # print(f'!!! Unfreeze parameter: {name}')

                    for name, param in model.named_parameters():
                        if 'lora_weight' in name:
                            param.requires_grad = False
                            # print(f'!!! Freeze parameter: {name}')

                    # make sure the add_weight is True
                    for name, m in model.named_modules():
                        if hasattr(m, 'lora_weight'):
                            m.add_weight = True

                if args.further_tune_lora:
                    for name, param in model.named_parameters():
                        param.requires_grad = False

                    for name, param in model.named_parameters():
                        if 'lora' in name and 'lora_weight' not in name and 'lora_gamma' not in name and 'lora_buffer' not in name:
                            param.requires_grad = True
                            # print(f'!!! Unfreeze parameter: {name}')

                    for name, param in model.named_parameters():
                        if 'lora_weight' in name:
                            param.requires_grad = False
                            # print(f'!!! Freeze parameter: {name}')

                    # make sure the add_weight is False
                    for name, m in model.named_modules():
                        if hasattr(m, 'lora_weight'):
                            m.add_weight = False
                    # read file
                    with open(args.layer_json_file, 'r') as f:
                        score = json.load(f)

                        if "weight_score" in args.layer_json_file:
                            key = 'weight_score'
                            largest_first = False
                        else:
                            raise ValueError(f'Invalid layer json file: {args.layer_json_file}')

                        score = sort_dict_by_value(score[key], largest_first=largest_first)

                        fixed_layers = list(score.keys())
                        fixed_layers = fixed_layers[:args.topK]

                        for name, m in model.named_modules():
                            if name in fixed_layers:
                                m.fix_param()
                                if args.reset_low_score_layers:
                                    m.reset_lora_parameters('default')

                        fixed_layers = []

                if args.further_tune_gamma:
                    assert args.whether_localization

                    # make sure the add_weight is False
                    for name, m in model.named_modules():
                        if hasattr(m, 'whether_localization'):
                            m.whether_localization = True
                            m.add_weight = False
                            m.init_localization(lora_gamma_interval=args.lora_gamma_interval, use_sigmoid=args.use_sigmoid,
                                                default_localization=args.default_localization, sigmoid_init_value=args.sigmoid_init_value)

                    for name, param in model.named_parameters():
                        if 'lora_gamma' not in name:
                            param.requires_grad = False
                            # print(f'!!! Freeze parameter: {name}')
                        else:
                            param.requires_grad = True
                            # print(f'!!! Unfreeze parameter: {name}')

                # for name, param in model.named_parameters():
                #    print(f'{name}: {param.requires_grad}')
                #    if 'lora_gamma' in name:
                #        print(f'{name}: {param}')
                #
                # for name, m in model.named_modules():
                #     if hasattr(m, 'lora_gamma_interval'):
                #         print(f'{name}: {m.lora_gamma_interval}')
                #
                # print('Test')
            else:
                # print(f'adding LoRA modules...')

                if args.all_layers == 'ALL':
                    modules = find_all_linear_names(args, model)
                    modules = modules
                elif args.all_layers == 'ALL+LM':
                    modules = find_all_linear_names(args, model) + ['lm_head']
                elif args.all_layers == 'ALL+EMBD':
                    modules = find_all_linear_names(args, model) + ['lm_head', 'embed_tokens']
                    # modules = modules + ['lm_head', 'embed_tokens']
                elif args.all_layers == 'ALL+NORM':
                    modules = find_all_linear_names(args, model) + find_all_norm_layers_for_llama(args, model)
                elif args.all_layers == 'ALL+NORM+LM':
                    modules = find_all_linear_names(args, model, include_lm_head=True) + find_all_norm_layers_for_llama(
                        args, model)
                elif args.all_layers == 'ATT':
                    modules = ['k_proj', 'v_proj', 'q_proj', 'o_proj']
                elif args.all_layers == 'FFN':
                    modules = ['down_proj', 'up_proj', 'gate_proj']
                elif args.all_layers == 'ATT_2':
                    modules = ['k_proj', 'v_proj', 'q_proj']
                else:
                    raise ValueError(f'Invalid all_layers: {args.all_layers}')
                if args.peft_type == 'lora':
                    config = LoraConfig(
                        block_wise=args.block_wise,
                        sigmoid_init_value=args.sigmoid_init_value,
                        use_sigmoid=args.use_sigmoid,
                        ratio=args.ratio,
                        drop_prob=args.lora_drop_prob,
                        default_localization=args.default_localization,
                        interval=args.interval,
                        start_iter=args.start_iter,
                        lora_gamma_interval=args.lora_gamma_interval,
                        upperbound=args.upperbound,
                        whether_localization=args.whether_localization,
                        r=args.lora_r,
                        lora_alpha=args.lora_alpha,
                        target_modules=modules,
                        lora_dropout=args.lora_dropout,
                        bias="none",
                        task_type="CAUSAL_LM",
                        add_weight=args.add_weight,
                    )
                elif args.peft_type == 'adalora':
                    config = AdaLoraConfig(
                        target_r=args.target_r,
                        init_r=args.lora_r,
                        tinit=args.tinit,
                        tfinal=args.tfinal,
                        total_step=args.total_steps,
                        orth_reg_weight=args.orth_reg_weight,
                        deltaT=args.deltaT,
                        lora_alpha=args.lora_alpha,
                        target_modules=modules,
                        lora_dropout=args.lora_dropout,
                    )
                else:
                    raise ValueError(f'Invalid peft type: {args.peft_type}')

                model = get_peft_model(model, config)
        else:
            print('Full finetune, no adapters.')

    else:
        if not args.full_finetune and checkpoint_dir is not None:
            # print("Loading adapters from checkpoint.")
            model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model'), is_trainable=False)

    # for name, param in model.named_parameters():
    #     print(f'{name}: {param.dtype}')

    # Set the model to the correct dtype
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            # module = module.to(torch.bfloat16)
            if args.bf16:
                module = module.to(torch.bfloat16)
                if args.whether_localization:
                    module.lora_gamma = module.lora_gamma.to(torch.float32)
                    module.lora_buffer = module.lora_buffer.to(torch.float32)
        # if 'norm' in name:
        #     module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

    if not args.further_tune_lora:
        if not args.full_finetune:
            if len(fixed_layers) > 0:
                # if args.layer_json_file is not None and ('mix' in args.layer_json_file or 'lowerbound' in args.layer_json_file):
                for name, param in model.named_parameters():
                    if any([layer in name for layer in fixed_layers]):
                        param.requires_grad = False
                        if local_rank == 0:
                            print(f'!!! Freeze parameter: {name}')
                        # print(f'!!! Freeze parameter: {name}')
                    if 'lora_gamma' in name or 'lora_buffer' in name:
                        param.requires_grad = False
                        if local_rank == 0:
                            print(f'!!! Freeze parameter: {name}')
            # else:
            #     if args.block_wise:
            #         for name, module in model.named_modules():
            #             if hasattr(module, 'block_buffer'):
            #                 module.add_block_params(use_sigmoid=args.use_sigmoid, sigmoid_init_value=args.sigmoid_init_value,
            #                                         whether_localization=args.whether_localization, start_iter=args.start_iter,
            #                                         lora_gamma_interval=args.lora_gamma_interval)
            #
            #         # Fixed the non-skilled blocks
            #         for name, module in model.named_modules():
            #             if hasattr(module, 'block_buffer'):
            #                 if name in fixed_layers:
            #                     module.disable_block_lora(LoraLayer)
            #
            #     else:
            #         # Fixed the non-skilled layers
            #         for name, module in model.named_modules():
            #             if name in fixed_layers:
            #                 module.fixed()

            # if args.whether_unfreeze_normal:
            #     unfreeze_all_normal_layers(model)
        else:
            # freeze the embedding layer
            for name, module in model.named_modules():
                if 'LM' not in args.all_layers:
                    if 'lm_head' in name:
                        module.weight.requires_grad = False

                if 'EMBED' not in args.all_layers:
                    if 'embed_tokens' in name:
                        module.weight.requires_grad = False

            if len(fixed_layers) > 0:
                for i in range(len(fixed_layers)):
                    fixed_layers[i] = '.'.join(fixed_layers[i].split('.')[2:])
                    # freeze the fixed layers
                    for name, param in model.named_parameters():
                        if any([layer in name for layer in fixed_layers]):
                            param.requires_grad = False
                            if local_rank == 0:
                                print(f'!!! Freeze parameter: {name}')
                            # print(f'!!! Freeze parameter: {name}')

            if args.first_ratio > 0.0:
                if not args.ratio_only_for_lora:
                    all_params = list(model.parameters())
                else:
                    all_params = []
                    for name, param in model.named_parameters():
                        keys = find_all_linear_names(args, model, include_lm_head=False)
                        if any([key in name for key in keys]):
                            all_params.append(param)

                n_params = len(all_params)
                n_first = int(n_params * args.first_ratio)
                freeze_params = all_params[:n_first]
                for p in freeze_params:
                    p.requires_grad = False

            if args.last_ratio > 0.0:
                if not args.ratio_only_for_lora:
                    all_params = list(model.parameters())
                else:
                    all_params = []
                    for name, param in model.named_parameters():
                        keys = find_all_linear_names(args, model, include_lm_head=False)
                        if any([key in name for key in keys]):
                            all_params.append(param)

                n_params = len(all_params)
                n_last = int(n_params * args.last_ratio)
                freeze_params = all_params[-n_last:]
                for p in freeze_params:
                    p.requires_grad = False

            # if len(tuned_layers) > 0:
            #     # first, freeze everything
            #     for name, param in model.named_parameters():
            #         param.requires_grad = False
            #
            #     for i in range(len(tuned_layers)):
            #         tuned_layers[i] = '.'.join(tuned_layers[i].split('.')[2:])
            #     # only unfreeze the tuned layers
            #     for name, param in model.named_parameters():
            #         if any([layer in name for layer in tuned_layers]):
            #             param.requires_grad = True
            #             param.requires_grad = True
            #
            #     if args.whether_unfreeze_normal:
            #         unfreeze_all_normal_layers(model)

    # for name, param in model.named_parameters():
    #     print(f'{name}: {param.requires_grad}')

    return model, tokenizer


def read_layer_score_file(args):
    """Load per-layer scores from ``args.layer_score_file``, sorted highest score first.

    The JSON object is indexed with ``'block_score'`` when ``args.block_wise`` is set,
    otherwise with ``'lora_score'``; the resulting mapping is ordered by
    ``sort_dict_by_value(..., largest_first=True)``.

    Args:
        args: namespace providing ``layer_score_file`` (path or None) and ``block_wise``.

    Returns:
        The sorted score mapping returned by ``sort_dict_by_value``.

    Raises:
        ValueError: if ``args.layer_score_file`` is None (previously this surfaced as an
            opaque ``UnboundLocalError`` on the return statement).
    """
    if args.layer_score_file is None:
        raise ValueError('read_layer_score_file requires args.layer_score_file to be set')

    key = 'block_score' if args.block_wise else 'lora_score'
    with open(args.layer_score_file, 'r') as f:
        layer_json = json.load(f)
    return sort_dict_by_value(layer_json[key], largest_first=True)


def create_optimizer_based_on_lora_score(args, model):
    """Build an AdamW optimizer whose per-layer learning rate is derived from layer scores.

    Pipeline:
      1. Read scores with ``read_layer_score_file`` (sorted highest first).
      2. Map them to raw values: the rank position when ``args.only_permutation``,
         otherwise ``args.total_steps - score``.
      3. Standardize the raw values and rescale them to mean ``args.desired_mean`` and
         std ``CV(raw) * desired_mean * args.desired_std_ratio``.
      4. Clip: negative ratios become 0.0 (layer effectively frozen), ratios above
         ``args.max_learning_rate_ratio`` / below ``args.min_learning_rate_ratio`` are
         clipped to those bounds.
      5. Create one param group per score key holding every trainable parameter whose
         name contains the key, with ``lr = args.learning_rate * ratio``.

    Args:
        args: namespace with the fields referenced above plus ``full_finetune``,
            ``learning_rate`` and ``weight_decay``.
        model: module whose ``named_parameters()`` are grouped.

    Returns:
        ``torch.optim.AdamW`` over the per-layer parameter groups.
    """
    lora_score = read_layer_score_file(args)
    if args.only_permutation:
        # Position in the sorted order (0 = highest score) replaces the raw score.
        raw = {key: i for i, key in enumerate(lora_score)}
    else:
        raw = {key: args.total_steps - value for key, value in lora_score.items()}

    values = np.asarray(list(raw.values()), dtype=np.float64)
    mean_value = values.mean()
    std_value = values.std()
    desired_mean = args.desired_mean

    if std_value == 0 or mean_value == 0:
        # Degenerate input (single layer, all scores equal, or zero mean): the
        # standardization / coefficient of variation is undefined and would produce
        # NaN or inf learning rates, so every layer simply gets desired_mean.
        scaled = {key: desired_mean for key in raw}
    else:
        original_cv = std_value / mean_value
        new_std = original_cv * desired_mean * args.desired_std_ratio
        scaled = {
            key: (value - mean_value) / std_value * new_std + desired_mean
            for key, value in raw.items()
        }

    max_ratio = args.max_learning_rate_ratio
    min_ratio = args.min_learning_rate_ratio

    def _clip(ratio):
        # Clip one learning-rate ratio into [min_ratio, max_ratio]; negatives freeze the layer.
        if ratio < 0.0:
            return 0.0
        if ratio > max_ratio:
            ratio = max_ratio
        elif ratio < min_ratio:
            ratio = min_ratio
        # Guard against a misconfigured (negative) bound ever yielding a negative LR.
        return float(max(ratio, 0.0))

    lr_ratios = {key: _clip(value) for key, value in scaled.items()}

    if args.full_finetune:
        # Drop the first two dotted components of each key, the same renaming applied
        # to fixed_layers for full fine-tuning.
        lr_ratios = {'.'.join(key.split('.')[2:]): value for key, value in lr_ratios.items()}

    # Single pass over the model instead of one full scan per key.
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    optimizer_grouped_parameters = [
        {
            "params": [param for name, param in trainable if key in name],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate * ratio,
        }
        for key, ratio in lr_ratios.items()
    ]

    return torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)


def print_trainable_parameters(args, model):
    """Print how many of ``model``'s parameters require gradients (rank 0 only).

    When ``args.bits == 4`` the trainable count is halved before printing (which makes
    it a float). NOTE(review): presumably this compensates for how 4-bit weights report
    ``numel()`` — confirm. Nothing is printed unless ``args.local_rank == 0``.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(size for size, _ in sizes)
    trainable_params = sum(size for size, needs_grad in sizes if needs_grad)

    if args.bits == 4:
        trainable_params = trainable_params / 2

    if args.local_rank != 0:
        return

    share = 100 * trainable_params / all_param
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable: {share}")

def smart_tokenizer_and_embedding_resize(
        args,
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer`` and grow ``model``'s embeddings to match.

    The model config's ``pad_token_id`` is synced to the tokenizer, the token embeddings
    are resized to ``len(tokenizer)``, and each newly added row of both the input and the
    output embedding matrices is initialised to the mean of the pre-existing rows.
    ``args`` is accepted for interface compatibility but not read.

    Note: this is the unoptimized version; the resulting embedding size may not be
    divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.config.pad_token_id = tokenizer.pad_token_id
    model.resize_token_embeddings(len(tokenizer))

    if added <= 0:
        return

    embedding_tables = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for table in embedding_tables:
        # Each table's mean is taken over its original rows only; writing the new rows
        # of one table never touches the rows averaged for the other, even if tied.
        table[-added:] = table[:-added].mean(dim=0, keepdim=True)


@dataclass
class DataCollatorForCausalLM(object):
    """Collate ``{'input': str, 'output': str}`` records into a causal-LM batch.

    Sources and targets are tokenized separately (no special tokens added, each
    truncated to its own max length) and concatenated. Source tokens are masked with
    ``IGNORE_INDEX`` in the labels unless ``train_on_source`` is set. With
    ``predict_with_generate`` only the source tokens are fed and no labels are returned.
    """
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int  # truncation length for the 'input' text
    target_max_len: int  # truncation length for the 'output' text
    train_on_source: bool
    predict_with_generate: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        def encode(texts, limit):
            return self.tokenizer(
                texts,
                max_length=limit,
                truncation=True,
                add_special_tokens=False,
            )

        sources = [record['input'] for record in instances]
        targets = [record['output'] for record in instances]
        source_batch = encode(sources, self.source_max_len)
        target_batch = encode(targets, self.target_max_len)

        input_rows, label_rows = [], []
        for src, tgt in zip(source_batch['input_ids'], target_batch['input_ids']):
            if self.predict_with_generate:
                # Generation: the model only sees the prompt.
                input_rows.append(torch.tensor(src))
                continue
            joined = src + tgt
            input_rows.append(torch.tensor(joined))
            if self.train_on_source:
                label_rows.append(torch.tensor(joined))
            else:
                label_rows.append(torch.tensor([IGNORE_INDEX] * len(src) + tgt))

        padded = pad_sequence(input_rows, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        batch = {
            'input_ids': padded,
            'attention_mask': padded.ne(self.tokenizer.pad_token_id),
        }
        if not self.predict_with_generate:
            batch['labels'] = pad_sequence(label_rows, batch_first=True, padding_value=IGNORE_INDEX)
        return batch

@dataclass
class DataCollatorForLIMA(object):
    """Collate multi-turn LIMA conversations into LLaMA-2 ``[INST]``-formatted batches.

    ``instance['output']`` is read as a flat alternating list
    ``[human_1, reply_1, human_2, reply_2, ...]``; an unpaired trailing entry is dropped.
    Only the first pair carries a ``<<SYS>>`` block with a fixed system prompt. Every pair
    starts with the tokenizer's BOS text, and every reply ends with its EOS text.
    Prompt tokens are masked with ``IGNORE_INDEX`` in the labels unless
    ``train_on_source`` is set; with ``train_without_system`` as well, only the system
    block stays masked.
    """
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int  # not read by __call__
    target_max_len: int  # per-conversation token cap; ids and labels are cut to this length
    train_on_source: bool
    train_without_system: bool
    predict_with_generate: bool  # when True the batch carries no 'labels'

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        conversations = [record['output'] for record in instances]
        bos, eos = self.tokenizer.bos_token, self.tokenizer.eos_token
        system_message = ("A chat between a curious human and an artificial intelligence assistant. "
                          "The assistant gives helpful, detailed, and polite answers to the human's questions.")

        def encode(text):
            return self.tokenizer(text, add_special_tokens=False)['input_ids']

        sequences, label_sequences = [], []
        for conversation in conversations:
            ids, labels = [], []
            for turn in range(len(conversation) // 2):
                human, reply = conversation[2 * turn], conversation[2 * turn + 1]

                if turn == 0:
                    # The opening turn alone carries the system block.
                    system_tokens = encode((bos + "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n").format(system=system_message))
                    prompt_tokens = encode("{human} [/INST] ".format(human=human))
                else:
                    system_tokens = []
                    prompt_tokens = encode((bos + "[INST] {human} [/INST] ").format(human=human))

                if self.train_on_source:
                    if self.train_without_system:
                        system_labels = [IGNORE_INDEX] * len(system_tokens)
                    else:
                        system_labels = list(system_tokens)
                    prompt_labels = list(prompt_tokens)
                else:
                    system_labels = [IGNORE_INDEX] * len(system_tokens)
                    prompt_labels = [IGNORE_INDEX] * len(prompt_tokens)

                reply_tokens = encode(reply + eos)

                for part in (system_tokens, prompt_tokens, reply_tokens):
                    ids.extend(part)
                for part in (system_labels, prompt_labels, reply_tokens):
                    labels.extend(part)

            if len(ids) > self.target_max_len:
                ids, labels = ids[:self.target_max_len], labels[:self.target_max_len]

            sequences.append(torch.tensor(ids))
            label_sequences.append(torch.tensor(labels))

        pad_id = self.tokenizer.pad_token_id
        input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
        batch = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(pad_id),
        }
        if not self.predict_with_generate:
            batch['labels'] = pad_sequence(label_sequences, batch_first=True, padding_value=IGNORE_INDEX)
        return batch

@dataclass
class DataCollatorForGemmaLIMA(object):
    """Collate multi-turn LIMA conversations into Gemma chat-formatted batches.

    ``instance['output']`` is read as a flat alternating list
    ``[user_1, reply_1, user_2, reply_2, ...]``; an unpaired trailing entry is dropped.
    Each user message is wrapped in the ``<start_of_turn>user`` / ``<start_of_turn>model``
    markers. Each reply is emitted exactly once, followed by ``<end_of_turn>``; the last
    emitted reply additionally ends with ``<eos>``. Prompt tokens are masked with
    ``IGNORE_INDEX`` in the labels unless ``train_on_source`` is set.
    """
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int  # not read by __call__
    target_max_len: int  # per-conversation token cap; ids and labels are cut to this length
    train_on_source: bool
    train_without_system: bool  # not read by __call__ (this format has no system block)
    predict_with_generate: bool  # when True the batch carries no 'labels'

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        conversations = [example['output'] for example in instances]

        # NOTE(review): "<bos>" is prefixed to every user turn, whereas Gemma's published
        # chat template emits it once per sequence; kept so the training format is unchanged.
        user_header = "<bos><start_of_turn>user\n"
        model_header = "<end_of_turn>\n<start_of_turn>model\n"

        input_ids = []
        labels_ids = []
        for conversation in conversations:
            num_pairs = len(conversation) // 2
            ids = []
            labels = []
            for i in range(num_pairs):
                human = conversation[2 * i]
                reply = conversation[2 * i + 1]

                # Prompt (user turn + opening of the model turn).
                prompt = (user_header + "{human} " + model_header).format(human=human)
                source_tokens = self.tokenizer(prompt, add_special_tokens=False)['input_ids']
                ids.extend(source_tokens)
                if self.train_on_source:
                    labels.extend(source_tokens)
                else:
                    labels.extend([IGNORE_INDEX] * len(source_tokens))

                # Target: only the last emitted reply closes the sequence with <eos>. Keying
                # on the pair index also covers odd-length conversations, whose trailing
                # unpaired entry is dropped.
                turn_end = ' <end_of_turn><eos>' if i == num_pairs - 1 else ' <end_of_turn>\n'
                target_tokens = self.tokenizer(reply + turn_end, add_special_tokens=False)['input_ids']
                ids.extend(target_tokens)
                labels.extend(target_tokens)

            # Truncate to the max length
            if len(ids) > self.target_max_len:
                ids = ids[:self.target_max_len]
                labels = labels[:self.target_max_len]

            input_ids.append(torch.tensor(ids))
            labels_ids.append(torch.tensor(labels))

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels_ids, batch_first=True,
                              padding_value=IGNORE_INDEX) if not self.predict_with_generate else None

        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict

# Prompt template for Unnatural Instructions records; ``.format`` fills the
# ``instruction`` and ``input`` slots.
UNNATURAL_INSTRUCTIONS_PROMPT_DICT = dict(
    prompt_input="### Instruction:\n{instruction}\n\n" "### Input:\n{input}\n\n" "### Response: ",
)

def extract_unnatural_instructions_data(examples):
    """Map one Unnatural Instructions record to a single ``input``/``output`` pair.

    ``input`` is ``'<s>'`` followed by the filled ``UNNATURAL_INSTRUCTIONS_PROMPT_DICT``
    template; ``output`` is the instance's answer followed by ``'</s>'``.

    NOTE(review): every entry of ``examples['instances']`` is visited, but each one
    overwrites the previous result, so only the LAST instance survives; with no
    instances both fields stay ``None``. Confirm this is intended.
    """
    pair = dict(input=None, output=None)
    for instance in examples['instances']:
        template = UNNATURAL_INSTRUCTIONS_PROMPT_DICT["prompt_input"]
        filled = template.format(instruction=examples['instruction'], input=instance['input'])
        pair['input'] = '<s>' + filled
        pair['output'] = instance['output'] + '</s>'
    return pair


def extract_unnatural_instructions_data_v2(examples):
    """Build a full prompt+response sequence from an Unnatural Instructions record.

    Unlike ``extract_unnatural_instructions_data`` the prompt is not returned
    separately: ``output`` holds ``'<s>' + prompt + response + '</s>'`` and
    ``input`` is set to ``''``.

    Args:
        examples: One dataset record holding an ``instruction`` string and an
            ``instances`` list whose items carry ``input`` and ``output``.

    Returns:
        ``{'input': '', 'output': ...}`` built from the last instance, or
        ``{'input': None, 'output': None}`` when ``instances`` is empty.
    """
    out = {
        'input': None,
        'output': None,
    }
    prompt_format = UNNATURAL_INSTRUCTIONS_PROMPT_DICT["prompt_input"]
    # NOTE(review): `out` is overwritten on every iteration, so only the LAST
    # instance of the record survives. Confirm this is intended.
    for instance in examples['instances']:
        fields = {
            "instruction": examples['instruction'],
            "input": instance['input'],
        }
        out['input'] = ''
        out['output'] = '<s>' + prompt_format.format(**fields) + instance['output'] + '</s>'
    return out

# Alpaca-style prompt templates: "prompt_input" for records that carry an
# `input` field, "prompt_no_input" for instruction-only records. Placeholders
# are filled with str.format.
ALPACA_PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response: "
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response: "
    ),
)

# Alpaca prompt templates wrapped in Llama-2-chat style "[INST] <<SYS>> ...
# <</SYS>> ... [/INST]" markers. "prompt_input" is for records with an `input`
# field, "prompt_no_input" for instruction-only records.
# Fix: "prompt_input" used to glue two sentences together ("context.Write")
# because its first literal lacked the trailing space that "prompt_no_input"
# has.
ALPACA_PROMPT_DICT_2 = {
    "prompt_input": (
        "[INST] <<SYS>> Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n<</SYS>>"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n[/INST] ### Response: "
    ),
    "prompt_no_input": (
        "[INST] <<SYS>> Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n<</SYS>>"
        "### Instruction:\n{instruction}\n\n[/INST] ### Response: "
    ),
}

# Llama-2-chat style templates with a fixed system message; the instruction
# (and, for "prompt_input", the input) follow the <</SYS>> block.
ALPACA_PROMPT_DICT_3 = dict(
    prompt_input=(
        "[INST] <<SYS>>\n"
        "You are a helpful, unbiased, uncensored assistant.\n"
        "<</SYS>>\n\n"
        "{instruction} The input is {input} [/INST] "
    ),
    prompt_no_input=(
        "[INST] <<SYS>>\n"
        "You are a helpful, unbiased, uncensored assistant.\n"
        "<</SYS>>\n\n"
        "{instruction} [/INST] "
    ),
)

# def extract_alpaca_dataset(example):
#     if example.get("input", "") != "":
#         prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
#     else:
#         prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
#     return {'input': prompt_format.format(**example)}
#
#
#
# def extract_alpaca_dataset_llama2(example):
#     out = {
#         'input': None,
#         'output': None,
#     }
#
#     if example.get("input", "") != "":
#         prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
#     else:
#         prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
#
#     out['input'] = '<s>' + prompt_format.format(**example)
#     out['output'] = example['output'] + '</s>'
#     return out
#
# def extract_alpaca_dataset_llama2_v2(example):
#     out = {
#         'input': "",
#         'output': None,
#     }
#
#     if example.get("input", "") != "":
#         prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
#     else:
#         prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
#
#     input = '<s>' + prompt_format.format(**example)
#     output = example['output'] + '</s>'
#
#     out['output'] = input + output
#
#     return out

def turn_alpaca_to_conversations(examples):
    """Convert one Alpaca-style record into a single-turn conversation.

    Args:
        examples: One record (non-batched ``Dataset.map``) with
            ``instruction`` and ``output`` fields and an optional ``input``.

    Returns:
        ``{'input': None, 'output': [user_turn, assistant_turn]}``. The user
        turn is the instruction, followed by a newline and ``Input: <input>``
        when the record has a non-empty input; the assistant turn is the
        record's ``output``.
    """
    # A missing, empty or None `input` all mean "no input". Comparing against
    # "" alone used to render a None input literally as "Input: None".
    if examples.get("input"):
        prompt = "{instruction}\nInput: {input}"
    else:
        prompt = "{instruction}"

    user_turn = prompt.format(**examples)
    return {
        'input': None,
        'output': [user_turn, examples['output']],
    }

# LIMA prompt template: the instruction followed by a blank line.
# NOTE: despite its name this is a plain format string, not a dict.
LIMA_PROMPT_DICT = "{instruction}" "\n\n"

# NORobots_PROMPT_DICT = {
#     "prompt_system": (
#         "<s>{system}\n\n"
#         "User: {user} Assistant: {assistant} </s>"
#     ),
#     "prompt_no_system": (
#         "<s> User: {user} Assistant: {assistant} </s>"
#     )
# }

def extract_norobots_data(examples, bos_token, eos_token):
    r"""Flatten a No Robots conversation into one Llama-2-chat formatted string.

    Every user/assistant exchange becomes
    ``{bos_token}[INST] {user} [/INST] {assistant}{eos_token}``. When the
    conversation opens with a system message, the first exchange instead reads
    ``{bos_token}[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant}{eos_token}``.
    Exchanges are concatenated into ``output``; ``input`` stays empty.

    A single unpaired trailing message is emitted as well, with the missing
    side (user or assistant) rendered as ``''``.

    Args:
        examples: One record (non-batched ``Dataset.map``) whose ``messages``
            is a list of ``{'role': ..., 'content': ...}`` dicts ordered
            ``[system,] user, assistant, user, assistant, ...``.
        bos_token: Text prepended to every exchange.
        eos_token: Text appended to every exchange.

    Returns:
        ``{'input': '', 'output': <formatted conversation>}``.

    Raises:
        ValueError: If the first message, or an unpaired trailing message, has
            an unexpected role.
        AssertionError: If the user/assistant alternation is violated.
    """
    # Templates are formatted first and wrapped with the special tokens
    # afterwards, so the tokens are never parsed as str.format fields.
    system_template = "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant}"
    turn_template = "[INST] {user} [/INST] {assistant}"

    messages = examples['messages']
    out = {
        'input': '',
        'output': '',
    }

    if messages[0]['role'] == 'system':
        assert messages[1]['role'] == 'user'
        assert messages[2]['role'] == 'assistant'
        first_turn = system_template.format(
            system=messages[0]['content'],
            user=messages[1]['content'],
            assistant=messages[2]['content'],
        )
        cur_index = 3
    elif messages[0]['role'] == 'user':
        assert messages[1]['role'] == 'assistant'
        first_turn = turn_template.format(
            user=messages[0]['content'],
            assistant=messages[1]['content'],
        )
        cur_index = 2
    else:
        raise ValueError(f"Invalid role: {messages[0]['role']}")
    out['output'] = bos_token + first_turn + eos_token

    for i in range(cur_index, len(messages), 2):
        if i + 1 == len(messages):
            # Unpaired trailing message: pad the missing side with ''.
            if messages[i]['role'] == 'user':
                user = messages[i]['content']
                assistant = ''
            elif messages[i]['role'] == 'assistant':
                user = ''
                assistant = messages[i]['content']
            else:
                raise ValueError(f"Invalid role: {messages[i]['role']}")
        else:
            assert messages[i]['role'] == 'user'
            assert messages[i + 1]['role'] == 'assistant'
            user = messages[i]['content']
            assistant = messages[i + 1]['content']

        # Emitted for paired AND unpaired turns. This line used to be nested in
        # the `else` branch above, so a trailing unpaired message was computed
        # and then silently dropped (turn_norobots_to_conversations keeps it).
        out['output'] += bos_token + turn_template.format(user=user, assistant=assistant) + eos_token

    return out

def turn_norobots_to_conversations(examples):
    """Turn a No Robots ``messages`` list into a flat list of alternating turns.

    The result alternates user text and assistant text. A leading system
    message is concatenated directly onto the first user message (no
    separator is inserted). An unpaired trailing message is kept, with the
    missing side rendered as ``''``.

    Args:
        examples: One record whose ``messages`` is a list of
            ``{'role': ..., 'content': ...}`` dicts ordered
            ``[system,] user, assistant, user, assistant, ...``.

    Returns:
        ``{'input': None, 'output': [user_1, assistant_1, user_2, ...]}``.

    Raises:
        ValueError: If the first message, or an unpaired trailing message, has
            an unexpected role.
        AssertionError: If the user/assistant alternation is violated.
    """
    msgs = examples['messages']
    turns = []

    first_role = msgs[0]['role']
    if first_role == 'system':
        assert msgs[1]['role'] == 'user'
        assert msgs[2]['role'] == 'assistant'
        # NOTE(review): system prompt and first user turn are glued together
        # with no separator; confirm this is intended.
        turns.extend([msgs[0]['content'] + msgs[1]['content'], msgs[2]['content']])
        start = 3
    elif first_role == 'user':
        assert msgs[1]['role'] == 'assistant'
        turns.extend([msgs[0]['content'], msgs[1]['content']])
        start = 2
    else:
        raise ValueError(f"Invalid role: {first_role}")

    total = len(msgs)
    for idx in range(start, total, 2):
        if idx + 1 < total:
            user_text = msgs[idx]['content']
            bot_text = msgs[idx + 1]['content']
            assert msgs[idx]['role'] == 'user'
            assert msgs[idx + 1]['role'] == 'assistant'
            turns.extend([user_text, bot_text])
            continue

        # Single unpaired trailing message.
        tail_role = msgs[idx]['role']
        if tail_role == 'user':
            turns.extend([msgs[idx]['content'], ''])
        elif tail_role == 'assistant':
            turns.extend(['', msgs[idx]['content']])
        else:
            raise ValueError(f"Invalid role: {tail_role}")

    return {'input': None, 'output': turns}
# def extract_lima_data(examples):
#     if len(examples['conversations']) == 1:
#         Input = {'instruction': examples['conversations'][0]}
#         out = {
#             'input': None,
#             'output': '',
#         }
#         out['input'] = '<s>' + LIMA_PROMPT_DICT.format(**Input)
#     else:
#         output = ''
#         p = int(len(examples['conversations']) / 2)
#         for i in range(p):
#             start = examples['conversations'][2 * i]
#             end = examples['conversations'][2 * i + 1]
#             output += f'<s>User: {start} \n\n Assistant: {end} </s>'
#
#         out = {
#             'input': '',
#             'output': output,
#         }
#
#     return out
#
# def extract_lima_data_v2(examples):
#     if len(examples['conversations']) == 1:
#         Input = {'instruction': examples['conversations'][0]}
#         out = {
#             'input': None,
#             'output': '',
#         }
#         out['input'] = '<s>' + LIMA_PROMPT_DICT.format(**Input)
#     else:
#         output = ''
#         p = int(len(examples['conversations']) / 2)
#         for i in range(p):
#             start = examples['conversations'][2 * i]
#             end = examples['conversations'][2 * i + 1]
#             if i == 0:
#                 system = ("A chat between a curious human and an artificial intelligence assistant. "
#                               "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n")
#
#                 output += f'<s>{system}HUMAN: {start} ASSISTANT: {end} </s>'
#             else:
#                 output += f'<s>HUMAN: {start} ASSISTANT: {end} </s>'
#
#         out = {
#             'input': '',
#             'output': output,
#         }
#
#     return out
#
# def extract_lima_data_v3(examples):
#     out = {
#         'ouput': examples['conversations']
#     }
#     return out

def local_dataset(dataset_name, test_size=0.1, seed=None):
    """Load a local data file and split it into ``train`` / ``test``.

    Args:
        dataset_name: Path to a ``.json``/``.jsonl``, ``.csv`` or ``.tsv``
            file; the reader is chosen from the extension.
        test_size: Size of the ``test`` split, passed to
            ``Dataset.train_test_split`` (a float is a fraction, an int a row
            count). Defaults to 0.1.
        seed: Optional seed for the split. ``None`` defers to the ``datasets``
            library's default behavior.

    Returns:
        The ``DatasetDict`` returned by ``train_test_split``.

    Raises:
        ValueError: If the file extension is not supported.
    """
    if dataset_name.endswith(('.json', '.jsonl')):
        full_dataset = Dataset.from_json(path_or_paths=dataset_name)
    elif dataset_name.endswith('.csv'):
        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name))
    elif dataset_name.endswith('.tsv'):
        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter='\t'))
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_name}")

    return full_dataset.train_test_split(test_size=test_size, seed=seed)

def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
    """
    Make dataset and collator for supervised fine-tuning.

    The dataset is selected with ``args.dataset``:
        - ``alpaca-gpt4``, ``LIMA``, ``no_robots`` and the No Robots category
          subsets (``no_robots_Coding``, ``no_robots_Rewrite``,
          ``no_robots_Summarize``, ``no_robots_Extract``,
          ``no_robots_Generation``, ``no_robots_Classify``,
          ``no_robots_Brainstorm``, ``no_robots_Open_QA``, ``no_robots_Chat``)
          are read with ``load_from_disk`` from ``./dataset/``.
        - Any other existing path is read with ``local_dataset``
          (.json/.jsonl/.csv/.tsv); ``args.dataset_format`` then defaults to
          ``input-output`` if unset.

    ``format_dataset`` maps records to ``input``/``output`` columns and drops
    every other column. For the alpaca and ``no_robots`` formats ``output`` is
    a list of conversation turns and ``input`` is ``None``.

    The collator is chosen from ``args.model_name_or_path``: names containing
    ``gemma`` use ``DataCollatorForGemmaLIMA``; names containing ``Llama`` or
    ``Mistral`` use ``DataCollatorForLIMA``.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` and ``predict_dataset``
        (each ``None`` unless ``args.do_train`` / ``args.do_eval`` /
        ``args.do_predict`` is set) plus ``data_collator``.

    Raises:
        NotImplementedError: ``args.dataset`` is neither a known name nor an
            existing path.
        ValueError: A local dataset fails to load, or the model family is not
            supported.
    """
    # Named datasets read with `load_from_disk`.
    disk_datasets = {
        'alpaca-gpt4': './dataset/alpaca-gpt4',
        'LIMA': './dataset/lima',
        'no_robots': './dataset/no_robots',
        'no_robots_Coding': './dataset/no_robots_Coding',
        'no_robots_Rewrite': './dataset/no_robots_Rewrite',
        'no_robots_Summarize': './dataset/no_robots_Summarize',
        'no_robots_Extract': './dataset/no_robots_Extract',
        'no_robots_Generation': './dataset/no_robots_Generation',
        'no_robots_Classify': './dataset/no_robots_Classify',
        'no_robots_Brainstorm': './dataset/no_robots_Brainstorm',
        'no_robots_Open_QA': './dataset/no_robots_Open QA',
        'no_robots_Chat': './dataset/no_robots_Chat',
    }

    def load_data(dataset_name):
        if dataset_name in disk_datasets:
            return load_from_disk(disk_datasets[dataset_name])
        if not os.path.exists(dataset_name):
            raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.")
        try:
            args.dataset_format = args.dataset_format if args.dataset_format else "input-output"
            return local_dataset(dataset_name)
        except Exception as err:
            # Chain the cause; the previous bare `except:` discarded it (and
            # also caught KeyboardInterrupt/SystemExit).
            raise ValueError(f"Error loading dataset from {dataset_name}") from err

    def format_dataset(dataset, dataset_format):
        if (
                dataset_format in ('alpaca', 'alpaca-clean', 'alpaca-gpt4') or
                (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean', 'alpaca-gpt4'])
        ):
            dataset = dataset.map(turn_alpaca_to_conversations, remove_columns=['instruction'])
        elif dataset_format == 'chip2' or (dataset_format is None and args.dataset == 'chip2'):
            dataset = dataset.map(lambda x: {
                'input': x['text'].split('\n<bot>: ')[0].replace('<human>: ', ''),
                'output': x['text'].split('\n<bot>: ')[1],
            })
        elif dataset_format == 'self-instruct' or (dataset_format is None and args.dataset == 'self-instruct'):
            for old, new in [["prompt", "input"], ["completion", "output"]]:
                dataset = dataset.rename_column(old, new)
        elif dataset_format == 'hh-rlhf' or (dataset_format is None and args.dataset == 'hh-rlhf'):
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': x['chosen']
            })
        elif dataset_format == 'LIMA' or (dataset_format is None and args.dataset == 'LIMA'):
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': x['conversations']})
        elif dataset_format == 'no_robots' or (dataset_format is None and args.dataset == 'no_robots'):
            dataset = dataset.map(turn_norobots_to_conversations)
            # Expose the SFT splits under the split names used below.
            dataset['train'] = dataset['train_sft']
            dataset['test'] = dataset['test_sft']
        elif 'no_robots_' in args.dataset:
            dataset = dataset.map(extract_norobots_data,
                                  fn_kwargs={'bos_token': tokenizer.bos_token, 'eos_token': tokenizer.eos_token})
        elif dataset_format == 'oasst1' or (dataset_format is None and args.dataset == 'oasst1'):
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': '<s>' + x['text'] + '</s>',
            })
        elif dataset_format == 'unnatural-instructions-core' or (
                dataset_format is None and args.dataset == 'unnatural-instructions-core'):
            dataset = dataset.map(extract_unnatural_instructions_data)
        elif dataset_format == 'input-output':
            # leave as is
            pass
        ### Remove unused columns.
        if 'train' in dataset.column_names:
            dataset = dataset.remove_columns(
                [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
            )
        else:
            dataset = dataset.remove_columns(
                [col for col in dataset.column_names if col not in ['input', 'output']]
            )
        return dataset

    def _text_length(value):
        # `input`/`output` may be None (conversation formats set input=None),
        # a string, or a list of turn strings; measure characters in every
        # case. The old `len(x['input'])` raised TypeError on None and counted
        # turns, not characters, for list outputs.
        if value is None:
            return 0
        if isinstance(value, str):
            return len(value)
        return sum(len(turn) for turn in value if turn is not None)

    def _add_length_column(ds):
        # Adds the `length` column used when `args.group_by_length` is set.
        return ds.map(lambda x: {'length': _text_length(x['input']) + _text_length(x['output'])})

    # Load dataset.
    dataset = load_data(args.dataset)
    dataset = format_dataset(dataset, args.dataset_format)

    train_dataset = None
    eval_dataset = None

    # Split train/eval, reduce size
    if args.do_eval or args.do_predict:
        if 'eval' in dataset:
            eval_dataset = dataset['eval']
        else:
            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
            dataset = dataset["train"].train_test_split(
                test_size=args.eval_dataset_size, shuffle=True, seed=42
            )
            eval_dataset = dataset['test']
        if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples:
            eval_dataset = eval_dataset.select(range(args.max_eval_samples))
        if args.group_by_length:
            eval_dataset = _add_length_column(eval_dataset)
    if args.do_train:
        if 'train' in dataset:
            train_dataset = dataset['train']
        else:
            train_dataset = dataset
        if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
            train_dataset = train_dataset.select(range(args.max_train_samples))
        if args.group_by_length:
            train_dataset = _add_length_column(train_dataset)

    if 'gemma' in args.model_name_or_path:
        data_collator = DataCollatorForGemmaLIMA(
            tokenizer=tokenizer,
            source_max_len=args.source_max_len,
            target_max_len=args.target_max_len,
            train_on_source=args.train_on_source,
            train_without_system=args.train_without_system,
            predict_with_generate=args.predict_with_generate,
        )
    elif 'Llama' in args.model_name_or_path or 'Mistral' in args.model_name_or_path:
        data_collator = DataCollatorForLIMA(
            tokenizer=tokenizer,
            source_max_len=args.source_max_len,
            target_max_len=args.target_max_len,
            train_on_source=args.train_on_source,
            train_without_system=args.train_without_system,
            predict_with_generate=args.predict_with_generate,
        )
    else:
        raise ValueError(f"Unsupported model: {args.model_name_or_path}")

    return dict(
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        predict_dataset=eval_dataset if args.do_predict else None,
        data_collator=data_collator
    )


def get_last_checkpoint(checkpoint_dir):
    """Locate the newest ``checkpoint-<step>`` sub-directory of ``checkpoint_dir``.

    Only sub-directories named ``checkpoint-`` followed by decimal digits are
    considered. Other entries such as ``checkpoints`` or ``checkpoint-best``
    are ignored.

    Returns:
        ``(path, is_completed)``:
            - ``path`` is the sub-directory with the highest step > 0, or
              ``None`` if there is none.
            - ``is_completed`` is True when a ``completed`` marker exists in
              ``checkpoint_dir``.
        ``(None, False)`` is returned when ``checkpoint_dir`` is not a
        directory.
    """
    if not isdir(checkpoint_dir):
        return None, False  # first training

    is_completed = exists(join(checkpoint_dir, 'completed'))
    prefix = 'checkpoint-'
    best_step, best_name = 0, None
    for filename in os.listdir(checkpoint_dir):
        if not filename.startswith(prefix) or not isdir(join(checkpoint_dir, filename)):
            continue
        suffix = filename[len(prefix):]
        # int() on a non-numeric suffix used to raise ValueError.
        if not suffix.isdecimal():
            continue
        step = int(suffix)
        if step > best_step:
            # Keep the real directory name so zero-padded steps resolve to an
            # existing path.
            best_step, best_name = step, filename
    if best_name is None:
        return None, is_completed  # training started, but no checkpoint
    checkpoint_dir = join(checkpoint_dir, best_name)
    print(f"Found a previous checkpoint at: {checkpoint_dir}")
    return checkpoint_dir, is_completed  # checkpoint found!


class MyTrainer(transformers.Trainer):
    """``transformers.Trainer`` with custom optimizer / scheduler construction.

    ``create_optimizer`` builds two parameter groups:
        - trainable parameters named by ``get_decay_parameter_names`` get
          ``args.weight_decay``;
        - all other trainable parameters get 0.0.
    With the bitsandbytes ``Adam8bit`` optimizer, ``nn.Embedding`` weights are
    registered for 32-bit optimizer state.
    """

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        An optimizer that is already set (e.g. passed through the Trainer's
        ``optimizers`` argument) is reused unchanged; otherwise
        ``create_optimizer`` builds one. The scheduler is then created for
        ``num_training_steps`` steps.
        """
        # Previously the optimizer was never created here, so a Trainer built
        # without `optimizers=` reached create_scheduler with self.optimizer
        # still None. The None check leaves an externally supplied optimizer
        # untouched (no second smp.DistributedOptimizer wrap).
        if self.optimizer is None:
            self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    def create_optimizer(self):
        """
        Setup the optimizer.

        Only parameters with ``requires_grad`` are optimized. Nothing is
        rebuilt if ``self.optimizer`` is already set; under SageMaker model
        parallelism the optimizer is wrapped in ``smp.DistributedOptimizer``.

        Returns:
            ``self.optimizer``. Returning it matches the base-class contract,
            so callers that use the return value get the optimizer instead of
            ``None``.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Deduplicate by data pointer so tied weights are counted once.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped / 2 ** 20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

class AdaLoraTrainer(Seq2SeqTrainer):
    """``Seq2SeqTrainer`` that adds an orthogonality penalty on LoRA matrices to the loss.

    For every parameter whose name contains ``lora_A`` or ``lora_B`` and the
    model's ``trainable_adapter_name``, the penalty is:
        - ``||P @ P.T - I||_F`` for ``lora_A`` parameters;
        - ``||P.T @ P - I||_F`` for ``lora_B`` parameters.
    It is averaged over those parameters and scaled by the adapter config's
    ``orth_reg_weight``.
    """

    def _orth_penalty(self, model, device):
        """Return the mean orthogonality penalty of the adapter's LoRA matrices on ``device`` (0 if none)."""
        regu_loss = 0
        num_param = 0
        for n, p in model.named_parameters():
            if ("lora_A" in n or "lora_B" in n) and model.trainable_adapter_name in n:
                para_cov = p @ p.T if "lora_A" in n else p.T @ p
                eye = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))
                eye.requires_grad = False
                num_param += 1
                regu_loss += torch.norm(para_cov - eye, p="fro").to(device)
        return regu_loss / num_param if num_param > 0 else 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss plus ``orth_reg_weight`` times the orthogonality penalty.

        Args:
            model: The PEFT model; it must expose ``peft_config`` and
                ``trainable_adapter_name``.
            inputs: Batch dict; ``labels`` is popped when a label smoother is
                configured.
            return_outputs: If True, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` versions that pass it; unused here.

        Raises:
            ValueError: If no label smoother is used and the model output
                holds no ``loss``.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)

        orth_reg_weight = model.peft_config[model.trainable_adapter_name].orth_reg_weight
        assert orth_reg_weight > 0

        # With labels popped the model returns loss=None; the previous code then
        # crashed on `outputs.loss.device`. Only regularize a real loss here;
        # the smoothed loss is regularized below.
        model_loss = getattr(outputs, "loss", None)
        if model_loss is not None:
            outputs.loss = model_loss + orth_reg_weight * self._orth_penalty(model, model_loss.device)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # Imported locally: these names are not imported at the top of the file.
            from transformers.modeling_utils import unwrap_model
            from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

            unwrapped_model = unwrap_model(model)
            if isinstance(unwrapped_model, PeftModel):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
            loss = loss + orth_reg_weight * self._orth_penalty(model, loss.device)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

def _configure_logging(training_args):
    """Configure stdlib, ``datasets`` and ``transformers`` logging for this process.

    The effective level is ``training_args.get_process_log_level()``; it is applied to
    this module's logger and to the ``datasets`` / ``transformers`` loggers. A one-line
    summary of the process/device setup is then emitted at WARNING level.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    # The ", " separator was previously missing, producing "n_gpu: 1distributed training".
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        f"n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, "
        f"bfloat16 training: {training_args.bf16}"
    )


def _print_param_dtype_summary(model):
    """Print, for each parameter dtype, its element count and share of all elements."""
    counts = defaultdict(int)
    for _, param in model.named_parameters():
        counts[param.dtype] += param.numel()
    total = sum(counts.values())
    for dtype, count in counts.items():
        # total can only be 0 here if every parameter tensor is empty.
        print(dtype, count, count / total if total else 0.0)


def train():
    """Parse CLI arguments, build model/tokenizer/data, then train, evaluate and predict.

    Which phases run is controlled by ``do_train`` / ``do_eval`` / ``do_predict``.
    Metrics from the phases that ran are merged and written (by the main process
    only) to ``<output_dir>/metrics.json``.
    """
    hfparser = transformers.HfArgumentParser((
        ModelArguments, DataArguments, TrainingArguments, GenerationArguments, ExtraArguments
    ))
    # With return_remaining_strings=True the parser returns one instance per dataclass
    # (so `project_args` is the ExtraArguments instance) followed by the leftover CLI strings.
    model_args, data_args, training_args, generation_args, project_args, extra_args = \
        hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
    # Flatten all argument groups into one namespace consumed by the helper functions.
    args = argparse.Namespace(
        **vars(model_args), **vars(data_args), **vars(training_args), **vars(project_args)
    )

    _configure_logging(training_args)

    if 'wandb' in args.report_to:
        init_wandb(args.project_name)

    # Create the results directory once: local rank 0 in distributed runs, or the single
    # process of a non-distributed run (local_rank == -1, which was previously skipped).
    if args.local_rank in (-1, 0):
        print(f"results_dir: {args.results_dir}")
        args.results_dir = mkdir_if_not_exists(args.results_dir)
        print(f"results_dir 2: {args.results_dir}")

    # Only an output_dir that names a 'checkpoint-*' directory is passed on as a checkpoint.
    checkpoint_dir = args.output_dir if 'checkpoint-' in args.output_dir else None

    model, tokenizer = get_accelerate_model(args, checkpoint_dir)

    model.config.use_cache = False
    print('loaded model')
    set_seed(args.seed)

    data_module = make_data_module(tokenizer=tokenizer, args=args)

    # Every data_module entry except the held-out predict split is a Trainer kwarg.
    trainer_kwargs = {k: v for k, v in data_module.items() if k != 'predict_dataset'}
    trainer_cls = AdaLoraTrainer if args.peft_type == 'adalora' else Seq2SeqTrainer
    trainer = trainer_cls(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **trainer_kwargs,
    )

    # Give every sub-module that exposes a `trainer` attribute a reference to the trainer.
    for _, module in model.named_modules():
        if hasattr(module, 'trainer'):
            module.trainer = trainer

    # Callbacks
    if args.peft_type == 'adalora':
        trainer.add_callback(AdaLoRACallback)

    if args.do_generate:
        gen_callback = GenerationCallback(train_dataset=data_module['train_dataset'],
                                          model_id=args.model_name_or_path, dataset_name=args.dataset,
                                          max_new_tokens=128)
        trainer.add_callback(gen_callback)

    if not args.full_finetune:
        trainer.add_callback(SavePeftModelCallback(train_dataset=data_module['train_dataset']))

    if args.distribution_json_file is not None:
        prob_dict = get_layers_prob_distribution(args)
        trainer.add_callback(FreezerCallback(prob_dict, args.full_finetune, args.prob_interval))

    if args.combine_training:
        trainer.add_callback(SKillCallback(divide_ratio=args.divide_ratio, interval=args.interval))
    if args.use_gradient_score:
        trainer.add_callback(GradientCallback(args.save_interval))

    # Verifying the datatypes and parameter counts before training.
    print_trainable_parameters(args, model)
    _print_param_dtype_summary(model)

    all_metrics = {"name": args.run_name}
    # Training
    if args.do_train:
        logger.info("*** Train ***")
        # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF.
        # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
        train_result = trainer.train()
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        if local_rank == 0:
            trainer.save_model(os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-final"))
            trainer.save_state()
            metrics = train_result.metrics
            trainer.log_metrics("train", metrics)
            trainer.save_metrics("train", metrics)
            all_metrics.update(metrics)

        if args.do_mmlu_eval:
            # NOTE(review): `mmluevalcallback` is not defined in this function; presumably a
            # module-level object defined elsewhere — verify, otherwise this raises NameError.
            eval_results = mmluevalcallback.get_results()
            if trainer.is_world_process_zero():
                with open(os.path.join(args.output_dir, 'mmlu_eval.json'), 'w', encoding='utf-8') as fout:
                    fout.write(json.dumps(eval_results))

    # Evaluation (collective: runs on every process; log/save_metrics gate on the main process).
    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        all_metrics.update(metrics)

    # Prediction
    if args.do_predict:
        logger.info("*** Predict ***")
        predict_dataset = data_module['predict_dataset']
        prediction_output = trainer.predict(test_dataset=predict_dataset, metric_key_prefix="predict")
        prediction_metrics = prediction_output.metrics
        raw_predictions = prediction_output.predictions
        # Replace -100 entries with the pad token id so batch_decode receives valid ids.
        token_ids = np.where(raw_predictions != -100, raw_predictions, tokenizer.pad_token_id)
        predictions = tokenizer.batch_decode(
            token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        # Only one process writes the file; concurrent "w" opens from several ranks
        # would race and could interleave/corrupt the output.
        if trainer.is_world_process_zero():
            with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w', encoding='utf-8') as fout:
                for i, example in enumerate(predict_dataset):
                    example['prediction_with_input'] = predictions[i].strip()
                    example['prediction'] = predictions[i].replace(example['input'], '').strip()
                    fout.write(json.dumps(example) + '\n')
        print(prediction_metrics)
        trainer.log_metrics("predict", prediction_metrics)
        trainer.save_metrics("predict", prediction_metrics)
        all_metrics.update(prediction_metrics)

    # Only the main process (which also holds the train metrics) writes metrics.json;
    # previously every rank wrote it and a later writer could drop the train metrics.
    if (args.do_train or args.do_eval or args.do_predict) and trainer.is_world_process_zero():
        with open(os.path.join(args.output_dir, "metrics.json"), "w", encoding="utf-8") as fout:
            fout.write(json.dumps(all_metrics))


if __name__ == "__main__":  # script entry point; importing this module does not start training
    train()
