########## The following part is copied from Transformers' trainer (3.4.0) and later ported to be compatible with v4.4.2 and to support initialization from linear head probing. ##########

# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import collections
import inspect
import math
import os
import re
import shutil
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from sympy import im
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data import Subset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.optim.lr_scheduler import LambdaLR
import math
import time
from torch._utils import _get_all_device_indices
from collections.abc import Mapping

import transformers
from transformers.file_utils import is_datasets_available, is_in_notebook, is_torch_tpu_available
from transformers.integrations import (
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
)
from transformers.optimization import AdamW, get_linear_schedule_with_warmup, get_scheduler

from transformers.trainer_callback import (
    DefaultFlowCallback,
    ProgressCallback,
)
from transformers.trainer_utils import (
    default_compute_objective,
)
from transformers.training_args import TrainingArguments
from transformers.utils import logging
from transformers.trainer_utils import TrainOutput

from tqdm import tqdm, trange
from torch.optim import SGD
import torch.nn.functional as F

from src.linearhead_trainer import LinearHeadTrainer
from transformers.trainer_callback import TrainerState

import copy
import logging as py_logging

# Mixed-precision backend flags; exactly one of them is switched on below,
# depending on the installed PyTorch version.
_use_native_amp = False
_use_apex = False

# Callbacks attached by default; optional integrations append to this list below.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# Inside a Jupyter notebook, use the notebook-friendly progress bar instead.
if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    # NOTE(review): _use_apex is set even when apex is not installed; train()
    # re-checks transformers.is_apex_available() and raises ImportError before
    # it touches `amp`.
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

# True on PyTorch >= 1.2; presumably gates DDP no_sync() usage elsewhere
# (no use is visible in this chunk).
if version.parse(torch.__version__) < version.parse("1.2"):
    _use_ddp_no_sync = False
else:
    _use_ddp_no_sync = True

if is_datasets_available():
    import datasets

# TPU helpers; `xm` and `pl` are used by train() when a TPU is available.
if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

# Optional experiment-tracking integrations register themselves as default callbacks.
if is_tensorboard_available():
    from transformers.integrations import TensorBoardCallback

    DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
    from transformers.integrations import WandbCallback

    DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
    from transformers.integrations import CometCallback

    DEFAULT_CALLBACKS.append(CometCallback)

# Hyper-parameter search backends, imported only when installed.
if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

# Module logger (transformers' logging wrapper), forced to INFO level.
logger = logging.get_logger(__name__)
logger.setLevel(logging.INFO)

########## The above part is copied from Transformers' trainer (3.4.0) ##########

def default_dev_objective(metrics):
    """
    Pick the scalar used to rank checkpoints on the development set.

    The first key found, in this priority order, wins: MNLI matched accuracy,
    MNLI mismatched accuracy, F1, MCC, Pearson correlation, plain accuracy.

    Args:
        metrics: mapping of evaluation metric names to values.

    Returns:
        The value stored under the highest-priority key present.

    Raises:
        Exception: if none of the recognised keys is present.
    """
    priority = (
        "eval_mnli/acc",
        "eval_mnli-mm/acc",
        "eval_f1",
        "eval_mcc",
        "eval_pearson",
        "eval_acc",
    )
    for key in priority:
        if key in metrics:
            return metrics[key]

    raise Exception("No metric founded for {}".format(metrics))

def get_model_device(model):
    """Return the device of the first tensor yielded by ``model.parameters()``."""
    first_param = next(model.parameters())
    return first_param.device

def get_resorted_device_ids(device_ids, main_device_index):
    """
    Return a new list of ``device_ids`` with ``main_device_index`` moved to the front.

    The input is copied and never mutated. Only the first occurrence of the main
    device is moved. If the main device is absent, the copy keeps its original order.
    """
    reordered = list(device_ids)
    if main_device_index not in reordered:
        return reordered
    pos = reordered.index(main_device_index)
    return [main_device_index] + reordered[:pos] + reordered[pos + 1:]

class Trainer(LinearHeadTrainer):
    """
    Adding some functions based on Transformers' Trainer class.
    """

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Build ``self.optimizer`` and ``self.lr_scheduler`` unless they already exist.

        Based on Transformers' default one, with an added layer-fixing option.
        When ``args.fix_layers > 0``, the embedding parameters and the bottom
        ``fix_layers`` encoder layers are excluded from the optimizer. Only the
        upper encoder layers and any parameters outside the encoder layers and
        embeddings are then fine-tuned. Each parameter's decision is printed as
        ``yes``/``no``.

        Does nothing when ``args.hf_inference_model`` is set.

        Args:
            num_training_steps: total optimization steps. Forwarded to the scheduler
                and used to derive the warmup length.

        Raises:
            ValueError: if an ``encoder.layer`` parameter name does not carry an
                integer layer index right after that prefix.
            NotImplementedError: if ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
        """
        if self.args.hf_inference_model:
            return

        if self.optimizer is None:
            layer_prefix = 'encoder.layer'
            params = {}
            for n, p in self.model.named_parameters():
                if self.args.fix_layers <= 0:
                    params[n] = p
                    continue

                if layer_prefix in n:
                    # Skip the prefix and the '.' after it, e.g. "...encoder.layer.3.attn" -> "3".
                    layer_str = n[n.find(layer_prefix) + len(layer_prefix) + 1:].split('.')[0]
                    try:
                        layer_num = int(layer_str)
                    except ValueError as err:
                        raise ValueError(
                            "Cannot parse encoder layer index from parameter name {!r}".format(n)
                        ) from err
                    trainable = layer_num >= self.args.fix_layers
                elif 'embeddings' in n:
                    trainable = False
                else:
                    trainable = True

                print('yes' if trainable else 'no ', n)
                if trainable:
                    params[n] = p

            # Biases and LayerNorm weights are excluded from weight decay.
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in params.items() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.optimizer == 'adam':
                self.optimizer = AdamW(
                    optimizer_grouped_parameters,
                    lr=self.args.learning_rate,
                    betas=(self.args.adam_beta1, self.args.adam_beta2),
                    eps=self.args.adam_epsilon,
                )
            elif self.args.optimizer == 'sgd':
                self.optimizer = SGD(
                    optimizer_grouped_parameters,
                    lr=self.args.learning_rate
                )
            else:
                raise NotImplementedError
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )

    def should_optim(self, name, param):
        """
        Decide whether parameter ``name`` is updated at the current step.

        A parameter with ``requires_grad`` false is never selected. When
        ``args.layer_wise_optim`` is enabled, a parameter is selected only if its
        name contains ``.{k}.``. Here ``k = global_step % num_hidden_layers``, so
        the active layer cycles from step to step.
        """
        if self.args.layer_wise_optim:
            active_layer = self.state.global_step % self.model.config.num_hidden_layers
            if f".{active_layer}." not in name:
                return False
        return param.requires_grad

    def zo_forward(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Evaluate the objective for one zeroth-order probe and return it detached.

        The model is put in eval mode and the batch goes through
        ``self._prepare_inputs``. The objective depends on ``args.optimize_acc``:

        * set: the negative batch accuracy, comparing the argmax of the softmaxed
          logits against ``labels``;
        * unset: ``self.compute_loss``, averaged when ``args.n_gpu > 1``.

        Increments ``self.state.zo_forward_step`` on every call.
        """
        model.eval()

        batch = self._prepare_inputs(inputs)

        if not self.args.optimize_acc:
            with self.compute_loss_context_manager():
                objective = self.compute_loss(model, batch)
            if self.args.n_gpu > 1:
                objective = objective.mean()  # mean() to average on multi-gpu parallel training
        else:
            _, logits = model(**batch)
            probs = F.softmax(logits, dim=-1)
            correct = torch.argmax(probs, 1) == batch['labels']
            objective = -(torch.sum(correct) / len(probs))

        self.state.zo_forward_step += 1
        return objective.detach()

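    # This is the original efficient_perturb_parameters. `model` must be `self.model`,
    # because the tensors it updates are the ones cached in `self.named_parameters_to_optim`.
    # The returned model carries no extra information; after the call, simply read `self.model`.
    # This breaks in generic federated-learning code, where `model` is not guaranteed to be
    # `self.model` (CeZO is unaffected).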
    def efficient_perturb_parameters(self, model: nn.Module, random_seed: int, scaling_factor=1):
        """
        Shift every tensor in ``self.named_parameters_to_optim`` along a seeded Gaussian direction.

        The global torch RNG is reseeded with ``random_seed``, so the same
        direction can be regenerated later instead of stored. Each tensor becomes
        ``param + scaling_factor * z * args.zero_order_eps``.

        ``model`` is not read. It is returned unchanged, so callers should pass the
        model that owns the cached tensors.
        """
        torch.manual_seed(random_seed)
        eps = self.args.zero_order_eps
        for _, tensor in self.named_parameters_to_optim:
            noise = torch.normal(mean=0, std=1, size=tensor.data.size(), device=tensor.data.device, dtype=tensor.data.dtype)
            tensor.data = tensor.data + scaling_factor * noise * eps
        return model
    
    # For FedAvg and FeedAvg
    def efficient_perturb_parameters_federated(self, model: nn.Module, random_seed: int, scaling_factor=1):
        """
        Federated variant of the seeded perturbation.

        Instead of the cached ``self.named_parameters_to_optim`` list, this walks
        ``model.named_parameters()`` and perturbs only the entries accepted by
        ``self.should_optim``, so it acts on whichever model is passed in. The
        global torch RNG is reseeded with ``random_seed``, and each selected tensor
        becomes ``param + scaling_factor * z * args.zero_order_eps``. Returns ``model``.
        """
        torch.manual_seed(random_seed)
        selected = (
            (name, param) for name, param in model.named_parameters() if self.should_optim(name, param)
        )
        for _, param in selected:
            noise = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            param.data = param.data + scaling_factor * noise * self.args.zero_order_eps
        return model

    # Likewise, this is the original function: `model` must be `self.model`.
    def norm_perturb_parameters(self, model: nn.Module, random_vector=None, scaling_factor=1):
        """
        Perturb ``self.named_parameters_to_optim`` with Gaussian noise rescaled per group.

        For each parameter, the direction comes from ``random_vector`` when it
        already has an entry for that name. Otherwise a standard-normal tensor is
        drawn and recorded there, unscaled. If the parameter's group (from
        ``self.retrieve_c``) is a key of ``self.cs``, the direction is divided by
        that value. The update is then ``param + scaling_factor * z * args.zero_order_eps``.
        ``model`` is not read.

        Returns:
            ``(model, random_vector)``. A passed-in ``random_vector`` is updated in place.
        """
        noise_by_name = {} if random_vector is None else random_vector

        for name, param in self.named_parameters_to_optim:
            if name not in noise_by_name:
                noise_by_name[name] = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            direction = noise_by_name[name]

            group = self.retrieve_c(name)
            if group in self.cs:
                direction = direction / self.cs[group]

            param.data = param.data + scaling_factor * direction * self.args.zero_order_eps

        return model, noise_by_name

    def norm_perturb_parameters_federated(self, model: nn.Module, random_vector=None, scaling_factor=1):
        """
        Federated variant of the group-rescaled perturbation.

        Walks ``model.named_parameters()`` and touches only the entries accepted by
        ``self.should_optim``. The direction is reused from ``random_vector`` or
        drawn fresh and recorded there, unscaled. It is divided by
        ``self.cs[group]`` when the group from ``self.retrieve_c`` is known. The
        update is ``param + scaling_factor * z * args.zero_order_eps``.

        Returns:
            ``(model, random_vector)``.
        """
        noise_by_name = {} if random_vector is None else random_vector

        for name, param in model.named_parameters():
            if not self.should_optim(name, param):
                continue
            if name not in noise_by_name:
                noise_by_name[name] = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            direction = noise_by_name[name]

            group = self.retrieve_c(name)
            if group in self.cs:
                direction = direction / self.cs[group]

            param.data = param.data + scaling_factor * direction * self.args.zero_order_eps

        return model, noise_by_name
    
    # Likewise, this is the original function: `model` must be `self.model`.
    def perturb_parameters(self, model: nn.Module, random_vector=None, scaling_factor=1):
        """
        Add ``scaling_factor * z * args.zero_order_eps`` to every tensor in
        ``self.named_parameters_to_optim``.

        ``z`` is reused from ``random_vector`` when it has an entry for the
        parameter name. Otherwise a standard-normal tensor is drawn and stored
        there, so the same direction can be replayed later (e.g. with
        ``scaling_factor=-2`` and then ``1``). ``model`` is not read and is
        returned as-is.

        Returns:
            ``(model, random_vector)``.
        """
        directions = {} if random_vector is None else random_vector
        for name, param in self.named_parameters_to_optim:
            if name not in directions:
                directions[name] = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            param.data = param.data + scaling_factor * directions[name] * self.args.zero_order_eps
        return model, directions
    
    def perturb_parameters_federated(self, model: nn.Module, random_vector=None, scaling_factor=1):
        """
        Federated variant of the plain Gaussian perturbation.

        Walks ``model.named_parameters()`` and perturbs only the entries accepted
        by ``self.should_optim``. The direction ``z`` is reused from
        ``random_vector`` or drawn fresh and recorded there. The update is
        ``param + scaling_factor * z * args.zero_order_eps``.

        Returns:
            ``(model, random_vector)``.
        """
        directions = {} if random_vector is None else random_vector
        trainable = (
            (name, param) for name, param in model.named_parameters() if self.should_optim(name, param)
        )
        for name, param in trainable:
            if name not in directions:
                directions[name] = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            param.data = param.data + scaling_factor * directions[name] * self.args.zero_order_eps
        return model, directions

    # Likewise, this is the original function: `model` must be `self.model`.
    def perturb_single_layer(self, model, layer_name, random_vector=None, scaling_factor=1):
        """
        Perturb only the cached parameters whose group equals ``layer_name``.

        The group of each entry in ``self.named_parameters_to_optim`` is looked up
        with ``self.retrieve_c``. Matching tensors get
        ``param + scaling_factor * z * args.zero_order_eps``, where ``z`` is reused
        from ``random_vector`` or drawn fresh and recorded there. ``model`` is not
        read.

        Returns:
            ``(model, random_vector)``. Only names from the chosen group are added.
        """
        directions = {} if random_vector is None else random_vector
        in_group = (
            (name, param) for name, param in self.named_parameters_to_optim if self.retrieve_c(name) == layer_name
        )
        for name, param in in_group:
            if name not in directions:
                directions[name] = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            param.data = param.data + scaling_factor * directions[name] * self.args.zero_order_eps
        return model, directions
    
    def perturb_single_layer_federated(self, model, layer_name, random_vector=None, scaling_factor=1):
        """
        Federated variant of the single-group perturbation.

        Walks ``model.named_parameters()`` and keeps entries that pass
        ``self.should_optim`` and whose ``self.retrieve_c`` group equals
        ``layer_name``. Each kept tensor becomes
        ``param + scaling_factor * z * args.zero_order_eps``, with ``z`` reused
        from ``random_vector`` or drawn fresh and recorded there.

        Returns:
            ``(model, random_vector)``.
        """
        directions = {} if random_vector is None else random_vector
        in_group = (
            (name, param)
            for name, param in model.named_parameters()
            if self.should_optim(name, param) and self.retrieve_c(name) == layer_name
        )
        for name, param in in_group:
            if name not in directions:
                directions[name] = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
            param.data = param.data + scaling_factor * directions[name] * self.args.zero_order_eps
        return model, directions

    # Likewise, this is the original function: `model` must be `self.model`.
    def initialize_c(self, model, inputs):
        """
        Estimate the per-group preconditioning scales ``self.cs`` for preconditioned ZO.

        Groups are ``'embed'``, ``'lm_head'`` and one key per hidden layer:
        ``'layers.{i}.'`` when ``model.config.model_type == 'opt'``, otherwise
        ``'layer.{i}.'``. Parameters are assigned to groups by ``self.retrieve_c``.
        The estimator is chosen from ``self.args``:

        * ``zo_variant != 'param_norm'`` and ``use_zo_grad_est``: the absolute
          two-point ZO projected gradient obtained by perturbing only that group.
          The result is kept as a tensor.
        * ``zo_variant == 'param_norm'``: the L2 norm of the group's parameters.
        * otherwise: the L2 norm of the group's gradient from one backward pass.
          This is mostly for debugging, to separate the variance of the ZO
          estimate of c from the effectiveness of the preconditioner.

        For both norm estimators, ``scale_norm_by_num_params`` divides each norm
        by ``sqrt(number of elements in the group)``, and every scale is stored as
        a Python float. Groups that received no parameters stay at ``0.0``.

        Side effects: rebuilds ``self.named_parameters_to_optim``, ``self.cs``,
        ``self.num_params``, ``self.num_model_layers`` and ``self.layer_names``,
        and calls ``model.zero_grad()``.

        Args:
            model: model whose parameters are scanned, and perturbed or
                backpropagated through.
            inputs: one training batch. Only used by the ZO and backprop estimators.
        """
        self.named_parameters_to_optim = []
        for name, param in model.named_parameters():
            if self.should_optim(name, param):
                self.named_parameters_to_optim.append((name, param))

        # OPT: embed_tokens; embed_positions
        # RoBERTa: embeddings
        self.cs = {'embed': 0.0, 'lm_head': 0.0}
        self.num_params = copy.deepcopy(self.cs)
        self.num_model_layers = model.config.num_hidden_layers
        layer_name = "layers" if model.config.model_type == "opt" else "layer"
        for i in range(self.num_model_layers):
            self.cs[f'{layer_name}.{i}.'] = 0.0
            self.num_params[f'{layer_name}.{i}.'] = 0

        if self.args.zo_variant != 'param_norm' and self.args.use_zo_grad_est:
            # ZO estimation of c's: probe each group with +eps*z and then -eps*z.
            for layer in self.cs.keys():
                with torch.no_grad():
                    model, z = self.perturb_single_layer(model, layer_name=layer)
                    loss1 = self.zo_forward(model, inputs)
                    model, z = self.perturb_single_layer(model, layer_name=layer, random_vector=z, scaling_factor=-2)
                    loss2 = self.zo_forward(model, inputs)

                projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)
                self.cs[layer] = torch.abs(projected_grad)

                # Shift back by +eps*z to restore the original parameters.
                model, z = self.perturb_single_layer(model, layer_name=layer, random_vector=z)
        else:
            if self.args.zo_variant == 'param_norm':
                # No backprop needed: parameter norms can be measured directly.
                for name, param in self.named_parameters_to_optim:
                    print(name)
                    ckey = self.retrieve_c(name)
                    if ckey in self.cs:
                        self.cs[ckey] += torch.sum(param.data ** 2)
                        self.num_params[ckey] += param.data.numel()
            else:
                # Backpropagation estimate of the c's (debugging aid, see docstring).
                model.eval()
                inputs = self._prepare_inputs(inputs)
                with self.compute_loss_context_manager():
                    loss = self.compute_loss(model, inputs)
                if self.args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
                loss.backward()
                for name, param in self.named_parameters_to_optim:
                    if param.grad is None:
                        print(name)
                    else:
                        ckey = self.retrieve_c(name)
                        if ckey in self.cs:
                            self.cs[ckey] += torch.sum(param.grad ** 2)
                            self.num_params[ckey] += param.grad.numel()

            # Turn the accumulated sums of squares into L2 norms. Groups that
            # received nothing are still the Python float 0.0, which torch.sqrt
            # rejects. The size normalisation divides by sqrt(element count) and
            # is skipped for empty groups to avoid dividing by zero.
            for ckey in list(self.cs):
                sq_sum = self.cs[ckey]
                if torch.is_tensor(sq_sum):
                    norm = torch.sqrt(sq_sum).detach().item()
                else:
                    norm = math.sqrt(sq_sum)
                if self.args.scale_norm_by_num_params and self.num_params[ckey] > 0:
                    norm /= math.sqrt(self.num_params[ckey])
                self.cs[ckey] = norm

        self.layer_names = list(self.cs.keys())
        model.zero_grad()

    def initialize_c_federated(self, model, inputs):
        """
        Estimate per-group preconditioning scales ``self.cs`` for the model passed in.

        This is the federated counterpart of ``initialize_c``. The parameters to
        scan are collected locally from ``model`` via ``self.should_optim``, and
        ZO probes go through ``self.perturb_single_layer_federated``, so nothing
        relies on cached ``self`` state.

        Groups are ``'embed'``, ``'lm_head'`` and one key per hidden layer:
        ``'layers.{i}.'`` for OPT, ``'layer.{i}.'`` otherwise. They are assigned
        by ``self.retrieve_c``. The estimator is chosen from ``self.args``:

        * ``zo_variant != 'param_norm'`` and ``use_zo_grad_est``: the absolute
          two-point ZO projected gradient per group, kept as a tensor.
        * ``zo_variant == 'param_norm'``: the L2 norm of the group's parameters.
        * otherwise: the L2 norm of the group's gradient from one backward pass,
          mostly for debugging.

        For both norm estimators, ``scale_norm_by_num_params`` divides each norm
        by ``sqrt(number of elements in the group)``, and every scale is stored as
        a Python float. Groups that received no parameters stay at ``0.0``.

        Side effects: rebuilds ``self.cs``, ``self.num_params``,
        ``self.num_model_layers`` and ``self.layer_names``, and calls
        ``model.zero_grad()``.

        Args:
            model: model whose parameters are scanned, and perturbed or
                backpropagated through.
            inputs: one training batch. Only used by the ZO and backprop estimators.
        """
        named_parameters_to_optim = []
        for name, param in model.named_parameters():
            if self.should_optim(name, param):
                named_parameters_to_optim.append((name, param))

        # OPT: embed_tokens; embed_positions
        # RoBERTa: embeddings
        self.cs = {'embed': 0.0, 'lm_head': 0.0}
        self.num_params = copy.deepcopy(self.cs)
        self.num_model_layers = model.config.num_hidden_layers
        layer_name = "layers" if model.config.model_type == "opt" else "layer"
        for i in range(self.num_model_layers):
            self.cs[f'{layer_name}.{i}.'] = 0.0
            self.num_params[f'{layer_name}.{i}.'] = 0

        if self.args.zo_variant != 'param_norm' and self.args.use_zo_grad_est:
            # ZO estimation of c's: probe each group with +eps*z and then -eps*z.
            for layer in self.cs.keys():
                with torch.no_grad():
                    model, z = self.perturb_single_layer_federated(model, layer_name=layer)
                    loss1 = self.zo_forward(model, inputs)
                    model, z = self.perturb_single_layer_federated(model, layer_name=layer, random_vector=z, scaling_factor=-2)
                    loss2 = self.zo_forward(model, inputs)

                projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)
                self.cs[layer] = torch.abs(projected_grad)

                # Shift back by +eps*z to restore the original parameters.
                model, z = self.perturb_single_layer_federated(model, layer_name=layer, random_vector=z)
        else:
            if self.args.zo_variant == 'param_norm':
                # No backprop needed: parameter norms can be measured directly.
                for name, param in named_parameters_to_optim:
                    print(name)
                    ckey = self.retrieve_c(name)
                    if ckey in self.cs:
                        self.cs[ckey] += torch.sum(param.data ** 2)
                        self.num_params[ckey] += param.data.numel()
            else:
                # Backpropagation estimate of the c's (debugging aid, see docstring).
                model.eval()
                inputs = self._prepare_inputs(inputs)
                with self.compute_loss_context_manager():
                    loss = self.compute_loss(model, inputs)
                if self.args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
                loss.backward()
                for name, param in named_parameters_to_optim:
                    if param.grad is None:
                        print(name)
                    else:
                        ckey = self.retrieve_c(name)
                        if ckey in self.cs:
                            self.cs[ckey] += torch.sum(param.grad ** 2)
                            self.num_params[ckey] += param.grad.numel()

            # Turn the accumulated sums of squares into L2 norms. Groups that
            # received nothing are still the Python float 0.0, which torch.sqrt
            # rejects. The size normalisation divides by sqrt(element count) and
            # is skipped for empty groups to avoid dividing by zero.
            for ckey in list(self.cs):
                sq_sum = self.cs[ckey]
                if torch.is_tensor(sq_sum):
                    norm = torch.sqrt(sq_sum).detach().item()
                else:
                    norm = math.sqrt(sq_sum)
                if self.args.scale_norm_by_num_params and self.num_params[ckey] > 0:
                    norm /= math.sqrt(self.num_params[ckey])
                self.cs[ckey] = norm

        self.layer_names = list(self.cs.keys())
        model.zero_grad()

    def retrieve_c(self, param_name):
        """
        Map a parameter name to its preconditioning group key in ``self.cs``.

        Returns the first key, in ``self.cs`` insertion order, that occurs as a
        substring of ``param_name``. Returns ``''`` when no key matches; such
        parameters are likely not used in the forward pass.
        """
        return next((key for key in self.cs.keys() if key in param_name), '')

    def get_num_samples(self):
        """
        Return how many random directions to sample at the current step.

        ``args.zero_order_sample_scheduler`` selects the rule:

        * ``None``: always 1.
        * ``"linear"``: grows with training progress, as
          ``int(global_step / max_steps * zero_order_sample)``, and is at least 1.
        * ``"constant"``: ``int(zero_order_sample)``.

        Raises:
            NotImplementedError: for any other scheduler name.
        """
        schedule = self.args.zero_order_sample_scheduler
        if schedule is None:
            return 1
        if schedule == "linear":
            progress = self.state.global_step / self.args.max_steps
            return max(1, int(progress * self.args.zero_order_sample))
        if schedule == "constant":
            return int(self.args.zero_order_sample)
        raise NotImplementedError

    def is_batch_nums_equal(self, data_loaders):
        """
        Return True when every loader in ``data_loaders`` reports the same ``len``.

        The reference length is taken from the first loader, so an empty sequence
        raises IndexError.
        """
        expected = len(data_loaders[0])
        return all(len(loader) == expected for loader in data_loaders)

    def train(self, model_path=None, dev_objective=None):
        """
        Main training entry point.

        The training logic is directly borrowed from transformers.Trainer (version 3.0.2).
        Add early stopping.
        """
        
        # 创建一个文件处理器，设置日志文件路径
        file_handler = py_logging.FileHandler(os.path.join(self.args.output_dir, self.args.log_file))
        file_handler.setLevel(py_logging.INFO)
        # 将处理器添加到 logger
        logger.addHandler(file_handler)
        
        if self.args.from_linearhead and model_path is None:
            super().train(model_path, dev_objective) # Train output layer using LinearHeadTrainer

        self.best_dir = None
        self.objective = -float("inf")
        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective

        # Data loading.
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        if num_update_steps_per_epoch == 0:
            num_update_steps_per_epoch = 1
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model

        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        # Train
        if transformers.is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.state = TrainerState()
        self.state.global_step = 0
        start_time = time.time()
        self.state.zo_forward_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        if self.args.gradient_checkpointing:
            model.gradient_checkpointing_enable()

        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.state.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.state.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.state.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.state.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.state.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = torch.tensor(0.0).to(self.args.device)
        logging_loss_scalar = 0.0
        model.zero_grad()
        metrics = None
        for epoch in range(epochs_trained, int(num_train_epochs)):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if transformers.is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):
                if self.args.sync_embedding_layers:
                    assert model.module.model_type == 'opt', 'did not implement embedding layer synchronization for non-OPT models'
                    model.module.model.decoder.embed_tokens.weight = model.module.lm_head.weight

                # estimate c's (param or grad norm) on epoch 0
                if epoch == 0 and step == 0 and self.args.zo_variant is not None:
                    self.initialize_c(model, inputs)
                elif step == 0 and self.args.zo_variant is not None and self.args.recompute_norms:
                    self.initialize_c(model, inputs)
                
                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue
                    
                if self.args.zero_order_optim:
                    # Get parameters that should be optimized (for layer-wise optimization and prefix-tuning)
                    self.named_parameters_to_optim = []
                    for name, param in model.named_parameters():
                        if self.should_optim(name, param):
                            self.named_parameters_to_optim.append((name, param))

                    if self.args.zo_by_layer:
                        assert not self.args.efficient_zero_order, 'did not implement preconditioned ZO for efficient ZO yet'
                        assert self.args.zero_order_use_trainer_optim, 'preconditioned ZO requires using the trainer optimizer'
                        num_zs = self.get_num_samples()
                        layers = [np.random.choice(self.layer_names)] if self.args.pc_rnd_layer else self.layer_names

                        # for each layer: perturb only that layer and store the gradient estimates in the grad buffer
                        for layer in self.layer_names:
                            for _ in range(num_zs):
                                c_i = self.cs[layer]
                                with torch.no_grad():
                                    c_i = 1.0 if c_i == 0 else c_i # if the scaling is 0, just reset it to 1 so that there can eventually be some gradient to those layers 
                                    model, random_vector = self.perturb_single_layer(model, layer, scaling_factor=1.0/c_i)
                                    loss1 = self.zo_forward(model, inputs)
                                    model, random_vector = self.perturb_single_layer(model, layer, random_vector=random_vector, scaling_factor=-2.0/c_i)
                                    loss2 = self.zo_forward(model, inputs)
                                    model, random_vector = self.perturb_single_layer(model, layer, random_vector=random_vector, scaling_factor=1.0/c_i)

                                projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)
                                # scale grad according to number of zs sampled
                                if not self.args.scale_lr_with_samples:
                                    projected_grad = projected_grad / float(num_zs)
                                
                                for name, param in self.named_parameters_to_optim:
                                    if self.retrieve_c(name) == layer:
                                        z_tilde = random_vector[name] * c_i

                                        if param.grad is None:
                                            param.grad = projected_grad * z_tilde
                                        else:
                                            param.grad += projected_grad * z_tilde

                                # note that  | E_z [ <z, grad of one layer > ] | is equal to norm of grad for that layer for gaussian z
                                # leverages this fact to update the grad norms
                                if self.args.zo_variant == 'grad_norm' and self.args.norm_running_update:
                                    self.cs[layer] = torch.abs(projected_grad)
                    else:
                        # get number of zs to sample
                        num_zs = self.get_num_samples()
                        if num_zs > 1:
                            assert self.args.zero_order_use_trainer_optim, 'cannot sample multiple zs without storing intermediate gradient. use trainer.'

                        for _ in range(num_zs):
                            # prepare for sampling new zs
                            random_vector = None
                            if self.args.efficient_zero_order:
                                random_seed = np.random.randint(1000000000)

                            with torch.no_grad():
                                # first function evaluation
                                if self.args.efficient_zero_order:
                                    model = self.efficient_perturb_parameters(model, random_seed)
                                elif self.args.zo_variant is not None:
                                    model, random_vector = self.norm_perturb_parameters(model)
                                else:
                                    model, random_vector = self.perturb_parameters(model)
                                loss1 = self.zo_forward(model, inputs)

                                # second function evaluation
                                if self.args.efficient_zero_order:
                                    model = self.efficient_perturb_parameters(model, random_seed, scaling_factor=-2)
                                elif self.args.zo_variant is not None:
                                    model, random_vector = self.norm_perturb_parameters(model, random_vector, scaling_factor=-2)
                                else:
                                    model, random_vector = self.perturb_parameters(model, random_vector, scaling_factor=-2)                 
                                loss2 = self.zo_forward(model, inputs)

                            projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)

                            # scale grad according to accumulation
                            if self.args.gradient_accumulation_steps > 1:
                                assert self.args.zero_order_use_trainer_optim, 'grad accumulation not implemented for non-trainer ZO yet'
                                projected_grad = projected_grad / self.args.gradient_accumulation_steps
                            
                            # scale grad according to number of zs sampled
                            if not self.args.scale_lr_with_samples:
                                projected_grad = projected_grad / float(num_zs)

                            # store gradient in parameter buffer if using trainer
                            # o/w, the loop will exit after one round and the update will be applied directly (see below)
                            if self.args.zero_order_use_trainer_optim:
                                if self.args.efficient_zero_order:
                                    # print(random_seed)
                                    torch.manual_seed(random_seed)
                                
                                for name, param in self.named_parameters_to_optim:
                                    # recover noise used in perturbations
                                    if self.args.efficient_zero_order:
                                        z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                                    else:
                                        z = random_vector[name]

                                    if self.args.zo_variant is not None and not self.args.change_grad_estimate:
                                        cname = self.retrieve_c(name)
                                        if cname in self.cs:
                                            z = z * self.cs[cname]

                                    if param.grad is None:
                                        param.grad = projected_grad * z
                                    else:
                                        param.grad += projected_grad * z

                            # reset model back to its parameters at start of step
                            if self.args.efficient_zero_order:
                                model = self.efficient_perturb_parameters(model, random_seed)
                            elif self.args.zo_variant is not None:
                                model, random_vector = self.norm_perturb_parameters(model, random_vector)   
                            else:
                                model, random_vector = self.perturb_parameters(model, random_vector)

                    # apply gradient updates
                    # if using trainer, follow trainer logic to clip grad and check if parameters should be updated
                    if self.args.zero_order_use_trainer_optim:
                        if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                            # last step in epoch but step is always smaller than gradient_accumulation_steps
                            len(epoch_iterator) <= self.args.gradient_accumulation_steps
                            and (step + 1) == len(epoch_iterator)
                        ):
                            # Gradient norm clipping
                            if self.args.zero_order_clip_grad:
                                norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                            # Update the parameters and step scheduler
                            optimizer.step()
                            scheduler.step()
                        
                            # logging
                            if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                                self.state.global_step == 1 and self.args.logging_first_step
                            ):
                                logs = {}
                                logs["loss"] = loss1.item()
                                if not self.args.zero_order_clip_grad:
                                    norm = 0.0
                                    for _, p in model.named_parameters():
                                        if p.grad is not None:
                                            norm += torch.sum(p.grad ** 2)
                                    norm = torch.sqrt(norm)
                                logs["grad_norm"] = norm.item()
                                logs["learning_rate"] = (
                                    scheduler.get_last_lr()[0]
                                    if version.parse(torch.__version__) >= version.parse("1.4")
                                    else scheduler.get_lr()[0]
                                )
                                logs["num_zs"] = num_zs
                                logs["global_step"] = self.state.global_step
                                logs["zo_forward_step"] = self.state.zo_forward_step
                                logs["max_steps"] = self.args.max_steps
                                logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                                logs["time"] = int(time.time() - start_time)
                                self.log(logs)
                                logger.info(str(logs))
                            
                            model.zero_grad()
                            self.state.global_step += 1
                            self.epoch = epoch + (step + 1) / len(epoch_iterator)
                    # if not using the trainer, the updates are resampled and directly applied to the parameters
                    else:
                        # Efficient mode 
                        # WARNING: no gradient accumulation when not storing the grad
                        assert self.args.gradient_accumulation_steps == 1, 'gradient accumulation is not supported for zero-order optimization'
                        assert self.args.zero_order_sample_scheduler is None
                        assert not self.args.zero_order_clip_grad, 'gradient clipping not implemented yet for non-trainer ZO'

                        if self.args.efficient_zero_order:
                            torch.manual_seed(random_seed)     
                        for name, param in self.named_parameters_to_optim:
                            if self.args.efficient_zero_order:
                                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                            else:
                                z = random_vector[name]
                            param.data = param.data - self.args.learning_rate * (projected_grad * z + self.args.weight_decay * param.data)

                        if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                                self.state.global_step == 1 and self.args.logging_first_step
                            ):
                                logs = {}
                                logs["loss"] = loss1.item()
                                logs["learning_rate"] = self.args.learning_rate
                                logs["global_step"] = self.state.global_step
                                logs["zo_forward_step"] = self.state.zo_forward_step
                                logs["max_steps"] = self.args.max_steps
                                logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                                logs["time"] = int(time.time() - start_time)
                                self.log(logs)
                                logger.info(str(logs))


                        self.state.global_step += 1
                        self.epoch = epoch + (step + 1) / len(epoch_iterator)
                    
                    # Debug information
                    # print("%.5f, %.5f" % (loss1.item(), loss2.item()))
                    # print("Loss: %.10f, projected_grad: %.5f" % (loss1, projected_grad))

                # standard, non-ZO optimization
                else:
                    tr_loss += self.training_step(model, inputs)

                    if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <= self.args.gradient_accumulation_steps
                        and (step + 1) == len(epoch_iterator)
                    ):
                        if self.args.fp16 and _use_native_amp:
                            self.scaler.unscale_(optimizer)
                            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                        elif self.args.fp16:
                            norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                        else:
                            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                        if self.args.optimizer_variant == 'signgd':
                            for n,p in model.named_parameters():
                                if p.grad is not None:
                                    p.grad = torch.sign(p.grad)

                        if transformers.is_torch_tpu_available():
                            xm.optimizer_step(optimizer)
                        elif self.args.fp16 and _use_native_amp:
                            self.scaler.step(optimizer)
                            self.scaler.update()
                        else:
                            optimizer.step()

                        scheduler.step()
                        model.zero_grad()
                        self.state.global_step += 1
                        self.epoch = epoch + (step + 1) / len(epoch_iterator)

                        if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                            self.state.global_step == 1 and self.args.logging_first_step
                        ):
                            logs = {}
                            tr_loss_scalar = tr_loss.item()
                            logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                            logs["norm"] = norm.item()
                            # backward compatibility for pytorch schedulers
                            logs["learning_rate"] = (
                                scheduler.get_last_lr()[0]
                                if version.parse(torch.__version__) >= version.parse("1.4")
                                else scheduler.get_lr()[0]
                            )
                            logging_loss_scalar = tr_loss_scalar

                            self.log(logs)
                            logger.info(str(logs))

                if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                    epoch_iterator.close()
                    break
                
                # if self.args.evaluate_during_training and self.state.global_step % self.args.eval_steps == 0:
                if self.args.evaluate_during_training and epoch % self.args.eval_epoch == 0:
                    output = self.evaluate()
                    metrics = output.metrics
                    objective = self.dev_objective(metrics)
                    if objective > self.objective:
                        logger.info("Best dev result: {}".format(objective))
                        self.objective = objective
                        # self.save_model(self.args.output_dir)

                        # Now we save this to (CPU) memory instead of disk <-- much faster
                        self.best_model_ckpt = {k: v.detach().cpu() for k, v in model.state_dict().items()}

            if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                # train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.state.global_step, tr_loss / self.state.global_step, metrics), self.objective


    # --------------------------------------------------- Federated ---------------------------------------------------
    def train_for_single_model_batch(self, model, optimizer, scheduler, epoch,
                                     step, inputs, tr_loss, start_time, steps_trained_in_current_epoch,
                                     num_batches_in_epoch=None):
        """Run one training step on a single batch for one (federated) model.

        If ``self.args.zero_order_optim`` is set, the gradient is estimated with two function
        evaluations per sampled direction, using the ``*_federated`` perturbation helpers.
        Otherwise a standard backprop step is taken via ``self.training_step``. On the
        trainer-optimizer paths, the optimizer and scheduler are stepped only when this step
        closes a gradient-accumulation window. The non-trainer ZO path writes its update
        directly into ``param.data`` on every call.

        Args:
            model: model being trained. It is rebound to whatever the perturbation helpers return.
            optimizer: stepped at accumulation boundaries on the trainer-optimizer paths.
            scheduler: stepped right after ``optimizer``.
            epoch: current epoch index. Together with ``step``, it decides when c's are (re)initialized.
            step: index of this batch within the current epoch.
            inputs: batch passed to ``self.zo_forward`` / ``self.training_step``.
            tr_loss: running training loss. Only the standard (non-ZO) path adds to it.
            start_time: wall-clock start time, used for the ``time`` log field.
            steps_trained_in_current_epoch: accepted for interface compatibility; not used here.
            num_batches_in_epoch: optional number of batches in the current epoch. When given,
                the last batch of an epoch that is shorter than ``gradient_accumulation_steps``
                also triggers an update. When ``None``, updates happen only every
                ``gradient_accumulation_steps`` steps.

        Returns:
            Tuple ``(model, tr_loss)``.
        """
        if self.args.sync_embedding_layers:
            assert model.module.model_type == 'opt', 'did not implement embedding layer synchronization for non-OPT models'
            # tie the decoder input embedding to the LM head weight
            model.module.model.decoder.embed_tokens.weight = model.module.lm_head.weight

        # estimate c's (param or grad norm) on the first step of epoch 0, and again on the first
        # step of later epochs when recompute_norms is set
        if self.args.zo_variant is not None and step == 0 and (epoch == 0 or self.args.recompute_norms):
            self.initialize_c_federated(model, inputs)

        # True when this step closes a gradient-accumulation window, or is the last batch of an
        # epoch shorter than that window. The epoch length must be passed in explicitly because
        # this method has no epoch iterator of its own to take len() of.
        accumulation_steps = self.args.gradient_accumulation_steps
        is_update_step = (step + 1) % accumulation_steps == 0 or (
            num_batches_in_epoch is not None
            and num_batches_in_epoch <= accumulation_steps
            and (step + 1) == num_batches_in_epoch
        )

        if self.args.zero_order_optim:
            # Get parameters that should be optimized (for layer-wise optimization and prefix-tuning)
            self.named_parameters_to_optim = [
                (name, param) for name, param in model.named_parameters() if self.should_optim(name, param)
            ]

            if self.args.zo_by_layer:
                assert not self.args.efficient_zero_order, 'did not implement preconditioned ZO for efficient ZO yet'
                assert self.args.zero_order_use_trainer_optim, 'preconditioned ZO requires using the trainer optimizer'
                num_zs = self.get_num_samples()
                # with pc_rnd_layer only one randomly chosen layer is perturbed this step; otherwise all layers
                layers = [np.random.choice(self.layer_names)] if self.args.pc_rnd_layer else self.layer_names

                # for each selected layer: perturb only that layer and store the gradient estimates in the grad buffer
                for layer in layers:
                    for _ in range(num_zs):
                        c_i = self.cs[layer]
                        with torch.no_grad():
                            # if the scaling is 0, reset it to 1 so that those layers can eventually receive some gradient
                            c_i = 1.0 if c_i == 0 else c_i
                            model, random_vector = self.perturb_single_layer_federated(model, layer, scaling_factor=1.0/c_i)
                            loss1 = self.zo_forward(model, inputs)
                            model, random_vector = self.perturb_single_layer_federated(model, layer, random_vector=random_vector, scaling_factor=-2.0/c_i)
                            loss2 = self.zo_forward(model, inputs)
                            # net scaling +1 - 2 + 1 = 0: presumably restores the layer to its unperturbed value
                            model, random_vector = self.perturb_single_layer_federated(model, layer, random_vector=random_vector, scaling_factor=1.0/c_i)

                        projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)
                        # scale grad according to number of zs sampled
                        if not self.args.scale_lr_with_samples:
                            projected_grad = projected_grad / float(num_zs)

                        for name, param in self.named_parameters_to_optim:
                            if self.retrieve_c(name) == layer:
                                z_tilde = random_vector[name] * c_i

                                if param.grad is None:
                                    param.grad = projected_grad * z_tilde
                                else:
                                    param.grad += projected_grad * z_tilde

                        # note that | E_z [ <z, grad of one layer > ] | is equal to norm of grad for that layer for gaussian z
                        # leverages this fact to update the grad norms
                        if self.args.zo_variant == 'grad_norm' and self.args.norm_running_update:
                            self.cs[layer] = torch.abs(projected_grad)
            else:
                # get number of zs to sample
                num_zs = self.get_num_samples()
                if num_zs > 1:
                    assert self.args.zero_order_use_trainer_optim, 'cannot sample multiple zs without storing intermediate gradient. use trainer.'

                for _ in range(num_zs):
                    # prepare for sampling new zs
                    random_vector = None
                    if self.args.efficient_zero_order:
                        # the seed lets z be regenerated later instead of being stored
                        random_seed = np.random.randint(1000000000)

                    with torch.no_grad():
                        # first function evaluation
                        if self.args.efficient_zero_order:
                            model = self.efficient_perturb_parameters_federated(model, random_seed)
                        elif self.args.zo_variant is not None:
                            model, random_vector = self.norm_perturb_parameters_federated(model)
                        else:
                            model, random_vector = self.perturb_parameters_federated(model)
                        loss1 = self.zo_forward(model, inputs)

                        # second function evaluation
                        if self.args.efficient_zero_order:
                            model = self.efficient_perturb_parameters_federated(model, random_seed, scaling_factor=-2)
                        elif self.args.zo_variant is not None:
                            model, random_vector = self.norm_perturb_parameters_federated(model, random_vector, scaling_factor=-2)
                        else:
                            model, random_vector = self.perturb_parameters_federated(model, random_vector, scaling_factor=-2)
                        loss2 = self.zo_forward(model, inputs)

                    # finite-difference estimate along the sampled direction
                    # (assumes the perturb helpers move parameters by eps * z)
                    projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)

                    # scale grad according to accumulation
                    if self.args.gradient_accumulation_steps > 1:
                        assert self.args.zero_order_use_trainer_optim, 'grad accumulation not implemented for non-trainer ZO yet'
                        projected_grad = projected_grad / self.args.gradient_accumulation_steps

                    # scale grad according to number of zs sampled
                    if not self.args.scale_lr_with_samples:
                        projected_grad = projected_grad / float(num_zs)

                    # store gradient in parameter buffer if using trainer
                    # o/w, the loop will exit after one round and the update will be applied directly (see below)
                    if self.args.zero_order_use_trainer_optim:
                        if self.args.efficient_zero_order:
                            torch.manual_seed(random_seed)

                        for name, param in self.named_parameters_to_optim:
                            # recover noise used in perturbations
                            if self.args.efficient_zero_order:
                                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                            else:
                                z = random_vector[name]

                            if self.args.zo_variant is not None and not self.args.change_grad_estimate:
                                cname = self.retrieve_c(name)
                                if cname in self.cs:
                                    z = z * self.cs[cname]

                            if param.grad is None:
                                param.grad = projected_grad * z
                            else:
                                param.grad += projected_grad * z

                    # reset model back to its parameters at start of step
                    if self.args.efficient_zero_order:
                        model = self.efficient_perturb_parameters_federated(model, random_seed)
                    elif self.args.zo_variant is not None:
                        model, random_vector = self.norm_perturb_parameters_federated(model, random_vector)
                    else:
                        model, random_vector = self.perturb_parameters_federated(model, random_vector)

            # apply gradient updates
            # if using trainer, follow trainer logic to clip grad and check if parameters should be updated
            if self.args.zero_order_use_trainer_optim:
                if is_update_step:
                    # Gradient norm clipping
                    if self.args.zero_order_clip_grad:
                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    # Update the parameters and step scheduler
                    optimizer.step()
                    scheduler.step()

                    # logging
                    if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                        self.state.global_step == 1 and self.args.logging_first_step
                    ):
                        logs = {}
                        logs["loss"] = loss1.item()
                        if not self.args.zero_order_clip_grad:
                            squared_norm = 0.0
                            for p in model.parameters():
                                if p.grad is not None:
                                    squared_norm = squared_norm + torch.sum(p.grad ** 2)
                            # as_tensor keeps this working when no parameter has a grad (squared_norm stays a float)
                            norm = torch.sqrt(torch.as_tensor(squared_norm))
                        logs["grad_norm"] = norm.item()
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logs["num_zs"] = num_zs
                        logs["global_step"] = self.state.global_step
                        logs["zo_forward_step"] = self.state.zo_forward_step
                        logs["max_steps"] = self.args.max_steps
                        logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                        logs["time"] = int(time.time() - start_time)
                        self.log(logs)
                        logger.info(str(logs))

                    model.zero_grad()
                    self.state.global_step += 1
            # if not using the trainer, the updates are resampled and directly applied to the parameters
            else:
                # WARNING: no gradient accumulation when not storing the grad
                assert self.args.gradient_accumulation_steps == 1, 'gradient accumulation is not supported for zero-order optimization'
                assert self.args.zero_order_sample_scheduler is None
                assert not self.args.zero_order_clip_grad, 'gradient clipping not implemented yet for non-trainer ZO'

                # regenerate the same z used for the perturbations (efficient mode) or reuse the stored one
                if self.args.efficient_zero_order:
                    torch.manual_seed(random_seed)
                for name, param in self.named_parameters_to_optim:
                    if self.args.efficient_zero_order:
                        z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                    else:
                        z = random_vector[name]
                    # plain SGD step with decoupled-into-the-update weight decay term
                    param.data = param.data - self.args.learning_rate * (projected_grad * z + self.args.weight_decay * param.data)

                if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                    self.state.global_step == 1 and self.args.logging_first_step
                ):
                    logs = {}
                    logs["loss"] = loss1.item()
                    logs["learning_rate"] = self.args.learning_rate
                    logs["global_step"] = self.state.global_step
                    logs["zo_forward_step"] = self.state.zo_forward_step
                    logs["max_steps"] = self.args.max_steps
                    logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                    logs["time"] = int(time.time() - start_time)
                    self.log(logs)
                    logger.info(str(logs))

                self.state.global_step += 1

        # standard, non-ZO optimization
        else:
            loss_step = self.training_step(model, inputs)
            tr_loss += loss_step

            if is_update_step:
                if self.args.fp16 and _use_native_amp:
                    self.scaler.unscale_(optimizer)
                    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                elif self.args.fp16:
                    norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                else:
                    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                if self.args.optimizer_variant == 'signgd':
                    for p in model.parameters():
                        if p.grad is not None:
                            p.grad = torch.sign(p.grad)

                if transformers.is_torch_tpu_available():
                    xm.optimizer_step(optimizer)
                elif self.args.fp16 and _use_native_amp:
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    optimizer.step()

                scheduler.step()
                model.zero_grad()
                self.state.global_step += 1

                if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                    self.state.global_step == 1 and self.args.logging_first_step
                ):
                    logs = {}
                    logs["loss"] = loss_step.item()
                    logs["norm"] = norm.item()
                    # backward compatibility for pytorch schedulers
                    logs["learning_rate"] = (
                        scheduler.get_last_lr()[0]
                        if version.parse(torch.__version__) >= version.parse("1.4")
                        else scheduler.get_lr()[0]
                    )

                    self.log(logs)
                    logger.info(str(logs))

        return model, tr_loss

    def train_for_single_model_epoch(self, model, train_dataloader, optimizer, scheduler, epoch, 
                               tr_loss, start_time, steps_trained_in_current_epoch):
        """Train a single (client) model for one epoch over ``train_dataloader``.

        Every batch is handled either by zeroth-order (ZO) optimization, when
        ``self.args.zero_order_optim`` is set, or by standard back-propagation
        through ``self.training_step``. The epoch ends early once
        ``self.state.global_step`` exceeds ``args.max_steps`` or
        ``self.state.zo_forward_step`` exceeds ``args.max_zo_forward_steps``
        (each limit only applies when it is > 0).

        Args:
            model: model to train. ``model.module`` is accessed when
                ``args.sync_embedding_layers`` is set.
            train_dataloader: iterable of batches for this epoch.
            optimizer: optimizer stepped on every update step.
            scheduler: LR scheduler stepped after every optimizer step.
            epoch: index of the current epoch. The ZO scaling factors ("c's")
                are initialized on the first step of epoch 0.
            tr_loss: running training-loss accumulator. It is only updated on
                the non-ZO path.
            start_time: ``time.time()`` at the start of training, used for the
                elapsed-time log entry.
            steps_trained_in_current_epoch: number of leading batches to skip
                when resuming from a checkpoint.

        Returns:
            Tuple ``(model, tr_loss)``.
        """
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        device = get_model_device(model)

        if transformers.is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
            epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)

        # Reset the past mems state at the beginning of each epoch if necessary.
        if self.args.past_index >= 0:
            self._past = None

        for step, inputs in enumerate(epoch_iterator):
            if self.args.sync_embedding_layers:
                assert model.module.model_type == 'opt', 'did not implement embedding layer synchronization for non-OPT models'
                # Tie the input embedding weights to the LM head weights.
                model.module.model.decoder.embed_tokens.weight = model.module.lm_head.weight

            # Estimate the c's (param or grad norm) on the first step of epoch 0, and again on the
            # first step of every later epoch when args.recompute_norms is set.
            if epoch == 0 and step == 0 and self.args.zo_variant is not None:
                self.initialize_c_federated(model, inputs)
            elif step == 0 and self.args.zo_variant is not None and self.args.recompute_norms:
                self.initialize_c_federated(model, inputs)

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            if self.args.zero_order_optim:
                # Get parameters that should be optimized (for layer-wise optimization and prefix-tuning)
                self.named_parameters_to_optim = []
                for name, param in model.named_parameters():
                    if self.should_optim(name, param):
                        self.named_parameters_to_optim.append((name, param))

                if self.args.zo_by_layer:
                    assert not self.args.efficient_zero_order, 'did not implement preconditioned ZO for efficient ZO yet'
                    assert self.args.zero_order_use_trainer_optim, 'preconditioned ZO requires using the trainer optimizer'
                    num_zs = self.get_num_samples()
                    # With args.pc_rnd_layer, only a single randomly chosen layer is perturbed this step.
                    layers = [np.random.choice(self.layer_names)] if self.args.pc_rnd_layer else self.layer_names

                    # For each selected layer: perturb only that layer and store the gradient estimate
                    # in the grad buffers of that layer's parameters.
                    # NOTE(review): unlike the all-parameter branch below, projected_grad is not divided by
                    # gradient_accumulation_steps here - confirm this is intended.
                    for layer in layers:
                        for _ in range(num_zs):
                            c_i = self.cs[layer]
                            with torch.no_grad():
                                # if the scaling is 0, reset it to 1 so that those layers can eventually receive some gradient
                                c_i = 1.0 if c_i == 0 else c_i
                                # Scaling factors +1/c_i, -2/c_i, +1/c_i sum to zero, so the last call presumably
                                # restores the layer's original weights.
                                model, random_vector = self.perturb_single_layer_federated(model, layer, scaling_factor=1.0/c_i)
                                loss1 = self.zo_forward(model, inputs)
                                model, random_vector = self.perturb_single_layer_federated(model, layer, random_vector=random_vector, scaling_factor=-2.0/c_i)
                                loss2 = self.zo_forward(model, inputs)
                                model, random_vector = self.perturb_single_layer_federated(model, layer, random_vector=random_vector, scaling_factor=1.0/c_i)

                            projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)
                            # scale grad according to number of zs sampled
                            if not self.args.scale_lr_with_samples:
                                projected_grad = projected_grad / float(num_zs)

                            for name, param in self.named_parameters_to_optim:
                                if self.retrieve_c(name) == layer:
                                    z_tilde = random_vector[name] * c_i

                                    if param.grad is None:
                                        param.grad = projected_grad * z_tilde
                                    else:
                                        param.grad += projected_grad * z_tilde

                            # For Gaussian z, E_z |<z, g>| is proportional to ||g|| (factor sqrt(2/pi)), so
                            # |projected_grad| is used as a running estimate of this layer's grad norm.
                            if self.args.zo_variant == 'grad_norm' and self.args.norm_running_update:
                                self.cs[layer] = torch.abs(projected_grad)
                else:
                    # get number of zs to sample
                    num_zs = self.get_num_samples()
                    if num_zs > 1:
                        assert self.args.zero_order_use_trainer_optim, 'cannot sample multiple zs without storing intermediate gradient. use trainer.'

                    for _ in range(num_zs):
                        # prepare for sampling new zs
                        random_vector = None
                        if self.args.efficient_zero_order:
                            random_seed = np.random.randint(1000000000)

                        with torch.no_grad():
                            # first function evaluation
                            if self.args.efficient_zero_order:
                                model = self.efficient_perturb_parameters_federated(model, random_seed)
                            elif self.args.zo_variant is not None:
                                model, random_vector = self.norm_perturb_parameters_federated(model)
                            else:
                                model, random_vector = self.perturb_parameters_federated(model)
                            loss1 = self.zo_forward(model, inputs)

                            # second function evaluation
                            if self.args.efficient_zero_order:
                                model = self.efficient_perturb_parameters_federated(model, random_seed, scaling_factor=-2)
                            elif self.args.zo_variant is not None:
                                model, random_vector = self.norm_perturb_parameters_federated(model, random_vector, scaling_factor=-2)
                            else:
                                model, random_vector = self.perturb_parameters_federated(model, random_vector, scaling_factor=-2)
                            loss2 = self.zo_forward(model, inputs)

                        projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)

                        # scale grad according to accumulation
                        if self.args.gradient_accumulation_steps > 1:
                            assert self.args.zero_order_use_trainer_optim, 'grad accumulation not implemented for non-trainer ZO yet'
                            projected_grad = projected_grad / self.args.gradient_accumulation_steps

                        # scale grad according to number of zs sampled
                        if not self.args.scale_lr_with_samples:
                            projected_grad = projected_grad / float(num_zs)

                        # store gradient in parameter buffer if using trainer
                        # o/w, the loop will exit after one round and the update will be applied directly (see below)
                        if self.args.zero_order_use_trainer_optim:
                            if self.args.efficient_zero_order:
                                # re-seed so the same z's used for the perturbations are regenerated below
                                torch.manual_seed(random_seed)

                            for name, param in self.named_parameters_to_optim:
                                # recover noise used in perturbations
                                if self.args.efficient_zero_order:
                                    z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                                else:
                                    z = random_vector[name]

                                if self.args.zo_variant is not None and not self.args.change_grad_estimate:
                                    cname = self.retrieve_c(name)
                                    if cname in self.cs:
                                        z = z * self.cs[cname]

                                if param.grad is None:
                                    param.grad = projected_grad * z
                                else:
                                    param.grad += projected_grad * z

                        # reset model back to its parameters at start of step
                        if self.args.efficient_zero_order:
                            model = self.efficient_perturb_parameters_federated(model, random_seed)
                        elif self.args.zo_variant is not None:
                            model, random_vector = self.norm_perturb_parameters_federated(model, random_vector)
                        else:
                            model, random_vector = self.perturb_parameters_federated(model, random_vector)

                # apply gradient updates
                # if using trainer, follow trainer logic to clip grad and check if parameters should be updated
                if self.args.zero_order_use_trainer_optim:
                    if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <= self.args.gradient_accumulation_steps
                        and (step + 1) == len(epoch_iterator)
                    ):
                        # Gradient norm clipping
                        if self.args.zero_order_clip_grad:
                            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                        # Update the parameters and step scheduler
                        optimizer.step()
                        scheduler.step()

                        # logging
                        if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                            self.state.global_step == 1 and self.args.logging_first_step
                        ):
                            logs = {}
                            logs["loss"] = loss1.item()
                            if not self.args.zero_order_clip_grad:
                                norm = 0.0
                                for _, p in model.named_parameters():
                                    if p.grad is not None:
                                        norm += torch.sum(p.grad ** 2)
                                # as_tensor: norm is still a plain float when no parameter has a gradient,
                                # and torch.sqrt only accepts tensors.
                                norm = torch.sqrt(torch.as_tensor(norm))
                            logs["grad_norm"] = norm.item()
                            logs["learning_rate"] = (
                                scheduler.get_last_lr()[0]
                                if version.parse(torch.__version__) >= version.parse("1.4")
                                else scheduler.get_lr()[0]
                            )
                            logs["num_zs"] = num_zs
                            logs["global_step"] = self.state.global_step
                            logs["zo_forward_step"] = self.state.zo_forward_step
                            logs["max_steps"] = self.args.max_steps
                            logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                            logs["time"] = int(time.time() - start_time)
                            self.log(logs)
                            logger.info(str(logs))

                        model.zero_grad()
                        self.state.global_step += 1
                # if not using the trainer, the updates are resampled and directly applied to the parameters
                else:
                    # Efficient mode
                    # WARNING: no gradient accumulation when not storing the grad
                    assert self.args.gradient_accumulation_steps == 1, 'gradient accumulation is not supported for zero-order optimization'
                    assert self.args.zero_order_sample_scheduler is None
                    assert not self.args.zero_order_clip_grad, 'gradient clipping not implemented yet for non-trainer ZO'

                    if self.args.efficient_zero_order:
                        # re-seed so the same z's used for the perturbations are regenerated below
                        torch.manual_seed(random_seed)
                    for name, param in self.named_parameters_to_optim:
                        if self.args.efficient_zero_order:
                            z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                        else:
                            z = random_vector[name]
                        # plain SGD step with decoupled weight decay
                        param.data = param.data - self.args.learning_rate * (projected_grad * z + self.args.weight_decay * param.data)

                    if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                        self.state.global_step == 1 and self.args.logging_first_step
                    ):
                        logs = {}
                        logs["loss"] = loss1.item()
                        logs["learning_rate"] = self.args.learning_rate
                        logs["global_step"] = self.state.global_step
                        logs["zo_forward_step"] = self.state.zo_forward_step
                        logs["max_steps"] = self.args.max_steps
                        logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                        logs["time"] = int(time.time() - start_time)
                        self.log(logs)
                        logger.info(str(logs))

                    self.state.global_step += 1

            # standard, non-ZO optimization
            else:
                loss_step = self.training_step(model, inputs)
                tr_loss += loss_step

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(optimizer)
                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16:
                        norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if self.args.optimizer_variant == 'signgd':
                        for p in model.parameters():
                            if p.grad is not None:
                                p.grad = torch.sign(p.grad)

                    if transformers.is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(optimizer)
                        self.scaler.update()
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.state.global_step += 1

                    if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                        self.state.global_step == 1 and self.args.logging_first_step
                    ):
                        logs = {}
                        logs["loss"] = loss_step.item()
                        logs["norm"] = norm.item()
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )

                        self.log(logs)
                        logger.info(str(logs))

            # Stop this model's epoch early once the step budget is exhausted.
            if (self.args.max_steps > 0 and self.state.global_step > self.args.max_steps) or (
                self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps
            ):
                break

        return model, tr_loss

    def federated_adding(self, running_model, model_state_dicts, fed_num):
        """Fold one client's averaged parameter delta into the aggregate state.

        For every parameter accepted by ``self.should_optim``, this adds
        ``(current_value - original_value) / fed_num`` to
        ``model_state_dicts["updated"]``. The dict is modified in place and
        nothing is returned.

        Args:
            running_model: model after this client's local training. When
                ``self.is_parallel`` is set, its ``.module`` is inspected instead.
            model_state_dicts: dict holding ``"original"`` (parameters at the
                start of the round) and ``"updated"`` (the running aggregate).
            fed_num: number of clients the delta is averaged over.
        """
        # Unwrap the parallel wrapper so parameter names line up with the state-dict keys.
        net = running_model.module if self.is_parallel else running_model
        base = model_state_dicts["original"]
        acc = model_state_dicts["updated"]

        with torch.no_grad():
            for pname, tensor in net.named_parameters():
                if not self.should_optim(pname, tensor):
                    continue
                delta = tensor.data - base[pname]
                acc[pname] = acc[pname] + delta / fed_num

    # 3倍实现，不可有optimizer和lr_scheduler
    def train_federated_3times(self, fed_num=1, model_path=None, dev_objective=None):
        """
        Main training entry point.

        The training logic is directly borrowed from transformers.Trainer (version 3.0.2).
        Add early stopping.
        """
        
        # 创建一个文件处理器，设置日志文件路径
        file_handler = py_logging.FileHandler(os.path.join(self.args.output_dir, self.args.log_file))
        file_handler.setLevel(py_logging.INFO)
        # 将处理器添加到 logger
        logger.addHandler(file_handler)
        
        if self.args.from_linearhead and model_path is None:
            super().train(model_path, dev_objective) # Train output layer using LinearHeadTrainer

        self.best_dir = None
        self.objective = -float("inf")
        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective

        # Data loading.
        train_dataloader = self.get_train_dataloader()
        data_number = len(train_dataloader)
        #     preparing for federated learning
        def split_dataloader(dataloader, fed_num):
            # 获取原始DataLoader中的所有数据和标签
            dataset = dataloader.dataset
            indices = list(range(len(dataset)))
            # 将数据索引均分为fed_num份
            splitted_indices = [indices[i::fed_num] for i in range(fed_num)]
            # 为每一份数据创建新的DataLoader
            dataloaders = []
            for indices in splitted_indices:
                subset = Subset(dataset, indices)
                new_dataloader = DataLoader(subset, batch_size=dataloader.batch_size, shuffle=True, collate_fn=dataloader.collate_fn)
                dataloaders.append(new_dataloader)
            return dataloaders
        train_dataloaders = split_dataloader(train_dataloader, fed_num) # 将原始DataLoader均分为fed_num份
        # data_number = len(train_dataloaders[0]) # 获取DataLoader中的数据数量。由于每个联邦设备都会执行epoch++，所以应当按照每个设备的data_number来算k倍的epoch
        
        num_update_steps_per_epoch = data_number // self.args.gradient_accumulation_steps
        if num_update_steps_per_epoch == 0:
            num_update_steps_per_epoch = 1
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(data_number // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        # Create model, optimizer and scheduler
        #     获取main_model
        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        #     Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        main_model = self.model # 获取主模型

        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            main_model, optimizer = amp.initialize(main_model, optimizer, opt_level=self.args.fp16_opt_level)

        #     Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            main_model = torch.nn.DataParallel(main_model)
        # DDP不好处理数据进程，禁用
        # #     Distributed training (should be after apex fp16 initialization)
        # if self.args.local_rank != -1:
        #     main_model = torch.nn.parallel.DistributedDataParallel(
        #         main_model,
        #         device_ids=[self.args.local_rank],
        #         output_device=self.args.local_rank,
        #         find_unused_parameters=True,
        #     )
        
        # ------------------------- Federated added -------------------------
        # 为 Federated 创建对应的 dataloaders
        self.is_parallel = isinstance(main_model, torch.nn.DataParallel) or isinstance(main_model, torch.nn.parallel.DistributedDataParallel) # 检查模型是否被 DataParallel 包装
        # 各个client的参数、优化器和学习率调度器
        model_state_dicts = {"original": copy.deepcopy(main_model.module.state_dict() if self.is_parallel else main_model.state_dict()),
                             "updated": copy.deepcopy(main_model.module.state_dict() if self.is_parallel else main_model.state_dict())}
        optimizer_state_dicts = []
        lr_scheduler_state_dicts = []
        for i in range(fed_num):
            optimizer_state_dicts.append(copy.deepcopy(self.optimizer.state_dict())) # 获取优化器参数
            lr_scheduler_state_dicts.append(copy.deepcopy(self.lr_scheduler.state_dict())) # 获取学习率参数

        # Train
        if transformers.is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.state = TrainerState()
        self.state.global_step = 0
        start_time = time.time()
        self.state.zo_forward_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.state.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.state.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.state.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.state.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.state.global_step = 0
                logger.info("  Starting fine-tuning.")

        # ------------------------- 开始训练 -------------------------
        # training
        tr_loss = torch.tensor(0.0).to(self.args.device)
        metrics = None
        
        for epoch in range(epochs_trained, int(num_train_epochs)):
            # training for federated learning
            if self.args.aggregation_freq == "batch": # 逐batch聚合
                if not self.is_batch_nums_equal(train_dataloaders):
                    raise ValueError("All train_dataloaders must have the same number of batches.")
                # 训练前的配置
                epoch_iterators = []
                # device
                device = get_model_device(main_model)
                for i in range(fed_num):
                    train_dataloader = train_dataloaders[i]
                    
                    if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                        train_dataloader.sampler.set_epoch(epoch)
                    # dataloader
                    if transformers.is_torch_tpu_available():
                        parallel_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(
                            device
                        )
                        epoch_iterator = parallel_loader
                    else:
                        epoch_iterator = train_dataloader
                    epoch_iterators.append(epoch_iterator)

                # 遍历每个batch
                for step, epoch_inputs in enumerate(zip(*epoch_iterators)):
                    # Reset the past mems state at the beginning of each epoch if necessary.
                    if self.args.past_index >= 0:
                        self._past = None
                    
                    # 逐batches训练
                    for i in range(fed_num):
                        # -------------- 初始化模型参数 --------------
                        if self.is_parallel: # 获取模型参数
                            main_model.module.load_state_dict(model_state_dicts["original"])
                        else:
                            main_model.load_state_dict(model_state_dicts["original"])
                        self.optimizer.load_state_dict(optimizer_state_dicts[i]) # 获取优化器参数
                        self.lr_scheduler.load_state_dict(lr_scheduler_state_dicts[i]) # 获取学习率参数
                        torch.cuda.empty_cache()
                        # -------------- 训练模型 --------------
                        inputs = epoch_inputs[i]
                        epoch_iterator = epoch_iterators[i]
                        main_model, tr_loss = self.train_for_single_model_batch(
                            main_model, self.optimizer, self.lr_scheduler, epoch, 
                            step, inputs, tr_loss, start_time, steps_trained_in_current_epoch)
                        # -------------- 保存模型参数 --------------
                        self.federated_adding(main_model, model_state_dicts, fed_num)
                        optimizer_state_dicts[i] = copy.deepcopy(self.optimizer.state_dict()) # 保存优化器参数
                        lr_scheduler_state_dicts[i] = copy.deepcopy(self.lr_scheduler.state_dict()) # 保存学习率参数
                        torch.cuda.empty_cache()
                    # federated aggregation
                    model_state_dicts["original"] = copy.deepcopy(model_state_dicts["updated"])
                    torch.cuda.empty_cache()
                    # 本model的本epoch是否训练结束
                    if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                        break
            elif re.search(r'(\d+)batch', self.args.aggregation_freq): # 形如 "[L]batch" 的聚合频率
                L = int(re.search(r'(\d+)batch', self.args.aggregation_freq).group(1))
                if not self.is_batch_nums_equal(train_dataloaders):
                    raise ValueError("All train_dataloaders must have the same number of batches.")
                
                # 训练前的配置
                epoch_iterators = []
                # device
                device = get_model_device(main_model)
                for i in range(fed_num):
                    train_dataloader = train_dataloaders[i]
                    
                    if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                        train_dataloader.sampler.set_epoch(epoch)
                    # dataloader
                    if transformers.is_torch_tpu_available():
                        parallel_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(
                            device
                        )
                        epoch_iterator = parallel_loader
                    else:
                        epoch_iterator = train_dataloader
                    epoch_iterators.append(epoch_iterator)

                # Reset the past mems state at the beginning of each epoch if necessary.
                if self.args.past_index >= 0:
                    self._past = None
                    
                # 初始化
                iterators = [iter(dl) for dl in epoch_iterators] # 每个 DataLoader 的迭代器
                step_counters = [0] * fed_num  # step 计数器列表
                exhausted_iterators = [False] * fed_num # 跟踪每个 DataLoader 是否已经遍历完
                # 遍历所有 dataLoader
                while not all(exhausted_iterators):
                    for i, it in enumerate(iterators):
                        if not exhausted_iterators[i]:
                            # 训练该 dataLoader 中的 L 个 batch
                            # -------------- 初始化模型参数 --------------
                            if self.is_parallel: # 获取模型参数
                                main_model.module.load_state_dict(model_state_dicts["original"])
                            else:
                                main_model.load_state_dict(model_state_dicts["original"])
                            self.optimizer.load_state_dict(optimizer_state_dicts[i]) # 获取优化器参数
                            self.lr_scheduler.load_state_dict(lr_scheduler_state_dicts[i]) # 获取学习率参数
                            torch.cuda.empty_cache()
                            # -------------- 训练 L 个 batch --------------
                            try:
                                # 尝试获取该 DataLoader 的 L 个 batch
                                for _ in range(L):
                                    # 获取 batch inputs
                                    step = step_counters[i]
                                    step_counters[i] += 1  # 增加 step 计数器
                                    inputs = next(it)
                                    # -------------- 训练模型 --------------
                                    main_model, tr_loss = self.train_for_single_model_batch(
                                        main_model, self.optimizer, self.lr_scheduler, epoch, 
                                        step, inputs, tr_loss, start_time, steps_trained_in_current_epoch)
                            except StopIteration:
                                # 如果 DataLoader 中的数据已经被完全遍历，则标记为已耗尽
                                exhausted_iterators[i] = True
                            except Exception as e:
                                # 其他异常就是真的异常了
                                raise e
                            # -------------- 保存模型参数 --------------
                            self.federated_adding(main_model, model_state_dicts, fed_num)
                            optimizer_state_dicts[i] = copy.deepcopy(self.optimizer.state_dict()) # 保存优化器参数
                            lr_scheduler_state_dicts[i] = copy.deepcopy(self.lr_scheduler.state_dict()) # 保存学习率参数
                            torch.cuda.empty_cache()
                    # 每个 dataloader 都跑完 L 个 batch 了，聚合分发新模型
                    model_state_dicts["original"] = copy.deepcopy(model_state_dicts["updated"])
                    torch.cuda.empty_cache()

                    # 本model的本epoch是否训练结束
                    if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                        break
            elif self.args.aggregation_freq == "epoch": # 逐epoch聚合
                for i in range(fed_num):
                    # -------------- 初始化模型参数 --------------
                    if self.is_parallel: # 获取模型参数
                        main_model.module.load_state_dict(model_state_dicts["original"])
                    else:
                        main_model.load_state_dict(model_state_dicts["original"])
                    self.optimizer.load_state_dict(optimizer_state_dicts[i]) # 获取优化器参数
                    self.lr_scheduler.load_state_dict(lr_scheduler_state_dicts[i]) # 获取学习率参数
                    torch.cuda.empty_cache()
                    # -------------- 训练模型 --------------
                    main_model, tr_loss = self.train_for_single_model_epoch(
                        main_model, train_dataloader, self.optimizer, self.lr_scheduler, epoch, 
                        tr_loss, start_time, steps_trained_in_current_epoch)
                    # -------------- 保存模型参数 --------------
                    self.federated_adding(main_model, model_state_dicts, fed_num)
                    optimizer_state_dicts[i] = copy.deepcopy(self.optimizer.state_dict()) # 保存优化器参数
                    lr_scheduler_state_dicts[i] = copy.deepcopy(self.lr_scheduler.state_dict()) # 保存学习率参数
                    torch.cuda.empty_cache()
                # federated aggregation
                model_state_dicts["original"] = copy.deepcopy(model_state_dicts["updated"])
                torch.cuda.empty_cache()
            else:
                raise ValueError("aggregation_freq must be 'batch' or 'epoch'.")
            
            self.epoch = epoch + 1
            
            # evaluate and save
            # if self.args.evaluate_during_training and self.state.global_step % self.args.eval_steps == 0:
            if self.args.evaluate_during_training and epoch % self.args.eval_epoch == 0:
                logger.info(f"Evaluating at epoch {epoch}...")
                output = self.evaluate()
                metrics = output.metrics
                objective = self.dev_objective(metrics)
                if objective > self.objective:
                    logger.info("Best dev result: {}".format(objective))
                    self.objective = objective
                    # self.save_model(self.args.output_dir)

                    # Now we save this to (CPU) memory instead of disk <-- much faster
                    # self.best_model_ckpt = {k: v.detach().cpu() for k, v in main_model.state_dict().items()}
                    self.best_model_ckpt = {k: v.detach().cpu() for k, v in (main_model.module.state_dict().items() if self.is_parallel else main_model.state_dict().items())}
                    
            if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                # train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        # ------------------------- 训练结束 -------------------------
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.state.global_step, tr_loss / self.state.global_step, metrics), self.objective


    # --------------------------------------------------- FedAvg ---------------------------------------------------
    def create_optimizer_and_scheduler_single(self, model, num_training_steps: int):
        """
        Build an optimizer and a learning-rate scheduler for a single (client) model.

        Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"`` are put in a
        group with no weight decay; all others use ``self.args.weight_decay``.
        When ``self.args.fix_layers > 0``, parameters of ``encoder.layer.<k>`` with
        ``k < fix_layers`` and every parameter whose name contains ``"embeddings"`` are
        excluded from optimization (each keep/skip decision is printed).

        Args:
            model: module whose ``named_parameters()`` are optimized.
            num_training_steps: total number of optimization steps given to the scheduler.

        Returns:
            ``(optimizer, lr_scheduler)``, or ``(None, None)`` when
            ``self.args.hf_inference_model`` is set.

        Raises:
            ValueError: if an ``encoder.layer`` parameter name has a non-integer layer index.
            NotImplementedError: if ``self.args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
        """
        if self.args.hf_inference_model:
            return None, None

        # optimizer
        layer_prefix = 'encoder.layer'
        params = {}
        for n, p in model.named_parameters():
            if self.args.fix_layers > 0:
                if layer_prefix in n:
                    # The layer index follows "encoder.layer." in the parameter name.
                    layer_str = n[n.find(layer_prefix) + len(layer_prefix) + 1:].split('.')[0]
                    try:
                        layer_num = int(layer_str)
                    except ValueError as err:
                        raise ValueError(
                            f"Cannot parse encoder layer index from parameter name {n!r}"
                        ) from err
                    if layer_num >= self.args.fix_layers:
                        print('yes', n)
                        params[n] = p
                    else:
                        print('no ', n)
                elif 'embeddings' in n:
                    print('no ', n)
                else:
                    print('yes', n)
                    params[n] = p
            else:
                params[n] = p
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in params.items() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        if self.args.optimizer == 'adam':
            optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        elif self.args.optimizer == 'sgd':
            optimizer = SGD(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate
            )
        else:
            raise NotImplementedError

        # scheduler
        lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )

        return optimizer, lr_scheduler

    def prepare_model_optimizer_scheduler_single(self, main_model, t_total, model_path, device=None):
        """
        Create an independent copy of ``main_model`` together with its own optimizer and scheduler.

        Args:
            main_model: the main model; when ``self.is_parallel`` its ``.module`` is copied.
            t_total: total number of training steps, forwarded to the scheduler.
            model_path: optional checkpoint directory; if it contains both ``optimizer.pt``
                and ``scheduler.pt`` their states are loaded into the new optimizer/scheduler.
            device: device to place the copy on; defaults to ``self.args.device``.

        Returns:
            ``(model, optimizer, scheduler)``. ``model`` is wrapped in ``DataParallel`` when
            ``self.args.n_gpu > 1``. ``optimizer``/``scheduler`` may be ``None`` when
            ``create_optimizer_and_scheduler_single`` returns ``None`` (e.g. inference mode),
            in which case checkpoint states are not loaded.
        """
        # Default to the trainer's main device when no device is specified.
        if device is None:
            device = self.args.device
        # Deep-copy the unwrapped model so the copy does not share parameters with main_model.
        model = copy.deepcopy(main_model.module if self.is_parallel else main_model)
        model.to(device)
        optimizer, scheduler = self.create_optimizer_and_scheduler_single(model=model, num_training_steps=t_total)

        # Restore optimizer/scheduler states from the checkpoint, if both exist and there is
        # something to restore them into.
        if (
            optimizer is not None
            and scheduler is not None
            and model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        # 16-bit (mixed) precision training
        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization).
        # The device list is reordered so that `device` becomes the primary DataParallel device.
        # DistributedDataParallel is intentionally not used: per-client data processes are
        # hard to coordinate with DDP.
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model, device_ids=get_resorted_device_ids(_get_all_device_indices(), device))

        return model, optimizer, scheduler

    def federated_averaging(self, models, main_model):
        """
        Overwrite ``main_model``'s parameters with the element-wise mean of the client models' parameters.

        Only ``named_parameters()`` of ``main_model`` are averaged (buffers are left untouched).
        When ``self.is_parallel`` is set, both the client models and ``main_model`` are
        expected to be wrapped (their ``.module`` is used).

        Args:
            models: non-empty sequence of client models with the same parameter names as ``main_model``.
            main_model: model updated in place.

        Returns:
            ``main_model`` (the same object, updated in place).

        Raises:
            ValueError: if ``models`` is empty.
        """
        if not models:
            raise ValueError("federated_averaging requires at least one client model.")
        device = self.args.device
        # Extract each client model's state_dict.
        state_dicts = [m.module.state_dict() if self.is_parallel else m.state_dict() for m in models]
        target = main_model.module if self.is_parallel else main_model
        with torch.no_grad():
            for param_name, param in target.named_parameters():
                # Accumulate in at least float32 so half-precision parameters neither lose
                # precision nor overflow while summing; float32/float64 are unchanged.
                acc_dtype = torch.promote_types(param.dtype, torch.float32)
                param_sum = torch.zeros(param.shape, dtype=acc_dtype, device=device)
                for state_dict in state_dicts:
                    param_sum += state_dict[param_name].to(device=device, dtype=acc_dtype)
                # copy_ casts back to the parameter's own dtype/device.
                param.data.copy_(param_sum / len(models))
        return main_model

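    # Known limitation: per-batch aggregation ("batch" / "<L>batch") is not supported by train_FedAvg; only "epoch" aggregation is implemented.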
    def train_FedAvg(self, fed_num=1, model_path=None, dev_objective=None):
        """
        Federated-averaging (FedAvg) training entry point.

        The training logic is borrowed from transformers.Trainer (version 3.0.2), with
        early stopping added. The training set is split round-robin (by index) into
        ``fed_num`` shards; each shard trains its own copy of the model with its own
        optimizer and scheduler. After every epoch the copies' parameters are averaged
        into the main model and loaded back into every copy.

        Args:
            fed_num: number of federated clients (data shards / model copies); must be >= 1.
            model_path: optional checkpoint directory. Optimizer/scheduler states are
                restored from it when present, and the global step is parsed from a
                ``...-<step>`` suffix of the path.
            dev_objective: callable mapping evaluation metrics to a scalar to maximize;
                defaults to ``default_dev_objective``.

        Returns:
            ``(TrainOutput(global_step, tr_loss / global_step, metrics), best_objective)``.

        Raises:
            ValueError: if ``fed_num < 1`` or ``self.args.aggregation_freq`` is unrecognised.
            NotImplementedError: if ``self.args.aggregation_freq`` is ``"batch"`` or ``"<L>batch"``.
        """
        if fed_num < 1:
            raise ValueError(f"fed_num must be a positive integer, got {fed_num}.")

        # Write training logs to <output_dir>/<log_file>. Attach the handler only once per
        # file so that repeated calls neither duplicate log lines nor leak file handles.
        log_path = os.path.abspath(os.path.join(self.args.output_dir, self.args.log_file))
        if not any(
            isinstance(handler, py_logging.FileHandler) and getattr(handler, "baseFilename", None) == log_path
            for handler in logger.handlers
        ):
            file_handler = py_logging.FileHandler(log_path)
            file_handler.setLevel(py_logging.INFO)
            logger.addHandler(file_handler)

        if self.args.from_linearhead and model_path is None:
            super().train(model_path, dev_objective) # Train output layer using LinearHeadTrainer

        self.best_dir = None
        self.objective = -float("inf")
        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective

        # Data loading.
        train_dataloader = self.get_train_dataloader()
        # NOTE: step/epoch bookkeeping below uses the number of batches of the *full*
        # dataloader, not of the per-client shards.
        data_number = len(train_dataloader)

        def split_dataloader(dataloader, num_shards):
            """Split the dataset round-robin by index into ``num_shards`` shuffled DataLoaders."""
            dataset = dataloader.dataset
            all_indices = list(range(len(dataset)))
            return [
                DataLoader(
                    Subset(dataset, all_indices[shard::num_shards]),
                    batch_size=dataloader.batch_size,
                    shuffle=True,
                    collate_fn=dataloader.collate_fn,
                )
                for shard in range(num_shards)
            ]

        train_dataloaders = split_dataloader(train_dataloader, fed_num)

        # Clamped to >= 1 so it can safely be used as a divisor below.
        num_update_steps_per_epoch = max(data_number // self.args.gradient_accumulation_steps, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(data_number // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        # Create the main model's optimizer and scheduler. They are only used for checkpoint
        # restoration and apex initialization; each client gets its own pair below.
        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        main_model = self.model

        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            main_model, optimizer = amp.initialize(main_model, optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization).
        # DistributedDataParallel is intentionally not used: per-client data processes are
        # hard to coordinate with DDP.
        if self.args.n_gpu > 1:
            main_model = torch.nn.DataParallel(main_model)

        # Training only steps the client optimizers/schedulers, so drop the local references.
        del optimizer, scheduler

        # Whether the main model is wrapped (DataParallel / DDP); used by the helpers.
        self.is_parallel = isinstance(
            main_model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        )
        # May be None or empty when no accelerator is available.
        device_ids = _get_all_device_indices() or []
        device_num = len(device_ids)

        # Create model, optimizer and scheduler for every federated client.
        models = []
        optimizers = []
        schedulers = []
        for i in range(fed_num):
            # Spread client models across devices for load balancing; fall back to the
            # trainer's default device when no device indices are available.
            client_device = device_ids[(i + 1) % device_num] if device_num > 0 else None
            model_single, optimizer_single, scheduler_single = self.prepare_model_optimizer_scheduler_single(
                main_model, t_total, model_path, device=client_device)
            if self.args.gradient_checkpointing:
                model_single.gradient_checkpointing_enable()
            model_single.zero_grad()
            models.append(model_single)
            optimizers.append(optimizer_single)
            schedulers.append(scheduler_single)

        # Train
        if transformers.is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.state = TrainerState()
        self.state.global_step = 0
        start_time = time.time()
        self.state.zo_forward_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.state.global_step = int(model_path.split("-")[-1].split("/")[0])
                # num_update_steps_per_epoch equals len(train_dataloader) // grad_accum but is
                # clamped to >= 1, avoiding a ZeroDivisionError for tiny datasets.
                epochs_trained = self.state.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.state.global_step % num_update_steps_per_epoch

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.state.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.state.global_step = 0
                logger.info("  Starting fine-tuning.")

        # training
        tr_loss = torch.tensor(0.0).to(self.args.device)
        metrics = None

        for epoch in range(epochs_trained, int(num_train_epochs)):
            aggregation_freq = self.args.aggregation_freq
            if aggregation_freq == "batch" or re.search(r'(\d+)batch', aggregation_freq):
                # Per-batch aggregation needs per-client model/optimizer/scheduler state dicts
                # that this FedAvg implementation does not maintain.
                raise NotImplementedError(
                    f"aggregation_freq={aggregation_freq!r} is not supported by train_FedAvg; use 'epoch'."
                )
            elif aggregation_freq == "epoch": # per-epoch aggregation
                for i in range(fed_num):
                    # Train each client model for one epoch on its own data shard.
                    _, tr_loss = self.train_for_single_model_epoch(
                        models[i], train_dataloaders[i], optimizers[i], schedulers[i], epoch,
                        tr_loss, start_time, steps_trained_in_current_epoch)
                # federated aggregation
                main_model = self.federated_averaging(models, main_model)
                # Broadcast the averaged parameters back to every client model.
                for model_single in models:
                    model_single.load_state_dict(main_model.state_dict())
                self.state.global_step += 1 # update global_step
            else:
                raise ValueError("aggregation_freq must be 'batch' or 'epoch'.")

            self.epoch = epoch + 1

            # evaluate and save
            if self.args.evaluate_during_training and epoch % self.args.eval_epoch == 0:
                logger.info(f"Evaluating at epoch {epoch}...")
                output = self.evaluate()
                metrics = output.metrics
                objective = self.dev_objective(metrics)
                if objective > self.objective:
                    logger.info("Best dev result: {}".format(objective))
                    self.objective = objective
                    # Keep the best checkpoint in (CPU) memory instead of on disk <-- much faster
                    self.best_model_ckpt = {k: v.detach().cpu() for k, v in (main_model.module.state_dict().items() if self.is_parallel else main_model.state_dict().items())}

            if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.state.global_step, tr_loss / self.state.global_step, metrics), self.objective


    # --------------------------------------------------- CeZO ---------------------------------------------------
    def train_for_single_model_CeZO_batch(self, model, inputs, start_time, steps_trained_in_current_epoch):
        """Compute one client's one-bit zeroth-order vote on a single batch.

        The model is evaluated at two symmetric perturbations along the direction identified
        by ``random_seed`` (``self.seed_`` when ``efficient_zero_order`` is set, otherwise 0)
        and is then perturbed back. Parameters are NOT updated here; in ``train_CeZO`` the
        returned votes are collected and applied by ``federated_averaging_CeZO_batch``.

        Args:
            model: model to probe (accessed through ``.module`` when embedding layers are
                synchronized, i.e. expected to be wrapped in that case).
            inputs: one batch, forwarded through ``self.zo_forward``.
            start_time: ``time.time()`` value at the start of training; only used for logging.
            steps_trained_in_current_epoch: number of steps still to skip when resuming. Any
                positive value turns this call into a no-op that reports a zero vote.

        Returns:
            ``(model, projected_grad, random_seed)``. ``projected_grad`` is the sign of the
            finite-difference estimate (+1 / -1, or 0 when the two losses are equal or not
            comparable, e.g. NaN), divided by the number of sampled directions unless
            ``scale_lr_with_samples`` is set. A skipped step, or ZO being disabled, returns
            ``(model, 0, 0)``.

        Raises:
            NotImplementedError: for layer-wise zeroth-order optimization (``zo_by_layer``).
        """
        if self.args.sync_embedding_layers:
            assert model.module.model_type == 'opt', 'did not implement embedding layer synchronization for non-OPT models'
            model.module.model.decoder.embed_tokens.weight = model.module.lm_head.weight

        # Values reported when no vote is cast (skipped step or zeroth-order optim disabled).
        projected_grad = 0
        random_seed = 0

        # Skip past any already trained steps if resuming training. The counter is an int
        # passed by value, so decrementing it here could never reach the caller; the caller
        # is responsible for advancing its own counter.
        if steps_trained_in_current_epoch > 0:
            return model, projected_grad, random_seed

        if not self.args.zero_order_optim:
            return model, projected_grad, random_seed

        if self.args.zo_by_layer:
            raise NotImplementedError()

        # get number of zs to sample
        num_zs = self.get_num_samples()
        if num_zs > 1:
            assert self.args.zero_order_use_trainer_optim, 'cannot sample multiple zs without storing intermediate gradient. use trainer.'

        loss1 = None
        for _ in range(num_zs):
            # The direction seed is chosen by the caller (self.seed_), so every client probes
            # the same direction for a given step.
            if self.args.efficient_zero_order:
                random_seed = self.seed_

            with torch.no_grad():
                # first function evaluation (presumably at theta + eps * z)
                model = self.efficient_perturb_parameters_federated(model, random_seed)
                loss1 = self.zo_forward(model, inputs)

                # second function evaluation (presumably at theta - eps * z)
                model = self.efficient_perturb_parameters_federated(model, random_seed, scaling_factor=-2)
                loss2 = self.zo_forward(model, inputs)

            # One-bit compression: keep only the sign of the finite-difference estimate.
            # An exact tie carries no directional information, so it votes 0 rather than -1.
            estimate = (loss1 - loss2) / (2 * self.args.zero_order_eps)
            if estimate > 0:
                projected_grad = 1
            elif estimate < 0:
                projected_grad = -1
            else:
                projected_grad = 0

            # scale grad according to number of zs sampled
            if not self.args.scale_lr_with_samples:
                projected_grad = projected_grad / float(num_zs)

            # reset model back to its parameters at start of step
            model = self.efficient_perturb_parameters_federated(model, random_seed)

        should_log = (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
            self.state.global_step == 1 and self.args.logging_first_step
        )
        # loss1 stays None when no direction was sampled; there is nothing to log then.
        if should_log and loss1 is not None:
            logs = {
                "loss": float(loss1),
                "learning_rate": self.args.learning_rate,
                "global_step": self.state.global_step,
                "zo_forward_step": self.state.zo_forward_step,
                "max_steps": self.args.max_steps,
                "max_zo_forward_steps": self.args.max_zo_forward_steps,
                "time": int(time.time() - start_time),
            }
            self.log(logs)
            logger.info(str(logs))

        self.state.global_step += 1
        return model, projected_grad, random_seed
    
    def train_for_single_model_CeZO_epoch(self, model, train_dataloader, epoch, start_time, steps_trained_in_current_epoch):
        """Train one client for a full local epoch with seed-based zeroth-order steps.

        For every batch a fresh seed is drawn with ``np.random.randint``. The loss is evaluated
        after perturbing the model with ``efficient_perturb_parameters_federated`` (default
        scale, then ``scaling_factor=-2``), and the model is perturbed back. The update
        ``param -= lr * (projected_grad * z + weight_decay * param)`` is then applied locally,
        with ``z`` drawn from ``torch.normal`` right after ``torch.manual_seed(seed)``.

        Args:
            model: the client model, updated in place.
            train_dataloader: the client's data.
            epoch: epoch index, forwarded to a ``DistributedSampler`` if present.
            start_time: ``time.time()`` value at the start of training; only used for logging.
            steps_trained_in_current_epoch: number of leading batches to skip when resuming.

        Returns:
            ``(model, projected_grad_epoch, random_seed_epoch)``. For each trained batch the
            two lists hold the scalar projected gradient and its direction seed, so the steps
            can be replayed elsewhere (see ``federated_averaging_CeZO_epoch``).

        Raises:
            NotImplementedError: for layer-wise ZO, trainer-optimizer ZO, or non-efficient ZO
                (the latter would need stored random vectors, which this method does not keep).
        """
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        device = get_model_device(model)
        if transformers.is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
            epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)

        # Reset the past mems state at the beginning of each epoch if necessary.
        if self.args.past_index >= 0:
            self._past = None

        random_seed_epoch = []  # direction seed used for each trained batch
        projected_grad_epoch = []  # projected gradient obtained for each trained batch

        for inputs in epoch_iterator:
            if self.args.sync_embedding_layers:
                assert model.module.model_type == 'opt', 'did not implement embedding layer synchronization for non-OPT models'
                model.module.model.decoder.embed_tokens.weight = model.module.lm_head.weight

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            if self.args.zero_order_optim:
                if self.args.zo_by_layer:
                    raise NotImplementedError()

                # get number of zs to sample
                num_zs = self.get_num_samples()
                if num_zs > 1:
                    assert self.args.zero_order_use_trainer_optim, 'cannot sample multiple zs without storing intermediate gradient. use trainer.'

                for _ in range(num_zs):
                    # prepare for sampling new zs
                    random_seed = np.random.randint(1000000000)

                    with torch.no_grad():
                        # first function evaluation
                        model = self.efficient_perturb_parameters_federated(model, random_seed)
                        loss1 = self.zo_forward(model, inputs)

                        # second function evaluation
                        model = self.efficient_perturb_parameters_federated(model, random_seed, scaling_factor=-2)
                        loss2 = self.zo_forward(model, inputs)

                    projected_grad = (loss1 - loss2) / (2 * self.args.zero_order_eps)

                    # scale grad according to accumulation
                    if self.args.gradient_accumulation_steps > 1:
                        assert self.args.zero_order_use_trainer_optim, 'grad accumulation not implemented for non-trainer ZO yet'
                        projected_grad = projected_grad / self.args.gradient_accumulation_steps

                    # scale grad according to number of zs sampled
                    if not self.args.scale_lr_with_samples:
                        projected_grad = projected_grad / float(num_zs)

                    # reset model back to its parameters at start of step
                    model = self.efficient_perturb_parameters_federated(model, random_seed)

                # apply gradient updates
                if self.args.zero_order_use_trainer_optim:
                    raise NotImplementedError()

                # Efficient mode
                # WARNING: no gradient accumulation when not storing the grad
                assert self.args.gradient_accumulation_steps == 1, 'gradient accumulation is not supported for zero-order optimization'
                assert self.args.zero_order_sample_scheduler is None
                assert not self.args.zero_order_clip_grad, 'gradient clipping not implemented yet for non-trainer ZO'
                if not self.args.efficient_zero_order:
                    # The non-efficient path would need a stored random vector per parameter,
                    # which is never materialized here.
                    raise NotImplementedError('CeZO local training requires efficient_zero_order (seed-regenerated z).')

                # Regenerate z from the seed, in the same parameter order as the perturbation.
                torch.manual_seed(random_seed)
                learning_rate = self.args.learning_rate
                weight_decay = self.args.weight_decay
                for name, param in model.named_parameters():
                    if self.should_optim(name, param):
                        z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                        param.data = param.data - learning_rate * (projected_grad * z + weight_decay * param.data)  # key function

                if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                    self.state.global_step == 1 and self.args.logging_first_step
                ):
                    logs = {}
                    logs["loss"] = loss1.item()
                    logs["learning_rate"] = self.args.learning_rate
                    logs["global_step"] = self.state.global_step
                    logs["zo_forward_step"] = self.state.zo_forward_step
                    logs["max_steps"] = self.args.max_steps
                    logs["max_zo_forward_steps"] = self.args.max_zo_forward_steps
                    logs["time"] = int(time.time() - start_time)
                    self.log(logs)
                    logger.info(str(logs))

                self.state.global_step += 1

                # Record the step so it can be replayed during aggregation. float() accepts
                # both 0-d tensors and plain Python numbers.
                random_seed_epoch.append(random_seed)
                projected_grad_epoch.append(float(projected_grad))

            # Stop this client's epoch once the step or forward-pass budget is exhausted.
            if (self.args.max_steps > 0 and self.state.global_step > self.args.max_steps) or (
                self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps
            ):
                epoch_iterator.close()
                break

        return model, projected_grad_epoch, random_seed_epoch
    
    # Communication-efficient federated aggregation driven only by one-bit votes and seeds.
    def federated_averaging_CeZO_batch(self, model, fed_num, projected_grad_batches, random_seed_batches):
        """Aggregate the clients' one-bit votes for one batch and apply them to ``model``.

        As built by ``train_CeZO``, both lists are flat: ``fed_num`` consecutive groups of
        ``self.Q`` entries, one group per client, and entry ``q`` of every group was produced
        with the same direction seed. For each direction ``q``:

        * the votes ``projected_grad_batches[q::self.Q]`` are summed and reduced to a majority
          sign: +1, -1, or 0 on an exact tie (or when every client abstained);
        * ``z`` is regenerated from ``random_seed_batches[q]`` with ``torch.manual_seed`` +
          ``torch.normal``;
        * ``param -= lr * (sign * z + weight_decay * param) / self.Q``.

        The global torch RNG state (CPU, and CUDA when available) is restored afterwards, so
        aggregation does not leave unrelated randomness reseeded.

        Args:
            model: model whose optimizable parameters are updated in place.
            fed_num: number of clients (the vote count is implied by the list length).
            projected_grad_batches: flat list of per-client, per-direction votes.
            random_seed_batches: flat list of seeds aligned with ``projected_grad_batches``.

        Returns:
            The same ``model`` object.
        """
        num_dirs = self.Q
        learning_rate = self.args.learning_rate
        weight_decay = self.args.weight_decay

        cpu_rng_state = torch.get_rng_state()
        cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        try:
            with torch.no_grad():
                for q in range(num_dirs):
                    vote = sum(projected_grad_batches[q::num_dirs])
                    if vote > 0:
                        projected_grad = 1
                    elif vote < 0:
                        projected_grad = -1
                    else:
                        # A tie (or all-zero abstentions) carries no direction information.
                        projected_grad = 0

                    torch.manual_seed(random_seed_batches[q])
                    for name, param in model.named_parameters():
                        if self.should_optim(name, param):
                            z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                            param.data = param.data - (learning_rate * (projected_grad * z + weight_decay * param.data)) / num_dirs
        finally:
            torch.set_rng_state(cpu_rng_state)
            if cuda_rng_states is not None:
                torch.cuda.set_rng_state_all(cuda_rng_states)

        return model

    def federated_averaging_CeZO_epoch(self, model, fed_num, projected_grad_epoches, random_seed_epoches):
        """Replay every client's local zeroth-order steps on ``model``, averaged over clients.

        ``projected_grad_epoches[i][j]`` and ``random_seed_epoches[i][j]`` are the projected
        gradient and direction seed of client ``i``'s ``j``-th step. Each step is replayed by
        regenerating ``z`` from its seed (``torch.manual_seed`` + ``torch.normal``) and applying
        ``param -= lr * (projected_grad * z + weight_decay * param) / fed_num``. The clients'
        parameter changes are therefore accumulated with weight ``1 / fed_num``.

        The global torch RNG state (CPU, and CUDA when available) is restored afterwards,
        even if an exception interrupts the replay.

        Args:
            model: model whose optimizable parameters are updated in place.
            fed_num: number of clients used as the averaging divisor.
            projected_grad_epoches: per-client lists of projected gradients.
            random_seed_epoches: per-client lists of seeds aligned with
                ``projected_grad_epoches``.

        Returns:
            The same ``model`` object.
        """
        learning_rate = self.args.learning_rate
        weight_decay = self.args.weight_decay

        cpu_rng_state = torch.get_rng_state()
        cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        try:
            with torch.no_grad():
                for client_idx, client_grads in enumerate(projected_grad_epoches):
                    client_seeds = random_seed_epoches[client_idx]
                    for step_idx, projected_grad in enumerate(client_grads):  # every batch of this client
                        # Regenerate the exact direction the client used for this step.
                        torch.manual_seed(client_seeds[step_idx])
                        for name, param in model.named_parameters():
                            if self.should_optim(name, param):
                                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                                # Incremental average: each client's change is weighted by 1 / fed_num.
                                param.data = param.data - (learning_rate * (projected_grad * z + weight_decay * param.data)) / fed_num  # key function
        finally:
            torch.set_rng_state(cpu_rng_state)
            if cuda_rng_states is not None:
                torch.cuda.set_rng_state_all(cuda_rng_states)

        return model

    def train_CeZO(self, fed_num=1, model_path=None, dev_objective=None):
        
        # self.Q = 5
        self.Q = self.args.onebit_Q
        self.seed_ = -1
        
        
        """
        Main training entry point.

        The training logic is directly borrowed from transformers.Trainer (version 3.0.2).
        Add early stopping.
        """
        # 首先保证训练参数正确
        assert self.args.zero_order_optim, "CeZO requires zero-order optim!"
        assert self.args.efficient_zero_order, "CeZO requires efficient zero order!"
        assert not self.args.zero_order_use_trainer_optim, "CeZO cannot use trainer to optimize!"
        assert self.args.gradient_accumulation_steps == 1, 'gradient accumulation is not supported for zero-order optimization'
        assert self.args.zero_order_sample_scheduler is None
        assert not self.args.zero_order_clip_grad, 'gradient clipping not implemented yet for non-trainer ZO'

        
        # 创建一个文件处理器，设置日志文件路径
        file_handler = py_logging.FileHandler(os.path.join(self.args.output_dir, self.args.log_file))
        file_handler.setLevel(py_logging.INFO)
        # 将处理器添加到 logger
        logger.addHandler(file_handler)
        
        if self.args.from_linearhead and model_path is None:
            super().train(model_path, dev_objective) # Train output layer using LinearHeadTrainer

        self.best_dir = None
        self.objective = -float("inf")
        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective

        # Data loading.
        train_dataloader = self.get_train_dataloader()
        data_number = len(train_dataloader) # 获取DataLoader中的数据数量
        #     preparing for federated learning
        def split_dataloader(dataloader, fed_num):
            # 获取原始DataLoader中的所有数据和标签
            dataset = dataloader.dataset
            batch_size = dataloader.batch_size
            indices = list(range(len(dataset)))
            # 将数据索引均分为fed_num份
            splitted_indices = [indices[i::fed_num] for i in range(fed_num)]
            # 确保每份中数据数量相同
            sample_num = max(len(indices) for indices in splitted_indices)
            # 为每一份数据创建新的DataLoader
            dataloaders = []
            for indices in splitted_indices:
                if len(indices) < sample_num:
                    temp_slice = indices[0: sample_num - len(indices)]
                    indices = indices + temp_slice
                subset = Subset(dataset, indices)
                new_dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True, collate_fn=dataloader.collate_fn)
                dataloaders.append(new_dataloader)
            return dataloaders
        train_dataloaders = split_dataloader(train_dataloader, fed_num) # 将原始DataLoader均分为fed_num份
        num_update_steps_per_epoch = data_number // self.args.gradient_accumulation_steps # 梯度聚合(未启用)
        if num_update_steps_per_epoch == 0:
            num_update_steps_per_epoch = 1
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(data_number // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        # Create model, optimizer and scheduler
        #     获取main_model
        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        #     Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        main_model = self.model # 获取主模型

        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            main_model, optimizer = amp.initialize(main_model, optimizer, opt_level=self.args.fp16_opt_level)

        #     Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            main_model = torch.nn.DataParallel(main_model)
        # DDP不好处理数据进程，禁用
        # #     Distributed training (should be after apex fp16 initialization)
        # if self.args.local_rank != -1:
        #     main_model = torch.nn.parallel.DistributedDataParallel(
        #         main_model,
        #         device_ids=[self.args.local_rank],
        #         output_device=self.args.local_rank,
        #         find_unused_parameters=True,
        #     )
        
        del optimizer, scheduler # 在之后的训练过程中，不会调用主模型中的optimizer和scheduler，因此删除
        
        # 判断是否为多GPU，注册为对象属性
        self.is_parallel = isinstance(main_model, torch.nn.DataParallel) or isinstance(main_model, torch.nn.parallel.DistributedDataParallel) # 检查模型是否被 DataParallel 包装

        # Train
        if transformers.is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.state = TrainerState()
        self.state.global_step = 0
        start_time = time.time()
        self.state.zo_forward_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # self.named_parameters_to_optim = []
        # for name, param in main_model.named_parameters():
        #     if self.should_optim(name, param):
        #         self.named_parameters_to_optim.append((name, param))
        
        # 各个client的参数
        model_state_dicts = {"original": copy.deepcopy(main_model.module.state_dict() if self.is_parallel else main_model.state_dict())}

        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.state.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.state.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.state.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.state.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.state.global_step = 0
                logger.info("  Starting fine-tuning.")

        # training
        tr_loss = torch.tensor(0.0).to(self.args.device)
        metrics = None
        
        for epoch in range(epochs_trained, int(num_train_epochs)):
            # training for federated learning
            if self.args.aggregation_freq == "batch": # 逐batch聚合
                if not self.is_batch_nums_equal(train_dataloaders):
                    raise ValueError("All train_dataloaders must have the same number of batches.")
                # 训练前的配置
                epoch_iterators = []
                # device
                device = get_model_device(main_model)
                for i in range(fed_num):
                    train_dataloader = train_dataloaders[i]
                    
                    if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                        train_dataloader.sampler.set_epoch(epoch)
                    # dataloader
                    if transformers.is_torch_tpu_available():
                        parallel_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(
                            device
                        )
                        epoch_iterator = parallel_loader
                    else:
                        epoch_iterator = train_dataloader
                    epoch_iterators.append(epoch_iterator)

                # 遍历每个batch
                for step, epoch_inputs in enumerate(zip(*epoch_iterators)):
                    # Reset the past mems state at the beginning of each epoch if necessary.
                    if self.args.past_index >= 0:
                        self._past = None
                    
                    # 逐batches训练
                    projected_grad_batches = []
                    random_seed_batches = []
                    for i in range(fed_num):
                        # -------------- 初始化模型参数 --------------
                        if self.is_parallel: # 获取模型参数
                            main_model.module.load_state_dict(model_state_dicts["original"])
                        else:
                            main_model.load_state_dict(model_state_dicts["original"])
                        torch.cuda.empty_cache()
                        # -------------- 训练模型 --------------
                        inputs = epoch_inputs[i]
                        epoch_iterator = epoch_iterators[i]
                        
                        for q in range(self.Q):
                            self.seed_ += 1
                            main_model, projected_grad, random_seed = self.train_for_single_model_CeZO_batch(
                                main_model, inputs, start_time, steps_trained_in_current_epoch)
                            projected_grad_batches.append(projected_grad)
                            random_seed_batches.append(random_seed)
                        self.seed_ -= self.Q
                    self.seed_ += self.Q
                        
                    # aggregation
                    main_model = self.federated_averaging_CeZO_batch(main_model, fed_num, projected_grad_batches, random_seed_batches)
                    model_state_dicts["original"] = copy.deepcopy(main_model.module.state_dict() if self.is_parallel else main_model.state_dict()) # 下发给K个client
                    torch.cuda.empty_cache()
                    # 本model的本epoch是否训练结束
                    if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                        break
            elif re.search(r'(\d+)batch', self.args.aggregation_freq): # 形如 "[L]batch" 的聚合频率
                raise NotImplementedError("This CeZO aggregation frequency is not supported.")
                L = int(re.search(r'(\d+)batch', self.args.aggregation_freq).group(1))
                if not self.is_batch_nums_equal(train_dataloaders):
                    raise ValueError("All train_dataloaders must have the same number of batches.")
                
                # 训练前的配置
                epoch_iterators = []
                projected_grad_multi_batches, random_seed_multi_batches = [], []
                # device
                device = get_model_device(main_model)
                for i in range(fed_num):
                    train_dataloader = train_dataloaders[i]
                    
                    if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                        train_dataloader.sampler.set_epoch(epoch)
                    # dataloader
                    if transformers.is_torch_tpu_available():
                        parallel_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(
                            device
                        )
                        epoch_iterator = parallel_loader
                    else:
                        epoch_iterator = train_dataloader
                    epoch_iterators.append(epoch_iterator)

                # Reset the past mems state at the beginning of each epoch if necessary.
                if self.args.past_index >= 0:
                    self._past = None
                    
                # 初始化
                iterators = [iter(dl) for dl in epoch_iterators] # 每个 DataLoader 的迭代器
                step_counters = [0] * fed_num  # step 计数器列表
                exhausted_iterators = [False] * fed_num # 跟踪每个 DataLoader 是否已经遍历完
                # 遍历所有 dataLoader
                while not all(exhausted_iterators):
                    for i, it in enumerate(iterators):
                        if not exhausted_iterators[i]:
                            # -------------- 初始化模型参数 --------------
                            if self.is_parallel: # 获取模型参数
                                main_model.module.load_state_dict(model_state_dicts["original"])
                            else:
                                main_model.load_state_dict(model_state_dicts["original"])
                            torch.cuda.empty_cache()
                            # -------------- 初始化 list，用来存放训练后的参数变动 --------------
                            projected_grad_batches = []
                            random_seed_batches = []
                            # -------------- 训练 L 个 batch --------------
                            try:
                                # 尝试获取该 DataLoader 的 L 个 batch
                                for _ in range(L):
                                    # 获取 batch inputs
                                    step = step_counters[i]
                                    step_counters[i] += 1  # 增加 step 计数器
                                    inputs = next(it)
                                    # -------------- 训练模型 --------------
                                    main_model, projected_grad, random_seed = self.train_for_single_model_CeZO_batch(
                                        main_model, inputs, start_time, steps_trained_in_current_epoch)
                                    projected_grad_batches.append(projected_grad)
                                    random_seed_batches.append(random_seed)
                            except StopIteration:
                                # 如果 DataLoader 中的数据已经被完全遍历，则标记为已耗尽
                                exhausted_iterators[i] = True
                            except Exception as e:
                                # 其他异常就是真的异常了
                                raise e
                            # -------------- 保存训练后的参数变动（无论满不满 L 个 batch ） --------------
                            projected_grad_multi_batches.append(projected_grad_batches)
                            random_seed_multi_batches.append(random_seed_batches)
                            
                    # 每个 dataloader 都跑完 L 个 batch 了，聚合分发新模型
                    self.federated_averaging_CeZO_epoch(main_model, fed_num, projected_grad_multi_batches, random_seed_multi_batches)
                    projected_grad_multi_batches, random_seed_multi_batches = [], []
                    model_state_dicts["original"] = copy.deepcopy(main_model.module.state_dict() if self.is_parallel else main_model.state_dict()) # 下发给K个client
                    torch.cuda.empty_cache()
                
                    # 本model的本epoch是否训练结束
                    if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                        break
            elif self.args.aggregation_freq == "epoch": # 逐epoch聚合
                raise NotImplementedError("This CeZO aggregation frequency is not supported.")
                projected_grad_epoches = []
                random_seed_epoches = []
                for i in range(fed_num):
                    # -------------- 初始化模型参数 --------------
                    if self.is_parallel: # 获取模型参数
                        main_model.module.load_state_dict(model_state_dicts["original"])
                    else:
                        main_model.load_state_dict(model_state_dicts["original"])
                    torch.cuda.empty_cache()
                    # -------------- 训练模型 --------------
                    main_model, projected_grad, random_seed = self.train_for_single_model_CeZO_epoch(
                        main_model, train_dataloaders[i], epoch, start_time, steps_trained_in_current_epoch)
                    projected_grad_epoches.append(projected_grad)
                    random_seed_epoches.append(random_seed)
                
                # federated aggregation
                self.federated_averaging_CeZO_epoch(main_model, fed_num, projected_grad_epoches, random_seed_epoches)
                model_state_dicts["original"] = copy.deepcopy(main_model.module.state_dict() if self.is_parallel else main_model.state_dict()) # 下发给K个client
                torch.cuda.empty_cache()
            else:
                raise ValueError("aggregation_freq must be 'batch' or 'epoch'.")
            
            self.epoch = epoch + 1
            
            # evaluate and save
            # if self.args.evaluate_during_training and self.state.global_step % self.args.eval_steps == 0:
            if self.args.evaluate_during_training and epoch % self.args.eval_epoch == 0:
                logger.info(f"Evaluating at epoch {epoch}...")
                output = self.evaluate()
                metrics = output.metrics
                objective = self.dev_objective(metrics)
                if objective > self.objective:
                    logger.info("Best dev result: {}".format(objective))
                    self.objective = objective
                    # self.save_model(self.args.output_dir)

                    # Now we save this to (CPU) memory instead of disk <-- much faster
                    # self.best_model_ckpt = {k: v.detach().cpu() for k, v in main_model.state_dict().items()}
                    self.best_model_ckpt = {k: v.detach().cpu() for k, v in (main_model.module.state_dict().items() if self.is_parallel else main_model.state_dict().items())}
                    
            if self.args.max_steps > 0 and self.state.global_step > self.args.max_steps or (self.args.max_zo_forward_steps > 0 and self.state.zo_forward_step > self.args.max_zo_forward_steps):
                # train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.state.global_step, tr_loss / self.state.global_step, metrics), self.objective

    """
    Difference from the original implementation: returns the full `output` object instead of only `output.metrics`, so the caller also gets the predictions (logits).
    """
    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Any:
        """
        Run evaluation and return the full prediction output (not just its metrics).

        Unlike the upstream Transformers implementation, which returns only ``output.metrics``, this
        returns the whole object produced by :meth:`prediction_loop`. Callers can therefore read both
        ``output.metrics`` and the predictions it carries.

        The calling script is responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
                the :obj:`__len__` method.

        Returns:
            The object returned by ``self.prediction_loop(...)`` for the evaluation dataloader. Its
            ``metrics`` attribute is what gets logged here (in upstream Transformers: the evaluation
            loss plus any metrics produced by ``compute_metrics``).

        Raises:
            ValueError: If ``eval_dataset`` is given but does not implement ``__len__``.
        """
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        # Report the metrics via `self.log` and the module-level logger.
        self.log(output.metrics)
        logger.info(output.metrics)

        # `xm` / `met` come from torch_xla. Presumably (as in upstream Transformers) they are only
        # imported when a TPU is available, so guard the call: otherwise enabling `debug` on GPU/CPU
        # would raise a NameError here. On TPU the behavior is unchanged.
        if (self.args.tpu_metrics_debug or self.args.debug) and is_torch_tpu_available():
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output
