"""
ESLM: Efficient Selective Language Modeling
Key implementation functions for hierarchical risk-aware batch selection

1. Instance-level selection via early-exit proxy (Stage 1)
2. Token-level loss shaping via VaR/CVaR thresholding (Stage 2)
3. Adaptive alpha adjustment (ADA-ESLM)
"""
import os
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from contextlib import nullcontext
import torch.distributed as dist
import pickle
import random
import json
import matplotlib.pyplot as plt
# Profiling related:
# from torch.profiler import profile, record_function, ProfilerActivity
from ptflops import get_model_complexity_info
from calflops import calculate_flops
import gc
import math
from utils.configurator import parse_arguments, set_random_seed
from utils.consts import set_globals
from utils.utils import estimate_reference_flops, estimate_reference_flops_theoretical, compute_cvar, log_to_file, print_memory_usage
from model_w_proxy_sentence import (GPTConfig, GPT)
from utils.flops_meter import FlopsMeter
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the pretrained "gpt2" tokenizer as a module-level object (runs at import time).
tokenizer = AutoTokenizer.from_pretrained("gpt2")


# -----------------------------------------------------------------------------
# Default configuration. Everything below is a plain module-level default;
# set_globals() further down receives the parsed arguments together with this
# module's globals() (presumably to overwrite these defaults from the command
# line — confirm in utils.consts).
# I/O
out_dir = './eslm/'
data_dir = './eslm/'
log_dir = './eslm/'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False  # if True, script exits right after the first eval
always_save_checkpoint = False  # if True, always save a checkpoint after each eval
init_from = 'scratch'  # 'scratch' or 'resume' or 'gpt2*'
resume_ckpt_path = 'ckpt.pt'  # checkpoint file read when init_from == 'resume'
ckpt_save_interval = eval_interval
# data
dataset = 'slim_6b'  # sub-directory of data_dir that holds the {split}.bin files
gradient_accumulation_steps = 5 * 8  # used to simulate larger batch sizes
batch_size = 12  # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024  # sequence length of every sampled training example
# model
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0  # for pretraining 0 is good, for finetuning try 0.1+
bias = False  # do we use bias inside LayerNorm and Linear layers?
# ESLM token selection; these are forwarded to GPTConfig when is_select_token is True
is_select_token = True
masker_type = 'self_supervised'
token_select_type = 'cvar_thr_loss'
is_ada_eslm = False
is_calc_precise_flops = False
num_tokens = 0  # for ESLM, to count selected number of tokens
# KD-based experiments and baselines
is_kl_distill = False
is_dense_kl_distill = False
is_salt = False  # enable SALT-baseline training
# Optimizer type
optimizer_type = 'AdamW'  # ['AdamW', 'SGD', 'TIDE']
learning_rate = 6e-4  # max learning rate
max_iters = 200000  # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True  # whether to decay the learning rate
warmup_iters = 2000  # how many steps to warm up for
lr_decay_iters = 200000  # should be ~= max_iters per Chinchilla
min_lr = 6e-5  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
backend = 'nccl'  # 'nccl', 'gloo', etc.
# system
device = 'cuda'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'  # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True  # use PyTorch 2.0 to compile the model to be faster (NOTE: this name shadows the builtin compile())
# wandb logging
wandb_log = True  # wandb logging flag; the default here is True (enabled)
wandb_project = 'eslm'
run_name = f'gpt2_{optimizer_type}'  # 'run' + str(time.time())
# NOTE(review): the default is None, yet torch.manual_seed(seed + seed_offset) is
# executed later in this file and would raise TypeError on None — confirm that
# set_globals() always installs an integer seed.
seed = None
# Label noise
label_noise = 0.0 # probability of replacing each label with a random token id
# Attention sparsity analysis
sparsity_thr = 1e-3 # threshold for attention sparsity to loss-masked tokens
sparsity_check_interval = 1000
# Grad norm analysis
is_ffn_grad_analysis = False
grad_analysis_interval = 1000
# -----------------------------------------------------------------------------
# dataset mixtures
is_train_on_mix_batch=False # if True, every sequence of a pool is drawn from a dataset sampled with the weights below
# Sampling weight per dataset for the 'train' split (used as random.choices weights).
train_dataset_mixtures = {
    "arxiv": 0.04235,
    "book": 0.08201,
    "cc": 0.381,
    "c4": 0.1141,
    "github": 0.0654,
    "stackexchange": 0.0847,
    "wikipedia": 0.2305
}
# Sampling weight per dataset for every non-'train' split: equal weight for all seven datasets.
val_dataset_mixtures = {
    "arxiv": 0.1428,
    "book": 0.1428,
    "cc": 0.1428,
    "c4": 0.1428,
    "github": 0.1428,
    "stackexchange": 0.1428,
    "wikipedia": 0.1428
}
# ----------CONFIGURATION & UPDATE CONSTS-------------
args = parse_arguments()
# print(args)
# Initialize the globals: set_globals is given the parsed args and this module's
# globals() dict, so it is able to mutate the defaults defined above in place.
set_globals(args, globals())
# print(f"globals set to: {globals().items}")
set_random_seed(seed)
# print(f"seed set to: {seed}")
# Names of every public module-level int/float/bool/str value at this point
# (the defaults above plus any other simple global that already exists).
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# print(f"config_keys set to: {config_keys}")
# Poor Man's configurator
# exec(open('utils/configurator.py').read())  # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# print(f"config set to: {config}")
# -----------------------------------------------------------------------------
# ----Assertions:
# args error checking and convenience variables
# NOTE: these are `assert` statements, so they are skipped when Python runs with -O.
assert 1 <= args.block_size <= 1024   # sequence length
assert args.dtype in {"float32", "float16", "bfloat16"}
# -----------------------------------------------------------------------------


