import json
from dataclasses import dataclass, field
import os
import argparse
import random
import re
import math
import numpy as np
from typing import List, Dict, Optional


import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, HfArgumentParser
from transformers.trainer_utils import IntervalStrategy
from rec_adam import RecAdam, anneal_function
from transformers import TrainerCallback
from result_evaluator import ResultEvaluator

from transformers import Trainer


class PrefixOnlyGenerateTrainer(Trainer):
    """Trainer whose evaluation generates continuations from prompt-only inputs.

    When ``args.predict_with_generate`` is truthy, ``prediction_step`` skips the
    loss and returns only the newly generated token ids as predictions, so
    ``compute_metrics`` sees generated answers instead of full logits.
    Otherwise it defers to ``Trainer.prediction_step``.

    Generation settings are read from ``args.generation_max_new_tokens``,
    ``args.generation_do_sample`` and ``args.generation_num_beams``. These are
    not standard ``TrainingArguments`` fields and must be attached by the
    caller.
    """

    @staticmethod
    def _first_not_none(*candidates):
        """Return the first candidate that is not None (0 is a valid token id)."""
        for value in candidates:
            if value is not None:
                return value
        return None

    @staticmethod
    def _to_left_padding(input_ids, attention_mask, pad_id):
        """Re-pack each row so padding is on the left and real tokens on the right.

        Decoder-only ``generate`` continues from the last column. Right-padded
        rows would therefore continue from pad tokens. Batches whose last
        column is fully attended (left-padded or unpadded) are returned as-is.
        """
        if (attention_mask is None or attention_mask.size(1) == 0
                or bool(attention_mask[:, -1].all())):
            return input_ids, attention_mask
        batch_size, width = input_ids.shape
        packed_ids = input_ids.new_full((batch_size, width), pad_id)
        packed_mask = attention_mask.new_zeros((batch_size, width))
        for row in range(batch_size):
            tokens = input_ids[row][attention_mask[row].bool()]
            n_tokens = tokens.size(0)
            if n_tokens > 0:
                packed_ids[row, width - n_tokens:] = tokens
                packed_mask[row, width - n_tokens:] = 1
        return packed_ids, packed_mask

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Generate from the prompt and return ``(None, new_token_ids, labels)``.

        Predictions are the generated continuation only, excluding the
        prompt. They are padded with the pad id, clamped to be non-negative
        and returned as CPU LongTensors. Labels are passed through unchanged
        (moved to CPU).
        """
        if not getattr(self.args, "predict_with_generate", False):
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

        # Normalize inputs to device tensors
        inputs = self._prepare_inputs(inputs)
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask", None)
        labels = inputs.get("labels", None)

        # Newer transformers store the tokenizer as `processing_class`.
        tok = getattr(self, "processing_class", None)
        if tok is None:
            tok = getattr(self, "tokenizer", None)
        config = self.model.config
        pad_id = self._first_not_none(
            getattr(tok, "pad_token_id", None),
            getattr(config, "pad_token_id", None),
            getattr(tok, "eos_token_id", None),
            getattr(config, "eos_token_id", None),
            0,
        )
        eos_id = self._first_not_none(
            getattr(tok, "eos_token_id", None),
            getattr(config, "eos_token_id", None),
            pad_id,
        )

        # The shared collator right-pads; generation needs left padding.
        input_ids, attention_mask = self._to_left_padding(input_ids, attention_mask, pad_id)
        prompt_width = input_ids.size(1)

        with torch.no_grad():
            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.args.generation_max_new_tokens,
                do_sample=self.args.generation_do_sample,
                num_beams=self.args.generation_num_beams,
                pad_token_id=pad_id,
                eos_token_id=eos_id,
            )

        # For decoder-only models generate() returns [prompt | new tokens] for
        # every row, so new tokens start at the (common) prompt width.
        preds = generated[:, prompt_width:]
        preds = preds.clamp_min(0).to(dtype=torch.long)  # Ensure no negative values

        return (None, preds.detach().cpu(),
                labels.detach().cpu() if labels is not None else None)


class ForceEvalEveryNSteps(TrainerCallback):
    """Callback that requests an evaluation every ``n_steps`` global steps.

    Args:
        n_steps: Request evaluation whenever ``state.global_step`` is a
            positive multiple of this value. Must be >= 1.
        also_on_epoch_end: If True, also request an evaluation at the end of
            every epoch.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1. Failing here is clearer
            than a ZeroDivisionError raised from inside the training loop.
    """

    def __init__(self, n_steps=500, also_on_epoch_end=True):
        n = int(n_steps)
        if n < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps!r}")
        self.n = n
        self.also_on_epoch_end = also_on_epoch_end

    def on_step_end(self, args, state, control, **kwargs):
        """Set ``control.should_evaluate`` on every n-th global step."""
        if state.global_step > 0 and (state.global_step % self.n == 0):
            control.should_evaluate = True
            print(f"[DEBUG] force-eval at step={state.global_step}")
        return control

    def on_epoch_end(self, args, state, control, **kwargs):
        """Set ``control.should_evaluate`` at epoch end when enabled."""
        if self.also_on_epoch_end:
            control.should_evaluate = True
            print(f"[DEBUG] force-eval at epoch_end={state.epoch}")
        return control


# Optional: Import swanlab for experiment tracking. SWANLAB_AVAILABLE lets the
# rest of the script fall back to no tracking when the package is missing.
try:
    import swanlab
except ImportError:
    SWANLAB_AVAILABLE = False
    print("swanlab not available. Install with: pip install swanlab")
else:
    SWANLAB_AVAILABLE = True


def load_model(args):
    """Load the tokenizer and a bfloat16 causal LM, optionally merging LoRA weights.

    ``~`` in ``args.model_path`` / ``args.lora_path`` is expanded before
    loading. ``from_pretrained`` checks for a local directory without
    expanding ``~``, so the default ``~/models/Qwen3-8B`` would otherwise be
    treated as a hub repo id.

    The tokenizer pads on the right. If it has no pad token, the EOS token
    is used instead. When ``args.lora_path`` is set, the adapter is loaded and
    merged into the base weights.

    Returns:
        (tokenizer, model)
    """
    model_path = os.path.expanduser(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True)
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)

    # Ensure tokenizer has pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(
        f"[TOK] pad={tokenizer.pad_token_id}, eos={getattr(tokenizer, 'eos_token_id', None)}")

    if args.lora_path is not None:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, os.path.expanduser(args.lora_path))
        # Fold the adapter into the base weights so a plain model is returned.
        model = model.merge_and_unload()

    return tokenizer, model


def model_to_lora_model(model, config_path="lora_config.json"):
    """Wrap ``model`` with LoRA adapters configured from a JSON file.

    Args:
        model: Base model to wrap.
        config_path: JSON file whose top-level keys are passed as keyword
            arguments to ``peft.LoraConfig``. Defaults to ``lora_config.json``
            relative to the current working directory, the path that was
            previously hard-coded here.

    Returns:
        The model returned by ``peft.get_peft_model``.
    """
    from peft import LoraConfig, get_peft_model

    with open(config_path, "r", encoding="utf-8") as f:
        lora_config_dict = json.load(f)

    lora_config = LoraConfig(**lora_config_dict)
    model = get_peft_model(model, lora_config)

    trainable_params = sum(p.numel()
                           for p in model.parameters() if p.requires_grad)
    print(
        f"\033[91m trainable_params: {trainable_params} (total: {model.num_parameters()})\033[0m")

    return model


def prepare_input_data(data, include_reasoning=False, train_data_type='loginum'):
    """Render one record as the numbered Facts / Rules / Query prompt text.

    Args:
        data: Record with keys 'facts-tuned-nl' and 'rules-tuned-nl' (lists of
            strings) and 'query' (an (entity, attribute) pair). When
            ``include_reasoning`` is set it also needs 'reasoning_process_nl'
            and 'answer'.
        include_reasoning: If True, append the "Reasoning:" section and the
            boxed answer after the query.
        train_data_type: Only 'loginum' is supported.

    Returns:
        The prompt text. Every section ends with a newline.

    Raises:
        ValueError: For an unsupported ``train_data_type``. Previously this
            silently returned None.
    """
    if train_data_type != 'loginum':
        raise ValueError(f"Unsupported train_data_type: {train_data_type!r}")

    facts = "Facts:\n" + "".join(
        f"{i}. {fact}\n" for i, fact in enumerate(data['facts-tuned-nl'], start=1))
    rules = "Rules:\n" + "".join(
        f"{i}. {rule}\n" for i, rule in enumerate(data['rules-tuned-nl'], start=1))
    query_entity, query_attribute = data['query']
    query = f"Query:\nWhat is the value of {query_entity}'s {query_attribute}?\n"
    prompt = facts + rules + query
    if not include_reasoning:
        return prompt

    reasoning_process = "Reasoning:\n" + data["reasoning_process_nl"] + "\n"
    answer = f"Answer: \\boxed{{{data['answer']}}}\n"
    return prompt + reasoning_process + answer


def prepare_output_data(data, train_data_type='loginum'):
    """Render the supervised target (reasoning followed by the boxed answer).

    Args:
        data: Record with 'reasoning_process_nl' (str) and 'answer'.
        train_data_type: Only 'loginum' is supported.

    Returns:
        ``"Reasoning:\\n<reasoning>\\nAnswer: \\\\boxed{<answer>}\\n"``

    Raises:
        ValueError: For an unsupported ``train_data_type``. Previously this
            silently returned None.
    """
    if train_data_type != 'loginum':
        raise ValueError(f"Unsupported train_data_type: {train_data_type!r}")
    reasoning_process = "Reasoning:\n" + data["reasoning_process_nl"] + "\n"
    answer = f"Answer: \\boxed{{{data['answer']}}}\n"
    return reasoning_process + answer


def load_jsonl_data(data_path):
    """Read a JSON Lines file into a list of parsed records, in file order.

    Whitespace-only lines are skipped, so stray or trailing blank lines do not
    raise ``json.JSONDecodeError``. The file is decoded with ``utf-8-sig``,
    which also accepts an optional UTF-8 byte-order mark at the start.

    Args:
        data_path: Path to the ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line.
    """
    data = []
    with open(data_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    return data


def data_type_from_path(data_path: str) -> str:
    """Infer the training-data format from the dataset path.

    Returns 'loginum' when the path contains 'train-el.json' or
    'train-en.json'. This is a substring test, so '.jsonl' files match too.

    Raises:
        ValueError: For any other path.
    """
    for marker in ('train-el.json', 'train-en.json'):
        if marker in data_path:
            return 'loginum'
    raise ValueError(f"Unknown dataset type in data_path: {data_path}.")


def preprocess_jsonl_data(tokenizer, instruction, data, args):
    """Build SFT prompt/target text pairs from raw JSONL reasoning records.

    The prompt is the instruction plus the Facts/Rules/Query text. It is
    rendered in one of three ways:
    - Qwen3 chat models: through the chat template with
      ``enable_thinking=False``.
    - Paths containing "base": as plain text, the instruction and input
      separated by a blank line.
    - Anything else: through the plain chat template.

    The target is the Reasoning + boxed Answer text followed by the literal
    ``<|im_end|>`` marker.

    NOTE(review): ``<|im_end|>`` is appended for every model, including base
    models and non-Qwen chat templates. Confirm this is intended.

    Returns:
        A list of ``{'input': prompt_text, 'output': target_text}`` dicts.
    """
    train_data_type = data_type_from_path(args.data_path)
    # The model kind does not change inside the loop; decide it once.
    is_base = "base" in args.model_path.lower()
    is_qwen3_chat = "Qwen3" in args.model_path and not is_base

    def build_prompt(user_text):
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_text},
        ]
        if is_qwen3_chat:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
        if is_base:
            return f"{instruction}\n\n{user_text}"
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)

    pairs = []
    for record in data:
        # Input: <system>instruction <user>Facts + Rules + Query <assistant>
        user_text = prepare_input_data(
            record, include_reasoning=False, train_data_type=train_data_type)
        # Output: Reasoning + Answer
        target_text = prepare_output_data(record, train_data_type=train_data_type)
        pairs.append({'input': build_prompt(user_text),
                      'output': target_text + "<|im_end|>"})
    return pairs


def preprocess_qa_val_data(tokenizer, data, args):
    """Convert validation records into prompt/target pairs for generation-based eval.

    Each record needs ``question``, ``answer`` and ``problem_type``.

    - Prompt: a type-specific system message plus the question. It is
      rendered with the same chat-template rules as the training set (Qwen3
      chat: ``enable_thinking=False``; paths containing "base": plain text).
    - Target: ``Answer: \\boxed{<label>}\\n<|im_end|>``. The label is
      normalized by the matching ``ResultEvaluator._normalize_*_label`` for
      logicnli/robust/folio/logiqa/abductionr, falling back to the raw answer
      if that returns a falsy value. Any other type uses ``str(answer).strip()``.
    - ``problem_type`` is kept as-is for grouped scoring in compute_metrics.

    Returns:
        A list of ``{"input", "output", "problem_type"}`` dicts.
    """
    math_types = ("gsm8k", "math", "mathqa", "minervamath", "svamp", "asdiv", "mawps")
    # Looked up by name at call time, so ResultEvaluator is only touched for
    # the problem types that actually need label normalization.
    normalizer_names = {
        "logicnli": "_normalize_logicnli_label",
        "robust": "_normalize_robustlr_label",
        "folio": "_normalize_folio_label",
        "logiqa": "_normalize_logiqa_label",
        "abductionr": "_normalize_abductionr_label",
    }
    is_base = "base" in args.model_path.lower()
    is_qwen3_chat = "Qwen3" in args.model_path and not is_base

    def system_prompt_for(problem_type):
        if problem_type.lower() in math_types:  # Math datasets
            return "Please answer the question. Return the final result as: Answer: \\boxed{...}."
        return "You are a careful logical reasoner, Answer the question according to the requirements."

    def normalized_answer(problem_type, raw_answer):
        method_name = normalizer_names.get(problem_type.lower())
        if method_name is None:
            return str(raw_answer).strip()
        normalized = getattr(ResultEvaluator, method_name)(raw_answer)
        return normalized or raw_answer

    def render_prompt(system_text, question):
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": question},
        ]
        if is_qwen3_chat:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
        if is_base:
            return f"{system_text}\n\n{question}"
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)

    results = []
    for record in data:
        question = record["question"]
        raw_answer = record["answer"]
        problem_type = record["problem_type"]

        system_text = system_prompt_for(problem_type)
        label = normalized_answer(problem_type, raw_answer)
        # Only the final answer is required (no forced "Reasoning:" section).
        results.append({
            "input": render_prompt(system_text, question),
            "output": f"Answer: \\boxed{{{label}}}\n<|im_end|>",
            "problem_type": problem_type,
        })
    return results


@dataclass
class CustomArguments:
    """Script-specific command-line options, parsed together with TrainingArguments.

    Fields that default to None are annotated ``Optional[str]`` so the
    declared type matches the default. HfArgumentParser maps ``Optional[X]``
    to an optional ``X`` argument, so the command line is unchanged.
    """
    # NOTE(review): this default names two files. The training entry point in
    # this script accepts exactly one file, whose name must contain
    # 'train-el.json' or 'train-en.json', so pass --data_path explicitly.
    data_path: str = field(
        default='../data/el-hn.jsonl,../data/el-en.jsonl',
        metadata={
            "help": "Path or comma-separated paths to reasoning dataset(s) (JSONL format)"}
    )
    # Token budget per tokenized sample; longer samples are truncated.
    max_seq_length: int = field(default=2048)
    # Local model directory (or hub id). NOTE(review): a '~' in this default is
    # not expanded by argument parsing; the loader has to expand it.
    model_path: str = field(
        default='~/models/Qwen3-8B')
    lora_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to LoRA weights"}
    )
    use_lora: bool = field(
        default=False,
        metadata={"help": "Whether to use LoRA"}
    )
    use_swanlab: bool = field(
        default=False,
        metadata={"help": "Whether to use swanlab for experiment tracking"}
    )
    swanlab_project: str = field(
        default="sft-reasoning",
        metadata={"help": "SwanLab project name"}
    )
    swanlab_experiment_name: Optional[str] = field(
        default=None,
        metadata={"help": "SwanLab experiment name (auto-generated if None)"}
    )
    swanlab_description: str = field(
        default="SFT training experiment",
        metadata={"help": "SwanLab experiment description"}
    )
    swanlab_api_key: Optional[str] = field(
        default=None,
        metadata={"help": "SwanLab API key for authentication"}
    )
    swanlab_mode: str = field(
        default="cloud",
        metadata={
            "help": "SwanLab mode: 'cloud', 'offline', 'disabled', or 'local'"}
    )
    # ==== RecAdam toggles and hyperparameters ====
    use_recall_adam: bool = field(
        default=False, metadata={"help": "Use RecAdam optimizer instead of AdamW"}
    )
    rec_anneal_fun: str = field(
        default="sigmoid", metadata={"help": "RecAdam anneal_fun: sigmoid|linear|cos|constant"}
    )
    rec_anneal_k: float = field(
        default=0.1, metadata={"help": "RecAdam anneal_k (slope for sigmoid, etc.)"}
    )
    rec_anneal_t0: float = field(
        default=500.0, metadata={"help": "RecAdam anneal_t0 (mid-point / scale of steps)"}
    )
    rec_anneal_w: float = field(
        default=1.0, metadata={"help": "RecAdam anneal_w (max lambda)"}
    )
    rec_pretrain_cof: float = field(
        default=3000.0, metadata={"help": "RecAdam pretrain_cof (gamma): quadratic penalty strength"}
    )
    # Required in practice: the training entry point loads this file for evaluation.
    val_data_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The JSONL path for the external validation set (fields include question/answer)"}
    )


if __name__ == '__main__':
    parser = HfArgumentParser((CustomArguments, TrainingArguments))
    args, training_args = parser.parse_args_into_dataclasses()

    # Evaluation below runs on an external validation set; fail before the
    # (expensive) model load if it is missing.
    if not args.val_data_path:
        raise ValueError(
            "--val_data_path is required: evaluation runs on an external validation set.")

    # === Configure evaluation/save strategy early (before swanlab.init) ===
    # NOTE: these assignments happen after TrainingArguments.__post_init__,
    # so none of its validation or value normalisation is applied to them.
    # Newer transformers read `eval_strategy` (`evaluation_strategy` is the
    # deprecated name); set both so evaluation is scheduled on either version.
    training_args.evaluation_strategy = IntervalStrategy.STEPS
    training_args.eval_strategy = IntervalStrategy.STEPS
    training_args.eval_steps = 20
    training_args.save_strategy = IntervalStrategy.STEPS
    # NOTE(review): with load_best_model_at_end, __post_init__ would normally
    # require save_steps to be a multiple of eval_steps (625 % 20 != 0).
    # Confirm best-checkpoint tracking behaves as intended.
    training_args.save_steps = 625
    training_args.load_best_model_at_end = True
    training_args.metric_for_best_model = "accuracy"
    training_args.greater_is_better = True
    # === Use generation-based evaluation to avoid accumulating logits ===
    # These are not TrainingArguments fields; PrefixOnlyGenerateTrainer reads them.
    training_args.predict_with_generate = True
    training_args.generation_max_new_tokens = 4096
    training_args.generation_num_beams = 1
    training_args.generation_do_sample = False
    # Keep eval batch small and clear intermediate tensors quickly
    training_args.per_device_eval_batch_size = 8
    training_args.eval_accumulation_steps = 1
    training_args.dataloader_pin_memory = False

    use_swanlab = args.use_swanlab and SWANLAB_AVAILABLE
    # Assign report_to in its normalised list form. The "none" -> [] conversion
    # only happens inside __post_init__, and the reporting-callback lookup may
    # reject a literal "none" entry.
    training_args.report_to = ["swanlab"] if use_swanlab else []

    print(f"[CFG BEFORE SWAN] eval_strategy={training_args.evaluation_strategy}, "
          f"eval_steps={getattr(training_args, 'eval_steps', None)}, "
          f"do_eval={(training_args.evaluation_strategy != 'no')}")

    # Initialize SwanLab if requested
    if use_swanlab:
        # Set up authentication if API key is provided
        if args.swanlab_api_key:
            swanlab.login(api_key=args.swanlab_api_key)

        # Auto-generate experiment name if not provided
        if args.swanlab_experiment_name is None:
            dataset_name = os.path.basename(
                args.data_path).replace('.jsonl', '')
            model_name = os.path.basename(args.model_path)
            args.swanlab_experiment_name = f"sft-{dataset_name}-{model_name}-{os.getpid()}"

        cfg = {
            "data_path": args.data_path,
            "model_path": args.model_path,
            "max_seq_length": args.max_seq_length,
            "lora_path": args.lora_path,
            "use_lora": args.use_lora,
            "learning_rate": training_args.learning_rate,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "per_device_eval_batch_size": training_args.per_device_eval_batch_size,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "num_train_epochs": training_args.num_train_epochs,
            "warmup_ratio": training_args.warmup_ratio,
            "lr_scheduler_type": training_args.lr_scheduler_type,
            "max_grad_norm": training_args.max_grad_norm,
            "evaluation_strategy": str(training_args.evaluation_strategy),
            "eval_steps": getattr(training_args, "eval_steps", None),
            "save_strategy": str(training_args.save_strategy),
            "load_best_model_at_end": training_args.load_best_model_at_end,
            "metric_for_best_model": training_args.metric_for_best_model,
            "greater_is_better": training_args.greater_is_better,
            "do_eval": (training_args.evaluation_strategy != "no"),
            "val_data_path": args.val_data_path,
        }

        swanlab.init(
            project=args.swanlab_project,
            experiment_name=args.swanlab_experiment_name,
            description=args.swanlab_description,
            config=cfg,
            mode=args.swanlab_mode
        )
        print(
            f"✓ SwanLab initialized ({args.swanlab_mode}): project={args.swanlab_project}, experiment={args.swanlab_experiment_name}")
    elif args.use_swanlab:
        print("⚠ SwanLab requested but not available. Install with: pip install swanlab")

    tokenizer, model = load_model(args)
    if args.use_lora:
        model = model_to_lora_model(model)

    # ==== Record initial trainable parameter snapshot (aligns with RecAdam pretrain_params) ====
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    # Only RecAdam consumes this fp32 CPU copy. Skipping it otherwise saves
    # 4 bytes of host RAM per trainable parameter.
    if args.use_recall_adam:
        pretrain_params = [p.detach().clone().float().cpu()
                           for p in trainable_params]
    else:
        pretrain_params = []

    model.cuda()

    # Exactly one training file is supported; its path also selects the data type.
    dataset_paths = [p.strip() for p in args.data_path.split(',') if p.strip()]
    if len(dataset_paths) != 1:
        raise ValueError(
            f"only allow one training file after 9.11 update; got {len(dataset_paths)}: {dataset_paths}")

    data = []
    for path in dataset_paths:
        cur = load_jsonl_data(path)
        if not isinstance(cur, list):
            raise ValueError(f"Data loader for {path} did not return a list.")
        data.extend(cur)

    if len(data) == 0:
        raise FileNotFoundError(
            f"No samples loaded from {dataset_paths}. Check --data_path")

    print(
        f"Loaded {len(data)} samples from {len(dataset_paths)} file(s): {dataset_paths}")

    # Shuffle data deterministically
    random.seed(42)
    random.shuffle(data)

    train_data_type = data_type_from_path(args.data_path)
    instruction_map = {
        "loginum": "Please analyze the given facts and rules, then answer the query with step-by-step reasoning.",
        **dict.fromkeys(
            ["folio", "logiqa", "robust", "abductionr", "arlsat", "reclor"],
            "You are a careful logical reasoner,Answer the question according to the requirements.",
        ),
        **dict.fromkeys(
            ["gsm8k", "math", "mathqa", "minervamath", "svamp", "asdiv", "mawps"],
            "Please answer the question. Return the final result as: Answer: \\boxed{...}.",
        ),
    }
    if train_data_type not in instruction_map:
        raise NotImplementedError(
            f"Instruction for data type {train_data_type} not defined.")

    processed_data = preprocess_jsonl_data(
        tokenizer, instruction_map[train_data_type], data, args)

    def tokenize(samples):
        """Tokenize SFT prompt/target pairs with loss only on the target tokens.

        For every (input, output) pair:
        1. ``input`` is the full prompt (chat template up to and including the
           assistant start marker, or the plain prompt for base models).
        2. ``output`` is the assistant reply, including its end marker.
        3. Labels are -100 over the prompt tokens and equal to the target
           token ids elsewhere, so only the reply contributes to the loss.

        Samples longer than ``args.max_seq_length`` are truncated from the right.

        Returns:
            A dict of per-sample lists: 'input_ids', 'attention_mask', 'labels'.
        """
        input_ids_list = []
        attention_mask_list = []
        labels_list = []

        for input_text, output_text in zip(samples['input'], samples['output']):
            # Tokenize prompt and target separately, then concatenate the ids.
            # This makes the label boundary exact (tokenizing the joined string
            # could merge tokens across it). It also matches evaluation, where
            # the prompt is tokenized on its own before generation.
            prompt_ids = tokenizer(input_text, add_special_tokens=False)['input_ids']
            target_ids = tokenizer(output_text, add_special_tokens=False)['input_ids']

            full_input_ids = prompt_ids + target_ids
            attention_mask = [1] * len(full_input_ids)
            labels = [-100] * len(prompt_ids) + target_ids

            # Truncate to max length if needed
            if len(full_input_ids) > args.max_seq_length:
                print(
                    f"\033[93mWarning: Input length {len(full_input_ids)} exceeds max_seq_length {args.max_seq_length}. Truncating...\033[0m")
                if len(prompt_ids) >= args.max_seq_length:
                    # Every remaining label would be -100: the sample adds no loss.
                    print(
                        f"\033[93mWarning: Prompt length {len(prompt_ids)} leaves no target tokens within max_seq_length {args.max_seq_length}; this sample contributes no loss.\033[0m")
                full_input_ids = full_input_ids[:args.max_seq_length]
                attention_mask = attention_mask[:args.max_seq_length]
                labels = labels[:args.max_seq_length]

            input_ids_list.append(full_input_ids)
            attention_mask_list.append(attention_mask)
            labels_list.append(labels)

        return {
            'input_ids': input_ids_list,
            'attention_mask': attention_mask_list,
            'labels': labels_list
        }

    def tokenize_eval_prompt_only(samples):
        """Tokenize validation pairs for generation-based evaluation.

        - input_ids / attention_mask hold only the prompt, truncated to
          ``args.max_seq_length``.
        - labels hold the tokenized reference answer. They are only decoded in
          compute_metrics and never used for a loss, so their length may
          differ from the input's.
        - problem_type is passed through for grouped scoring.
        """
        input_ids_list, attention_mask_list, labels_list = [], [], []

        for input_text, output_text in zip(samples["input"], samples["output"]):
            in_ids = tokenizer(input_text, add_special_tokens=False)["input_ids"]
            lab_ids = tokenizer(output_text, add_special_tokens=False)["input_ids"]

            if len(in_ids) > args.max_seq_length:
                # Right-truncation drops the end of the prompt (for chat
                # templates, the assistant header), so generation for this
                # sample is unreliable.
                print(
                    f"\033[93mWarning: Eval prompt length {len(in_ids)} exceeds max_seq_length {args.max_seq_length}. Truncating...\033[0m")
                in_ids = in_ids[:args.max_seq_length]

            input_ids_list.append(in_ids)
            attention_mask_list.append([1] * len(in_ids))
            labels_list.append(lab_ids)

        return {
            "input_ids": input_ids_list,
            "attention_mask": attention_mask_list,
            # Note: eval labels may differ in length from the input (not used for loss)
            "labels": labels_list,
            # Passed through for per-problem-type scoring in compute_metrics
            "problem_type": samples["problem_type"]
        }

    def data_collator_fn(features):
        """Right-pad a list of tokenized features into a batch of LongTensors.

        - input_ids are padded with the tokenizer's pad id.
        - attention_mask is padded with 0 to the same width as input_ids.
        - labels are padded with -100 to their own maximum length, which may
          differ from the input width (eval labels are answer-only).
        """
        pad_id = tokenizer.pad_token_id

        def right_pad(rows, fill, width):
            return torch.tensor(
                [row + [fill] * (width - len(row)) for row in rows],
                dtype=torch.long,
            )

        input_width = max(len(f["input_ids"]) for f in features)
        batch_input_ids = right_pad([f["input_ids"] for f in features], pad_id, input_width)
        batch_attention = right_pad([f["attention_mask"] for f in features], 0, input_width)

        label_width = max(len(f["labels"]) for f in features)
        batch_labels = right_pad([f["labels"] for f in features], -100, label_width)

        return {"input_ids": batch_input_ids, "attention_mask": batch_attention, "labels": batch_labels}

    def build_compute_metrics(tokenizer, problem_types: List[str]):
        """Create a ``compute_metrics`` callback for generation-based evaluation.

        Args:
            tokenizer: Used to decode predicted and reference token ids.
            problem_types: One entry per evaluation example, in dataset order.
                Examples are grouped by ``problem_type.lower()`` and each group
                is scored with the matching ResultEvaluator routine. Unknown
                types use the shared math evaluator.

        Returns:
            A function that takes an object with ``.predictions`` and
            ``.label_ids`` and returns ``{"accuracy_<type>": ..., "accuracy": ...}``.
            Overall accuracy is total correct / total over all groups.
        """
        # Id used in place of negative padding (-100 etc.) before decoding.
        # Explicit None checks: 0 is a valid token id and must not be skipped.
        PAD = 0
        for candidate in (getattr(tokenizer, "pad_token_id", None),
                          getattr(tokenizer, "eos_token_id", None)):
            if candidate is not None:
                PAD = candidate
                break

        # The grouping depends only on problem_types, so build it once.
        problem_types = list(problem_types)
        by_type_indices = {}
        for i, pt in enumerate(problem_types):
            by_type_indices.setdefault(pt.lower(), []).append(i)

        def _to_2d_int_list(x, replace_neg=True):
            # Accepts torch.Tensor / np.ndarray / list / ragged object array.
            if isinstance(x, tuple):
                x = x[0]
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(x, np.ndarray):
                if x.dtype == object:
                    rows = [np.asarray(r).astype(np.int64, copy=False)
                            for r in x.tolist()]
                else:
                    rows = [r.astype(np.int64, copy=False) for r in x]
            elif isinstance(x, list):
                rows = [np.asarray(r).astype(np.int64, copy=False) for r in x]
            else:
                rows = [np.asarray(list(x)).astype(np.int64, copy=False)]

            out = []
            for r in rows:
                if replace_neg:
                    # Negative ids (-1/-100 padding) would break decoding.
                    r = np.where(r < 0, PAD, r)
                out.append(r.tolist())
            return out

        def _eval_for_type(pt: str, preds: List[str], refs: List[str]):
            pt = pt.lower()
            if pt == "logicnli":
                return ResultEvaluator.evaluate_logicnli(preds, refs)
            if pt == "robust":
                return ResultEvaluator.evaluate_robustlr(preds, refs)
            if pt == "folio":
                return ResultEvaluator.evaluate_folio(preds, refs)
            if pt == "logiqa":
                return ResultEvaluator.evaluate_logiqa(preds, refs)
            if pt == "abductionr":
                return ResultEvaluator.evaluate_abductionr(preds, refs)
            if pt == "arlsat":
                return ResultEvaluator.evaluate_arlsat(preds, refs)
            if pt == "reclor":
                return ResultEvaluator.evaluate_reclor(preds, refs)
            # A single evaluation is shared for mathematics datasets.
            return ResultEvaluator.evaluate_mathdata(preds, refs)

        def compute_metrics(eval_pred):
            # Predictions and labels arrive padded with -100 (collator label
            # padding / cross-batch concatenation); both are mapped to PAD.
            preds = _to_2d_int_list(eval_pred.predictions, replace_neg=True)
            labels = _to_2d_int_list(eval_pred.label_ids, replace_neg=True)
            if len(preds) != len(problem_types):
                raise ValueError(
                    f"compute_metrics got {len(preds)} predictions but "
                    f"{len(problem_types)} problem_types; they must align 1:1.")

            pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
            label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

            total_correct = 0
            total_total = 0
            metrics = {}
            for pt, idxs in by_type_indices.items():
                p_sub = [pred_texts[i] for i in idxs]
                r_sub = [label_texts[i] for i in idxs]
                m = _eval_for_type(pt, p_sub, r_sub)
                total_correct += int(m.get("correct", 0))
                total_total += int(m.get("total", len(idxs)))
                # Per-type accuracy as well
                metrics[f"accuracy_{pt}"] = float(m.get("accuracy", 0.0))

            overall_acc = (total_correct / total_total) if total_total > 0 else 0.0
            metrics["accuracy"] = float(overall_acc)
            return metrics

        return compute_metrics

    dataset_all = Dataset.from_list(processed_data)
    val_raw = load_jsonl_data(args.val_data_path)
    val_proc = preprocess_qa_val_data(tokenizer, val_raw, args)
    train_dataset = dataset_all.map(tokenize, batched=True, num_proc=16)
    eval_dataset = Dataset.from_list(val_proc).map(
        tokenize_eval_prompt_only, batched=True, num_proc=16)

    # problem_type per example, in eval order, for grouped scoring.
    eval_problem_types = eval_dataset["problem_type"]
    compute_metrics_fn = build_compute_metrics(tokenizer, eval_problem_types)

    # Length statistics of the tokenized training set.
    input_ids_lens = [len(ids) for ids in train_dataset["input_ids"]]
    n_train = len(input_ids_lens)
    print(f'max input_ids length: {max(input_ids_lens)}')
    print(f'min input_ids length: {min(input_ids_lens)}')
    print(f'average input_ids length: {sum(input_ids_lens) / n_train}')
    for threshold in (512, 768, 1024):
        n_longer = sum(1 for length in input_ids_lens if length > threshold)
        print(f'>{threshold} input_ids length ratio: {n_longer / n_train}')

    if args.use_lora:
        model.print_trainable_parameters()

    # ====== Build Trainer ======
    # Optimizer steps per epoch across all processes. ForceEvalEveryNSteps uses
    # it to request roughly one evaluation per epoch's worth of steps.
    steps_per_epoch = math.ceil(
        len(train_dataset) / (training_args.per_device_train_batch_size
                              * training_args.gradient_accumulation_steps
                              * training_args.world_size)
    )
    n_steps = max(1, steps_per_epoch)

    print("[CHECK]", type(training_args.evaluation_strategy),
          training_args.evaluation_strategy)
    print(
        f"[DATA] train size = {len(train_dataset)}   eval size = {len(eval_dataset)}")
    if len(eval_dataset) == 0:
        raise ValueError("Eval dataset is empty! Check --val_data_path or split.")

    # (None, None) is the Trainer default: it builds its own optimizer and scheduler.
    optimizers = (None, None)
    if args.use_recall_adam:
        if len(trainable_params) != len(pretrain_params):
            raise ValueError("params/pretrain_params length mismatch!")
        device = next(model.parameters()).device
        pretrain_params_on_device = [pp.to(device) for pp in pretrain_params]

        optim = RecAdam(
            params=[{"params": trainable_params,
                     "pretrain_params": pretrain_params_on_device}],
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            correct_bias=True,  # usually True
            anneal_fun=args.rec_anneal_fun,
            anneal_k=args.rec_anneal_k,
            anneal_t0=args.rec_anneal_t0,
            anneal_w=args.rec_anneal_w,
            pretrain_cof=args.rec_pretrain_cof,
        )
        # Only the optimizer is supplied; Trainer still builds the LR scheduler.
        optimizers = (optim, None)

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator_fn,
        compute_metrics=compute_metrics_fn,
        tokenizer=tokenizer,
    )

    # Pre-training evaluation smoke test on a throwaway Trainer (default
    # optimizer, no forced-eval callback), as before the shared kwargs refactor.
    tmp_trainer = PrefixOnlyGenerateTrainer(**trainer_kwargs)
    print("[PRE-EVAL]", tmp_trainer.evaluate(eval_dataset))
    del tmp_trainer

    trainer = PrefixOnlyGenerateTrainer(
        **trainer_kwargs,
        optimizers=optimizers,
        callbacks=[ForceEvalEveryNSteps(
            n_steps=n_steps, also_on_epoch_end=True)],
    )

    trainer.train()
    trainer.save_model()

    # Finish SwanLab logging
    if args.use_swanlab and SWANLAB_AVAILABLE:
        swanlab.finish()
        print("✓ SwanLab logging finished")
