# NEW
from transformers.modeling_utils import load_sharded_checkpoint, unwrap_model
######### Trl Trainer import ########
import contextlib
import dataclasses
import os
import warnings
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union, Tuple, Dict, Tuple as Tup
import math


import torch
import torch.nn as nn
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainingArguments,
    is_wandb_available,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ..data_utils import (
    apply_chat_template,
    is_conversational,
    # maybe_convert_to_chatml,
    # pack_dataset,
    # truncate_dataset,
    # truncate_input_ids_innermost,
    # convert_generations_to_chatml_with_sot_tokens,
    # is_list_conversational,
)
# from ..models import get_act_offloading_ctx_manager
from .sft_config import SFTConfig
from .utils import (
    ConstantLengthDataset,
    generate_model_card,
    # get_comet_experiment_url,
    pad,
    peft_module_casting_to_bf16,
)
import gc
# if is_peft_available():
#     import peft
#     from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb

######### Transformers Trainer import ########
import inspect
import time
from functools import partial
from accelerate import Accelerator, skip_first_batches


from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from transformers.integrations import (
    get_reporting_integration_callbacks,
)


from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available

from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    ExportableState,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedTensorGatherer,
    EvalLoopContainer,
    IterableDatasetShard,
    LabelSmoother,
    LayerWiseDummyOptimizer,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
    remove_dummy_checkpoint,
    set_rng_state_for_device,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    HPSearchBackend,
    HubStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    SaveStrategy,
    TrainerMemoryTracker,
    TrainOutput,
    check_target_module_exists,
    default_compute_objective,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    neftune_post_forward_hook,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from transformers.utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    XLA_FSDPV2_MIN_VERSION,
    PushInProgress,
    PushToHubMixin,
    can_return_loss,
    check_torch_load_is_safe,
    find_labels,
    is_accelerate_available,
    is_apex_available,
    is_apollo_torch_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_ipex_available,
    is_liger_kernel_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_torch_hpu_available,
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    logging,
    strtobool,
)

from accelerate.utils import (
    AutocastKwargs,
    DistributedDataParallelKwargs,
    DistributedType,
    load_fsdp_model,
    load_fsdp_optimizer,
    save_fsdp_model,
    save_fsdp_optimizer,
)

# Default callback classes. NOTE(review): presumably consumed by the Trainer
# initialisation code further down this file -- confirm.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# `datasets` is optional: import it only when installed (it is referenced in
# string type annotations such as "datasets.Dataset").
if is_datasets_available():
    import datasets

# Module-level logger from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Names of the files used for checkpointing.
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCALER_NAME = "scaler.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

######## New import ##########
from collections import deque
from accelerate.utils import gather_object
from accelerate.utils import send_to_device
from scipy.optimize import linear_sum_assignment
# import rpdb
import itertools
import torch.distributed as dist


# Debug random matching performance
@torch.no_grad()
def random_configuration(cost: torch.Tensor, *, seed: int | None = None):
    """
    Uniform random one-to-one matching (no repeated columns).
    """
    N, M = cost.shape
    K = min(N, M)
    dev = cost.device

    g = torch.Generator(device=dev)
    if seed is not None:
        g.manual_seed(seed)

    if N <= M:
        # use all rows; pick N distinct columns
        rows = torch.arange(N, device=dev, dtype=torch.long)
        cols = torch.randperm(M, generator=g, device=dev)[:K]
    else:
        # use all columns; pick M distinct rows, keep rows sorted like your solver
        rows = torch.randperm(N, generator=g, device=dev)[:K].sort().values
        cols = torch.randperm(M, generator=g, device=dev)  # length K

    return rows, cols



