import os
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch.nn.functional as F
import logging

from sal.config import Config
from sal.utils.parser import H4ArgumentParser
from utils.save_mapping import save_temperature_dict_npz

# Module-wide logging: configure the root logger at INFO (a no-op if the root
# logger already has handlers) and expose a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


def _get_base_model(module):
    """Return the backbone wrapped by ``module``, or ``None`` if none is found.

    The attributes ``model``, ``transformer``, ``base_model``,
    ``language_model`` and ``backbone`` are probed in that order; the value of
    the first one present is returned.
    """
    probe_order = ("model", "transformer", "base_model", "language_model", "backbone")
    hit = next((name for name in probe_order if hasattr(module, name)), None)
    if hit is None:
        return None
    return getattr(module, hit)


def get_last_hidden_only(model, input_ids, attention_mask):
    """Return the last-layer hidden states of ``model`` for one batch (single GPU).

    First tries the wrapped base model found by ``_get_base_model`` (typically
    avoiding the vocabulary projection), returning its ``last_hidden_state`` or,
    failing that, the last entry of its ``hidden_states``. Otherwise runs the
    full model with ``output_hidden_states=True`` and returns the last entry of
    ``hidden_states``.
    """
    base = _get_base_model(model)
    # A probed attribute such as ``base_model`` may resolve to the model itself
    # (transformers' ``PreTrainedModel.base_model`` falls back to ``self``).
    # Calling it would run a full forward whose output carries no hidden states,
    # followed by a second full forward below -- skip that wasted pass.
    if base is not None and base is not model:
        out = base(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        last_hidden = getattr(out, "last_hidden_state", None)
        if last_hidden is not None:
            return last_hidden
        all_layers = getattr(out, "hidden_states", None)
        if all_layers is not None:
            return all_layers[-1]
    # Fallback: request hidden_states from the full model and take the last layer.
    out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True)
    return out.hidden_states[-1]


def _select_completions(completions, scores, config):
    """Choose which completions of one problem are used for calibration.

    Only the first ``config.n1`` completions (and scores) are considered. Each
    score is reduced to a scalar according to ``config.score_selection``:
    ``"last"`` (last element), ``"product"`` or ``"average"`` of a per-step
    score list; scalar scores are used as-is. Then ``k = min(config.k, n1)``
    completions are kept according to ``config.dataset_selection``:

    * ``"top"``     -- the k highest scores (returned in ascending score order);
    * ``"softmax"`` -- k draws without replacement with probability
      ``softmax(score)``, using NumPy's global RNG.

    Returns:
        List of selected completion strings; empty when ``k <= 0``.

    Raises:
        ValueError: for an unknown ``score_selection`` or ``dataset_selection``.
    """
    # Clamp by len(scores) too, so a short score list cannot misalign indices.
    n1 = min(config.n1, len(completions), len(scores))
    k = min(config.k, n1)
    completions_n1 = completions[:n1]
    scores_n1 = scores[:n1]

    def _is_seq(s):
        return isinstance(s, (list, np.ndarray))

    if config.score_selection == "last":
        scores_flat = [s[-1] if _is_seq(s) else s for s in scores_n1]
    elif config.score_selection == "product":
        scores_flat = [float(np.prod(s)) if _is_seq(s) else s for s in scores_n1]
    elif config.score_selection == "average":
        scores_flat = [float(np.mean(s)) if _is_seq(s) else s for s in scores_n1]
    else:
        raise ValueError(f"Unknown score_selection: {config.score_selection}")

    if config.dataset_selection not in ("top", "softmax"):
        raise ValueError(f"Unknown dataset_selection: {config.dataset_selection}")
    # ``arr[-0:]`` would select *everything*, so k == 0 must be handled explicitly.
    if k <= 0:
        return []

    if config.dataset_selection == "top":
        chosen = np.argsort(scores_flat)[-k:]
    else:
        scores_arr = np.asarray(scores_flat)
        exp_scores = np.exp(scores_arr - np.max(scores_arr))  # shift for numerical stability
        probs = exp_scores / np.sum(exp_scores)
        chosen = np.random.choice(np.arange(n1), size=k, replace=False, p=probs)
    return [completions_n1[i] for i in chosen]