def print0(*args, **kwargs):
    """Print only on the master process.

    The rank is read from the ``RANK`` environment variable; when it is not
    set (a non-distributed run) the rank counts as 0 and the call behaves
    exactly like the builtin ``print``. All arguments are forwarded unchanged.
    """
    rank = int(os.environ.get("RANK", 0))
    if rank != 0:
        # non-master ranks stay silent
        return
    print(*args, **kwargs)


# Startup banner (master process only). NOTE(review): the message ends with the
# label "nvidia-smi:" but nothing in this file prints nvidia-smi output after it.
print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:")


# various inits, derived attributes, I/O setup
# A run counts as DDP when the RANK environment variable is set.
ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    # each process drives the GPU that matches its local rank
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank  # each process gets a different seed
    # world_size number of processes will be training simultaneously, so we can scale
    # down the desired gradient accumulation iterations per process proportionally
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
# tokens consumed per optimizer step across all processes and micro-batches
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
# NOTE(review): plain print (not print0), so every rank emits this line in a DDP run.
print(f"tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)
# NOTE(review): raises TypeError if `seed` is still None at this point (None is its
# default in the config section) — confirm set_globals always provides an integer.
torch.manual_seed(seed + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# autocast context on CUDA, a no-op context on CPU
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
# from here on data_dir points at the sub-directory of the selected dataset
data_dir = os.path.join(data_dir, dataset)


# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    # NOTE: pickle.load executes arbitrary code from the file; only use trusted meta.pkl files
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
# model init
# The two variants differ only in the token-selection fields passed to GPTConfig;
# vocab_size is filled in below once it is known.
if is_select_token:
    model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout, is_flash_att=args.is_flash_att,
                  is_select_token=is_select_token, token_select_type=token_select_type)  # start with model_args from command line
else:
    model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout, is_flash_att=args.is_flash_att,
                  is_select_token=None, token_select_type='')  # start with model_args from command line
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    # NOTE(review): the message prints out_dir, but the checkpoint is read from resume_ckpt_path
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    # NOTE: torch.load unpickles the file; only resume from trusted checkpoints
    checkpoint = torch.load(resume_ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    # iterate over a snapshot of the items because keys are renamed inside the loop
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    # iter_num / best_val_loss are only defined on this resume path
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)
# NOTE(review): there is no else branch — an unrecognised init_from leaves `model`
# undefined and the next statement fails with NameError.
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size  # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
# NOTE(review): eslm_training_step in this file calls loss.backward() and
# optimizer.step() directly and uses neither `scaler` nor `ctx` — confirm how
# float16 / autocast runs are meant to be handled.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)


# ============================================================================
# STAGE 1: Instance-Level Selection (Proxy Phase)
# ============================================================================

