# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os, sys
from transformers import TrainerState
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import time
import math
######## packaging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import shutil
import importlib.metadata

# isort: on

import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler

#import config
from prox_op import replace_weight, compute_mask_loss


from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from transformers.integrations.tpu import tpu_spmd_dataloader
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from transformers.optimization import Adafactor, get_scheduler
from transformers.pytorch_utils import (
    ALL_LAYERNORM_LAYERS,
    is_torch_greater_or_equal_than_1_13,
    is_torch_greater_or_equal_than_2_3,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    ExportableState,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedTensorGatherer,
    EvalLoopContainer,
    IterableDatasetShard,
    LabelSmoother,
    LayerWiseDummyOptimizer,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
    remove_dummy_checkpoint,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    TrainerMemoryTracker,
    TrainOutput,
    check_target_module_exists,
    default_compute_objective,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    neftune_post_forward_hook,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    XLA_FSDPV2_MIN_VERSION,
    PushInProgress,
    PushToHubMixin,
    can_return_loss,
    find_labels,
    is_accelerate_available,
    is_apex_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_in_notebook,
    is_ipex_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_compile_available,
    is_torch_mlu_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_xla_available,
    logging,
    strtobool,
)
from transformers.utils.quantization_config import QuantizationMethod
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)


# Default callback classes: the flow callback, plus a progress callback that
# is swapped for the notebook-specific one when running inside a notebook.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    # Imported lazily so the notebook callback is only loaded when needed.
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback




if is_safetensors_available():
    import safetensors.torch

if is_peft_available():
    from peft import PeftModel


if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        GradientAccumulationPlugin,
        is_mlu_available,
        is_mps_available,
        is_npu_available,
        is_torch_version,
        is_xpu_available,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )

    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
    from accelerate.utils import DataLoaderConfiguration

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    from torch_xla import __version__ as XLA_VERSION

if is_apex_available():
    from apex import amp

def _is_peft_model(model):
    """Tell whether `model` is an instance of one of the PEFT model classes.

    Returns False when PEFT is not available. Otherwise the check is made
    against `PeftModel`, and additionally against `PeftMixedModel` when the
    installed peft version is at least 0.7.0.
    """
    if not is_peft_available():
        return False

    peft_classes = [PeftModel]
    # `PeftMixedModel` was introduced in peft>=0.7.0:
    # https://github.com/huggingface/transformers/pull/28321
    installed_peft = version.parse(importlib.metadata.version("peft"))
    if installed_peft >= version.parse("0.7.0"):
        from peft import PeftMixedModel

        peft_classes.append(PeftMixedModel)
    return isinstance(model, tuple(peft_classes))

from transformers.integrations import (
    get_reporting_integration_callbacks,
    hp_params,
)


# SageMaker model-parallel support: the smdistributed helpers are only
# imported when SageMaker MP is enabled. IS_SAGEMAKER_MP_POST_1_10 records
# whether the installed smdistributed version is >= 1.10 and is False when
# SageMaker MP is not enabled.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
######## packaging
# File name of the serialized trainer state; it is joined with the
# checkpoint directory when resuming training.
TRAINER_STATE_NAME = "trainer_state.json"
from contextlib import contextmanager



@contextmanager
def _tmp_disable_cache_if_any(model):
    use_cache = getattr(getattr(model, "config", object()), "use_cache", None)
    has_attr = hasattr(getattr(model, "config", object()), "use_cache")
    try:
        if has_attr:
            model.config.use_cache = False
        yield
    finally:
        if has_attr:
            model.config.use_cache = use_cache

@contextmanager
def _tmp_eval_mode(model: nn.Module):
    was_training = model.training
    try:
        model.eval()
        yield
    finally:
        if was_training:
            model.train()
from collections.abc import Mapping