def _build_masked_examples(prompt, completions, tokenizer, config):
    """Tokenize each (system, user, assistant) conversation and mask the prompt.

    Returns:
        ``(input_ids_list, labels_list)``: lists of 1-D tensors, one per
        completion. Each labels tensor copies its input ids with the
        system+user prefix positions, and any token equal to
        ``tokenizer.pad_token_id``, set to -100 (ignored by ``F.cross_entropy``).
        NOTE(review): if ``pad_token`` is aliased to ``eos_token`` (``main`` does
        this when the tokenizer has none), EOS targets are dropped as well.
    """
    use_template = hasattr(tokenizer, "apply_chat_template")
    sys_prefix = config.system_prompt + "\n" if config.system_prompt else ""
    tok_kwargs = dict(return_tensors="pt", add_special_tokens=True, truncation=True, max_length=config.max_tokens)

    # The system+user prefix is identical for every completion: tokenize it once.
    user_conv = [
        {"role": "system", "content": config.system_prompt},
        {"role": "user", "content": prompt},
    ]
    if use_template:
        user_text = tokenizer.apply_chat_template(user_conv, tokenize=False, add_generation_prompt=False)
    else:
        user_text = sys_prefix + prompt
    assistant_start = tokenizer(user_text, **tok_kwargs)["input_ids"].shape[-1]

    input_ids_list, labels_list = [], []
    for ans_text in completions:
        if use_template:
            conv = user_conv + [{"role": "assistant", "content": ans_text}]
            full_text = tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
        else:
            full_text = sys_prefix + prompt + "\n" + ans_text
        input_ids = tokenizer(full_text, **tok_kwargs)["input_ids"].squeeze(0)
        labels = input_ids.clone()
        labels[:assistant_start] = -100
        if tokenizer.pad_token_id is not None:
            labels[labels == tokenizer.pad_token_id] = -100
        input_ids_list.append(input_ids)
        labels_list.append(labels)
    return input_ids_list, labels_list