def forward_early_exit(model, input_ids, labels, num_layers, metric="nll"):
    """
    Compute per-instance risk scores using a shallow early-exit pass.

    The input is embedded with ``model.transformer.wte``, passed through the
    first ``num_layers`` blocks of ``model.transformer.h``, normalised with
    ``model.transformer.ln_f`` and projected with ``model.lm_head``. The pass
    runs in eval mode under ``torch.no_grad()``; afterwards the model is put
    back into whatever train/eval mode it was in before the call, also when
    an exception is raised.

    Args:
        model: Transformer model with accessible layer structure
        input_ids: Input token IDs [batch_size, seq_len]
        labels: Target token IDs [batch_size, seq_len]; -1 marks ignored positions
        num_layers: Number of layers to use for proxy (L in paper, typically 1-3)
        metric: Risk metric - "nll" (negative log-likelihood) or "entropy"

    Returns:
        seq_scores: Per-instance risk scores [batch_size], i.e. the mean
            per-token score over the positions whose label is not -1

    Raises:
        ValueError: if ``metric`` is neither "nll" nor "entropy"
    """
    # Validate before touching the model so a bad metric costs no forward pass
    # and cannot leave the model stuck in eval mode.
    if metric not in ("nll", "entropy"):
        raise ValueError(f"Unknown metric: {metric}")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # NOTE(review): only the token embedding feeds the blocks here; no
            # positional embedding is added — confirm this matches the model's
            # own forward pass.
            hidden_states = model.transformer.wte(input_ids)

            # Pass through the first L transformer blocks; each block is
            # expected to return a sequence whose first item is the hidden state.
            for i in range(num_layers):
                hidden_states = model.transformer.h[i](hidden_states)[0]

            # Apply layer norm and project to vocab
            hidden_states = model.transformer.ln_f(hidden_states)
            logits = model.lm_head(hidden_states)  # [batch_size, seq_len, vocab_size]

            if metric == "nll":
                # Per-token negative log-likelihood; ignored labels yield 0
                token_scores = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    ignore_index=-1,
                    reduction='none'
                ).view(labels.shape)  # [batch_size, seq_len]
            else:
                # Per-token predictive entropy (epsilon guards log(0))
                probs = F.softmax(logits, dim=-1)
                token_scores = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)

            # Average over valid (non-ignored) tokens per sequence; the clamp
            # avoids a division by zero for sequences without any valid token.
            valid_mask = (labels != -1)
            seq_scores = (token_scores * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    return seq_scores


def instance_level_selection(seq_scores, alpha, ddp=False, world_size=1):
    """
    Keep the instances whose risk score reaches the alpha-quantile (VaR).

    Args:
        seq_scores: Per-instance risk scores [pool_size]
        alpha: Quantile level used as threshold (e.g., 0.1 keeps roughly the top 90%)
        ddp: Whether using distributed training
        world_size: Number of processes for DDP

    Returns:
        keep_indices: 1-D tensor with the indices of the selected instances
        threshold: VaR threshold value (0-dim tensor)
    """
    # The quantile is taken over all ranks only for a real multi-process run;
    # otherwise the local scores alone define the threshold.
    use_global = ddp and world_size > 1
    threshold = (
        _global_quantile(seq_scores, alpha, world_size)
        if use_global
        else torch.quantile(seq_scores, alpha)
    )

    # Instances at or above the threshold survive.
    keep_indices = (seq_scores >= threshold).nonzero().squeeze(1)
    return keep_indices, threshold


def _global_quantile(scores, q, world_size):
    """Return the q-quantile of the scores held by all ranks of a DDP run.

    Ranks may hold different numbers of scores, so the counts are exchanged
    first, every rank pads its scores to the largest count, the padded
    buffers are all-gathered, and the padding is cut off again before the
    quantile is taken over the concatenation.
    """
    import torch.distributed as dist

    device = scores.device
    count = scores.numel()

    # Exchange the per-rank element counts.
    local_n = torch.tensor([count], device=device, dtype=torch.long)
    ns = [torch.zeros_like(local_n) for _ in range(world_size)]
    dist.all_gather(ns, local_n)
    max_n = int(torch.stack(ns).max().item())

    # all_gather needs equally sized tensors: pad with -inf up to max_n.
    padded = torch.full((max_n,), float('-inf'), device=device, dtype=scores.dtype)
    padded[:count] = scores

    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # Keep only the real entries of every rank, then take the global quantile.
    pieces = []
    for buf, n in zip(gathered, ns):
        pieces.append(buf[:n.item()])
    return torch.quantile(torch.cat(pieces), q)


# ============================================================================
# STAGE 2: Token-Level Loss Shaping
# ============================================================================

def compute_token_risk_scores(logits, labels, metric="loss"):
    """
    Score every token position for loss shaping.

    Args:
        logits: Model predictions [batch_size, seq_len, vocab_size]
        labels: Target tokens [batch_size, seq_len]; -1 marks ignored positions
        metric: "loss" (per-token cross-entropy) or "entropy" (predictive entropy)

    Returns:
        token_scores: Per-token risk scores [batch_size, seq_len]

    Raises:
        ValueError: if ``metric`` is neither "loss" nor "entropy"
    """
    if metric == "loss":
        # Cross-entropy per position; positions labelled -1 score 0.
        flat_logits = logits.view(-1, logits.size(-1))
        flat_losses = F.cross_entropy(
            flat_logits, labels.view(-1), ignore_index=-1, reduction='none'
        )
        return flat_losses.view(labels.shape)

    if metric == "entropy":
        # Entropy of the predictive distribution; the epsilon guards log(0).
        probs = F.softmax(logits, dim=-1)
        return -(probs * torch.log(probs + 1e-10)).sum(dim=-1)

    raise ValueError(f"Unknown metric: {metric}")


def token_level_selection(token_scores, labels, alpha):
    """
    Select high-risk tokens using VaR thresholding.

    The threshold is the alpha-quantile of the scores at valid positions
    (label != -1); a token is selected when it is valid and its score is at
    or above that threshold.

    Args:
        token_scores: Per-token risk scores [batch_size, seq_len]
        labels: Target tokens (to identify valid positions) [batch_size, seq_len]
        alpha: Quantile level in [0, 1]

    Returns:
        selected_mask: Boolean mask [batch_size, seq_len], True for selected tokens
        threshold: VaR threshold value (0-dim tensor). When there is no valid
            token at all, the mask is all False and the threshold is +inf.
    """
    valid_mask = (labels != -1)
    valid_scores = token_scores[valid_mask]

    if valid_scores.numel() == 0:
        # torch.quantile raises on empty input; with nothing to rank, select
        # nothing and report a threshold no score can reach.
        threshold = torch.full(
            (), float('inf'), dtype=token_scores.dtype, device=token_scores.device
        )
        return torch.zeros_like(valid_mask), threshold

    # Compute VaR threshold over valid tokens only
    threshold = torch.quantile(valid_scores, alpha)

    # Invalid positions are never selected, whatever their score
    selected_mask = (token_scores >= threshold) & valid_mask

    return selected_mask, threshold


def shape_loss(logits, labels, selected_mask):
    """
    Average the cross-entropy over the selected tokens only.

    Args:
        logits: Model predictions [batch_size, seq_len, vocab_size]
        labels: Target tokens [batch_size, seq_len]; -1 marks ignored positions
        selected_mask: Boolean mask [batch_size, seq_len]

    Returns:
        loss: Mean loss over the selected tokens, or over all valid tokens
            when the mask selects nothing
    """
    vocab = logits.size(-1)
    per_token = F.cross_entropy(
        logits.view(-1, vocab),
        labels.view(-1),
        ignore_index=-1,
        reduction='none',
    ).view(labels.shape)

    n_selected = selected_mask.sum()
    if n_selected > 0:
        # Regular case: mean over the selected positions.
        return (per_token * selected_mask).sum() / n_selected

    # Nothing selected: fall back to the plain mean over valid positions
    # (the clamp avoids a division by zero when none is valid either).
    valid = (labels != -1)
    return (per_token * valid).sum() / valid.sum().clamp(min=1)


# ============================================================================
# ADAPTIVE ALPHA (ADA-ESLM)
# ============================================================================

def compute_cvar(risk_scores, alpha_tail=0.1):
    """
    Compute Conditional Value-at-Risk (CVaR) for adaptive alpha adjustment.

    NOTE: this definition shadows the ``compute_cvar`` imported from
    ``utils.utils`` at the top of the file.

    Args:
        risk_scores: Tensor of risk scores of any shape (it is flattened);
            dtypes other than float32/float64 are cast to float32 because
            torch.quantile accepts only those two
        alpha_tail: Defines tail region (e.g., 0.1 = top 10%)

    Returns:
        cvar: Mean risk score in the tail as a Python float; 0.0 for an
            empty tensor
    """
    # Flatten first so 0-dim inputs and empty multi-dimensional inputs are
    # handled; len() raises on the former and only checks the first dimension
    # of the latter.
    flat = risk_scores.reshape(-1)
    if flat.numel() == 0:
        return 0.0

    if flat.dtype not in (torch.float32, torch.float64):
        flat = flat.float()

    # VaR threshold that delimits the upper tail
    threshold = torch.quantile(flat, 1.0 - alpha_tail)

    # Mean of the tail (CVaR); fall back to the overall mean if the tail is
    # empty, e.g. because of NaN scores.
    tail_scores = flat[flat >= threshold]
    source = tail_scores if tail_scores.numel() > 0 else flat
    return source.mean().item()


def update_alpha_ada_eslm(current_cvar, prev_cvar, current_alpha, step_size=0.05):
    """
    Adaptively update alpha based on CVaR feedback.

    The update is ``alpha * exp(-step_size * delta_norm)`` where ``delta_norm``
    is the change in CVaR relative to ``|prev_cvar|``; the result is clipped
    to [0.01, 0.5].

    Args:
        current_cvar: CVaR at current iteration
        prev_cvar: CVaR at previous iteration (None on the first iteration)
        current_alpha: Current alpha level
        step_size: Update rate (gamma in paper)

    Returns:
        new_alpha: Updated alpha level as a Python float. ``current_alpha`` is
            returned unchanged (and unclipped) when there is no usable
            history (``prev_cvar`` is None or 0) or when either CVaR value is
            NaN or infinite.
    """
    if prev_cvar is None or prev_cvar == 0:
        return current_alpha

    # A diverged step (NaN/inf loss) must not poison alpha: NaN would survive
    # the clamp below and later be used as a quantile level.
    if not (math.isfinite(current_cvar) and math.isfinite(prev_cvar)):
        return current_alpha

    # Change in CVaR relative to the previous value; the epsilon keeps the
    # division well-defined for very small prev_cvar.
    delta = current_cvar - prev_cvar
    normalized_delta = delta / (abs(prev_cvar) + 1e-6)

    # Exponential update: alpha * exp(-gamma * delta_norm)
    # If CVaR increases (harder batch) -> decrease alpha (include more tokens)
    # If CVaR decreases (easier batch) -> increase alpha (be more selective)
    new_alpha = current_alpha * torch.exp(torch.tensor(-step_size * normalized_delta))

    # Clip to reasonable bounds
    new_alpha = float(torch.clamp(new_alpha, 0.01, 0.5))

    return new_alpha


# ============================================================================
# COMPLETE ESLM TRAINING STEP
# ============================================================================

def eslm_training_step(
        model,
        candidate_pool_X,
        candidate_pool_Y,
        optimizer,
        alpha=0.1,
        proxy_layers=1,
        metric="loss",
        is_ada_eslm=False,
        ada_eslm_state=None
):
    """
    Complete ESLM training step with hierarchical selection.

    Stage 1 scores every candidate with a shallow early-exit pass and keeps
    the instances at or above the alpha-quantile. Stage 2 runs the full model
    on the kept instances, selects the high-risk tokens and back-propagates
    the loss of those tokens only, followed by one optimizer step.

    Args:
        model: Language model; called as ``model(X, Y)`` and expected to
            return a pair whose first item is the logits
        candidate_pool_X: Input candidates [pool_size, seq_len]
        candidate_pool_Y: Target candidates [pool_size, seq_len]
        optimizer: Optimizer
        alpha: Quantile level used for both instance and token selection
        proxy_layers: Number of layers for Stage 1 proxy
        metric: "loss" or "entropy"
        is_ada_eslm: Whether to use adaptive alpha
        ada_eslm_state: Dictionary with CVaR history for ADA-ESLM
            (keys 'prev_cvar' and optionally 'step_size'); it is updated in
            place. Adaptive alpha is skipped when this is None.

    Returns:
        loss: Training loss (0-dim tensor)
        stats: Dictionary with selection statistics
        alpha: Alpha to use for the next step (changed only by ADA-ESLM)
    """

    # ===== STAGE 1: Instance-Level Selection =====
    proxy_metric = "nll" if metric == "loss" else "entropy"
    seq_scores = forward_early_exit(
        model, candidate_pool_X, candidate_pool_Y,
        num_layers=proxy_layers,
        metric=proxy_metric
    )

    keep_indices, instance_threshold = instance_level_selection(
        seq_scores, alpha, ddp=False, world_size=1
    )

    # Select instances
    X_selected = candidate_pool_X[keep_indices]
    Y_selected = candidate_pool_Y[keep_indices]

    # ===== STAGE 2: Token-Level Loss Shaping =====
    model.train()
    optimizer.zero_grad()

    # Forward pass on selected instances
    logits, _ = model(X_selected, Y_selected)

    # The risk scores only drive the (non-differentiable) selection, so they
    # are computed on detached logits; this avoids keeping a second autograd
    # graph over the vocabulary alive. shape_loss recomputes the loss with
    # gradients below.
    token_scores = compute_token_risk_scores(logits.detach(), Y_selected, metric=metric)

    # Select high-risk tokens
    selected_mask, token_threshold = token_level_selection(
        token_scores, Y_selected, alpha
    )
    loss = shape_loss(logits, Y_selected, selected_mask)

    # Backward pass
    loss.backward()
    optimizer.step()

    # ===== ADAPTIVE ALPHA (Optional) =====
    use_ada = is_ada_eslm and ada_eslm_state is not None
    if use_ada:
        current_cvar = compute_cvar(token_scores[Y_selected != -1], alpha_tail=0.1)
        prev_cvar = ada_eslm_state.get('prev_cvar', None)

        new_alpha = update_alpha_ada_eslm(
            current_cvar, prev_cvar, alpha,
            step_size=ada_eslm_state.get('step_size', 0.05)
        )

        # Update state
        ada_eslm_state['prev_cvar'] = current_cvar
        alpha = new_alpha

    # ===== Statistics =====
    instances_total = len(candidate_pool_X)
    instances_kept = len(keep_indices)
    tokens_total = (Y_selected != -1).sum().item()
    tokens_kept = selected_mask.sum().item()
    stats = {
        'loss': loss.item(),
        'alpha': alpha,
        'instance_threshold': instance_threshold.item(),
        'token_threshold': token_threshold.item(),
        'instances_kept': instances_kept,
        'tokens_kept': tokens_kept,
        'instances_total': instances_total,
        'tokens_total': tokens_total,
        # ratios are reported as 0.0 instead of dividing by zero
        'instance_keep_ratio': instances_kept / instances_total if instances_total else 0.0,
        'token_keep_ratio': tokens_kept / tokens_total if tokens_total else 0.0
    }

    # Only report a CVaR when the adaptive branch ran; indexing a missing
    # state dict would raise TypeError.
    if use_ada:
        stats['cvar'] = ada_eslm_state['prev_cvar']

    return loss, stats, alpha


def corrupt_labels(y_batch, vocab_size, noise_prob=0.1):
    """Return a copy of `y_batch` in which each label is replaced, with
    probability `noise_prob`, by a token id drawn uniformly from [0, vocab_size)."""
    # One uniform draw per label decides whether that label is replaced.
    flip = torch.rand_like(y_batch, dtype=torch.float) < noise_prob
    replacements = torch.randint_like(y_batch, low=0, high=vocab_size)
    return torch.where(flip, replacements, y_batch)

def _pool_to_device(x, y):
    """Move an (x, y) pair to the configured device, via pinned memory on CUDA."""
    if device_type == 'cuda':
        # pinned host memory allows the host-to-device copy to be non-blocking
        return (x.pin_memory().to(device, non_blocking=True),
                y.pin_memory().to(device, non_blocking=True))
    return x.to(device), y.to(device)


def get_candidate_pool(split, candidate_pool_size):
    """
    Samples a large pool of candidate sequences for scoring.

    Every sequence is a random window of `block_size` tokens read from a
    `{split}.bin` file of uint16 token ids; the targets are the same window
    shifted by one token.

    Args:
        split: Name of the split; selects `{split}.bin` and, for mixture
            batches, the train mixture for 'train' and the validation mixture
            otherwise
        candidate_pool_size: Number of sequences to sample; a float is
            rounded up

    Returns: X_pool, Y_pool, selected_datasets_pool
        X_pool / Y_pool are int64 tensors [candidate_pool_size, block_size] on
        `device`; selected_datasets_pool holds one dataset name per sequence.
    """
    if isinstance(candidate_pool_size, float):
        candidate_pool_size = math.ceil(candidate_pool_size)

    if is_train_on_mix_batch:
        mixtures = train_dataset_mixtures if split == 'train' else val_dataset_mixtures
        dataset_names = list(mixtures.keys())
        dataset_probs = list(mixtures.values())

        # Draw the source dataset of every pool entry according to the mixture weights
        selected_datasets = random.choices(dataset_names, weights=dataset_probs, k=candidate_pool_size)

        # Open each dataset file once per call instead of once per sequence;
        # the memmaps are released again when this function returns.
        memmaps = {}
        x_rows, y_rows = [], []
        for dataset_name in selected_datasets:
            data = memmaps.get(dataset_name)
            if data is None:
                dataset_path = os.path.join(data_dir, dataset_name, f"{split}.bin")
                data = np.memmap(dataset_path, dtype=np.uint16, mode='r')
                memmaps[dataset_name] = data

            # Sample a single random sequence from this dataset
            i = int(torch.randint(len(data) - block_size, (1,)))
            x_rows.append(torch.from_numpy((data[i:i + block_size]).astype(np.int64)))
            y_rows.append(torch.from_numpy((data[i + 1:i + 1 + block_size]).astype(np.int64)))

        x_batch = torch.stack(x_rows)
        y_batch = torch.stack(y_rows)

        if label_noise > 0:
            y_batch = corrupt_labels(y_batch, model_args['vocab_size'], noise_prob=label_noise)

        x_batch, y_batch = _pool_to_device(x_batch, y_batch)

        return x_batch, y_batch, selected_datasets  # Return dataset names for each sequence

    # Single dataset case
    # NOTE(review): label_noise is applied only in the mixture branch above —
    # confirm whether the single-dataset path should corrupt labels as well.
    dataset_path = os.path.join(data_dir, f"{split}.bin")
    data = np.memmap(dataset_path, dtype=np.uint16, mode='r')

    ix = torch.randint(len(data) - block_size, (candidate_pool_size,))
    x = torch.stack([torch.from_numpy((data[i:i + block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i + 1:i + 1 + block_size]).astype(np.int64)) for i in ix])

    x, y = _pool_to_device(x, y)
    # One dataset name per pool entry, so the list lines up with x and y
    # (sizing it by batch_size is wrong whenever the pool is larger than a batch).
    return x, y, [dataset] * candidate_pool_size



# ============================================================================

def run():
    """Run the ESLM training loop for `max_iters` iterations.

    Each iteration samples a candidate pool of K * batch_size training
    sequences and performs one `eslm_training_step` on the module-level
    `model` and `optimizer`; the alpha returned by a step is fed into the
    next one. Progress is printed every 100 iterations.
    """
    # ESLM hyperparameters
    # NOTE(review): these are hard-coded here; the local batch_size hides the
    # module-level batch_size setting and adaptive alpha is always enabled
    # regardless of the module-level is_ada_eslm flag — confirm this is intended.
    alpha = 0.1  # Confidence level
    proxy_layers = 1  # Early-exit depth for Stage 1
    K = 1  # Candidate pool multiplier
    batch_size = 12

    # ADA-ESLM state (optional)
    ada_eslm_state = {
        'prev_cvar': None,
        'step_size': 0.05
    }

    for iter_num in range(max_iters):
        # Sample candidate pool (K times larger than batch).
        # get_candidate_pool returns three values; the per-sequence dataset
        # names are not needed here, and unpacking only two raises ValueError.
        candidate_pool_X, candidate_pool_Y, _ = get_candidate_pool(
            split='train', candidate_pool_size=K * batch_size
        )

        # ESLM training step
        loss, stats, alpha = eslm_training_step(
            model=model,
            candidate_pool_X=candidate_pool_X,
            candidate_pool_Y=candidate_pool_Y,
            optimizer=optimizer,
            alpha=alpha,
            proxy_layers=proxy_layers,
            metric="loss",  # or "entropy"
            is_ada_eslm=True,  # Enable adaptive alpha
            ada_eslm_state=ada_eslm_state
        )

        # Logging
        if iter_num % 100 == 0:
            print(f"Iter {iter_num}: Loss={loss:.4f}, Alpha={alpha:.4f}, "
                  f"Instances kept={stats['instance_keep_ratio']:.2%}, "
                  f"Tokens kept={stats['token_keep_ratio']:.2%}")


# Script entry point: run() is only invoked when this file is executed directly.
# The configuration and model setup above run at module level regardless.
if __name__ == "__main__":
    run()