def recompute_wanda_metrics_for_step_(model: nn.Module,
                                     optimizer,
                                     batch_inputs: dict,
                                     normalize: bool = True,
                                     lamda_2: float=0.1):
    """Recompute the per-weight importance metric and store it in `optimizer.state`.

    One forward pass `model(**batch_inputs)` is run under `torch.no_grad()`,
    in eval mode and with `config.use_cache` disabled. During that pass a
    forward pre-hook on every `nn.Linear` whose weight belongs to one of the
    optimizer's param groups accumulates the per-input-feature mean of the
    squared activations. For each such weight `W` the metric is then

        (|W| / column_sum(|W|) + |W| / row_sum(|W|)) * sqrt(mean(x^2))

    with the activation term broadcast over the rows of `W`.

    Results are written to `optimizer.state[p]`:
        - "w_metric": the metric, same shape as the weight.
        - "scaler_row": the averaged squared activations, shape [1, in_features].

    Side effect: the set of parameter ids managed by the optimizer is cached
    on the optimizer as `_opt_param_id_set` and reused on later calls.

    Args:
        model: Model whose `nn.Linear` layers are inspected.
        optimizer: Optimizer whose param groups select the weights and whose
            `state` receives the results.
        batch_inputs: Keyword arguments for the forward pass.
        normalize: If True, divide the metric by its mean absolute value when
            that mean is finite and positive.
        lamda_2: Currently unused; kept so existing callers keep working.
    """
    opt_param_ids = getattr(optimizer, "_opt_param_id_set", None)
    if opt_param_ids is None:
        opt_param_ids = {
            id(p) for g in optimizer.param_groups for p in g["params"] if isinstance(p, torch.Tensor)
        }
        optimizer._opt_param_id_set = opt_param_ids

    hooks = []
    stats = {}  # p -> {"sum": Tensor[in_features], "cnt": int}

    def make_pre_hook(p):
        def _pre_hook(module, inputs):
            # Modules called with keyword-only arguments receive an empty
            # positional tuple; there is nothing to record in that case.
            if not inputs or not torch.is_tensor(inputs[0]):
                return
            x = inputs[0]
            with torch.no_grad():
                in_features = x.shape[-1]
                x2_mean = x.detach().reshape(-1, in_features).pow(2).mean(dim=0).to(p.device)
                rec = stats.get(p)
                if rec is None or rec["sum"].numel() != x2_mean.numel():
                    stats[p] = {"sum": x2_mean.clone(), "cnt": 1}
                else:
                    rec["sum"].add_(x2_mean)
                    rec["cnt"] += 1
        return _pre_hook

    try:
        for mod in model.modules():
            if not isinstance(mod, nn.Linear):
                continue
            p = getattr(mod, "weight", None)
            if p is None or id(p) not in opt_param_ids:
                continue
            hooks.append(mod.register_forward_pre_hook(make_pre_hook(p)))

        with torch.no_grad(), _tmp_eval_mode(model), _tmp_disable_cache_if_any(model):
            model(**batch_inputs)
    finally:
        # Always detach the hooks, even when the forward pass fails, so they
        # do not stay registered on the model.
        for h in hooks:
            h.remove()

    with torch.no_grad():
        for p, rec in stats.items():
            if rec["cnt"] <= 0:
                continue
            scaler_row = rec["sum"] / rec["cnt"]         # [in_features]
            W = p.data                                   # [out_features, in_features]
            if scaler_row.numel() != W.shape[-1]:
                continue
            abs_W = W.abs()

            # A row/column that is entirely zero would otherwise yield
            # 0 / 0 = NaN. Clamping by the smallest positive normal value only
            # changes those degenerate sums, whose entries are all zero anyway.
            eps = torch.finfo(abs_W.dtype).tiny
            col_sum = abs_W.sum(dim=0, keepdim=True).clamp_min(eps)
            row_sum = abs_W.sum(dim=1, keepdim=True).clamp_min(eps)
            ri_ratio = abs_W / col_sum + abs_W / row_sum

            W_metric = ri_ratio * (scaler_row.reshape(1, -1)) ** 0.5
            if normalize:
                denom = W_metric.abs().mean()
                if torch.isfinite(denom) and denom > 0:
                    W_metric = W_metric / denom

            st = optimizer.state.setdefault(p, {})
            st["w_metric"] = W_metric
            st["scaler_row"] = scaler_row.reshape(1, -1)