####### Final version for bipartite matching and logging #######
# assignments_gpu_or_hybrid_atleast1d.py
# ---------- enumeration (GPU) ----------
@torch.no_grad()
def enumerate_all_assignments_gpu_vec(
    cost: torch.Tensor,
    sum_dtype: torch.dtype = torch.float64,
    normalize_by_cols: bool=True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Exhaustively score every injective column->row assignment of a tall CUDA cost matrix.

    An assignment chooses C distinct rows (a subset, in ``itertools.combinations``
    order) and pairs them position-wise with a permutation of the C columns (in
    ``itertools.permutations`` order). Its score is the sum of the selected entries,
    gathered and accumulated in ``sum_dtype``.

    Args:
        cost: CUDA tensor of shape (R, C) with R >= C (both asserted).
        sum_dtype: dtype used for the gather and the summation.
        normalize_by_cols: if True, every total is divided by C.

    Returns:
        rows_sorted:  (S*P, C) int64 row indices of each assignment,
        cols_sorted:  (S*P, C) int64 column indices aligned with ``rows_sorted``,
        costs_sorted: (S*P,) assignment totals in ascending order,
        where S = comb(R, C) and P = C!. All outputs are on ``cost.device``.
        The output size grows combinatorially with R and C.
    """
    assert cost.is_cuda and cost.ndim == 2
    n_rows, n_cols = (int(d) for d in cost.shape)
    assert n_rows >= n_cols
    device = cost.device

    subsets = torch.tensor(
        list(itertools.combinations(range(n_rows), n_cols)), dtype=torch.long, device=device
    )
    orderings = torch.tensor(
        list(itertools.permutations(range(n_cols))), dtype=torch.long, device=device
    )
    n_sub = subsets.size(0)
    n_perm = orderings.size(0)

    # Broadcast both index sets to an (S, P, C) grid so every total is gathered at once.
    gather_rows = subsets.unsqueeze(1).expand(n_sub, n_perm, n_cols)
    gather_cols = orderings.unsqueeze(0).expand(n_sub, n_perm, n_cols)
    totals = cost.to(sum_dtype)[gather_rows, gather_cols].sum(-1, dtype=sum_dtype).reshape(-1)

    # Flat id = subset_id * P + perm_id.
    ranking = torch.argsort(totals)
    subset_id = torch.div(ranking, n_perm, rounding_mode="floor")
    perm_id = ranking - subset_id * n_perm

    sorted_costs = totals.index_select(0, ranking)
    if normalize_by_cols:
        sorted_costs = sorted_costs / n_cols
    return subsets.index_select(0, subset_id), orderings.index_select(0, perm_id), sorted_costs


# ---------- enumeration (CPU -> GPU hybrid) ----------
@torch.no_grad()
def enumerate_all_assignments_cpu_vec_to_device( # cost GPU in -> rows, costs GPU out
    cost: torch.Tensor,
    sum_dtype: torch.dtype = torch.float64,
    normalize_by_cols: bool=True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Host-side variant of ``enumerate_all_assignments_gpu_vec``.

    Enumerates all assignments of C distinct rows (``itertools.combinations`` order)
    to the C columns (``itertools.permutations`` order) on the CPU, then moves the
    sorted results back to ``cost.device``.

    Args:
        cost: 2-D tensor of shape (R, C) with R >= C (asserted), on any device.
        sum_dtype: dtype used for the gather and the summation (same meaning as in
            the GPU variant, so both backends rank assignments with equal precision).
        normalize_by_cols: if True, every total is divided by C.

    Returns:
        rows_sorted:  (S*P, C) int64 row indices,
        cols_sorted:  (S*P, C) int64 column indices aligned with ``rows_sorted``,
        costs_sorted: (S*P,) totals (dtype ``sum_dtype``) in ascending order,
        with S = comb(R, C), P = C!; all on ``cost.device``.
    """
    assert cost.ndim == 2
    dev = cost.device
    R, C = map(int, cost.shape); assert R >= C

    # Accumulate in ``sum_dtype``: summing in a low-precision input dtype (fp16/bf16/fp32)
    # can round distinct assignment totals together and change the ranking.
    x = cost.detach().to("cpu", dtype=sum_dtype)
    row_sets = torch.tensor(list(itertools.combinations(range(R), C)), dtype=torch.long)
    perms    = torch.tensor(list(itertools.permutations(range(C))),    dtype=torch.long)

    S, P = row_sets.size(0), perms.size(0)
    # (S, P, C) index grids -> every (row subset, permutation) total in one gather
    row_idx = row_sets[:, None, :].expand(S, P, C)
    col_idx = perms[None, :, :].expand(S, P, C)

    vals  = x[row_idx, col_idx].sum(-1, dtype=sum_dtype)   # (S, P)
    flat  = vals.reshape(-1)                                # (S*P,)
    order = torch.argsort(flat)

    # flat id = subset_id * P + perm_id
    s = torch.div(order, P, rounding_mode="floor")
    p = order - s * P

    rows_sorted  = row_sets.index_select(0, s).to(dev)
    cols_sorted  = perms.index_select(0, p).to(dev)
    costs_sorted = flat.index_select(0, order).to(dev)
    if normalize_by_cols:
        costs_sorted = costs_sorted/C
    return rows_sorted, cols_sorted, costs_sorted


# ---------- stats (GPU/CPU device-agnostic; returns 0-D tensors) ----------
@torch.no_grad()
def assignment_stats_minimal_tensor(
    costs_sorted: torch.Tensor,
    eps: float = 1e-12,
    exclude_all_min_ties: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Summarise how clearly the best assignment beats the others.

    Args:
        costs_sorted: 1-D tensor of assignment costs in ascending order (must be
            non-empty; indexing the first element raises otherwise).
        eps: added to the denominators of the relative gaps.
        exclude_all_min_ties: if True, the "second" and "rest" statistics consider
            only costs strictly greater than the minimum; otherwise they use
            positions 1.. of ``costs_sorted`` (ties with the minimum included).

    Returns:
        Dict of 0-D float64 tensors on ``costs_sorted.device``:
        min_cost, gap_second (+inf when there is no second cost), gap_second_rel,
        gap_rest / gap_rest_rel (NaN when nothing remains), mean_all, gap_max,
        gap_max_rel.
    """
    device = costs_sorted.device
    costs = costs_sorted.to(dtype=torch.float64, device=device)

    best = costs[0]
    worst = costs[-1]
    abs_denominator = best.abs() + eps

    def _scalar(value: float) -> torch.Tensor:
        return torch.tensor(value, dtype=costs.dtype, device=device)

    if not exclude_all_min_ties:
        runner_up = costs[1] if costs.numel() >= 2 else _scalar(float("inf"))
        others = costs[1:]
    else:
        strictly_worse = costs[costs > best]
        if strictly_worse.numel() > 0:
            runner_up = strictly_worse[0]
        else:
            runner_up = _scalar(float("inf"))
        others = strictly_worse

    second_gap = runner_up - best

    if others.numel() == 0:
        missing = _scalar(float("nan"))
        rest_gap, rest_gap_rel = missing, missing
    else:
        rest_gap = others.mean() - best
        rest_gap_rel = rest_gap / abs_denominator

    spread = worst - best
    stats: Dict[str, torch.Tensor] = {}
    stats["min_cost"] = best
    stats["gap_second"] = second_gap
    stats["gap_second_rel"] = second_gap / abs_denominator
    stats["gap_rest"] = rest_gap
    stats["gap_rest_rel"] = rest_gap_rel
    stats["mean_all"] = costs.mean()
    stats["gap_max"] = spread
    # NOTE(review): unlike the other relative gaps this divides by the *signed*
    # minimum, so it differs (and can turn negative) when min_cost < 0 -- confirm intent.
    stats["gap_max_rel"] = spread / (best + eps)
    return stats


# ---------- tracker (everything lives on CUDA; choose backend per call) ----------
class MatchingTrackerFixed:
    """
    Running statistics about optimal row/column matchings of cost matrices whose
    row count is fixed at ``R_fixed``.

    Tensors (all on ``self.device``):
      counts         (R_fixed,) float64 -- times each row was part of the best matching.
      token_counts   (R_fixed,) float64 -- per matched row, the sum of
                     ``num_tokens_global_Nki[row, col]`` over best matchings.
      _stat_sums     {stat name: 0-D float64} -- running sums of the values returned by
                     ``assignment_stats_minimal_tensor``.
      mats_processed 0-D int64 -- matrices processed since the last ``reset``.
      rowset_counts  {C: (comb(R_fixed, C),) float64} -- best-matching usage per row
                     subset, indexed in ``itertools.combinations(range(R_fixed), C)`` order.
      config_counts  {C: (comb(R_fixed, C) * C!,) float64} -- usage per (row subset,
                     column permutation) pair; flat id = subset_id * C! + perm_id, with
                     permutations in ``itertools.permutations(range(C))`` order.
    """
    _STAT_KEYS = (
        "min_cost",
        "gap_second",
        "gap_second_rel",
        "gap_rest",
        "gap_rest_rel",
        "mean_all",
        "gap_max",
        "gap_max_rel",
    )

    def __init__(self, R_fixed: int, device: Optional[torch.device] = None):
        """
        Args:
            R_fixed: number of rows the processed cost matrices are expected to have.
            device: where every counter lives; defaults to the current CUDA device
                (asserts that CUDA is available in that case).
        """
        if device is None:
            assert torch.cuda.is_available(), "GPU required for this class"
            device = torch.device("cuda", torch.cuda.current_device())
        self.device = device
        self.R_fixed = int(R_fixed)
        self.counts = torch.zeros(self.R_fixed, dtype=torch.float64, device=device)
        self.token_counts = torch.zeros(self.R_fixed, dtype=torch.float64, device=device)  # Σ tokens per row

        self._stat_sums: Dict[str, torch.Tensor] = {
            k: torch.zeros((), dtype=torch.float64, device=device) for k in self._STAT_KEYS
        }
        self.mats_processed = torch.zeros((), dtype=torch.int64, device=device)
        # Per-C counters are allocated lazily on first use (see the _ensure_* helpers).
        self.rowset_counts: dict[int, torch.Tensor] = {}
        self.config_counts: dict[int, torch.Tensor] = {}

    def _ensure_rowset_counts(self, C: int) -> None:
        """Allocate the per-C row-subset counter (length comb(R_fixed, C)) if missing."""
        if C not in self.rowset_counts:
            S = math.comb(self.R_fixed, C)
            self.rowset_counts[C] = torch.zeros(S, dtype=torch.float64, device=self.device)

    def _ensure_config_counts(self, C: int) -> None:
        """Allocate the per-C (subset, permutation) counter (length comb(R, C) * C!) if missing."""
        if C not in self.config_counts:
            R = int(self.counts.numel())
            N = math.comb(R, C) * math.factorial(C)
            self.config_counts[C] = torch.zeros(N, dtype=torch.float64, device=self.device)

    @staticmethod
    def _rowset_rank(rows, R: int) -> int:
        """
        Index of the strictly increasing subset ``rows`` within
        ``itertools.combinations(range(R), len(rows))`` (lexicographic order).
        """
        idx = [int(v) for v in torch.as_tensor(rows).reshape(-1).tolist()]
        C = len(idx)
        rank, prev = 0, -1
        for i, r in enumerate(idx):
            # Skip every subset that agrees on positions < i but places a smaller
            # value j at position i; each leaves C-1-i values to pick from the R-1-j above j.
            for j in range(prev + 1, r):
                rank += math.comb(R - 1 - j, C - 1 - i)
            prev = r
        return rank

    @staticmethod
    def _perm_rank(perm) -> int:
        """Index of ``perm`` within ``itertools.permutations(range(len(perm)))`` (Lehmer code)."""
        vals = [int(v) for v in torch.as_tensor(perm).reshape(-1).tolist()]
        C = len(vals)
        rank = 0
        for i, v in enumerate(vals):
            smaller_later = sum(1 for w in vals[i + 1:] if w < v)
            rank += smaller_later * math.factorial(C - 1 - i)
        return rank

    @torch.no_grad()
    def _best_config_and_rowset_ids_cpu(self, cost: torch.Tensor):
        """
        Standalone exhaustive search (on CPU, float64) for the minimum-cost assignment.
        Kept for backward compatibility; ``process_optimal`` no longer calls it.

        Ids are ranked within comb(R, C) of ``cost`` (R = cost rows), which matches the
        tracker counters only when R == R_fixed.

        Returns:
          C : int                          # number of columns
          k : int                          # flat config id in [0, S*P)
          s : int                          # row-set id in [0, S)
          p : int                          # permutation id in [0, P)
          S : int                          # comb(R, C)
          P : int                          # C!
        """
        R, C = map(int, cost.shape)
        x = cost.detach().to("cpu", dtype=torch.float64)

        row_sets = torch.tensor(list(itertools.combinations(range(R), C)), dtype=torch.long)  # (S, C)
        perms    = torch.tensor(list(itertools.permutations(range(C))),    dtype=torch.long)  # (P, C)
        S, P = row_sets.size(0), math.factorial(C)

        # broadcasted index grids → all (rowset, perm) totals at once
        row_idx = row_sets[:, None, :].expand(S, P, C)   # (S,P,C)
        col_idx = perms[None, :, :].expand(S, P, C)      # (S,P,C)
        vals = x[row_idx, col_idx].sum(-1)               # (S,P)

        k = int(vals.reshape(-1).argmin().item())        # flat argmin in [0, S*P)
        s, p = divmod(k, P)                               # row-set id, perm id
        return C, k, s, p, S, P

    @torch.no_grad()
    def process_optimal(
        self,
        cost: torch.Tensor,
        num_tokens_global_Nki: torch.Tensor, # NOTE CHANGE 2
        *,
        backend: str = "gpu",     # "gpu" or "cpu"
        return_stats: bool = True,
        exclude_all_min_ties: bool = False,
    ):
        """
        Find the minimum-cost assignment of one cost matrix and fold it into all trackers.

        Args:
            cost: (R, C) cost matrix with R >= C (asserted by the enumeration).
            num_tokens_global_Nki: tensor indexed as ``[row, col]`` with row < R and
                col < C (presumably shaped like ``cost`` -- confirm); entries of the
                matched cells are added to ``token_counts`` of their rows. It is moved
                to ``self.device`` before indexing.
            backend: "gpu" enumerates on CUDA; "cpu" enumerates on the host and moves
                the results back to ``self.device``.
            return_stats: currently unused; statistics are always accumulated.
            exclude_all_min_ties: forwarded to ``assignment_stats_minimal_tensor``.

        Returns:
            (best_rows, best_cols): int64 tensors of length C on ``self.device``;
            row ``best_rows[i]`` is matched with column ``best_cols[i]``.

        Raises:
            ValueError: if ``backend`` is neither "gpu" nor "cpu".
        """
        cost_dev = cost.to(self.device)
        if backend == "gpu":
            rows_sorted, cols_sorted, costs_sorted = enumerate_all_assignments_gpu_vec(cost_dev)
        elif backend == "cpu":
            rows_sorted, cols_sorted, costs_sorted = enumerate_all_assignments_cpu_vec_to_device(cost_dev)
        else:
            raise ValueError("backend must be 'gpu' or 'cpu'")

        best_rows = rows_sorted[0]  # (C,) int64, ascending row ids
        best_cols = cols_sorted[0]  # (C,) int64, a permutation of range(C)

        # per-row usage
        hist = torch.bincount(best_rows, minlength=self.counts.numel()).to(self.counts.dtype)
        self.counts.add_(hist)

        # per-row token totals for the matched cells
        assert num_tokens_global_Nki is not None
        tok_sel = num_tokens_global_Nki.to(self.device)[best_rows, best_cols].to(self.token_counts.dtype)
        self.token_counts.index_add_(0, best_rows, tok_sel)

        # Rank the *same* best assignment used above instead of re-enumerating: a second
        # search could pick a different optimum on ties, and ranking against R_fixed keeps
        # ids consistent with the counter sizes.
        C = int(best_cols.numel())
        rowset_id = self._rowset_rank(best_rows, self.R_fixed)
        config_id = rowset_id * math.factorial(C) + self._perm_rank(best_cols)

        # config usage (rowset × perm)
        self._ensure_config_counts(C)
        self.config_counts[C][config_id] += 1.0

        # row-set usage (perm-marginal)
        self._ensure_rowset_counts(C)
        self.rowset_counts[C][rowset_id] += 1.0

        stats_dev = assignment_stats_minimal_tensor(costs_sorted, exclude_all_min_ties=exclude_all_min_ties)
        for stat_key in self._STAT_KEYS:
            self._stat_sums[stat_key].add_(stats_dev[stat_key])
        self.mats_processed.add_(1)

        return best_rows, best_cols

    # getters
    def get_counts(self) -> torch.Tensor:
        """Per-row usage counts (the live tensor, not a copy)."""
        return self.counts

    def get_distribution(self) -> torch.Tensor:
        """Per-row usage normalised to sum to 1 (all zeros when nothing was counted)."""
        s = self.counts.sum()
        return torch.where(s > 0, self.counts / s, torch.zeros_like(self.counts))

    def get_stat_sums(self) -> Dict[str, torch.Tensor]:
        """Running statistic sums (returned by reference)."""
        return self._stat_sums  # by-ref

    def get_stat_means_tensors(self) -> Dict[str, torch.Tensor]:
        """Statistic sums divided by ``mats_processed`` (NaN when no matrix was processed)."""
        den = self.mats_processed.to(torch.float64)
        nan = torch.tensor(float("nan"), dtype=torch.float64, device=self.device)
        return {k: torch.where(den > 0, self._stat_sums[k] / den, nan) for k in self._STAT_KEYS}

    def reset(self) -> None:
        """
        Zero the statistic sums and ``mats_processed``.

        ``counts``, ``token_counts``, ``rowset_counts`` and ``config_counts`` are not
        cleared and keep accumulating across resets.
        """
        for k in self._STAT_KEYS:
            self._stat_sums[k].zero_()
        self.mats_processed.zero_()


# ---------- one-collective SUM using atleast_1d(...).contiguous() ----------
@torch.no_grad()
def allreduce_sums_keep_shapes(values: Dict[str, torch.Tensor], pack_dtype=torch.float64):
    """
    Element-wise SUM of many tensors across all ranks using a single ``all_reduce``.

    Each tensor is flattened (0-D scalars become length 1), cast to ``pack_dtype`` and
    concatenated into one 1-D buffer on the device of the first tensor (which must be
    a device the process-group backend supports). After the collective the buffer is
    split back and every piece is restored to its original shape, dtype and device.

    All ranks must pass dicts with the same keys, in the same insertion order, with the
    same per-key sizes; otherwise the packed buffers do not line up. With the default
    float64 buffer, integer tensors round-trip exactly while their sums stay below 2**53.

    Args:
        values: name -> tensor of any shape.
        pack_dtype: dtype of the communication buffer.

    Returns:
        A new dict with the same keys and reduced tensors. Without an initialized
        process group, or with world size 1, detached clones of the inputs are returned.
    """
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return {k: v.detach().clone() for k, v in values.items()}
    if not values:
        # Nothing to reduce (and nothing to take a device from / concatenate).
        return {}

    keys = list(values)
    device = values[keys[0]].device

    flat = torch.cat(
        [
            torch.atleast_1d(values[k]).reshape(-1).to(device=device, dtype=pack_dtype)
            for k in keys
        ],
        dim=0,
    )  # 1-D, contiguous

    dist.all_reduce(flat, op=dist.ReduceOp.SUM)

    out: Dict[str, torch.Tensor] = {}
    offset = 0
    for k in keys:
        src = values[k]
        n = src.numel()
        out[k] = flat[offset:offset + n].to(device=src.device, dtype=src.dtype).reshape(src.shape)
        offset += n
    return out


# ---------- DDP reduction for stats + mats_processed + loss ----------

# @torch.no_grad()
# def ddp_reduce_and_average_stats_and_loss_tensors(
#     tracker,                      # AssignmentTrackerFixedNoMapGPU
#     loss_scalar: torch.Tensor,    # scalar (0-D preferred)
#     num_processes: int,           # average loss over this
#     include_perm_counts: bool = False,  # <-- default OFF
# ) -> Tuple[Dict[str, float], int]:
#     """
#     One CUDA collective (SUM) over:
#       - tracker._stat_sums[*]      (0-D float64)
#       - tracker.mats_processed     (0-D int64)
#       - loss_scalar                (0-D, any dtype)
#       - tracker.counts             (1-D float64)
#       - tracker.config_counts[C]   (1-D float64) for each C
# 
#     Returns:
#       - logs: Dict[str, float] with stat means and "loss_avg", plus list[int] payloads:
#           logs["row_usage_counts"]            = [...]
#           logs["config_usage_counts_C{C}"]    = [...]
#           logs["config_rowset_counts_C{C}"]   = [...]   # permutation-marginalized
#           # (perm counts omitted unless include_perm_counts=True)
#       - n_total: int
#     """
#     dev = tracker.device
# 
#     # ensure 0-D loss on device
#     loss0d = loss_scalar.detach()
#     if loss0d.ndim != 0:
#         loss0d = loss0d.mean()
#     loss0d = loss0d.to(dev)
# 
#     # pack locals for one collective
#     vals: Dict[str, torch.Tensor] = {k: tracker._stat_sums[k] for k in tracker._STAT_KEYS}
#     vals["mats_processed"] = tracker.mats_processed
#     vals["loss_sum"] = loss0d
#     vals["counts"] = tracker.counts
# 
#     cfg_keys = []
#     for C, t in tracker.config_counts.items():
#         key = f"cfg_counts_C{C}"
#         vals[key] = t
#         cfg_keys.append((C, key))
# 
#     reduced = allreduce_sums_keep_shapes(vals, pack_dtype=torch.float64)
# 
#     # scalars → floats
#     n_total = int(reduced["mats_processed"].item())
#     logs: Dict[str, float] = {}
#     if n_total == 0:
#         for k in tracker._STAT_KEYS:
#             logs[k] = float("nan")
#     else:
#         den = float(n_total)
#         for k in tracker._STAT_KEYS:
#             logs[k] = float(reduced[k].item() / den)
# 
#     loss_avg = reduced["loss_sum"] / torch.tensor(num_processes, dtype=loss0d.dtype, device=dev)
#     logs["loss_avg"] = float(loss_avg.item())
# 
#     # row usage → list[int]
#     logs["row_usage_counts"] = reduced["counts"].to(torch.int64).detach().cpu().tolist()
# 
#     # per-C config payloads (full, and row-set marginal only)
#     R = tracker.R_fixed
#     for C, key in cfg_keys:
#         v = reduced[key]  # 1-D float64, length comb(R,C)*C!
#         logs[f"config_usage_counts_C{C}"] = v.to(torch.int64).detach().cpu().tolist()
# 
#         S = math.comb(R, C)
#         P = math.factorial(C)
#         mat = v.reshape(S, P)
# 
#         # Row-set marginal (sum over permutations) → length S
#         logs[f"config_rowset_counts_C{C}"] = mat.sum(dim=1).to(torch.int64).cpu().tolist()
# 
#         # Optional: permutation marginal (usually not meaningful for you)
#         if include_perm_counts:
#             logs[f"config_perm_counts_C{C}"] = mat.sum(dim=0).to(torch.int64).cpu().tolist()
# 
#     return logs, n_total

@torch.no_grad()
def ddp_reduce_and_average_stats_and_loss_tensors(
    tracker,                      # e.g., MatchingTrackerFixed
    loss_scalar: torch.Tensor,    # scalar; any dtype/shape -> reduced to 0-D
    num_processes: int,           # average loss over this
    include_config_counts: bool = True,  # leave False if you don't track them
    logged_num_cols: Tuple[int, ...] = (4,),  # which per-C counters to reduce and log
) -> Tuple[Dict[str, Any], int]:
    """
    Reduce a matching tracker's accumulators and a loss value across ranks in one all-reduce.

    Summed across ranks (via ``allreduce_sums_keep_shapes``):
      - ``tracker._stat_sums[k]`` for every ``k`` in ``tracker._STAT_KEYS``
      - ``tracker.mats_processed``
      - ``loss_scalar`` (averaged to 0-D first when it is not a scalar)
      - ``tracker.counts`` and ``tracker.token_counts``
      - ``tracker.rowset_counts[C]``, and ``tracker.config_counts[C]`` when
        ``include_config_counts``, for each ``C`` in ``logged_num_cols``

    This is a collective call: every rank must call it with the same arguments.

    Args:
        tracker: object exposing the attributes listed above (e.g. MatchingTrackerFixed).
        loss_scalar: loss value on any device; detached before use.
        num_processes: divisor applied to the summed loss.
        include_config_counts: also reduce and log the flat (subset, permutation) counters.
        logged_num_cols: column counts C whose per-C counters are reduced and logged.

    Returns:
        logs: dict with
          - each stat key -> summed stat / total matrices processed (NaN if none)
          - "loss_avg" and "loss" -> summed loss / num_processes
          - "row_usage_counts", "row_token_counts" -> per-row lists of ints
          - "row_tokens_per_use" -> token_counts / max(counts, 1) per row
          - "rowset_counts_C{C}" / "config_usage_counts_C{C}" -> lists of ints
        n_total: total matrices processed across ranks.
    """
    dev = tracker.device

    # 0-D float64 loss on the tracker device; kept in float64 so the reduction does
    # not round-trip through a low-precision loss dtype.
    loss0d = loss_scalar.detach().to(device=dev, dtype=torch.float64)
    if loss0d.ndim != 0:
        loss0d = loss0d.mean()

    # ---- pack for one all-reduce ----
    vals: Dict[str, torch.Tensor] = {k: tracker._stat_sums[k] for k in tracker._STAT_KEYS}
    vals["mats_processed"] = tracker.mats_processed
    vals["loss_sum"] = loss0d
    vals["counts"] = tracker.counts              # per-row usage
    vals["token_counts"] = tracker.token_counts  # per-row token totals

    requested_cols = sorted({int(c) for c in logged_num_cols})

    def _pack_per_c(attr: str, ensure_name: str, prefix: str):
        """Add the requested per-C counters of ``tracker.<attr>`` to ``vals``; return (C, key) pairs."""
        packed = []
        if not hasattr(tracker, attr):
            return packed
        ensure = getattr(tracker, ensure_name, None)
        for C in requested_cols:
            # Per-C counters are allocated lazily, i.e. only on ranks that already saw
            # a C-column matrix. Allocate them everywhere so every rank packs an
            # identically laid-out buffer (a mismatch would break the all-reduce).
            if ensure is not None:
                ensure(C)
            table = getattr(tracker, attr)
            if C in table:
                key = f"{prefix}{C}"
                vals[key] = table[C]
                packed.append((C, key))
        return packed

    rowset_keys = _pack_per_c("rowset_counts", "_ensure_rowset_counts", "rowset_counts_C")
    config_keys = (
        _pack_per_c("config_counts", "_ensure_config_counts", "config_usage_counts_C")
        if include_config_counts
        else []
    )

    reduced = allreduce_sums_keep_shapes(vals, pack_dtype=torch.float64)

    # ---- build logs ----
    n_total = int(reduced["mats_processed"].item())
    logs: Dict[str, Any] = {}
    if n_total == 0:
        for k in tracker._STAT_KEYS:
            logs[k] = float("nan")
    else:
        den = float(n_total)
        for k in tracker._STAT_KEYS:
            logs[k] = float(reduced[k].item() / den)

    loss_avg_f = float(reduced["loss_sum"].item() / num_processes)
    logs["loss_avg"] = loss_avg_f
    logs["loss"] = loss_avg_f

    # arrays to lists for histogram/bar logging
    counts_r = reduced["counts"]
    tokens_r = reduced["token_counts"]
    logs["row_usage_counts"] = counts_r.to(torch.int64).cpu().tolist()
    logs["row_token_counts"] = tokens_r.to(torch.int64).cpu().tolist()
    tokens_per_use = tokens_r / counts_r.clamp_min(1)
    logs["row_tokens_per_use"] = tokens_per_use.to(torch.float64).cpu().tolist()

    for C, key in rowset_keys:
        logs[f"rowset_counts_C{C}"] = reduced[key].to(torch.int64).cpu().tolist()

    for C, key in config_keys:
        logs[f"config_usage_counts_C{C}"] = reduced[key].to(torch.int64).cpu().tolist()

    return logs, n_total



# @torch.no_grad()
# def ddp_reduce_and_average_stats_and_loss_tensors(
#     tracker: MatchingTrackerFixed,
#     tr_loss: torch.Tensor,        # scalar (0-D preferred); any dtype
#     num_processes: int,               # average loss over this value
# ):
#     """
#     All-reduces on CUDA:
#       - tracker's stat sums (0-D float64)
#       - mats_processed (0-D int64)
#       - loss scalar
# 
#     Returns:
#       - logs: Dict[str, float]  -> {<stat_means>..., "loss_avg": <float>}
#               where stat means are (sums / mats_processed), NaN if mats_processed==0
#       - n_total: int            -> reduced mats_processed across ranks
#     """
#     dev = tracker.device
# 
#     # ensure 0-D loss on device
#     loss0d = tr_loss.detach()
#     if loss0d.ndim != 0:
#         loss0d = loss0d.mean()
#     loss0d = loss0d.to(dev)
# 
#     # pack locals (CUDA tensors) for one collective
#     vals = {k: tracker._stat_sums[k] for k in tracker._STAT_KEYS}  # 0-D float64
#     vals["mats_processed"] = tracker.mats_processed                # 0-D int64
#     vals["loss_sum"] = loss0d                                      # loss dtype
#     vals["counts"] = tracker.counts #NOTE for hist
# 
#     # single all-reduce SUM (uses your atleast_1d(...).contiguous() packer)
#     reduced = allreduce_sums_keep_shapes(vals, pack_dtype=torch.float64)
# 
#     # convert to Python types for logging
#     n_total = int(reduced["mats_processed"].item())
# 
#     logs: Dict[str, float] = {}
#     if n_total == 0:
#         # no data: stats -> NaN
#         for k in tracker._STAT_KEYS:
#             logs[k] = float("nan")
#     else:
#         den = float(n_total)
#         for k in tracker._STAT_KEYS:
#             logs[k] = float(reduced[k].item() / den)
# 
#     # include loss averaged over num_processes
#     loss_avg = reduced["loss_sum"] / torch.tensor(num_processes, dtype=loss0d.dtype, device=dev)
#     logs["loss"] = float(loss_avg.item())
# 
#     counts_list: list[int] = reduced["counts"].to(torch.int64).detach().cpu().tolist() #NOTE for hist
#     logs["row_usage_counts"] = counts_list # NOTE for hist
#     return logs



class ScalableSoFTTrainer(Trainer):
    def __init__(
            self,
            model: Union[PreTrainedModel, nn.Module, None] = None,
            args: TrainingArguments = None,
            data_collator: Optional[DataCollator] = None,
            train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
            eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
            processing_class: Optional[
                Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
            ] = None,
            compute_loss_func: Optional[Callable] = None,
            model_init: Optional[Callable[[], PreTrainedModel]] = None,
            compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
            callbacks: Optional[list[TrainerCallback]] = None,
            optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
            optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
            preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
            formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
            bmsft_phase_train=True, # train
            # peft_config: Optional["PeftConfig"] = None,
        ):
        """Build the trainer by inlining the TRL SFTTrainer and HF Trainer init sequences.

        Step 1 (TRL-style) resolves the tokenizer and prepares ``train_dataset``
        with ``self._prepare_dataset``. Step 2 mirrors ``transformers.Trainer.__init__``.
        It inflates ``args.gradient_accumulation_steps`` by
        ``args.instantaneous_pad_batch_size`` *before* the accelerator is created.

        Args:
            model: Model to train. Required when ``processing_class`` is None,
                because the tokenizer is then loaded from ``model.config._name_or_path``.
            args: Training config. Besides the standard ``TrainingArguments``
                fields, it reads ``N_num_sot_tokens``, ``instantaneous_pad_batch_size``,
                ``packing`` and ``dataset_num_proc``. It is presumably an
                ``SFTConfig`` subclass (TODO confirm). Mutated in place:
                ``dataset_kwargs``, ``chars_per_token`` and
                ``gradient_accumulation_steps``.
            data_collator: Collator for training batches; must not be None.
            train_dataset: Raw training dataset, prepared with ``self._prepare_dataset``.
            eval_dataset: Stored unchanged on ``self.eval_dataset``.
            processing_class: Tokenizer. Loaded from the model path when None.
            compute_loss_func: Stored on ``self.compute_loss_func``.
            model_init: Stored on ``self.model_init``.
            compute_metrics: Stored on ``self.compute_metrics``.
            callbacks: Extra callbacks, appended after the default and
                reporting-integration callbacks.
            optimizers: ``(optimizer, lr_scheduler)`` pair; ``(None, None)`` means
                they are created later.
            optimizer_cls_and_kwargs: Stored on ``self.optimizer_cls_and_kwargs``.
            preprocess_logits_for_metrics: Stored on ``self.preprocess_logits_for_metrics``.
            formatting_func: Forwarded to ``self._prepare_dataset``.
            bmsft_phase_train: Stored on ``self.bmsft_phase_train``. When False, the
                training loop clears the per-epoch update-step list (inspect-only mode).

        Raises:
            AssertionError: If the auto-loaded tokenizer has no pad token, if
                ``data_collator`` is None, or if ``args.past_index != -1``.
            ImportError: If ``args.use_liger_kernel`` is set but liger-kernel is
                not installed.
            ValueError: If FSDPv2 is requested without ``torch_xla`` >= 2.2.
        """
        ####################################
        ##### STEP 1 TRL trainer init  #####
        ####################################
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
            # Batches are padded downstream, so fail fast if the tokenizer has no pad token.
            assert getattr(processing_class, "pad_token", None) is not None
        # args.max_seq_length is expected to be set by the launcher (sft.py).

        if args.dataset_kwargs is None:
            args.dataset_kwargs = {}
        # The SFTTrainer `chars_per_token` override is disabled here; always reset it.
        args.chars_per_token = None
        # Mirrors trl/sft_trainer.py (both 0.12.0 and 0.19.0).
        self.dataset_num_proc = args.dataset_num_proc
        assert data_collator is not None

        # One (<thinkK>, </thinkK>) tag pair per SOT token, K = 1..N_num_sot_tokens.
        self.sot_tokens = [(f"<think{i+1}>", f"</think{i+1}>") for i in range(args.N_num_sot_tokens)]
        self.N_num_sot_tokens = len(self.sot_tokens)

        train_dataset = self._prepare_dataset(
            train_dataset,
            processing_class,
            args,
            args.packing,
            formatting_func,
        )

        #############################################
        ##### STEP 2 superclass trainer.py init #####
        #############################################
        self.args = args
        self.bmsft_phase_train = bmsft_phase_train
        self.compute_loss_func = compute_loss_func
        # enable_full_determinism() seeds the RNGs itself, so one call per branch is enough.
        if self.args.full_determinism:
            enable_full_determinism(self.args.seed)
        else:
            set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False
        self.model = model

        # Pad the gradient-accumulation count so the accelerator / DeepSpeed are
        # configured for the padded number of micro-batches per optimizer step.
        self.original_gradient_accumulation_steps = args.gradient_accumulation_steps
        self.padded_gradient_accumulation_steps = (
            args.gradient_accumulation_steps + args.instantaneous_pad_batch_size
        )
        # NOTE: the padded value is meant only for propagate_args_to_deepspeed (run by
        # create_accelerator_and_postprocess below). The training loop reads
        # self.original_gradient_accumulation_steps instead.
        self.args.gradient_accumulation_steps = self.padded_gradient_accumulation_steps
        assert self.args.past_index == -1

        # Builds self.accelerator; the padded gradient_accumulation_steps is in effect here.
        self.create_accelerator_and_postprocess()

        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)
        # Bare property access, kept from HF Trainer (presumably for its device-setup side effect).
        args._setup_devices
        self.model_init = model_init
        self.is_model_parallel = bool(
            getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False)
        )

        if self.args.use_liger_kernel:
            if is_liger_kernel_available():
                from liger_kernel.transformers import _apply_liger_kernel_to_instance
                if isinstance(model, PreTrainedModel):
                    _apply_liger_kernel_to_instance(model=model)
                elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
                    # PEFT-style wrapper: patch the underlying base model.
                    _apply_liger_kernel_to_instance(model=model.get_base_model())
                else:
                    logger.warning(
                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
                    )
            else:
                raise ImportError(
                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
                    "Please install it with `pip install liger-kernel`"
                )
        self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
        self.place_model_on_device = args.place_model_on_device
        # Sharded / parallel / eval-only-half-precision setups handle device placement themselves.
        if (
            self.is_model_parallel
            or self.is_deepspeed_enabled
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or self.is_fsdp_xla_enabled
            or self.is_fsdp_enabled
        ):
            self.place_model_on_device = False

        self.data_collator = data_collator
        self.train_dataset = train_dataset
        # Previously accepted but dropped. The inherited evaluation path
        # (Trainer.evaluate / get_eval_dataloader) reads self.eval_dataset.
        self.eval_dataset = eval_dataset
        self.processing_class = processing_class
        if self.is_model_parallel:
            self.args._n_gpu = 1
        self.model_wrapped = model
        self.model = model
        unwrapped_model = self.accelerator.unwrap_model(model)
        if hasattr(unwrapped_model, "accepts_loss_kwargs"):
            self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
        else:
            # Fallback: a **kwargs parameter on forward() is taken to mean loss kwargs are accepted.
            forward_params = inspect.signature(unwrapped_model.forward).parameters
            self.model_accepts_loss_kwargs = any(
                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
            )
        self.neftune_noise_alpha = args.neftune_noise_alpha

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        # Both may be None here; they are created later in the training loop.
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs

        # Callback order: built-in defaults, reporting integrations, user callbacks, then progress/printer.
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self._loggers_initialized = False

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)
        if args.max_steps > 0 and args.num_train_epochs > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")
        self._signature_columns = None
        # Apex / CPU AMP and label smoothing are not used by this trainer.
        self.use_apex = False
        self.use_cpu_amp = False
        self.label_smoother = None

        # TrainerControl / TrainerState are the objects callbacks read and update.
        self.control = TrainerControl()
        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )
        self.current_flos = 0
        self.hp_search_backend = None
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
        # Internal variables to help with automatic batch size reduction.
        self._train_batch_size = args.train_batch_size
        self._created_lr_scheduler = False
        # Stop the memory tracker started above so init memory metrics are recorded.
        self._memory_tracker.stop_and_update_metrics()
        self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
        if self.is_fsdp_xla_v2_enabled:
            if not IS_XLA_FSDPV2_POST_2_2:
                raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
            # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
            # Tensor axis is just a placeholder where it will not be used in FSDPv2.
            num_devices = xr.global_runtime_device_count()
            xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
        # TRL's activation-offloading context (get_act_offloading_ctx_manager) is not ported yet.

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", dict[str, Any], None] = None,
        ignore_keys_for_eval: Optional[list[str]] = None,
        **kwargs,
    ):
        """Run the custom training loop, optionally resuming from a checkpoint.

        Args:
            resume_from_checkpoint: Path to a checkpoint directory. ``True``
                resumes from the latest checkpoint in ``args.output_dir``;
                ``None`` or ``False`` starts fresh.
            trial: Hyper-parameter-search trial. Forwarded to ``_hp_search_setup``
                and to the inner training loop.
            ignore_keys_for_eval: Forwarded unchanged to ``_inner_training_loop``.
            **kwargs: Unused; accepted for signature compatibility. A warning
                listing them is emitted so they are not dropped silently.

        Returns:
            Whatever ``_inner_training_loop`` returns.

        Raises:
            ValueError: If ``resume_from_checkpoint`` is ``True`` but no
                checkpoint exists in ``args.output_dir``.
        """
        if kwargs:
            warnings.warn(
                f"{type(self).__name__}.train() ignores unexpected keyword arguments: "
                f"{', '.join(sorted(kwargs))}.",
                stacklevel=2,
            )
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None
        self._memory_tracker.start()
        args = self.args
        self.is_in_train = True
        # Eval-only half-precision runs: move the model to the device here (mirrors HF Trainer).
        if (
            (args.fp16_full_eval or args.bf16_full_eval)
            and not args.do_train
            and not self.is_model_parallel
            and self.model_init is None
        ):
            self._move_model_to_device(self.model, args.device)
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # `True` means "use the latest checkpoint in output_dir".
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None:
            # DeepSpeed / FSDP / SageMaker-MP weights are loaded later, inside _inner_training_loop.
            if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint)
            # Restore the batch size recorded in the checkpoint (matters when
            # find_executable_batch_size re-runs the inner loop).
            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            if state.train_batch_size is not None:
                self._train_batch_size = state.train_batch_size

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ): 
        device = self.accelerator.device # device(type='cuda', index=0)
        self.accelerator.free_memory()
        self._train_batch_size = args.train_batch_size #=2
        train_dataloader = self.get_train_dataloader()

        total_train_batch_size = self.get_total_train_batch_size(args)
        assert args.max_steps == -1
        # num_train_epochs = args.num_train_epochs
        num_train_epochs = int(args.num_train_epochs)
        precomputed_num_update_steps_per_epoch_list = args.precomputed_num_update_steps_per_epoch_list # TODO precomputed list
        # precomputed_num_examples = args.precomputed_num_examples
        # precomputed_num_train_samples = args.precomputed_num_train_samples # total train samples over all epochs
        epoch_based = True
        # NOTE len_dataloader only represents the number of sets in dataloader, not actually the number of sequences being trained on, DONT USE IT
        precomputed_max_steps = sum(precomputed_num_update_steps_per_epoch_list)   # 
        #####################
        delay_optimizer_creation = ( # NOTE self.is_fsdp_xla_enabled=True, delay optimizer creation.  only deepspeed has no delay
            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
        ) # =True
        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
        if is_fsdp2:
            delay_optimizer_creation = False

        if self._created_lr_scheduler: #=False
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        if self.is_deepspeed_enabled: #=False, 
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation: # delay_optimizer_creation=True
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size #self._train_batch_size=1

        # Compute absolute values for logging, eval, and save if given as ratio
        self.state.compute_steps(args, precomputed_max_steps) #max_steps=1250

        if args.gradient_checkpointing: #=True
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

        model = self._wrap_model(self.model_wrapped)

        use_accelerator_prepare = True if model is self.model else False #NOTE=True, model=self.model

        if use_accelerator_prepare and self.is_fsdp_enabled: # True and True
            # In case of auto_find_batch_size=True
            # Remove FSDP wrapping from sub-models.
            self.model = unwrap_model(self.model, recursive=True)

        if delay_optimizer_creation: # NOTE True
            if use_accelerator_prepare: # NOTE True
                # configure fsdp plugin for qlora if any
                self._fsdp_qlora_plugin_updates()
                if self.accelerator.mixed_precision != "fp8": # self.accelerator.mixed_precision='bf16'
                    self.model = self.accelerator.prepare(self.model) 
            self.create_optimizer_and_scheduler(num_training_steps=precomputed_max_steps) #max_steps=1250

        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    if delay_optimizer_creation:
                        model = self.accelerator.prepare(self.model)
                    else:
                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # ckpt loading
        if resume_from_checkpoint is not None:
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(
                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
                )
            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)
        self._load_scaler(resume_from_checkpoint)

        # Train!
        logger.info("***** Running training *****")
        # logger.info(f"  Precomputed num examples = {precomputed_num_examples:,}") #=100
        logger.info(f"  Num Epochs = {num_train_epochs:,}") #=3
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") #=2
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Padded GAS (args.gradient_accumulation_steps) = {args.gradient_accumulation_steps}")
        logger.info(f"  Precomputed Total optimization steps = {precomputed_max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_progress_bar = None

        steps_trained_before_current_epoch = 0

        self.set_queue = deque() # process1=[[z1_R1, z2_R2], [z1_R1]]  process2=[[z1_R1,], [z1_R1, z2_R2]]
        self.steps_sampled_in_current_epoch = 0
        self.steps_trained_in_current_epoch = 0
        skip_loaded_idxs = False
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ): # TRAINER_STATE_NAME='trainer_state.json'
            set_queue_list, self.steps_sampled_in_current_epoch, self.steps_trained_in_current_epoch, self.remain_sets_sizes_list = torch.load(os.path.join(resume_from_checkpoint, f"set_queue_{args.process_index}.pth")) 
            self.set_queue = deque(set_queue_list)
            del set_queue_list
            skip_loaded_idxs = True

            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            self.compare_trainer_and_checkpoint_args(self.args, self.state)
            self._load_callback_state()
            #epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
            #assert epochs_trained == int(self.state.epoch)
            epochs_trained == int(self.state.epoch)
            steps_trained_before_current_epoch = sum(precomputed_num_update_steps_per_epoch_list[:epochs_trained])
            assert self.steps_trained_in_current_epoch == self.state.global_step-steps_trained_before_current_epoch # previously "self.state.global_step%num_update_steps_per_epoch" "steps_trained_in_current_epoch*=args.gradient_accumulation_steps"
            curr_epoch_progress = round(self.state.epoch-int(self.state.epoch), 6)
            curr_epoch_progress_ = round((self.state.global_step-steps_trained_before_current_epoch)/precomputed_num_update_steps_per_epoch_list[epochs_trained], 6)
            # port = 4444+self.args.process_index
            # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
            # rpdb.set_trace(port=port)
            assert curr_epoch_progress == curr_epoch_progress_

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {self.steps_trained_in_current_epoch} batches in the first epoch."
                )
        for attr in ("model", "optimizer", "lr_scheduler"):
            setattr(self.callback_handler, attr, getattr(self, attr))
        self.callback_handler.train_dataloader = train_dataloader
        #NOTE max_steps is already computed above as precomputed_max_steps
        self.state.init_training_references(self, precomputed_max_steps, num_train_epochs, trial)

        tr_loss = torch.tensor(0.0, device=args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step #=0
        model.zero_grad()
        grad_norm: Optional[float] = None
        learning_rate = None
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        self.matching_tracker = MatchingTrackerFixed(self.N_num_sot_tokens)
        assert self.state.global_step-steps_trained_before_current_epoch==self.steps_trained_in_current_epoch
        if not self.bmsft_phase_train:
            precomputed_num_update_steps_per_epoch_list = [] #NOTE inspect_only
        _bs = self._train_batch_size
        pad_token_id = self.data_collator.pad_token_id

        # import rpdb
        # import torch.distributed as dist
        # port = 4444+dist.get_rank()
        # print(f"Process {dist.get_rank()} waiting for debugger on port {port}")
        # rpdb.set_trace(port=port)
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            epoch_dataloader.set_epoch(epoch)
            len_dataloader = len(epoch_dataloader) # same "number of per_device_batch input prompts" as before, also same every epoch, 
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and self.steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint) #NOTE what's the point of this?

            rng_to_sync = False
            if skip_loaded_idxs: # Only come here once for the first loading.  all reset for subsequent epochs
                epoch_dataloader = skip_first_batches(epoch_dataloader, self.steps_sampled_in_current_epoch)
                skip_loaded_idxs = False
                rng_to_sync = True
            else:
                self.steps_sampled_in_current_epoch = 0

                self.set_queue = deque()

                self.steps_trained_in_current_epoch = 0

                self.remain_sets_sizes_list = []
            epoch_iterator = iter(train_dataloader)
            local_wrapped_once = False
            # remain_sets_sizes_list = [] # NOTE THIS value needs to be saved and loaded, THIS WAS WRONG
            while True: # NOTE EPOCH LOOP
                dataloader_num_batches = self.original_gradient_accumulation_steps # guarantees there's at least one valid global_batch
                # NOTE original_per_device_global_batch_size=2 = 2*1 (correct)
                original_per_device_global_batch_size = self.original_gradient_accumulation_steps*_bs
                accumu_i = 0
                while accumu_i < dataloader_num_batches: # 2 iterations with original_gradient_accumulation_steps
                    try:
                        inputs_ = next(epoch_iterator) # [[z1_R1, z1_R2, z2_R1, z2_R2, z3_R1, z3_R2], [z1_R1, z2_R1]]
                        # len(inputs_['input_ids'][0])=8
                        for m_i in range(_bs):
                            input_ids_Nki = inputs_['input_ids'][m_i]
                            # import ipdb; ipdb.set_trace()
                            assert len(input_ids_Nki)%self.N_num_sot_tokens == 0
                            set_size = len(input_ids_Nki)//self.N_num_sot_tokens #=2
                            self.set_queue.append({
                                'input_ids':input_ids_Nki,
                                'attention_mask':inputs_['attention_mask'][m_i],
                                'position_ids': inputs_['position_ids'][m_i],
                                'labels':inputs_['labels'][m_i],
                                'matching_input_ids':inputs_['matching_input_ids'][m_i],
                                'matching_attention_mask':inputs_['matching_attention_mask'][m_i],
                                'matching_position_ids': inputs_['matching_position_ids'][m_i],
                                'matching_labels':inputs_['matching_labels'][m_i],
                                })
                            self.remain_sets_sizes_list.append(set_size)
                        accumu_i+=1
                        self.steps_sampled_in_current_epoch+=1
                    except StopIteration:
                        epoch_iterator = iter(train_dataloader)
                        local_wrapped_once = True
                packed_obj = (self.remain_sets_sizes_list, local_wrapped_once)
                ################ GATHER every sample step #######################
                packed_all_set_sizes__wrapped = gather_object(packed_obj) #[0].shape=[num_processes, original_per_device_global_batch_size=8*2]
                temp_all_remain_sets_sizes_lists = []
                all_wrapped_once = []
                num_items_unpack = 2
                for p_idx in range(self.accelerator.num_processes):
                    process_items = packed_all_set_sizes__wrapped[p_idx*num_items_unpack:(p_idx+1)*(num_items_unpack)]
                    temp_all_remain_sets_sizes_lists.append(process_items[0])
                    all_wrapped_once.append(process_items[1])

                while True: # Do optimizer steps until one of the processes doesn't have enough sequences to form "gradient_accumulation_steps*per_device_batch_size=8*2. a.k.a not enough for 16" 
                    all_remain_sets_sizes_lists = temp_all_remain_sets_sizes_lists
                    all_reached_sets_sizes_lists, temp_all_remain_sets_sizes_lists = [], []
                    all_processes_reachable = True
                    all_processes_max_num_per_device_global_batch_size = -1
                    for M_i, process_remain_sets_sizes_lists in enumerate(all_remain_sets_sizes_lists):
                        cumsum = 0 
                        reached_pos = None
                        for pos, device_set_size in enumerate(process_remain_sets_sizes_lists):
                            cumsum += device_set_size
                            #original_per_device_global_batch_size=2 device bs * 8 GAS = 16
                            if cumsum >= original_per_device_global_batch_size: 
                                reached_pos = pos
                                all_processes_max_num_per_device_global_batch_size = max(all_processes_max_num_per_device_global_batch_size, cumsum)
                                break
                        if reached_pos is None:
                            all_processes_reachable = False
                            break
                        else:
                            reached_sizes = process_remain_sets_sizes_lists[:reached_pos+1] # [2,2,2,2]
                            remain_sizes = process_remain_sets_sizes_lists[reached_pos+1:] # [2,3,2,3]
                            all_reached_sets_sizes_lists.append(reached_sizes)
                            temp_all_remain_sets_sizes_lists.append(remain_sizes)

                    ####
                    # port = 4444+self.args.process_index
                    # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                    # rpdb.set_trace(port=port)
                    ####

                    if not all_processes_reachable: 
                        break # gathered all_set_sizes is no longer needed if it breaks here, sample next batch and append it to the queue

                    ########################## NOTE enough sets to update ###############################
                    # callback_handler will set "should_log=False", "should_evaluate=False", "should_save=False" and then call_event("on_step_begin")
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 

                    # forward_this_global_Nki = [] # \sum_i^number_reached_sets{ N*k_i }
                    forward_Nki_input_ids = []
                    forward_Nki_attention_mask = []
                    forward_Nki_position_ids = []
                    forward_Nki_labels = []
                    # backward_ki= [] #  
                    backward_ki_input_ids = []
                    backward_ki_attention_mask = []
                    backward_ki_position_ids = []
                    backward_ki_labels = []
                    nonpad_this_global_ki = 0
                    # matching ids NOTE CHANGE 5
                    forward_Nki_matching_input_ids = []
                    forward_Nki_matching_attention_mask = []
                    forward_Nki_matching_position_ids = []
                    forward_Nki_matching_labels = []
                    for set_idx, set_size in enumerate(all_reached_sets_sizes_lists[self.args.process_index]):
                        set_seqs = self.set_queue.popleft()
                        # forward_this_global_Nki.extend(set_seqs) # Extend list of N_ki-length, ``seq_input_ids_tensor''
                        # Key line to debug the number of dequeued items from ckpt set-queue matches their reported set sizes 
                        assert len(set_seqs['input_ids']) == self.N_num_sot_tokens*set_size #TODO comment out when actually training
                        forward_Nki_input_ids.extend(set_seqs['input_ids'])
                        forward_Nki_attention_mask.extend(set_seqs['attention_mask'])
                        forward_Nki_position_ids.extend(set_seqs['position_ids'])
                        forward_Nki_labels.extend(set_seqs['labels'])
                        nonpad_this_global_ki += set_size
                        forward_Nki_matching_input_ids.extend(set_seqs['matching_input_ids'])
                        forward_Nki_matching_attention_mask.extend(set_seqs['matching_attention_mask'])
                        forward_Nki_matching_position_ids.extend(set_seqs['matching_position_ids'])
                        forward_Nki_matching_labels.extend(set_seqs['matching_labels'])
                    ################ remaining sets from this process for the next full-grad-accumu ############ #[2,3,2,3]
                    self.remain_sets_sizes_list = self.remain_sets_sizes_list[len(all_reached_sets_sizes_lists[self.args.process_index]):]
                    ############################################################################################
                    # all_processes_max_num_per_device_global_batch_size=17  - nonpad_this_global_ki=16
                    # num_pads = 4*(17-16)=4
                    num_pads = self.N_num_sot_tokens*(all_processes_max_num_per_device_global_batch_size-nonpad_this_global_ki)
                    # max_forward_per_device_global_bs=N4* max_num_per_device_global4=16??
                    max_forward_per_device_global_bs = self.N_num_sot_tokens*all_processes_max_num_per_device_global_batch_size #4*4=68   >4*16=64
                    assert nonpad_this_global_ki * self.N_num_sot_tokens == len(forward_Nki_input_ids) # 
                    # forward_this_global_Nki = forward_this_global_Nki + [torch.tensor([pad_token_id]) for i in range(num_pads)]
                    forward_Nki_input_ids = forward_Nki_input_ids + [torch.tensor([pad_token_id]) for _ in range(num_pads)]
                    forward_Nki_attention_mask = forward_Nki_attention_mask+[torch.tensor([0]) for _ in range(num_pads)]
                    forward_Nki_position_ids = forward_Nki_position_ids+[torch.tensor([0]) for _ in range(num_pads)]
                    forward_Nki_labels = forward_Nki_labels+[torch.tensor([-100]) for _ in range(num_pads)]
                    assert len(forward_Nki_input_ids) == max_forward_per_device_global_bs
                    forward_Nki_matching_input_ids = forward_Nki_matching_input_ids + [torch.tensor([pad_token_id]) for _ in range(num_pads)]
                    forward_Nki_matching_attention_mask = forward_Nki_matching_attention_mask+[torch.tensor([0]) for _ in range(num_pads)]
                    forward_Nki_matching_position_ids = forward_Nki_matching_position_ids+[torch.tensor([0]) for _ in range(num_pads)]
                    forward_Nki_matching_labels = forward_Nki_matching_labels+[torch.tensor([-100]) for _ in range(num_pads)]
                    # NOTE Should be correct here for now
                    # port = 4444+self.args.process_index
                    # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                    # rpdb.set_trace(port=port)

                    ################# Inner discrete optimization #################
                    # forward_bs = _bs*(self.args.max_length//self.args.L_first_matching_tokens)  # NOTE foward_bs can be larger and more efficient due to "no_grad"
                    # _bs=1
                    forward_bs = _bs*(32768//(2*self.args.L_first_matching_tokens))  # NOTE foward_bs can be larger and more efficient due to "no_grad"
                    forward_bs = max(forward_bs, 1)
                    ####
                    # import rpdb
                    # port = 4444+self.args.process_index
                    # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                    # rpdb.set_trace(port=port)
                    # TODO max_forward_per_device_global_bs should be 4 sot tokens * 4 targets in the set!!! 
                    forward_bs = min(forward_bs, max_forward_per_device_global_bs) #min(32768//2000=16, 32)=16 It's okay not to be divisible by 2
                    forward_num_steps = math.ceil(len(forward_Nki_input_ids) / forward_bs) # 2
                    # score_mat = torch.zeros(self.N_num_sot_tokens*nonpad_this_global_ki) # NOTE WRONG, 
                    # NOTE nonpad_this_global_ki=8, all_processes_max_num_per_device_global_batch_size=9
                    # TODO maybe no need to score pad elements since they're not used later but idk if ddp would hang
                    score_mat = torch.zeros(self.N_num_sot_tokens*all_processes_max_num_per_device_global_batch_size)
                    num_toks_mat = torch.zeros(self.N_num_sot_tokens*all_processes_max_num_per_device_global_batch_size)
                    # max_forward_per_device_global_batch_size=24
                    num_tokens_global_Nki = send_to_device(torch.zeros(max_forward_per_device_global_bs), device, non_blocking=True)
                    num_items_in_batch = 0
                    # import ipdb; ipdb.set_trace() # all_processes_max_num_per_device_global_batch_size=8, nonpad_this_global_ki=8
                    with torch.no_grad():
                        for f_i in range(forward_num_steps): # NOTE forward_num_steps=2
                            padded_inputs = {}
                            full_len_labels = pad(forward_Nki_labels[f_i*forward_bs:(f_i+1)*forward_bs], padding_value=-100, padding_side="right")
                            num_tokens_global_Nki[f_i*forward_bs:(f_i+1)*forward_bs] = full_len_labels.ne(-100).sum(dim=1) # all 6*4= 24
                            # padded_inputs["labels"]=padded_inputs["labels"][:, :self.args.L_first_matching_tokens]
                            # NOTE padded_inputs["input_ids"].shape=[16,1173]
                            padded_inputs["input_ids"] = pad(forward_Nki_matching_input_ids[f_i*forward_bs:(f_i+1)*forward_bs], padding_value=pad_token_id, padding_side="right")
                            padded_inputs["attention_mask"] = pad(forward_Nki_matching_attention_mask[f_i*forward_bs:(f_i+1)*forward_bs], padding_value=0, padding_side="right")
                            padded_inputs["position_ids"] = pad(forward_Nki_matching_position_ids[f_i*forward_bs:(f_i+1)*forward_bs], padding_value=0, padding_side="right")
                            padded_inputs["labels"] = pad(forward_Nki_matching_labels[f_i*forward_bs:(f_i+1)*forward_bs], padding_value=-100, padding_side="right")
                            padded_inputs = send_to_device(padded_inputs, device, non_blocking=True)
                            if self.bmsft_phase_train:
                                scores = self.compute_matching_scores(padded_inputs) # TODO put back for training
                                # ## NOTE this should work now
                                # import rpdb
                                # port = 4444+self.args.process_index
                                # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                                # rpdb.set_trace(port=port)
                                # ##
                                # if f_i == 2:
                                #     port = 4444+self.args.process_index
                                #     print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                                #     rpdb.set_trace(port=port)
                                # NOTE below score_mat.shape=16 will all be assigned the same value if scores is a scalar without using updated Liger-Kernel
                                score_mat[f_i*forward_bs:(f_i+1)*forward_bs] = scores
                            del padded_inputs
                    start_idx = 0
                    # fixed_global_bs = self.padded_gradient_accumulation_steps*_bs #3*1=3 NOTE this part is wrong for "padded_gradient_accumulation_steps<num_targets"
                    fixed_global_bs = max(self.padded_gradient_accumulation_steps*_bs, all_processes_max_num_per_device_global_batch_size) #3*1=3
                    # assert fixed_global_bs % _bs == 0
                    backward_num_steps = fixed_global_bs // _bs
                    # score_mat has shape [sum(Nki) till padded ]
                    if self.bmsft_phase_train:
                        for set_idx, set_size in enumerate(all_reached_sets_sizes_lists[self.args.process_index]): #won't iterate pad sets
                            cost_mat = (score_mat[start_idx:start_idx+self.N_num_sot_tokens*set_size]).view(self.N_num_sot_tokens, set_size)
                            # import ipdb; ipdb.set_trace()
                            # row_idx_, col_idx_ = linear_sum_assignment(cost_mat.detach().cpu().float().numpy()) #.shape=ki
                            # import rpdb
                            # port = 4444+self.args.process_index
                            # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                            # rpdb.set_trace(port=port)
                            # CHANGE 1
                            row_idx, col_idx = self.matching_tracker.process_optimal(cost_mat, num_tokens_global_Nki.view(self.N_num_sot_tokens, set_size), backend='cpu')
                            # row_idx, col_idx = self.matching_tracker.process_optimal(cost_mat, backend='cpu')
                            if self.args.debug_randomized_matching:
                                row_idx, col_idx = random_configuration(cost_mat)
                            # import rpdb
                            # port = 4444+self.args.process_index
                            # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                            # rpdb.set_trace(port=port)

                            for N_idx, zi_idx in zip(row_idx, col_idx):
                                # selected_idx = start_idx+N_idx*self.N_num_sot_tokens+zi_idx # THIS IS WRONG
                                selected_idx = start_idx+N_idx*set_size+zi_idx
                                # backward_ki.append(forward_this_global_Nki[selected_idx])
                                # import ipdb; ipdb.set_trace()
                                backward_ki_input_ids.append(forward_Nki_input_ids[selected_idx])
                                backward_ki_attention_mask.append(forward_Nki_attention_mask[selected_idx])
                                backward_ki_position_ids.append(forward_Nki_position_ids[selected_idx])
                                backward_ki_labels.append(forward_Nki_labels[selected_idx])
                                num_items_in_batch += num_tokens_global_Nki[selected_idx]
                                if len(backward_ki_input_ids) >= fixed_global_bs:
                                    break
                            # import rpdb
                            # port = 4444+self.args.process_index
                            # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                            # rpdb.set_trace(port=port)

                            if len(backward_ki_input_ids) >= fixed_global_bs:
                                break
                            start_idx += self.N_num_sot_tokens*set_size
                        ################ GATHER every optimization step ##############
                        num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()

                    # backward_num_steps = math.ceil(all_processes_max_num_per_device_global_batch_size/args.per_device_batch_size)
                    num_pads = fixed_global_bs-len(backward_ki_input_ids) # 0 in case this process has the most num of sequences and cause negative
                    assert num_pads >= 0
                    # import ipdb; ipdb.set_trace() # all_processes_max_num_per_device_global_batch_size=8, nonpad_this_global_ki=8

                    backward_ki_input_ids = backward_ki_input_ids + [torch.tensor([pad_token_id]) for _ in range(num_pads)]
                    backward_ki_attention_mask = backward_ki_attention_mask+[torch.tensor([0]) for _ in range(num_pads)]
                    backward_ki_position_ids = backward_ki_position_ids+[torch.tensor([0]) for _ in range(num_pads)]
                    backward_ki_labels = backward_ki_labels+[torch.tensor([-100]) for _ in range(num_pads)]

                    # backward_ki = backward_ki + [torch.tensor([pad_token_id]) for i in range(all_processes_max_num_per_device_global_batch_size-nonpad_this_global_ki)]
                    # print(f"f{self.set_queue}")
                    for b_i in range(backward_num_steps): #b_i=0, 1, ..., backward_num_steps-1 # backward_num_steps=4
                        do_sync_step = (b_i==backward_num_steps-1)
                        self.accelerator.gradient_state._set_sync_gradients(do_sync_step) # NOTE key line to signal optimizer.step
                        if rng_to_sync: # False
                            self._load_rng_state(resume_from_checkpoint)
                            rng_to_sync = False
                        padded_inputs = {}
                        padded_inputs["input_ids"] = pad(backward_ki_input_ids[b_i*_bs:(b_i+1)*_bs], padding_value=pad_token_id, padding_side="right")
                        padded_inputs["attention_mask"] = pad(backward_ki_attention_mask[b_i*_bs:(b_i+1)*_bs], padding_value=0, padding_side="right")
                        padded_inputs["position_ids"] = pad(backward_ki_position_ids[b_i*_bs:(b_i+1)*_bs], padding_value=0, padding_side="right")
                        padded_inputs["labels"] = pad(backward_ki_labels[b_i*_bs:(b_i+1)*_bs], padding_value=-100, padding_side="right")
                        padded_inputs = send_to_device(padded_inputs, device, non_blocking=True)
                        # ##
                        # import rpdb
                        # port = 4444+self.args.process_index
                        # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                        # rpdb.set_trace(port=port)
                        # ##
                        context = (
                            partial(self.accelerator.no_sync, model=model)
                            if b_i != backward_num_steps - 1
                            and self.accelerator.distributed_type != DistributedType.DEEPSPEED
                            else contextlib.nullcontext
                        )
                        if self.bmsft_phase_train:
                            with context(): 
                                tr_loss_step = self.training_step(model, padded_inputs, num_items_in_batch)
                                # tr_loss_step=0.0833 when self.accelerator.num_processes is not multiplied, fixed after multiplication
                                # num_items_in_batch=209746, padded_inputs['input_ids'][0].shape=14074
                                # import rpdb
                                # port = 4444+self.args.process_index
                                # print(f"Process {self.args.process_index} waiting for debugger on port {port}")
                                # rpdb.set_trace(port=port)
                                # import ipdb; ipdb.set_trace() # all_processes_max_num_per_device_global_batch_size=8, nonpad_this_global_ki=8
                            if (
                                args.logging_nan_inf_filter      # logging_nan_inf_filter=True
                                and not is_torch_xla_available() # not False = True
                                and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) # nan / inf edge case
                            ):
                                # if loss is nan or inf simply add the average of previous logged losses
                                tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                            else:
                                if tr_loss.device != tr_loss_step.device:
                                    raise ValueError(
                                        f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                                    )
                                tr_loss = tr_loss + tr_loss_step #NOTE tr_loss is just total loss for logging purpose
                        # self.current_flos += float(self.floating_point_ops(inputs))
                        if do_sync_step:
                            self.accelerator.gradient_state._set_sync_gradients(True)
                            if args.max_grad_norm is not None and args.max_grad_norm > 0:
                                _grad_norm = self.accelerator.clip_grad_norm_(
                                    model.parameters(),
                                    args.max_grad_norm, #=0.2
                                )
                            if (
                                is_accelerate_available()
                                and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                            ): #NOTE HERE, is_accelerate_available()=True, self.accelerator.distributed_type=DEEPSPEED
                                grad_norm = model.get_global_grad_norm() # tensor(6.7867, device='cuda:0', dtype=torch.float64)
                                # In some cases the grad norm may not return a float
                                if hasattr(grad_norm, "item"):
                                    grad_norm = grad_norm.item() # Here, turns into float
                            else:
                                grad_norm = _grad_norm
                            self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
                            # import ipdb; ipdb.set_trace() # all_processes_max_num_per_device_global_batch_size=8, nonpad_this_global_ki=8
                            if self.bmsft_phase_train:
                                self.optimizer.step()
                            self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
                            learning_rate = self._get_learning_rate()
                            if not self.accelerator.optimizer_step_was_skipped: #NOTE not False=True
                                # Delay optimizer scheduling until metrics are generated
                                if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                    self.lr_scheduler.step() # NOTE HERE 

                            model.zero_grad()
                            ###########
                            self.steps_trained_in_current_epoch += 1
                            ###########
                            self.state.global_step += 1
                            if self.bmsft_phase_train:
                                self.state.epoch = epoch + self.steps_trained_in_current_epoch/precomputed_num_update_steps_per_epoch_list[epoch] # self.state.epoch is a float value
                                self.control = self.callback_handler.on_step_end(args, self.state, self.control) #set control.should_log=True
                                self._maybe_log_save_evaluate(
                                    tr_loss,
                                    grad_norm,
                                    model,
                                    trial,
                                    epoch,
                                    ignore_keys_for_eval,
                                    start_time,
                                    learning_rate=learning_rate,
                                ) # args.should_save=True only for self.args.process_index==0
                                self.matching_tracker.reset()
                            else:
                                self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                        else:
                            self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                        # PyTorch/XLA relies on the data loader to insert the mark_step for
                        # each step. Since we are breaking the loop early, we need to manually
                        # insert the mark_step here.
                        if self.bmsft_phase_train and (self.control.should_epoch_stop or self.control.should_training_stop):
                            if is_torch_xla_available():
                                xm.mark_step()
                            break
                    if self.bmsft_phase_train and (self.control.should_epoch_stop or self.control.should_training_stop):
                        if is_torch_xla_available():
                            xm.mark_step()
                        break
                if self.bmsft_phase_train and (self.control.should_epoch_stop or self.control.should_training_stop):
                    if is_torch_xla_available():
                        xm.mark_step()
                    break
                if not self.bmsft_phase_train:
                    print(f"sim steps trained in epoch{epoch} = {self.steps_trained_in_current_epoch}")
                # EPOCH WILL END AFTER THIS
                if all(all_wrapped_once):
                    precomputed_num_update_steps_per_epoch_list.append(self.steps_trained_in_current_epoch)
                    print(f"epoch {epoch} done")
                    break
                
        if self.bmsft_phase_train:
            self._total_loss_scalar += tr_loss.item()
            effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
            train_loss = self._total_loss_scalar / effective_global_step

            metrics = speed_metrics(
                "train",
                start_time,
                num_samples=4000, # temp
                num_steps=self.state.max_steps,
                num_tokens=None, # NOTE None
            )
            # self.store_flos()
            # metrics["total_flos"] = self.state.total_flos
            metrics["train_loss"] = train_loss

            # self.is_in_train = False

            # self._memory_tracker.stop_and_update_metrics(metrics)

            # self.log(metrics)

            run_dir = self._get_output_dir(trial)
            checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

            if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
                for checkpoint in checkpoints_sorted:
                    if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                        logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                        shutil.rmtree(checkpoint, ignore_errors=True)

            self.control = self.callback_handler.on_train_end(args, self.state, self.control)

            # Wait for the checkpoint to be uploaded.
            print("before finish_current_push")
            self._finish_current_push()
            print("after finish_current_push")

            # After training we make sure to retrieve back the original forward pass method
            # for the embedding layer by removing the forward post hook.
            if self.neftune_noise_alpha is not None:
                self._deactivate_neftune(self.model)

            return TrainOutput(self.state.global_step, train_loss, metrics)
        else:
            print(f"precomputed_num_update_steps_per_epoch_list={precomputed_num_update_steps_per_epoch_list}") 

        

    # def training_step(
    #     self,
    #     model: nn.Module,
    #     inputs: dict[str, Union[torch.Tensor, Any]],
    #     num_items_in_batch: Optional[torch.Tensor] = None,
    # ) -> torch.Tensor:
    #     model.train()
    #     if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
    #         self.optimizer.train()

    def _prepare_dataset(
            self,
            dataset,
            processing_class,
            args,
            packing,
            formatting_func,
            # dataset_name
            ):
        """Tokenize a set-structured dataset and attach prefix-length metadata for matching.

        Each example's ``args.dataset_text_field`` column holds a list of chat-formatted
        strings. The whole list is tokenized in one call into ``input_ids``, truncated to
        ``args.max_seq_length``. Three extra per-example fields are computed from the first
        ``len(list) // self.N_num_sot_tokens`` strings. Presumably these are the distinct
        targets rendered with the first SoT tag ``<think1>``; TODO confirm that the list
        is laid out SoT-tag-major.

        * ``len_assistant_skip``: number of tokens in the text up to and including the
          assistant header. Computed from the first target only.
        * ``len_pre_skip``: number of tokens in the text up to and including ``<think1>``.
          Computed from the first target only.
        * ``len_in_matching``: one value per target, equal to
          ``min(len_pre_skip + args.L_first_matching_tokens, tokens before </think1>)``.

        Args:
            dataset: a ``datasets.Dataset`` or ``datasets.IterableDataset``.
            processing_class: tokenizer, called as ``processing_class(text, ...)``.
            args: training config. Reads ``dataset_text_field``, ``max_seq_length`` and
                ``L_first_matching_tokens``.
            packing: unused; kept for interface compatibility.
            formatting_func: must be ``None``; formatting functions are not supported.

        Returns:
            The mapped dataset, restricted to whichever of ``input_ids``, ``position_ids``,
            ``len_pre_skip``, ``len_in_matching`` and ``len_assistant_skip`` it contains.

        Raises:
            ValueError: if ``formatting_func`` is given; if an example has fewer texts
                than ``self.N_num_sot_tokens``; or if a text lacks the assistant header,
                ``<think1>`` or ``</think1>`` (in that order).
        """
        if formatting_func is not None:
            raise ValueError("`formatting_func` is not supported when preparing set-structured datasets.")

        assistant_tag = "<|im_start|>assistant\n"

        def find_required_indices(s: str, i: int):
            """Return the start offsets of the assistant header, ``<think{i+1}>`` and ``</think{i+1}>`` in ``s``."""
            open_tok = f"<think{i+1}>"
            close_tok = f"</think{i+1}>"
            try:
                a = s.index(assistant_tag)
                o = s.index(open_tok, a + len(assistant_tag))
                c = s.index(close_tok, o + len(open_tok))
            except ValueError as err:
                # str.index only says "substring not found"; name the tags we expected.
                raise ValueError(
                    f"Training text must contain {assistant_tag!r}, then {open_tok!r}, then {close_tok!r}."
                ) from err
            return a, o, c

        def tokenize(example, processing_class, dataset_text_field,
                     N_num_sot_tokens, L_first_matching_tokens, max_seq_length):
            # Every input arrives through the arguments (via `fn_kwargs`) rather than a
            # closure over `self`/`args`. That lets `datasets` fingerprint this function
            # deterministically and reuse the cached map result.
            texts = example[dataset_text_field]
            processed = processing_class(
                texts,
                add_special_tokens=True,  # original note: "won't do anything here" for this tokenizer
                truncation=True,
                padding=False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,
                return_length=False,
            )
            num_targets = len(texts) // N_num_sot_tokens
            if num_targets == 0:
                raise ValueError(
                    f"Each example needs at least {N_num_sot_tokens} texts in '{dataset_text_field}', got {len(texts)}."
                )
            # Per the original note, which SoT tag is used does not change the lengths below.
            sot_tag_id = 0
            len_think_tag = len(f"<think{sot_tag_id + 1}>")
            processed["len_in_matching"] = []
            for target_id, seq_ in enumerate(texts[:num_targets]):
                s_assistant_idx, s_think_idx, c_think_idx = find_required_indices(seq_, sot_tag_id)
                if target_id == 0:
                    # The prefix lengths are measured once, on the first target, and shared by the rest.
                    think_prefix = seq_[:s_think_idx + len_think_tag]
                    processed["len_pre_skip"] = len(
                        processing_class(text=think_prefix, add_special_tokens=False)["input_ids"]
                    )
                    assistant_prefix = seq_[:s_assistant_idx + len(assistant_tag)]
                    processed["len_assistant_skip"] = len(
                        processing_class(text=assistant_prefix, add_special_tokens=False)["input_ids"]
                    )
                # NOTE(review): this count comes from the untruncated text tokenized on its own.
                # It may differ from the corresponding truncated `input_ids` row. Confirm that
                # downstream code tolerates this.
                matching_input_ids = processing_class(text=seq_[:c_think_idx], add_special_tokens=False)["input_ids"]
                processed["len_in_matching"].append(
                    min(processed["len_pre_skip"] + L_first_matching_tokens, len(matching_input_ids))
                )
            return processed

        fn_kwargs = {
            "processing_class": processing_class,
            "dataset_text_field": args.dataset_text_field,
            "N_num_sot_tokens": self.N_num_sot_tokens,
            "L_first_matching_tokens": args.L_first_matching_tokens,
            "max_seq_length": args.max_seq_length,
        }
        map_kwargs = {"batched": False}
        if isinstance(dataset, datasets.Dataset):
            # Only the map-style `Dataset.map` accepts these options; `IterableDataset.map` does not.
            map_kwargs["num_proc"] = self.dataset_num_proc
            map_kwargs["load_from_cache_file"] = True
            map_kwargs["desc"] = "Tokenizing dataset"

        # An ordered tuple, not a set, so the selected column order is the same on every run.
        kept_columns = ("input_ids", "position_ids", "len_pre_skip", "len_in_matching", "len_assistant_skip")
        with PartialState().local_main_process_first():
            dataset = dataset.map(tokenize, fn_kwargs=fn_kwargs, **map_kwargs)
            dataset = dataset.select_columns([c for c in kept_columns if c in dataset.column_names])
        return dataset
            
    def _save_checkpoint(self, model, trial):
        """Save the model and trainer state for the current global step.

        The checkpoint directory is ``<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``.
        Unless ``args.save_only_model`` is set, every process also saves the
        optimizer/scheduler, scaler and RNG state. Every process also writes
        ``set_queue_<process_index>.pth``, which holds
        ``(list(set_queue), steps_sampled_in_current_epoch,
        steps_trained_in_current_epoch, remain_sets_sizes_list)``. That file
        presumably exists to restore the custom set-sampling position on resume;
        confirm against the loading code.

        Args:
            model: Unused here; kept for signature compatibility with ``Trainer``.
            trial: Hyper-parameter search trial, or ``None``.
        """
        checkpoint_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_name)
        self.save_model(output_dir, _internal_call=True)

        # Point `best_model_checkpoint` at the best step's folder if it is still on disk.
        periodic_strategies = (SaveStrategy.STEPS, SaveStrategy.EPOCH)
        if self.args.save_strategy in periodic_strategies and self.state.best_global_step:
            best_dir = os.path.join(run_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}")
            if os.path.exists(best_dir):
                self.state.best_model_checkpoint = best_dir

        if not self.args.save_only_model:
            self._save_optimizer_and_scheduler(output_dir)
            self._save_scaler(output_dir)
            self._save_rng_state(output_dir)
            # Per-process snapshot of the custom set-sampling bookkeeping.
            sampler_snapshot = (
                list(self.set_queue),
                self.steps_sampled_in_current_epoch,
                self.steps_trained_in_current_epoch,
                self.remain_sets_sizes_list,
            )
            snapshot_path = os.path.join(output_dir, f"set_queue_{self.args.process_index}.pth")
            torch.save(sampler_snapshot, snapshot_path)

        # Trainer-state JSON and checkpoint rotation are gated by `should_save`
        # (per the original note, only process_index 0 passes this check).
        if not self.args.should_save:
            return

        # Refresh exportable callback / TrainerControl state before serializing.
        stateful = self.state.stateful_callbacks
        for callback in self.callback_handler.callbacks + [self.control]:
            if not isinstance(callback, ExportableState):
                continue
            key = type(callback).__name__
            snapshot = callback.state()
            if isinstance(stateful[key], list):
                stateful[key].append(snapshot)
            else:
                stateful[key] = snapshot
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # mtime ordering is the default; filesystems without mtime support are
        # detected in `_sorted_checkpoints`.
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def compute_matching_scores(self, inputs, return_outputs=False):
        """Run one forward pass in scoring mode and return the detached loss.

        The prepared batch is passed to ``self.model`` with ``is_scoring=True``.
        ``compute_loss`` passes ``is_scoring=False`` for the regular training loss.
        No gradient context is changed here; only the returned loss is detached.

        Args:
            inputs: Batch dict of model inputs; run through ``self._prepare_inputs``
                before the forward pass.
            return_outputs: Accepted for signature parity with ``compute_loss`` but
                ignored; only the loss is returned.

        Returns:
            ``outputs["loss"]`` from the model, detached from the autograd graph.
        """
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            # Merge rather than pass as a keyword, so an `is_scoring` already present
            # in `inputs` is overridden instead of raising a duplicate-kwarg TypeError.
            scoring_inputs = {**inputs, "is_scoring": True}
            outputs = self.model(**scoring_inputs)
            loss = outputs["loss"].detach()
        return loss

    def _get_dataloader(
        self,
        dataset: Dataset,
        description: str,
        batch_size: int,
        sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
        is_training: bool = False,
        dataloader_key: Optional[str] = None,
    ) -> DataLoader:
        """Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
        # import ipdb; ipdb.set_trace()
        data_collator = self.data_collator
        # NOTE CHANGE 4 keep all the columns
        # self.args.remove_unused_columns = False
        # if is_datasets_available() and isinstance(dataset, datasets.Dataset):
        #     dataset = self._remove_unused_columns(dataset, description=description)
        # else:
        #     data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)

        dataloader_params = {
            "batch_size": batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(dataset, torch.utils.data.IterableDataset):
            if sampler_fn is not None:
                dataloader_params["sampler"] = sampler_fn(dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last  #self.args.dataloader_drop_last=False
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
            if is_training:
                dataloader_params["worker_init_fn"] = partial(
                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
                )
        # import ipdb; ipdb.set_trace()
        # dataloader_params['batch_size']:2, 'drop_last': False,  'num_workers':0,  
        # 'persistent_workers':False,  'pin_memory':True,  'prefetch_factor':None,
        # 'sampler': RandomSampler,  'worker_init_fn': seed_worker
        # 'collate_fn': DataCollatorForLanguageModeling(pad_token_id=151643, 
        # completion_only_loss=False, padding_free=False, return_position_ids=True,
        # pad_to_multiple_of=None, return_tensors='pt'
        dataloader = DataLoader(dataset, **dataloader_params) # NOTE dataloader.sampler = <sampler.RandomSampler>

        # Accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version for eval dataloaders.
        if dataloader_key is not None and self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = dataloader
            else:
                self._eval_dataloaders = {dataloader_key: dataloader}
        return self.accelerator.prepare_data_loader(dataloader,device_placement=False)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Mirrors ``transformers.Trainer.compute_loss`` with two differences:

        * When the model accepts loss kwargs, ``is_scoring=False`` is passed to it
          (``compute_matching_scores`` uses ``is_scoring=True``).
        * The ``num_processes`` loss scaling does not consult
          ``args.average_tokens_across_devices``. It is applied whenever a token count
          was supplied and reaches the loss computation.

        Args:
            model: The (possibly wrapped) model to call.
            inputs: Batch dict of model inputs. ``"labels"`` is popped only when a
                label smoother or ``compute_loss_func`` is configured; otherwise labels
                stay in ``inputs`` and the model computes the loss itself.
            return_outputs: If True, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: Token count used for loss normalization, or ``None``
                when the caller provides none.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: If no labels were popped and the model returned a dict without
                a ``"loss"`` entry.
        """
        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            # The model computes its own loss (fused forward + loss) from inputs["labels"].
            labels = None
        if self.model_accepts_loss_kwargs:
            loss_kwargs = {"is_scoring": False}  # the only addition vs. transformers' Trainer
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # User-defined compute_loss function
            if self.compute_loss_func is not None:
                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # Upstream scales only when `average_tokens_across_devices` is set. That flag was
        # observed False in this setup, so it is deliberately not required here. The
        # scaling is still limited to batches whose token count was supplied and used.
        # Scaling unconditionally would multiply mean losses from callers without
        # `num_items_in_batch` (presumably evaluation) by the world size.
        if num_items_in_batch is not None and (
            self.model_accepts_loss_kwargs or self.compute_loss_func is not None
        ):
            loss *= self.accelerator.num_processes

        return (loss, outputs) if return_outputs else loss

    def _maybe_log_save_evaluate(
        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
    ):
        """Log training metrics, then evaluate and/or checkpoint as ``self.control`` requests.

        Logging happens only when ``self.control.should_log`` is set and the global step
        has advanced since the previous log. The accumulated loss and the statistics in
        ``self.matching_tracker`` are reduced across processes by
        ``ddp_reduce_and_average_stats_and_loss_tensors``, replacing upstream's
        ``self._nested_gather(tr_loss).mean()``.

        Args:
            tr_loss: Loss accumulator since the last log; reset to zero in place here
                (presumably a tensor owned by the training loop; confirm).
            grad_norm: Latest gradient norm (tensor or number), or ``None``.
            model: Forwarded to ``_save_checkpoint``.
            trial: Hyper-parameter search trial, or ``None``.
            epoch: Unused here; kept for signature compatibility.
            ignore_keys_for_eval: Forwarded to ``self._evaluate``.
            start_time: Forwarded to ``self.log``.
            learning_rate: Learning rate to log; when ``None`` it is read from
                ``self._get_learning_rate()``.
        """
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            if is_torch_xla_available():
                xm.mark_step()

            # Cross-process reduction of tracked stats and loss; the second return
            # value is not needed here.
            logs, _n_total = ddp_reduce_and_average_stats_and_loss_tensors(
                self.matching_tracker, loss_scalar=tr_loss, num_processes=self.args.world_size
            )
            # Keep the reduced interval loss *before* per-step averaging. This is what
            # `_total_loss_scalar` must accumulate, as upstream does with `tr_loss_scalar`.
            # Adding the averaged, rounded value instead shrinks the running total
            # (upstream later divides it by the step count to report `train_loss`)
            # by roughly the logging interval.
            tr_loss_scalar = logs["loss"]

            # Reset the caller's accumulator to zero in place.
            tr_loss -= tr_loss

            steps_since_last_log = self.state.global_step - self._globalstep_last_logged
            logs["loss"] = round(tr_loss_scalar / steps_since_last_log, 4)

            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            if learning_rate is not None:
                logs["learning_rate"] = learning_rate
            else:
                logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs, start_time)

        if self.control.should_evaluate:
            metrics = self._evaluate(trial, ignore_keys_for_eval)
            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

            if self.args.save_strategy == SaveStrategy.BEST:
                self.control.should_save = is_new_best_metric

        if self.control.should_save:
            self._save_checkpoint(model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