def _pad_with_length_mask(input_ids_list, labels_list, pad_id):
    """Right-pad ids (with ``pad_id``) and labels (with -100); mask by true length.

    The attention mask is derived from each sequence's length instead of
    ``input_ids != pad_id``: when ``pad_id`` also occurs inside the text (e.g.
    pad_token aliased to eos_token) the comparison would hide real tokens.
    """
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_id)
    labels = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)
    lengths = torch.tensor([seq.shape[0] for seq in input_ids_list])
    attention_mask = torch.arange(input_ids.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return input_ids, labels, attention_mask


def _infer_hidden_size(model, sample_ids, device):
    """Return the width of the hidden state fed to ``lm_head``.

    Prefers ``lm_head.in_features``, then the input-embedding dimension, and
    only as a last resort runs one forward pass on the 1-D ``sample_ids``.
    """
    hidden_size = getattr(getattr(model, "lm_head", None), "in_features", None)
    if hidden_size is None:
        hidden_size = getattr(model.get_input_embeddings(), "embedding_dim", None)
    if hidden_size is None:
        with torch.inference_mode():
            ids = sample_ids.unsqueeze(0).to(device)
            mask = torch.ones_like(ids, dtype=torch.bool)  # single sequence: no padding
            hidden_size = get_last_hidden_only(model, ids, mask).shape[-1]
        del ids, mask
        torch.cuda.empty_cache()
    return hidden_size


def _detached_lm_head(model):
    """Return ``(weight, bias_or_None)`` of ``model.lm_head`` detached from autograd."""
    lm_head = model.lm_head
    weight = lm_head.weight.detach()
    bias = lm_head.bias.detach() if getattr(lm_head, "bias", None) is not None else None
    return weight, bias


def _cache_supervised_hidden(model, input_ids, labels, attention_mask, device):
    """Run one frozen full-batch forward pass and keep only supervised positions.

    The hidden state at position t predicts token t + 1, so states are paired
    with labels shifted left by one and only pairs whose label != -100 are kept.

    Returns:
        ``(hidden [N, H] in lm_head's dtype, labels [N])`` or ``(None, None)``
        when no position is supervised.

    Raises:
        torch.cuda.OutOfMemoryError: if the full-batch forward does not fit.
    """
    input_ids = input_ids.to(device)
    labels = labels.to(device)
    attention_mask = attention_mask.to(device)
    with torch.inference_mode():
        hidden = get_last_hidden_only(model, input_ids, attention_mask).detach()
    labels_next = labels[..., 1:]
    valid = labels_next != -100
    if not valid.any().item():
        return None, None
    # Selection happens outside inference_mode so the result can take part in
    # autograd; clone() ensures no storage of the full ``hidden`` is retained.
    sel_hidden = hidden[..., :-1, :][valid].contiguous().to(model.lm_head.weight.dtype).clone()
    sel_labels = labels_next[valid].contiguous()
    return sel_hidden, sel_labels


def _cached_backward(model, sel_hidden, sel_labels, delta, temperature, pos_chunk):
    """Backpropagate the mean calibration loss over cached hidden states.

    The loss is the mean token cross-entropy of
    ``lm_head(hidden + delta) / temperature``. When ``0 < pos_chunk < N``, each
    chunk's summed loss divided by N is backpropagated immediately. Gradients
    accumulate to exactly the gradient of the mean, and only one chunk's
    autograd graph is alive at a time.

    Returns:
        The mean loss as a Python float.
    """
    lm_w, lm_b = _detached_lm_head(model)
    dtype = lm_w.dtype
    n_total = int(sel_labels.numel())

    if pos_chunk <= 0 or pos_chunk >= n_total:
        # No chunking (legacy behavior).
        logits = F.linear(sel_hidden + delta.to(dtype).view(1, -1), lm_w, lm_b) / temperature.to(dtype)
        loss = F.cross_entropy(logits, sel_labels)
        loss.backward()
        return float(loss.detach().cpu())

    total = 0.0
    for start in range(0, n_total, pos_chunk):
        stop = min(start + pos_chunk, n_total)
        # Re-cast delta/temperature per chunk so every chunk owns an independent
        # graph that its own backward() can free.
        logits = F.linear(sel_hidden[start:stop] + delta.to(dtype).view(1, -1), lm_w, lm_b) / temperature.to(dtype)
        chunk_loss = F.cross_entropy(logits, sel_labels[start:stop], reduction="sum").float() / n_total
        chunk_loss.backward()
        total += float(chunk_loss.detach().cpu())
        del logits, chunk_loss
        torch.cuda.empty_cache()
    return total


def _micro_batch_backward(model, ids_list, labels_list, pad_id, device, delta, temperature, loss_divisor):
    """Forward one micro-batch through the frozen model and backprop its loss.

    The loss is the mean token cross-entropy of
    ``lm_head(hidden + delta) / temperature`` over supervised positions,
    divided by ``loss_divisor`` (for gradient accumulation).

    Returns:
        The (divided) loss as a float, or ``None`` when the micro-batch has no
        supervised position.
    """
    mb_ids, mb_labels, mb_attn = _pad_with_length_mask(ids_list, labels_list, pad_id)
    mb_ids, mb_labels, mb_attn = mb_ids.to(device), mb_labels.to(device), mb_attn.to(device)
    with torch.inference_mode():
        last_hidden = get_last_hidden_only(model, mb_ids, mb_attn)
    labels_next = mb_labels[..., 1:]
    valid = labels_next != -100
    if not valid.any().item():
        return None
    lm_w, lm_b = _detached_lm_head(model)
    dtype = lm_w.dtype
    sel_hidden = last_hidden[..., :-1, :][valid].contiguous().to(dtype)
    sel_labels = labels_next[valid].contiguous()
    logits = F.linear(sel_hidden + delta.to(dtype).view(1, -1), lm_w, lm_b) / temperature.to(dtype)
    loss = F.cross_entropy(logits, sel_labels)
    if loss_divisor != 1:
        loss = loss / loss_divisor
    loss.backward()
    return float(loss.detach().cpu())


def joint_train_delta_temp(model, prompt, completions, scores, tokenizer, config):
    """
    Jointly train delta (hidden state adjustment) and temperature (logits scaling).

    For one problem, ``min(config.k, n1)`` completions are selected (see
    ``_select_completions``). Two parameters are then fit on them with the model
    frozen:

    * an additive offset ``delta`` of shape ``[1, 1, H]``, applied to the last
      hidden state before ``lm_head``;
    * a scalar ``temperature`` that divides the logits.

    The objective is token cross-entropy on the non-prompt tokens.
    ``config.ablate_delta`` / ``config.ablate_temperature`` freeze one of the two
    parameters.

    Training strategy:

    1. When ``config.calib_batch_size >= k``, all supervised hidden states are
       cached with one forward pass. An OOM while training on the cache enables,
       then halves, position chunking.
    2. If the cache forward itself OOMs, or chunking cannot shrink further,
       training falls back to re-running the model on micro-batches, halving
       their size on OOM.

    Epochs already completed in step 1 are not repeated in step 2.

    Returns:
        ``(delta, temperature)`` as detached float32 tensors on the model's
        device, or ``(None, None)`` when there is no completion to train on.

    Raises:
        ValueError: unknown selection modes, or both ablation flags set.
        torch.cuda.OutOfMemoryError: if OOM persists at micro-batch size 1
            outside the per-micro-batch forward/backward.
    """
    selected = _select_completions(completions, scores, config)
    k = len(selected)
    if k == 0:
        # Nothing to calibrate on; main() skips examples returning (None, None).
        return None, None

    input_ids_list, labels_list = _build_masked_examples(prompt, selected, tokenizer, config)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    device = model.device
    calib_bs = int(getattr(config, "calib_batch_size", k))
    log_loss = bool(getattr(config, "log_epoch_loss", False))

    ablate_delta = getattr(config, "ablate_delta", False)
    ablate_temperature = getattr(config, "ablate_temperature", False)
    if ablate_delta and ablate_temperature:
        raise ValueError("Both ablate_delta and ablate_temperature are True, nothing to optimize.")

    hidden_size = _infer_hidden_size(model, input_ids_list[0], device)
    # fp32 parameters keep results consistent with earlier experiments and are
    # easy to store via numpy; they are cast to lm_head's dtype in the forward.
    # Ablated (frozen) parameters do not require grad at all.
    delta = torch.nn.Parameter(
        torch.zeros(1, 1, hidden_size, device=device, dtype=torch.float32),
        requires_grad=not ablate_delta,
    )
    temperature = torch.nn.Parameter(
        torch.tensor(float(config.init_temp), device=device, dtype=torch.float32),
        requires_grad=not ablate_temperature,
    )
    param_groups = []
    if not ablate_delta:
        param_groups.append({"params": delta, "weight_decay": getattr(config, "weight_decay", 1e-2)})
    if not ablate_temperature:
        param_groups.append({"params": temperature, "weight_decay": 0.0})
    optimizer = torch.optim.AdamW(param_groups, lr=config.calib_lr, eps=1e-5)

    model.eval()

    # ---- Phase 1: cache supervised hidden states with one full-batch forward ----
    epochs_done = 0
    cached = False
    sel_hidden = sel_labels = None
    if calib_bs >= k:
        batch_ids, batch_labels, batch_mask = _pad_with_length_mask(input_ids_list, labels_list, pad_id)
        try:
            sel_hidden, sel_labels = _cache_supervised_hidden(model, batch_ids, batch_labels, batch_mask, device)
            cached = True
        except torch.cuda.OutOfMemoryError:
            logger.warning("OOM on full batch; switching to micro-batch and halving batch_size...")
        del batch_ids, batch_labels, batch_mask
        torch.cuda.empty_cache()

    if cached:
        if sel_labels is None or sel_labels.numel() == 0:
            # No supervised token at all: return the untouched initial values.
            return delta.detach(), temperature.detach()

        # 0 or negative means "no position-level chunking".
        pos_chunk = int(getattr(config, "max_selected_chunk_size", 0))
        n_total = int(sel_labels.numel())
        finished = False
        try:
            while epochs_done < config.calib_epochs:
                try:
                    optimizer.zero_grad()
                    total_loss = _cached_backward(model, sel_hidden, sel_labels, delta, temperature, pos_chunk)
                    torch.cuda.empty_cache()
                    optimizer.step()
                except torch.cuda.OutOfMemoryError:
                    # Enable/shrink position chunking and retry this epoch; once it
                    # cannot shrink further, escalate to the micro-batch fallback.
                    torch.cuda.empty_cache()
                    optimizer.zero_grad(set_to_none=True)
                    if pos_chunk <= 0:
                        pos_chunk = max(256, min(4096, n_total))
                        logger.warning(f"OOM in full-batch (no chunk). Retrying epoch {epochs_done+1} with pos_chunk={pos_chunk}.")
                        continue
                    new_chunk = max(1, pos_chunk // 2)
                    if new_chunk == pos_chunk:
                        raise
                    logger.warning(f"OOM in full-batch with pos_chunk={pos_chunk}. Retrying epoch {epochs_done+1} with pos_chunk={new_chunk}.")
                    pos_chunk = new_chunk
                    continue

                if log_loss:
                    delta_norm = delta.norm().item()
                    temp_value = float(temperature.detach().to(torch.float32).item())
                    logger.info(f"Joint training epoch {epochs_done+1}/{config.calib_epochs} | loss: {total_loss:.6f} | norm(delta): {delta_norm:.6f} | temperature: {temp_value:.6f} | pos_chunk: {pos_chunk if pos_chunk>0 else 'none'}")
                epochs_done += 1
            finished = True
        except torch.cuda.OutOfMemoryError:
            logger.warning("OOM during full-batch even after chunking; falling back to micro-batch with halving...")
        sel_hidden = sel_labels = None
        torch.cuda.empty_cache()
        if finished:
            return delta.detach(), temperature.detach()

    # ---- Phase 2: micro-batches, re-running the frozen model every epoch ----
    # After a full-batch failure start at half of k; otherwise honour calib_batch_size.
    micro_bs = max(1, k // 2) if calib_bs >= k else max(1, min(calib_bs, k))
    grad_accum = bool(getattr(config, "gradient_accumulation", False))

    epoch = epochs_done  # do not repeat epochs already completed in phase 1
    while epoch < config.calib_epochs:
        try:
            optimizer.zero_grad()
            total_loss = 0.0
            accum_steps = int(np.ceil(k / micro_bs))
            loss_divisor = accum_steps if (grad_accum and micro_bs < k) else 1

            for i in range(accum_steps):
                start, end = i * micro_bs, min((i + 1) * micro_bs, k)
                try:
                    step_loss = _micro_batch_backward(
                        model, input_ids_list[start:end], labels_list[start:end],
                        pad_id, device, delta, temperature, loss_divisor,
                    )
                except torch.cuda.OutOfMemoryError:
                    if micro_bs > 1:
                        # Let the outer handler halve micro_bs and redo this epoch.
                        raise
                    logger.warning(f"OOM in micro-batch step {i+1}/{accum_steps}, skipping this micro-batch.")
                    step_loss = None
                if step_loss is not None:
                    total_loss += step_loss
                torch.cuda.empty_cache()

            optimizer.step()

            if log_loss:
                delta_norm = delta.norm().item()
                temp_value = float(temperature.detach().to(torch.float32).item())
                logger.info(f"Joint training epoch {epoch+1}/{config.calib_epochs} | loss: {total_loss:.6f} | norm(delta): {delta_norm:.6f} | temperature: {temp_value:.6f}")

            epoch += 1
        except torch.cuda.OutOfMemoryError:
            logger.warning(f"OOM at batch_size={micro_bs}, reducing batch_size...")
            torch.cuda.empty_cache()
            optimizer.zero_grad(set_to_none=True)
            new_bs = max(1, micro_bs // 2)
            if new_bs == micro_bs == 1:
                raise
            micro_bs = new_bs

    torch.cuda.empty_cache()
    return delta.detach().to(torch.float32), temperature.detach().to(torch.float32)


def joint_train_delta_temp_stepwise(model, prompt, completions, scores, tokenizer, config):
    """Train delta/temperature with signed per-step weights from step-level scores.

    Step splitting and weighting:

    * Each of the first ``config.n1`` completions is split into steps on blank
      lines and truncated to as many steps as it has scores.
    * A step with score ``s`` gets weight
      ``clip(2 * (s - mean) / (max - min), -1, 1)``. Here ``mean`` is taken over
      all step scores of the considered completions, while ``max``/``min`` are
      the completion's own extremes. The weight is 0 when max equals min.
    * Every token of a step inherits the step's weight.

    The objective is ``sum(w * CE) / sum(|w|)`` over supervised tokens. Under
    gradient descent, positive weights raise token likelihood and negative
    weights lower it. ``config.ablate_delta`` / ``config.ablate_temperature``
    freeze the corresponding parameter.

    Returns:
        ``(delta [1, 1, H], temperature)`` as detached float32 tensors, or
        ``(None, None)`` when there is nothing to train on or micro-batch size 1
        still runs out of memory.

    Raises:
        ValueError: if both ablation flags are set.
    """
    n1 = min(config.n1, len(completions))
    completions_n1 = completions[:n1]
    scores_n1 = scores[:n1]

    # 1. Pool all step scores for the global mean, then derive per-step weights.
    all_step_scores = [s for per_step in scores_n1 for s in per_step]
    if not all_step_scores:
        return None, None
    mean_score = float(np.mean(all_step_scores))

    selected_completions, stepwise_weights = [], []
    for ans_text, step_scores in zip(completions_n1, scores_n1):
        steps = ans_text.split("\n\n")
        n_steps = min(len(steps), len(step_scores))
        if n_steps == 0:
            continue
        steps = steps[:n_steps]
        step_scores = step_scores[:n_steps]
        hi, lo = max(step_scores), min(step_scores)
        if hi == lo:
            weights = [0.0] * n_steps
        else:
            weights = [max(-1.0, min(1.0, 2 * (s - mean_score) / (hi - lo))) for s in step_scores]
        selected_completions.append("\n\n".join(steps))
        stepwise_weights.append(weights)

    if not selected_completions:
        return None, None

    # 2. Build input_ids / labels / per-token weights.
    use_template = hasattr(tokenizer, "apply_chat_template")
    sys_prefix = config.system_prompt + "\n" if config.system_prompt else ""
    tok_kwargs = dict(return_tensors="pt", add_special_tokens=True, truncation=True, max_length=config.max_tokens)
    user_conv = [
        {"role": "system", "content": config.system_prompt},
        {"role": "user", "content": prompt},
    ]
    if use_template:
        user_text = tokenizer.apply_chat_template(user_conv, tokenize=False, add_generation_prompt=False)
        # Tokens the template inserts before the answer text (presumably an
        # assistant header -- template dependent) must not absorb step weights.
        gen_text = tokenizer.apply_chat_template(user_conv, tokenize=False, add_generation_prompt=True)
    else:
        user_text = sys_prefix + prompt
        gen_text = user_text + "\n"
    # Labels are masked up to the end of the user turn; weights start after the
    # generation prompt (never earlier than the label mask).
    assistant_start = tokenizer(user_text, **tok_kwargs)["input_ids"].shape[-1]
    content_start = max(assistant_start, tokenizer(gen_text, **tok_kwargs)["input_ids"].shape[-1])

    def _step_token_lengths(step_texts):
        # Differences of cumulative-prefix token counts: separator tokens and
        # boundary merges are attributed to a step rather than silently dropped,
        # which would shift every later step's weights.
        lengths, prev = [], 0
        for idx in range(len(step_texts)):
            prefix = "\n\n".join(step_texts[: idx + 1])
            if idx < len(step_texts) - 1:
                prefix += "\n\n"
            cur = len(tokenizer(prefix, add_special_tokens=False)["input_ids"])
            lengths.append(max(0, cur - prev))
            prev = cur
        return lengths

    input_ids_list, labels_list, weights_list = [], [], []
    for ans_text, weights in zip(selected_completions, stepwise_weights):
        if use_template:
            conv = user_conv + [{"role": "assistant", "content": ans_text}]
            full_text = tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
        else:
            full_text = sys_prefix + prompt + "\n" + ans_text
        input_ids = tokenizer(full_text, **tok_kwargs)["input_ids"].squeeze(0)
        labels = input_ids.clone()
        labels[:assistant_start] = -100
        if tokenizer.pad_token_id is not None:
            labels[labels == tokenizer.pad_token_id] = -100

        token_weights = [0.0] * content_start
        for w, n_tok in zip(weights, _step_token_lengths(ans_text.split("\n\n"))):
            token_weights.extend([w] * n_tok)
        # Positions after the answer (e.g. end-of-turn tokens) get weight 0;
        # truncate to the (possibly max_tokens-truncated) sequence length.
        seq_len = input_ids.shape[0]
        token_weights = (token_weights + [0.0] * seq_len)[:seq_len]

        input_ids_list.append(input_ids)
        labels_list.append(labels)
        weights_list.append(torch.tensor(token_weights, dtype=torch.float32))

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_id)
    labels_padded = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)
    weights_padded = torch.nn.utils.rnn.pad_sequence(weights_list, batch_first=True, padding_value=0.0)
    # Mask from true lengths: comparing ids with pad_id would also hide real
    # tokens whenever pad_token aliases a token that occurs in the text.
    lengths = torch.tensor([seq.shape[0] for seq in input_ids_list])
    attention_mask = torch.arange(input_ids_padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

    # 3. Parameters and optimizer (fp32; cast to lm_head's dtype in the forward).
    ablate_delta = getattr(config, "ablate_delta", False)
    ablate_temperature = getattr(config, "ablate_temperature", False)
    if ablate_delta and ablate_temperature:
        raise ValueError("Both ablate_delta and ablate_temperature are True, nothing to optimize.")

    device = model.device
    hidden_size = getattr(getattr(model, "lm_head", None), "in_features", None)
    if hidden_size is None:
        hidden_size = getattr(model.get_input_embeddings(), "embedding_dim", None)
    if hidden_size is None:
        with torch.inference_mode():
            tmp_ids = input_ids_list[0].unsqueeze(0).to(device)
            hidden_size = get_last_hidden_only(model, tmp_ids, torch.ones_like(tmp_ids, dtype=torch.bool)).shape[-1]
        del tmp_ids
        torch.cuda.empty_cache()

    delta = torch.nn.Parameter(
        torch.zeros(1, 1, hidden_size, device=device, dtype=torch.float32),
        requires_grad=not ablate_delta,
    )
    temperature = torch.nn.Parameter(
        torch.tensor(float(config.init_temp), device=device, dtype=torch.float32),
        requires_grad=not ablate_temperature,
    )
    param_groups = []
    if not ablate_delta:
        param_groups.append({"params": delta, "weight_decay": getattr(config, "weight_decay", 1e-2)})
    if not ablate_temperature:
        param_groups.append({"params": temperature, "weight_decay": 0.0})
    optimizer = torch.optim.AdamW(param_groups, lr=config.calib_lr, eps=1e-5)

    model.eval()

    def _weighted_backward(start, end):
        # Forward one micro-batch through the frozen model, backprop the weighted
        # loss, and return it as a float (None if no position is supervised).
        mb_ids = input_ids_padded[start:end].to(device)
        mb_labels = labels_padded[start:end].to(device)
        mb_weights = weights_padded[start:end].to(device)
        mb_attn = attention_mask[start:end].to(device)
        with torch.inference_mode():
            hidden = get_last_hidden_only(model, mb_ids, mb_attn).detach()
        # Position t predicts token t + 1.
        labels_next = mb_labels[..., 1:]
        valid = labels_next != -100
        if not valid.any().item():
            return None
        lm_w = model.lm_head.weight.detach()
        lm_b = model.lm_head.bias.detach() if getattr(model.lm_head, "bias", None) is not None else None
        dtype = lm_w.dtype
        sel_hidden = hidden[..., :-1, :][valid].contiguous().to(dtype).clone()
        sel_labels = labels_next[valid].contiguous()
        sel_weights = mb_weights[..., 1:][valid]  # float32
        logits = F.linear(sel_hidden + delta.to(dtype).view(1, -1), lm_w, lm_b) / temperature.to(dtype)
        token_loss = F.cross_entropy(logits, sel_labels, reduction="none").float()
        # Normalize by total |weight|: the signed sum can be ~0 or negative, which
        # would blow up the loss or flip the sign of the entire objective.
        loss = (token_loss * sel_weights).sum() / sel_weights.abs().sum().clamp_min(1e-8)
        loss.backward()
        return float(loss.detach().cpu())

    # 4. Training loop with micro-batch halving on OOM.
    num_samples = input_ids_padded.shape[0]
    micro_bs = max(1, int(getattr(config, "calib_batch_size", min(8, num_samples))))
    log_loss = bool(getattr(config, "log_epoch_loss", False))
    epoch = 0
    while epoch < config.calib_epochs:
        try:
            optimizer.zero_grad()
            total_loss = 0.0
            for start in range(0, num_samples, micro_bs):
                mb_loss = _weighted_backward(start, min(start + micro_bs, num_samples))
                if mb_loss is not None:
                    total_loss += mb_loss
                torch.cuda.empty_cache()

            optimizer.step()

            if log_loss:
                delta_norm = delta.norm().item()
                temp_value = float(temperature.detach().to(torch.float32).item())
                logger.info(f"Stepwise joint training epoch {epoch+1}/{config.calib_epochs} | loss: {total_loss:.6f} | norm(delta): {delta_norm:.6f} | temperature: {temp_value:.6f}")

            torch.cuda.empty_cache()
            epoch += 1
        except torch.cuda.OutOfMemoryError:
            logger.warning(f"[stepwise] OOM at micro_bs={micro_bs}, reducing batch_size...")
            torch.cuda.empty_cache()
            optimizer.zero_grad(set_to_none=True)
            new_bs = max(1, micro_bs // 2)
            if new_bs == micro_bs == 1:
                logger.error("[stepwise] OOM at micro_bs=1, skip this example!")
                return None, None
            micro_bs = new_bs

    delta_detached = delta.detach().to(torch.float32)
    temperature_detached = temperature.detach().to(torch.float32)
    logger.info(f"[stepwise] delta shape: {tuple(delta_detached.shape)}")
    torch.cuda.empty_cache()
    return delta_detached, temperature_detached


def main():
    """Command-line entry point: calibrate a (delta, temperature) pair per problem.

    Steps:

    1. Parse ``Config`` with ``H4ArgumentParser``.
    2. Load the causal LM and tokenizer from ``config.model_path``; the model
       goes to CUDA when available, otherwise CPU.
    3. Read the JSON dataset at ``config.input_dataset_path``.
    4. For every record, run ``joint_train_delta_temp_stepwise`` when
       ``config.dataset_selection == "stepwise"``, else ``joint_train_delta_temp``.
       Records for which training returns ``None`` are skipped.
    5. Save both result dictionaries, keyed by ``str(unique_id)``, with
       ``save_temperature_dict_npz``.
    """
    config = H4ArgumentParser(Config).parse()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(config.model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        # Batching needs a pad token; reuse EOS when the tokenizer has none.
        tokenizer.pad_token = tokenizer.eos_token

    startup_lines = [
        f"Loading input dataset from {config.input_dataset_path}",
        f"Model path: {config.model_path}",
        f"Output delta path: {config.output_delta_path}",
        f"Output temperature path: {config.output_temperature_path}",
        f"Joint training epochs: {config.calib_epochs}",
        f"Learning rate: {config.calib_lr}",
        f"n1: {config.n1}, k: {config.k}",
        f"Device: {device}",
    ]
    for line in startup_lines:
        logger.info(line)

    if getattr(config, "ablate_delta", False):
        mode_msg = "Ablation mode: ONLY train temperature (delta is frozen)"
    elif getattr(config, "ablate_temperature", False):
        mode_msg = "Ablation mode: ONLY train delta (temperature is frozen)"
    else:
        mode_msg = "Joint train delta and temperature"
    logger.info(mode_msg)

    dataset = load_dataset("json", data_files=config.input_dataset_path)["train"]

    # config is fixed for the whole run, so the trainer can be chosen once.
    if config.dataset_selection == "stepwise":
        train_fn = joint_train_delta_temp_stepwise
    else:
        train_fn = joint_train_delta_temp

    delta_by_id = {}
    temp_by_id = {}
    for record in tqdm(dataset, desc="Processing"):
        uid = str(record["unique_id"])
        delta, temp = train_fn(
            model, record["problem"], record["completions"], record["scores"], tokenizer, config
        )

        if delta is None or temp is None:
            logger.info(f"[skip] unique_id={uid} all steps low score, skip calibration.")
            continue

        delta_by_id[uid] = {"delta": delta.cpu().numpy()}
        temp_by_id[uid] = {"temperature": temp.cpu().item()}

        # Drop device references before the next example.
        del delta, temp
        torch.cuda.empty_cache()

    save_temperature_dict_npz(delta_by_id, config.output_delta_path)
    save_temperature_dict_npz(temp_by_id, config.output_temperature_path)
    logger.info(f"Saved delta to {config.output_delta_path}")
    logger.info(f"Saved temperature to {config.output_temperature_path}")


if __name__ == "__main__":
    main()