def recompute_stochria_metrics_for_step(model: nn.Module,
                                     optimizer,
                                     batch_inputs: dict,
                                     normalize: bool = True,
                                     beta: float = 0.2,
                                     act_power: float = 0.35,
                                     ):
    """Recompute the stochastic row/column importance metric and store it in `optimizer.state`.

    One forward pass `model(**batch_inputs)` is run under `torch.no_grad()`,
    in eval mode and with `config.use_cache` disabled. During that pass a
    forward pre-hook on every `nn.Linear` whose weight belongs to one of the
    optimizer's param groups accumulates the per-input-feature mean E[x] and
    mean of squares E[x^2] of the layer input.

    For each such weight `W`, `tau` entries are drawn uniformly at random in
    every row and in every column, and the metric is

        (|W| / sampled_column_sum + |W| / sampled_row_sum) * E[x^2] ** act_power

    where the sampled sums only add the randomly selected entries of `|W|`.
    The sampling uses `torch.rand`, so it draws from the global torch RNG.

    Results are written to `optimizer.state[p]` under the keys "w_metric",
    "scaler_row" (E[x], shape [1, in_features]), "scaler_row_x2" (E[x^2],
    shape [1, in_features]), "stochria_tau" and "stochria_beta".

    Side effect: the set of parameter ids managed by the optimizer is cached
    on the optimizer as `_opt_param_id_set` and reused on later calls.

    Args:
        model: Model whose `nn.Linear` layers are inspected.
        optimizer: Optimizer whose param groups select the weights and whose
            `state` receives the results.
        batch_inputs: Keyword arguments for the forward pass.
        normalize: If True, divide the metric by its mean absolute value when
            that mean is finite and positive.
        beta: Fraction of `min(out_features, in_features)` sampled per row and
            per column; the resulting count is clamped to a valid range.
        act_power: Exponent applied to E[x^2] in the metric.
    """
    opt_param_ids = getattr(optimizer, "_opt_param_id_set", None)
    if opt_param_ids is None:
        opt_param_ids = {
            id(p) for g in optimizer.param_groups for p in g["params"] if isinstance(p, torch.Tensor)
        }
        optimizer._opt_param_id_set = opt_param_ids

    hooks = []
    # p -> {"sum_mean": Tensor[in_features], "sum_sq": Tensor[in_features], "cnt": int}
    stats = {}

    def make_pre_hook(p):
        def _pre_hook(module, inputs):
            # Modules called with keyword-only arguments receive an empty
            # positional tuple; there is nothing to record in that case.
            if not inputs or not torch.is_tensor(inputs[0]):
                return
            x = inputs[0]
            with torch.no_grad():
                in_features = x.shape[-1]
                x_flat = x.detach().reshape(-1, in_features)
                x_mean = x_flat.mean(dim=0).to(p.device)         # E[x]
                x2_mean = x_flat.pow(2).mean(dim=0).to(p.device) # E[x^2]

                rec = stats.get(p)
                if rec is None or rec["sum_mean"].numel() != x_mean.numel():
                    stats[p] = {"sum_mean": x_mean.clone(),
                                "sum_sq":   x2_mean.clone(),
                                "cnt": 1}
                else:
                    rec["sum_mean"].add_(x_mean)
                    rec["sum_sq"].add_(x2_mean)
                    rec["cnt"] += 1
        return _pre_hook

    try:
        for mod in model.modules():
            if not isinstance(mod, nn.Linear):
                continue
            p = getattr(mod, "weight", None)
            if p is None or id(p) not in opt_param_ids:
                continue
            hooks.append(mod.register_forward_pre_hook(make_pre_hook(p)))

        with torch.no_grad(), _tmp_eval_mode(model), _tmp_disable_cache_if_any(model):
            model(**batch_inputs)
    finally:
        # Always detach the hooks, even when the forward pass fails, so they
        # do not stay registered on the model.
        for h in hooks:
            h.remove()

    with torch.no_grad():
        for p, rec in stats.items():
            if rec["cnt"] <= 0:
                continue

            scaler_row_mean = rec["sum_mean"] / rec["cnt"]  # [in_features], E[x]
            scaler_row_x2 = rec["sum_sq"] / rec["cnt"]      # [in_features], E[x^2]

            W = p.data  # [out_features, in_features]
            if scaler_row_x2.numel() != W.shape[-1]:
                continue

            abs_W = W.detach().abs()
            out_dim, in_dim = abs_W.shape

            # Number of sampled entries per row/column, kept within what
            # `topk` accepts along both dimensions.
            smallest_dim = min(out_dim, in_dim)
            tau = min(max(1, int(beta * smallest_dim)), smallest_dim)
            # 1e-12 rounds to zero in float16, which would make the clamp a
            # no-op; fall back to the dtype's smallest normal value there.
            eps = max(1e-12, torch.finfo(abs_W.dtype).tiny)

            # Random subset of `tau` columns for every row.
            rand_row = torch.rand(out_dim, in_dim, device=W.device)
            row_topk = rand_row.topk(tau, dim=1).indices             # [out_dim, tau]
            row_mask = torch.zeros_like(abs_W, dtype=torch.bool)
            row_mask.scatter_(1, row_topk, True)
            row_subsum = (abs_W * row_mask).sum(dim=1, keepdim=True).clamp_min(eps)  # [out,1]

            # Random subset of `tau` rows for every column.
            rand_col = torch.rand(out_dim, in_dim, device=W.device)
            col_topk = rand_col.topk(tau, dim=0).indices             # [tau, in_dim]
            col_mask = torch.zeros_like(abs_W, dtype=torch.bool)
            col_mask.scatter_(0, col_topk, True)
            col_subsum = (abs_W * col_mask).sum(dim=0, keepdim=True).clamp_min(eps)  # [1,in]

            ri_ratio = abs_W / col_subsum + abs_W / row_subsum

            W_metric = ri_ratio * (scaler_row_x2.reshape(1, -1)) ** act_power

            if normalize:
                denom = W_metric.abs().mean()
                if torch.isfinite(denom) and denom > 0:
                    W_metric = W_metric / denom

            st = optimizer.state.setdefault(p, {})
            st["w_metric"] = W_metric
            st["scaler_row"] = scaler_row_mean.reshape(1, -1)
            st["scaler_row_x2"] = scaler_row_x2.reshape(1, -1)
            st["stochria_tau"] = tau
            st["stochria_beta"] = beta


def spp_inner_training_loop(
    self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
    self.accelerator.free_memory()
    self._train_batch_size = batch_size
    if self.args.auto_find_batch_size:
        if self.state.train_batch_size != self._train_batch_size:
            from accelerate.utils import release_memory

            (self.model_wrapped,) = release_memory(self.model_wrapped)
            self.model_wrapped = self.model

            # Check for DeepSpeed *after* the intial pass and modify the config
            if self.is_deepspeed_enabled:
                # Temporarily unset `self.args.train_batch_size`
                original_bs = self.args.per_device_train_batch_size
                self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                self.propagate_args_to_deepspeed(True)
                self.args.per_device_train_batch_size = original_bs
        self.state.train_batch_size = self._train_batch_size
    logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()
    if self.is_fsdp_xla_v2_enabled:
        train_dataloader = tpu_spmd_dataloader(train_dataloader)

    # Setting up training control variables:
    # number of training epochs: num_train_epochs
    # number of training steps per epoch: num_update_steps_per_epoch
    # total number of training steps to execute: max_steps
    total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

    len_dataloader = None
    num_train_tokens = None
    if has_length(train_dataloader):
        len_dataloader = len(train_dataloader)
        num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        num_examples = self.num_examples(train_dataloader)
        if args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
            # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
            # the best we can do.
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = (
                    self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                )
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)
            num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
    elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
        max_steps = args.max_steps
        # Setting a very large number of epochs so we go as many times as necessary over the iterator.
        num_train_epochs = sys.maxsize
        num_update_steps_per_epoch = max_steps
        num_examples = total_train_batch_size * args.max_steps
        num_train_samples = args.max_steps * total_train_batch_size
        if args.include_tokens_per_second:
            num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
    else:
        raise ValueError(
            "args.max_steps must be set to a positive value if dataloader does not have a length, was"
            f" {args.max_steps}"
        )

    if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
        if self.args.n_gpu > 1:
            # nn.DataParallel(model) replicates the model, creating new variables and module
            # references registered here no longer work on other gpus, breaking the module
            raise ValueError(
                "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                " (torchrun or torch.distributed.launch (deprecated))."
            )
        else:
            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

    delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

    # We need to reset the scheduler, as its parameters may be different on subsequent calls
    if self._created_lr_scheduler:
        self.lr_scheduler = None
        self._created_lr_scheduler = False

    if self.is_deepspeed_enabled:
        self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

    if not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    self.state = TrainerState(
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ]
    )
    self.state.is_hyper_param_search = trial is not None
    self.state.train_batch_size = self._train_batch_size

    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        if args.gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {}
        else:
            gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    model = self._wrap_model(self.model_wrapped)

    # as the model is wrapped, don't use `accelerator.prepare`
    # this is for unhandled cases such as
    # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
    use_accelerator_prepare = True if model is self.model else False

    if delay_optimizer_creation:
        if use_accelerator_prepare:
            self._fsdp_qlora_plugin_updates()
            self.model = self.accelerator.prepare(self.model)
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    # prepare using `accelerator` prepare
    if use_accelerator_prepare:
        self.model.train()
        if hasattr(self.lr_scheduler, "step"):
            if self.use_apex:
                model = self.accelerator.prepare(self.model)
            else:
                model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        else:
            # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.lr_scheduler
            )
    elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        # In this case we are in DDP + LOMO, which should be supported
        self.optimizer = self.accelerator.prepare(self.optimizer)

    if self.is_fsdp_enabled:
        self.model = self.model_wrapped = model

    # for the rest of this function `model` is the outside model, whether it was wrapped or not
    if model is not self.model:
        self.model_wrapped = model

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model_wrapped

    # ckpt loading
    if resume_from_checkpoint is not None:
        if self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(
                self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
            )
        elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
            self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

    # Check if saved optimizer or scheduler states exist
    self._load_optimizer_and_scheduler(resume_from_checkpoint)

    # important: at this point:
    # self.model         is the Transformers Model
    # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
    # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

    # Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {num_examples:,}")
    logger.info(f"  Num Epochs = {num_train_epochs:,}")
    logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    if self.args.per_device_train_batch_size != self._train_batch_size:
        logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps:,}")
    logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

    self.state.epoch = 0
    start_time = time.time()
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    steps_trained_progress_bar = None

    # Check if continuing training from a checkpoint
    if resume_from_checkpoint is not None and os.path.isfile(
        os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    ):
        self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
        self.compare_trainer_and_checkpoint_args(self.args, self.state)
        self._load_callback_state()
        epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
        if not args.ignore_data_skip:
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        else:
            steps_trained_in_current_epoch = 0

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f"  Continuing training from epoch {epochs_trained}")
        logger.info(f"  Continuing training from global step {self.state.global_step}")
        if not args.ignore_data_skip:
            logger.info(
                f"  Will skip the first {epochs_trained} epochs then the first"
                f" {steps_trained_in_current_epoch} batches in the first epoch."
            )

    # Update the references
    self.callback_handler.model = self.model
    self.callback_handler.optimizer = self.optimizer
    self.callback_handler.lr_scheduler = self.lr_scheduler
    self.callback_handler.train_dataloader = train_dataloader
    if self.hp_name is not None and self._trial is not None:
        # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
        # parameter to Train when using DDP.
        self.state.trial_name = self.hp_name(self._trial)
    if trial is not None:
        assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
        self.state.trial_params = hp_params(assignments)
    else:
        self.state.trial_params = None
    # This should be the same if the state has been saved but in case the training arguments changed, it's safer
    # to set this after the load.
    self.state.max_steps = max_steps
    self.state.num_train_epochs = num_train_epochs
    self.state.is_local_process_zero = self.is_local_process_zero()
    self.state.is_world_process_zero = self.is_world_process_zero()

    # tr_loss is a tensor to avoid synchronization of TPUs through .item()
    tr_loss = torch.tensor(0.0).to(args.device)
    # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step

    model.zero_grad()

    grad_norm: Optional[float] = None
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    if args.eval_on_start:
        self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

    total_batched_samples = 0
    for epoch in range(epochs_trained, num_train_epochs):
        epoch_iterator = train_dataloader
        if hasattr(epoch_iterator, "set_epoch"):
            epoch_iterator.set_epoch(epoch)

        # Reset the past mems state at the beginning of each epoch if necessary.
        if args.past_index >= 0:
            self._past = None

        steps_in_epoch = (
            len(epoch_iterator)
            if len_dataloader is not None
            else args.max_steps * args.gradient_accumulation_steps
        )
        self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

        if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
            self._load_rng_state(resume_from_checkpoint)

        rng_to_sync = False
        steps_skipped = 0
        if steps_trained_in_current_epoch > 0:
            epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
            steps_skipped = steps_trained_in_current_epoch
            steps_trained_in_current_epoch = 0
            rng_to_sync = True

        step = -1

        for step, inputs in enumerate(epoch_iterator):
            total_batched_samples += 1

            if self.args.include_num_input_tokens_seen:
                main_input_name = getattr(self.model, "main_input_name", "input_ids")
                if main_input_name not in inputs:
                    logger.warning(
                        "Tried to track the number of tokens seen, however the current model is "
                        "not configured properly to know what item is the input. To fix this, add "
                        "a `main_input_name` attribute to the model class you are using."
                    )
                else:
                    self.state.num_input_tokens_seen += (
                        torch.sum(
                            self.accelerator.gather(
                                torch.tensor(
                                    inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
                                )
                            )
                        )
                        .cpu()
                        .item()
                    )
            if rng_to_sync:
                self._load_rng_state(resume_from_checkpoint)
                rng_to_sync = False

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                if steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.update(1)
                if steps_trained_in_current_epoch == 0:
                    self._load_rng_state(resume_from_checkpoint)
                continue
            elif steps_trained_progress_bar is not None:
                steps_trained_progress_bar.close()
                steps_trained_progress_bar = None

            if step % args.gradient_accumulation_steps == 0:
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

            with self.accelerator.accumulate(model):
                tr_loss_step = self.training_step(model, inputs)

            if (
                args.logging_nan_inf_filter
                and not is_torch_xla_available()
                and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
            ):
                # if loss is nan or inf simply add the average of previous logged losses
                tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
            else:
                if tr_loss.device != tr_loss_step.device:
                    raise ValueError(
                        f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                    )
                tr_loss += tr_loss_step

            self.current_flos += float(self.floating_point_ops(inputs))

            is_last_step_and_steps_less_than_grad_acc = (
                steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
            )

            if (
                total_batched_samples % args.gradient_accumulation_steps == 0
                or
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                is_last_step_and_steps_less_than_grad_acc
            ):
                # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
                # in accelerate. So, explicitly enable sync gradients to True in that case.
                if is_last_step_and_steps_less_than_grad_acc:
                    self.accelerator.gradient_state._set_sync_gradients(True)

                # Gradient clipping
                if args.max_grad_norm is not None and args.max_grad_norm > 0:
                    # deepspeed does its own clipping

                    if is_sagemaker_mp_enabled() and args.fp16:
                        _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                    elif self.use_apex:
                        # Revert to normal clipping otherwise, handling Apex or full precision
                        _grad_norm = nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            args.max_grad_norm,
                        )
                    else:
                        _grad_norm = self.accelerator.clip_grad_norm_(
                            model.parameters(),
                            args.max_grad_norm,
                        )

                    if (
                        is_accelerate_available()
                        and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                    ):
                        grad_norm = model.get_global_grad_norm()
                        # In some cases the grad norm may not return a float
                        if hasattr(grad_norm, "item"):
                            grad_norm = grad_norm.item()
                    else:
                        grad_norm = _grad_norm
                #inputs
                #from utils_wanda import recompute_wanda_metrics_for_step
                inputs_for_metric = self._prepare_inputs(inputs)
                real_model = getattr(model, "module", model)    
                recompute_stochria_metrics_for_step(real_model, self.optimizer, inputs_for_metric, normalize=True)

                self.optimizer.step()
                #use_24_panelty=True
                if self.lambda2_param>0.0:
                    
                    replace_weight(model, self.lambda2_param) 
                    #param.data=prox_op(param.data,lamda_2)
                self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

                optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                if optimizer_was_run:
                    # Delay optimizer scheduling until metrics are generated
                    if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.lr_scheduler.step()
                # model.zero_grad()
                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
            else:
                self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                # PyTorch/XLA relies on the data loader to insert the mark_step for
                # each step. Since we are breaking the loop early, we need to manually
                # insert the mark_step here.
                if is_torch_xla_available():
                    xm.mark_step()
                break
        if step < 0:
            logger.warning(
                "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                f" num_steps ({max_steps}) higher than the number of available samples."
            )
            self.control.should_training_stop = True

        self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
        self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            if is_torch_xla_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            else:
                logger.warning(
                    "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                    "configured. Check your training configuration if this is unexpected."
                )
        if self.control.should_training_stop:
            break

    if args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        # Wait for everyone to get here so we are sure the model has been saved by process 0.
        if is_torch_xla_available():
            xm.rendezvous("load_best_model_at_end")
        elif args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()

        self._load_best_model()

    # add remaining tr_loss
    self._total_loss_scalar += tr_loss.item()
    effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
    train_loss = self._total_loss_scalar / effective_global_step

    metrics = speed_metrics(
        "train",
        start_time,
        num_samples=num_train_samples,
        num_steps=self.state.max_steps,
        num_tokens=num_train_tokens,
    )
    self.store_flos()
    metrics["total_flos"] = self.state.total_flos
    metrics["train_loss"] = train_loss

    self.is_in_train = False

    self._memory_tracker.stop_and_update_metrics(metrics)

    self.log(metrics)

    run_dir = self._get_output_dir(trial)
    checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

    # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
    if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
        for checkpoint in checkpoints_sorted:
            if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                shutil.rmtree(checkpoint, ignore_errors=True)

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)

    # Wait for the checkpoint to be uploaded.
    self._finish_current_push()

    # After training we make sure to retrieve back the original forward pass method
    # for the embedding layer by removing the forward post hook.
    if self.neftune_noise_alpha is not None:
        self._deactivate_neftune(self.model)

    return TrainOutput(self.state.global_step, train_loss, metrics)

######## functioning
