import argparse
import contextlib
import inspect
import json
import math
import random
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
from typing import List

import evaluate
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTConfig, SFTTrainer


# Dataset name used as the default for the --train-dataset-name option.
DEFAULT_TRAIN_DATASET = "flytech/python-codes-25k"
# Defaults for the forgetting-evaluation dataset, in the order
# (dataset name, dataset config, split, text column). Each element is the
# default of the matching --eval-* command-line option.
DEFAULT_EVAL_DATASET = ("wiki40b", "fr", "test", "text")
# Model identifier passed to the tokenizer and model from_pretrained calls.
MODEL_ID = "Qwen/Qwen3-0.6B"


@lru_cache(maxsize=1)
def get_perplexity_metric() -> evaluate.Metric:
    """Return the ``perplexity`` metric, loading it on the first call only.

    The result is memoized by ``lru_cache``, so later calls reuse the same
    metric object instead of invoking ``evaluate.load`` again.
    """
    perplexity_metric = evaluate.load("perplexity")
    return perplexity_metric


@contextlib.contextmanager
def _use_in_memory_model(model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
    """Temporarily patch ``from_pretrained`` to reuse already loaded objects.

    While the context is active, ``AutoModelForCausalLM.from_pretrained``
    returns ``model`` and ``AutoTokenizer.from_pretrained`` returns
    ``tokenizer``, ignoring every argument they are called with.

    On exit the classes are put back exactly as they were: the raw attribute
    stored in each class ``__dict__`` (normally a ``classmethod`` descriptor)
    is reinstated, or the temporary attribute is removed when the class only
    inherited ``from_pretrained``.  Saving the raw descriptor, rather than the
    bound method obtained through attribute access, avoids leaving a method
    bound to one specific class behind in the class namespace.

    Args:
        model: Object returned by the patched model ``from_pretrained``.
        tokenizer: Object returned by the patched tokenizer ``from_pretrained``.

    Yields:
        None.  The patch is active for the duration of the ``with`` body.
    """
    attr_name = "from_pretrained"
    missing = object()  # sentinel: the class had no own `from_pretrained`
    replacements = (
        (AutoModelForCausalLM, model),
        (AutoTokenizer, tokenizer),
    )
    # Snapshot the raw class-level attributes before touching anything.
    saved = [(cls, vars(cls).get(attr_name, missing)) for cls, _ in replacements]

    def _returning(obj):
        # Bind `obj` through a function argument to avoid late-binding closures.
        return classmethod(lambda cls, *_args, **_kwargs: obj)

    try:
        # Patch inside the try so a failure half-way is still rolled back.
        for cls, obj in replacements:
            setattr(cls, attr_name, _returning(obj))
        yield
    finally:
        for cls, original in saved:
            if original is not missing:
                setattr(cls, attr_name, original)
            elif attr_name in vars(cls):
                # Drop our patch so lookup falls back to the inherited method.
                delattr(cls, attr_name)


def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    texts: List[str],
    device: torch.device,
    max_length: int,
    batch_size: int,
) -> tuple[float, float]:
    """Score ``texts`` with the ``evaluate`` perplexity metric.

    The already loaded ``model`` and ``tokenizer`` are injected into the
    metric through ``_use_in_memory_model``.  If the tokenizer has no pad
    token, the EOS token is used for the duration of the call and the
    original pad token is restored afterwards.

    Args:
        model: Model to evaluate; its ``config._name_or_path`` is passed to
            the metric as ``model_id``.
        tokenizer: Tokenizer paired with ``model``.
        texts: Texts to score; must contain at least one element.
        device: Torch device; any CUDA device maps to ``"cuda"``, everything
            else to ``"cpu"`` for the metric.
        max_length: Forwarded to the metric as ``max_length``.
        batch_size: Forwarded to the metric as ``batch_size``.

    Returns:
        ``(mean_perplexity, mean_loss)`` where ``mean_loss`` is the mean of
        the natural logarithm of the per-sample perplexities.

    Raises:
        ValueError: If ``texts`` is empty, the tokenizer has neither a pad
            nor an EOS token, or the metric reports no or non-positive
            perplexities.
    """
    samples = list(texts)
    if len(samples) == 0:
        raise ValueError("compute_perplexity received an empty text sequence")

    saved_pad_token = tokenizer.pad_token
    try:
        if tokenizer.pad_token is None:
            # Borrow EOS as the padding token for this evaluation only.
            if tokenizer.eos_token is None:
                raise ValueError(
                    "Tokenizer lacks both pad_token and eos_token; cannot evaluate perplexity."
                )
            tokenizer.pad_token = tokenizer.eos_token

        metric = get_perplexity_metric()
        if device.type == "cuda":
            metric_device = "cuda"
        else:
            metric_device = "cpu"

        compute_kwargs = {
            "model_id": model.config._name_or_path,
            "predictions": samples,
            "batch_size": batch_size,
            "device": metric_device,
            "max_length": max_length,
            "add_start_token": False,
        }
        with _use_in_memory_model(model, tokenizer):
            results = metric.compute(**compute_kwargs)
    finally:
        tokenizer.pad_token = saved_pad_token

    sample_perplexities = [float(value) for value in results["perplexities"]]
    has_non_positive = any(value <= 0 for value in sample_perplexities)
    if not sample_perplexities or has_non_positive:
        raise ValueError(
            "Perplexity metric returned non-positive values; cannot take logarithm."
        )

    mean_perplexity = float(results["mean_perplexity"])
    log_total = sum(math.log(value) for value in sample_perplexities)
    mean_loss = float(log_total / len(sample_perplexities))
    return mean_perplexity, mean_loss


def parse_dtype(dtype: str) -> torch.dtype:
    """Translate a dtype name (case-insensitive) into a ``torch.dtype``.

    ``"auto"`` resolves to ``torch.float32``.

    Raises:
        ValueError: If the name is not one of auto, float32, float16, bfloat16.
    """
    supported = {
        "auto": torch.float32,
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    resolved = supported.get(dtype.lower())
    if resolved is None:
        raise ValueError(
            f"Unsupported dtype '{dtype}'. Choose from auto, float32, float16, bfloat16."
        )
    return resolved


def resolve_device(device_arg: str) -> torch.device:
    """Turn the ``--device`` value into a ``torch.device``.

    ``"auto"`` selects CUDA when ``torch.cuda.is_available()`` reports it and
    the CPU otherwise; any other value is handed to ``torch.device`` as is.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)


@dataclass(slots=True)
class EvalRecord:
    """One perplexity measurement taken before, during, or at the end of training."""

    # Trainer global step at which the measurement was taken (0 for the baseline).
    step: int
    # Training epoch reported by the trainer state (0.0 when unavailable).
    epoch: float
    # Label of the measurement; the visible code uses "baseline", "train" and "final".
    phase: str
    # Most recently logged training loss, or None if none has been logged yet.
    train_loss: float | None
    # Mean log of the per-sample perplexities on the evaluation texts.
    eval_loss: float
    # Mean perplexity on the evaluation texts.
    eval_perplexity: float
    # Same as eval_loss, for the held-out training texts; None without a holdout.
    heldout_eval_loss: float | None
    # Same as eval_perplexity, for the held-out training texts; None without a holdout.
    heldout_eval_perplexity: float | None


class PerplexityCallback(TrainerCallback):
    """Trainer callback that records perplexity on fixed text sets.

    Every ``eval_steps`` global steps, and once more when training ends if
    that step was not already measured, the model is scored with
    ``compute_perplexity`` on ``eval_texts`` and, when provided, on
    ``heldout_texts``.  Each measurement is appended to ``records``, printed,
    and optionally forwarded to ``log_fn``.
    """

    def __init__(
        self,
        *,
        model: torch.nn.Module,
        tokenizer: AutoTokenizer,
        eval_texts: List[str],
        heldout_texts: List[str] | None,
        device: torch.device,
        eval_max_length: int,
        eval_batch_size: int,
        eval_steps: int,
        records: List[EvalRecord],
        log_fn=None,
    ) -> None:
        """Store the evaluation configuration.

        Args:
            model: Fallback model used when the trainer passes none to a hook.
            tokenizer: Tokenizer handed to ``compute_perplexity``.
            eval_texts: Texts scored at every measurement.
            heldout_texts: Optional second text set; skipped when empty or None.
            device: Device handed to ``compute_perplexity``.
            eval_max_length: ``max_length`` handed to ``compute_perplexity``.
            eval_batch_size: ``batch_size`` handed to ``compute_perplexity``.
            eval_steps: Measurement interval in global steps; ``<= 0`` disables
                the periodic measurements.
            records: List that receives one ``EvalRecord`` per measurement.
            log_fn: Optional callable invoked with a dict of the new metrics.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.eval_texts = eval_texts
        self.heldout_texts = heldout_texts
        self.device = device
        self.eval_max_length = eval_max_length
        self.eval_batch_size = eval_batch_size
        self.eval_steps = eval_steps
        self.records = records
        self.log_fn = log_fn
        self.last_logged_train_loss: float | None = None
        self.last_recorded_step = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Remember the latest training loss found under the ``"loss"`` key."""
        if logs and "loss" in logs:
            try:
                self.last_logged_train_loss = float(logs["loss"])
            except (TypeError, ValueError):
                # Best effort: keep the previous value if the loss is not numeric.
                pass
        return control

    def _record(self, phase: str, state, model: torch.nn.Module | None) -> None:
        """Evaluate ``model`` and store, print and log the measurement.

        Does nothing on non-zero world processes or when ``model`` is None.
        The model is switched to eval mode for scoring and its previous
        training mode is restored afterwards, even if scoring raises.
        """
        if not state.is_world_process_zero or model is None:
            return

        # Capture the current mode so the finally clause can restore it;
        # otherwise a failed evaluation would leave the model in eval mode.
        was_training = model.training
        model.eval()
        try:
            eval_ppl, eval_loss = compute_perplexity(
                model,
                self.tokenizer,
                self.eval_texts,
                self.device,
                self.eval_max_length,
                self.eval_batch_size,
            )
            heldout_ppl = None
            heldout_loss = None
            if self.heldout_texts:
                heldout_ppl, heldout_loss = compute_perplexity(
                    model,
                    self.tokenizer,
                    self.heldout_texts,
                    self.device,
                    self.eval_max_length,
                    self.eval_batch_size,
                )
        finally:
            model.train(was_training)

        epoch = float(state.epoch) if state.epoch is not None else 0.0
        self.records.append(
            EvalRecord(
                step=int(state.global_step),
                epoch=epoch,
                phase=phase,
                train_loss=self.last_logged_train_loss,
                eval_loss=eval_loss,
                eval_perplexity=eval_ppl,
                heldout_eval_loss=heldout_loss,
                heldout_eval_perplexity=heldout_ppl,
            )
        )

        has_heldout = heldout_ppl is not None and heldout_loss is not None
        train_loss_str = (
            f"{self.last_logged_train_loss:.4f}" if self.last_logged_train_loss is not None else "n/a"
        )
        heldout_str = ""
        if has_heldout:
            heldout_str = f" heldout_loss={heldout_loss:.4f} heldout_ppl={heldout_ppl:.3f}"
        print(
            f"[{phase}] step={state.global_step} epoch={epoch:.2f} train_loss={train_loss_str} "
            f"eval_loss={eval_loss:.4f} eval_ppl={eval_ppl:.3f}{heldout_str}",
            flush=True,
        )
        if self.log_fn is not None:
            log_payload = {"eval_ppl": eval_ppl, "eval_loss": eval_loss}
            if has_heldout:
                log_payload["eval_heldout_ppl"] = heldout_ppl
                log_payload["eval_heldout_loss"] = heldout_loss
            self.log_fn(log_payload)

    def on_step_end(self, args, state, control, **kwargs):
        """Take a "train" measurement every ``eval_steps`` global steps."""
        if self.eval_steps <= 0:
            return control
        if state.global_step == 0 or state.global_step == self.last_recorded_step:
            return control
        if state.global_step % self.eval_steps != 0:
            return control

        self._record("train", state, kwargs.get("model", self.model))
        self.last_recorded_step = state.global_step
        return control

    def on_train_end(self, args, state, control, **kwargs):
        """Take a "final" measurement unless the last step was already measured."""
        if state.global_step == 0 or state.global_step == self.last_recorded_step:
            return control
        self._record("final", state, kwargs.get("model", self.model))
        # Mark the step as measured, mirroring on_step_end.
        self.last_recorded_step = state.global_step
        return control


def _filter_kwargs(allowed: set[str], kwargs: dict) -> tuple[dict, List[str]]:
    """Split ``kwargs`` into the entries whose key is allowed and the rest.

    Returns:
        ``(kept, rejected)``: a dict with the allowed entries and a list of
        the keys that were dropped, both in the iteration order of ``kwargs``.
    """
    kept: dict = {}
    rejected: List[str] = []
    for name, value in kwargs.items():
        if name in allowed:
            kept[name] = value
        else:
            rejected.append(name)
    return kept, rejected


def _normalize_config(value: str | None) -> str | None:
    if value is None:
        return None
    cleaned = value.strip()
    if cleaned == "" or cleaned.lower() == "none":
        return None
    return cleaned


def load_split(dataset_name: str, dataset_config: str | None, split: str) -> Dataset:
    """Load one split via ``load_dataset``, passing the config only when it is set."""
    positional = (dataset_name, dataset_config) if dataset_config else (dataset_name,)
    return load_dataset(*positional, split=split)


def resolve_text_column(dataset: Dataset, preferred: str | None) -> str:
    """Choose the column holding the training text.

    Resolution order:
      1. ``preferred``, when given and present in the dataset;
      2. the first of a list of well-known column names that exists;
      3. the first column whose value in the first row is a string.

    A notice is printed when ``preferred`` was requested but is missing.

    Raises:
        ValueError: If no column can be selected.
    """
    available = list(dataset.column_names)

    if preferred and preferred in available:
        return preferred
    if preferred:
        print(
            f"Requested text column '{preferred}' not found; auto-selecting from {available}",
            flush=True,
        )

    well_known = ("text", "content", "code", "python", "completion", "prompt")
    match = next((name for name in well_known if name in available), None)
    if match is not None:
        return match

    if len(dataset) > 0:
        # Fall back to inspecting the first row for any string-valued column.
        first_row = dataset[0]
        for name in available:
            if isinstance(first_row.get(name), str):
                return name

    raise ValueError(
        "Unable to infer a text column from the dataset; please set --train-text-column.")


def resolve_chat_columns(
    dataset: Dataset,
    instruction_column: str,
    input_column: str,
    output_column: str,
) -> tuple[str, str | None, str]:
    """Validate the chat column names against the dataset.

    The instruction and output columns are mandatory.  The input column is
    optional and is reported as None when the dataset does not have it.

    Returns:
        ``(instruction_column, input_column_or_None, output_column)``.

    Raises:
        ValueError: If the instruction or output column is missing.
    """
    available = set(dataset.column_names)

    absent = []
    for required in (instruction_column, output_column):
        if required not in available:
            absent.append(required)
    if absent:
        raise ValueError(f"Missing required chat columns: {', '.join(absent)}")

    if input_column not in available:
        return instruction_column, None, output_column
    return instruction_column, input_column, output_column


def filter_text_dataset(dataset: Dataset, text_column: str) -> Dataset:
    """Keep only rows whose ``text_column`` is a string with non-whitespace content."""

    def has_text(row: dict) -> bool:
        candidate = row.get(text_column)
        if not isinstance(candidate, str):
            return False
        return candidate.strip() != ""

    return dataset.filter(has_text)


def filter_chat_dataset(
    dataset: Dataset,
    instruction_column: str,
    input_column: str | None,
    output_column: str,
) -> Dataset:
    """Drop rows that cannot be turned into a chat sample.

    A row is kept when its instruction and output are strings with
    non-whitespace content and, if an input column is configured, its input
    value is either None or a string.
    """

    def keep(row: dict) -> bool:
        for column in (instruction_column, output_column):
            field = row.get(column)
            if not isinstance(field, str) or field.strip() == "":
                return False
        if input_column is None:
            return True
        extra = row.get(input_column)
        return extra is None or isinstance(extra, str)

    return dataset.filter(keep)


def build_chat_text(
    tokenizer: AutoTokenizer,
    instruction: str,
    input_text: str | None,
    output_text: str,
    system_prompt: str | None,
) -> str:
    """Render one sample through the tokenizer's chat template.

    The stripped instruction becomes the user turn.  The assistant turn is
    the stripped ``input_text`` (if any) followed by the stripped
    ``output_text``, joined by a blank line; empty parts are omitted.  A
    system turn is prepended when ``system_prompt`` is truthy.

    Returns:
        The untokenized template output, without a generation prompt.
    """
    user_message = instruction.strip()
    preface = input_text.strip() if input_text else ""
    answer = output_text.strip()
    assistant_message = "\n\n".join(piece for piece in (preface, answer) if piece)

    conversation = [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )


def make_chat_formatting_func(
    tokenizer: AutoTokenizer,
    instruction_column: str,
    input_column: str | None,
    output_column: str,
    system_prompt: str | None,
):
    """Build a formatting function that renders rows as chat text.

    The returned callable accepts either a single example (column values are
    scalars) and returns one string, or a batch (the instruction column holds
    a list) and returns a list of strings.  In batch mode a non-list input
    column yields None inputs and a non-list output column yields empty
    outputs.
    """

    def formatting_func(example: dict) -> str | List[str]:
        instructions = example.get(instruction_column)
        outputs = example.get(output_column)
        inputs = example.get(input_column) if input_column else None

        def render(instruction, input_text, output_text) -> str:
            return build_chat_text(
                tokenizer,
                instruction=instruction,
                input_text=input_text,
                output_text=output_text,
                system_prompt=system_prompt,
            )

        # Single example: column values are used directly.
        if not isinstance(instructions, list):
            return render(instructions, inputs, outputs)

        inputs_batched = isinstance(inputs, list)
        outputs_batched = isinstance(outputs, list)
        return [
            render(
                instruction,
                inputs[position] if inputs_batched else None,
                outputs[position] if outputs_batched else "",
            )
            for position, instruction in enumerate(instructions)
        ]

    return formatting_func


def gather_texts(
    dataset_name: str,
    dataset_config: str | None,
    dataset_split: str,
    text_column: str,
    max_samples: int,
) -> List[str]:
    """Load a dataset split and collect its non-empty, stripped texts.

    Rows whose ``text_column`` is not a string, or is blank after stripping,
    are skipped.  Collection stops once ``max_samples`` texts were gathered;
    ``max_samples <= 0`` means no limit.

    Raises:
        ValueError: If no usable text was found.
    """
    dataset = load_split(dataset_name, dataset_config, dataset_split)

    collected: List[str] = []
    for row in dataset:
        raw = row.get(text_column)
        cleaned = raw.strip() if isinstance(raw, str) else ""
        if not cleaned:
            continue
        collected.append(cleaned)
        if 0 < max_samples <= len(collected):
            break

    if collected:
        return collected
    raise ValueError(
        f"Dataset {dataset_name}/{dataset_config}:{dataset_split} yielded no non-empty '{text_column}' entries"
    )


def gather_texts_from_dataset(
    dataset: Dataset,
    text_column: str | None,
    formatting_func,
    max_samples: int,
) -> List[str]:
    """Collect non-empty, stripped texts from an already loaded dataset.

    When ``formatting_func`` is given, each row is rendered through it (a
    returned list contributes every element).  Otherwise the value of
    ``text_column`` is used.  Non-string and blank values are skipped.
    Collection stops once ``max_samples`` texts were gathered;
    ``max_samples <= 0`` means no limit.

    Raises:
        ValueError: If a row is processed while both ``formatting_func`` and
            ``text_column`` are None.
    """
    collected: List[str] = []
    use_formatter = formatting_func is not None

    for row in dataset:
        if use_formatter:
            rendered = formatting_func(row)
            pending = rendered if isinstance(rendered, list) else [rendered]
        elif text_column is None:
            raise ValueError(
                "text_column is required when formatting_func is not provided"
            )
        else:
            raw = row.get(text_column)
            pending = [raw] if isinstance(raw, str) else []

        for item in pending:
            cleaned = item.strip() if isinstance(item, str) else ""
            if not cleaned:
                continue
            collected.append(cleaned)
            if 0 < max_samples <= len(collected):
                return collected

    return collected


def set_seed(seed: int | None) -> None:
    if seed is None:
        return
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the fine-tuning script.

    Options cover the training data (dataset, format, columns, holdout),
    optimisation (batch size, steps, learning rate, checkpointing), the
    perplexity evaluation dataset, LoRA hyper-parameters, runtime settings
    (dtype, device, packing, seed) and output locations.

    Returns:
        The namespace produced by ``argparse`` for ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "LoRA fine-tuning for Qwen/Qwen3-0.6B on python code, "
            "with periodic perplexity evaluation to detect forgetting."
        )
    )
    parser.add_argument("--train-dataset-name", default=DEFAULT_TRAIN_DATASET)
    parser.add_argument("--train-dataset-config", default=None)
    parser.add_argument("--train-split", default="train")
    parser.add_argument(
        "--train-format",
        choices=["chat", "plain"],
        default="chat",
        help="Train on chat-formatted samples or raw text",
    )
    parser.add_argument(
        "--train-text-column",
        default=None,
        help="Text column to use when --train-format=plain",
    )
    parser.add_argument("--instruction-column", default="instruction")
    parser.add_argument(
        "--input-column",
        default="input",
        help="Optional assistant-preface column prepended to the output",
    )
    parser.add_argument("--output-column", default="output")
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="Optional system prompt injected into the chat template",
    )
    parser.add_argument(
        "--max-train-samples",
        type=int,
        default=0,
        help="Limit training samples (<=0 uses the full split)",
    )
    parser.add_argument(
        "--train-max-length",
        type=int,
        default=1024,
        help="Max sequence length for packed training examples",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--grad-accumulation", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max-steps", type=int, default=0)
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=None,
        # The fallback values are applied by the caller when this stays None;
        # the text below mirrors the values used there (5e-4 for LoRA).
        help="Defaults to 5e-4 for LoRA or 5e-5 for full fine-tuning when unset.",
    )
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--warmup-steps", type=int, default=0)
    parser.add_argument("--log-steps", type=int, default=5)
    parser.add_argument(
        "--save-strategy",
        choices=["no", "steps", "epoch"],
        default="epoch",
        help="Checkpoint save strategy during training.",
    )
    parser.add_argument(
        "--save-steps",
        type=int,
        default=500,
        help="Save checkpoint every N steps when --save-strategy=steps.",
    )
    parser.add_argument(
        "--save-total-limit",
        type=int,
        default=None,
        help="Maximum number of checkpoints to keep (default keeps all).",
    )
    parser.add_argument(
        "--resume-from-checkpoint",
        default=None,
        help="Path to a checkpoint directory or 'auto' to resume latest in output_dir.",
    )
    parser.add_argument(
        "--logging-dir",
        type=Path,
        default=Path("results/lora/tensorboard"),
        help="TensorBoard base log directory (run name is appended)",
    )
    parser.add_argument("--eval-steps", type=int, default=100)
    # The --eval-* defaults come from DEFAULT_EVAL_DATASET:
    # (dataset name, config, split, text column).
    parser.add_argument(
        "--eval-dataset-name",
        default=DEFAULT_EVAL_DATASET[0],
        help="Dataset to track forgetting (perplexity evaluation)",
    )
    parser.add_argument("--eval-dataset-config",
                        default=DEFAULT_EVAL_DATASET[1])
    parser.add_argument("--eval-split", default=DEFAULT_EVAL_DATASET[2])
    parser.add_argument("--eval-text-column", default=DEFAULT_EVAL_DATASET[3])
    parser.add_argument("--eval-max-samples", type=int, default=512)
    parser.add_argument("--eval-max-length", type=int, default=256)
    parser.add_argument("--eval-batch-size", type=int, default=8)
    parser.add_argument(
        "--train-holdout-fraction",
        type=float,
        default=0.05,
        help="Fraction of filtered training data held out for in-domain eval (0 disables).",
    )
    parser.add_argument(
        "--train-holdout-max-samples",
        type=int,
        default=512,
        help="Max held-out training samples to evaluate (<=0 uses full holdout split).",
    )
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument(
        "--lora-target-modules",
        default="q_proj,v_proj,k_proj,o_proj",
        help="Comma-separated list of module names to LoRA-adapt",
    )
    parser.add_argument(
        "--full-finetune",
        action="store_true",
        help="Train all model weights instead of LoRA adapters.",
    )
    parser.add_argument("--gradient-checkpointing", action="store_true")
    parser.add_argument("--dtype", default="auto",
                        help="Model dtype: auto|float32|float16|bfloat16")
    parser.add_argument("--device", default="auto",
                        help="Device to run on: auto|cpu|cuda|cuda:<index>")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument(
        "--packing",
        choices=["on", "off"],
        default="off",
        help="Enable packing (off avoids flash-attn requirements)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-proc", type=int, default=None)
    parser.add_argument("--output-json", type=Path,
                        default=Path("results/lora/lora_finetune_metrics.json"))
    parser.add_argument("--output-dir", type=Path,
                        default=Path("results/lora/lora_adapter"))
    parser.add_argument("--save-adapter", action="store_true")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    set_seed(args.seed)
    device = resolve_device(args.device)
    dtype = parse_dtype(args.dtype)
    run_name = "full-finetune" if args.full_finetune else f"lora-r{args.lora_r}"
    logging_dir = args.logging_dir
    if logging_dir.name != run_name:
        logging_dir = logging_dir / run_name
    effective_lr = args.learning_rate
    if effective_lr is None:
        effective_lr = 5e-5 if args.full_finetune else 5e-4

    if device.type == "cuda" and device.index is not None:
        torch.cuda.set_device(device.index)

    train_dataset_config = _normalize_config(args.train_dataset_config)
    eval_dataset_config = _normalize_config(args.eval_dataset_config)

    print(
        "Loading evaluation dataset "
        f"{args.eval_dataset_name}/{eval_dataset_config}:{args.eval_split} (max_samples={args.eval_max_samples}) ...",
        flush=True,
    )
    eval_texts = gather_texts(
        dataset_name=args.eval_dataset_name,
        dataset_config=eval_dataset_config,
        dataset_split=args.eval_split,
        text_column=args.eval_text_column,
        max_samples=args.eval_max_samples,
    )
    print(f"Collected {len(eval_texts)} evaluation texts", flush=True)

    print(
        f"Loading tokenizer (trust_remote_code={args.trust_remote_code}) ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is None:
            raise ValueError(
                "Tokenizer lacks both pad_token and eos_token; cannot proceed.")
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model on {device} with dtype={dtype} ...", flush=True)
    model_allowed = set(inspect.signature(
        AutoModelForCausalLM.from_pretrained).parameters)
    model_allowed.discard("self")
    model_kwargs = {
        "trust_remote_code": args.trust_remote_code,
        "dtype": dtype,
        "torch_dtype": dtype,
    }
    if "dtype" in model_allowed:
        model_kwargs.pop("torch_dtype", None)
    else:
        model_kwargs.pop("dtype", None)
    model_kwargs, _ = _filter_kwargs(model_allowed, model_kwargs)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
    model.config.use_cache = False
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    model.to(device)

    print(
        "Loading training dataset "
        f"{args.train_dataset_name}/{train_dataset_config}:{args.train_split} ...",
        flush=True,
    )
    train_dataset = load_split(
        args.train_dataset_name, train_dataset_config, args.train_split)
    formatting_func = None
    dataset_text_field: str | None = None
    instruction_column = None
    input_column = None
    output_column = None

    if args.train_format == "chat":
        instruction_column, input_column, output_column = resolve_chat_columns(
            train_dataset,
            args.instruction_column,
            args.input_column,
            args.output_column,
        )
        train_dataset = filter_chat_dataset(
            train_dataset,
            instruction_column,
            input_column,
            output_column,
        )
    else:
        dataset_text_field = resolve_text_column(
            train_dataset, args.train_text_column)
        train_dataset = filter_text_dataset(train_dataset, dataset_text_field)

    if args.max_train_samples > 0:
        train_dataset = train_dataset.select(
            range(min(args.max_train_samples, len(train_dataset))))

    if len(train_dataset) == 0:
        raise ValueError("Training dataset is empty after filtering.")

    heldout_dataset = None
    if args.train_holdout_fraction > 0:
        if not 0.0 < args.train_holdout_fraction < 1.0:
            raise ValueError(
                "--train-holdout-fraction must be between 0 and 1 (exclusive)."
            )
        if len(train_dataset) < 2:
            raise ValueError(
                "Need at least 2 training samples to create a holdout split."
            )
        split = train_dataset.train_test_split(
            test_size=args.train_holdout_fraction,
            seed=args.seed,
            shuffle=True,
        )
        train_dataset = split["train"]
        heldout_dataset = split["test"]
        if len(train_dataset) == 0 or len(heldout_dataset) == 0:
            raise ValueError(
                "Holdout split produced an empty train or held-out dataset."
            )
        print(
            f"Held out {len(heldout_dataset)} training samples for eval; "
            f"training on {len(train_dataset)} samples.",
            flush=True,
        )

    max_seq_length = min(args.train_max_length, tokenizer.model_max_length)
    if max_seq_length <= 0:
        raise ValueError("--train-max-length must be positive")

    if args.train_format == "chat":
        system_prompt = args.system_prompt.strip() if args.system_prompt else None
        formatting_func = make_chat_formatting_func(
            tokenizer,
            instruction_column,
            input_column,
            output_column,
            system_prompt,
        )

    heldout_texts = None
    if heldout_dataset is not None:
        heldout_texts = gather_texts_from_dataset(
            dataset=heldout_dataset,
            text_column=dataset_text_field,
            formatting_func=formatting_func,
            max_samples=args.train_holdout_max_samples,
        )
        if not heldout_texts:
            raise ValueError(
                "Held-out training split produced no usable text samples."
            )
        print(
            f"Collected {len(heldout_texts)} held-out training texts",
            flush=True,
        )

    target_modules: List[str] = []
    lora_config = None
    if not args.full_finetune:
        target_modules = [module.strip() for module in args.lora_target_modules.split(
            ",") if module.strip()]
        if not target_modules:
            raise ValueError(
                "--lora-target-modules must include at least one module name")

        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=target_modules,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )

    if args.log_steps <= 0:
        logging_strategy = "no"
        logging_steps = 1
    else:
        logging_strategy = "steps"
        logging_steps = args.log_steps

    max_steps = args.max_steps if args.max_steps > 0 else -1
    trainer_allowed = set(inspect.signature(
        SFTTrainer.__init__).parameters) - {"self"}

    if formatting_func is not None and "formatting_func" not in trainer_allowed:
        train_dataset = train_dataset.map(
            lambda example: {"text": formatting_func(example)},
            remove_columns=list(train_dataset.column_names),
            num_proc=args.num_proc,
            desc="Formatting chat samples",
        )
        dataset_text_field = "text"
        formatting_func = None

    packing_enabled = args.packing == "on"
    if packing_enabled:
        print(
            "Warning: packing enabled without flash-attn; disable with --packing off if you see cross-contamination.",
            flush=True,
        )

    config_kwargs = {
        "output_dir": str(args.output_dir),
        "logging_dir": str(logging_dir),
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "gradient_accumulation_steps": args.grad_accumulation,
        "num_train_epochs": args.epochs,
        "max_steps": max_steps,
        "learning_rate": effective_lr,
        "weight_decay": args.weight_decay,
        "warmup_steps": args.warmup_steps,
        "logging_strategy": logging_strategy,
        "logging_steps": logging_steps,
        "evaluation_strategy": "no",
        "save_strategy": args.save_strategy,
        "save_steps": args.save_steps,
        "save_total_limit": args.save_total_limit,
        "gradient_checkpointing": args.gradient_checkpointing,
        "bf16": dtype == torch.bfloat16,
        "fp16": dtype == torch.float16,
        "seed": args.seed,
        "report_to": ["tensorboard"],
        "run_name": run_name,
        "max_seq_length": max_seq_length,
        "packing": packing_enabled,
        "dataset_num_proc": args.num_proc,
    }
    if dataset_text_field is not None:
        config_kwargs["dataset_text_field"] = dataset_text_field
    accepted = set(inspect.signature(SFTConfig.__init__).parameters) - {"self"}
    filtered_kwargs, skipped = _filter_kwargs(accepted, config_kwargs)
    if "max_seq_length" in skipped:
        tokenizer.model_max_length = max_seq_length
        skipped.remove("max_seq_length")
    if "evaluation_strategy" in skipped:
        skipped.remove("evaluation_strategy")
    if "run_name" in skipped:
        skipped.remove("run_name")
    if skipped:
        print(
            "Skipping unsupported SFTConfig args: " +
            ", ".join(sorted(skipped)),
            flush=True,
        )
    training_args = SFTConfig(**filtered_kwargs)

    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
    }
    if lora_config is not None:
        trainer_kwargs["peft_config"] = lora_config
    if formatting_func is not None and "formatting_func" in trainer_allowed:
        trainer_kwargs["formatting_func"] = formatting_func
    if dataset_text_field is not None and "dataset_text_field" not in accepted:
        trainer_kwargs["dataset_text_field"] = dataset_text_field
    if "max_seq_length" not in accepted and "max_seq_length" in trainer_allowed:
        trainer_kwargs["max_seq_length"] = max_seq_length
    if "packing" not in accepted:
        trainer_kwargs["packing"] = packing_enabled
    if "dataset_num_proc" not in accepted:
        trainer_kwargs["dataset_num_proc"] = args.num_proc
    if "tokenizer" in trainer_allowed:
        trainer_kwargs["tokenizer"] = tokenizer
    elif "processing_class" in trainer_allowed:
        trainer_kwargs["processing_class"] = tokenizer

    trainer_kwargs, trainer_skipped = _filter_kwargs(
        trainer_allowed, trainer_kwargs)
    if trainer_skipped:
        print(
            "Skipping unsupported SFTTrainer args: " +
            ", ".join(sorted(trainer_skipped)),
            flush=True,
        )
    trainer = SFTTrainer(**trainer_kwargs)

    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    records: List[EvalRecord] = []

    print("Computing baseline perplexity before fine-tuning ...", flush=True)
    baseline_ppl, baseline_loss = compute_perplexity(
        trainer.model,
        tokenizer,
        eval_texts,
        device,
        args.eval_max_length,
        args.eval_batch_size,
    )
    baseline_heldout_ppl = None
    baseline_heldout_loss = None
    if heldout_texts:
        baseline_heldout_ppl, baseline_heldout_loss = compute_perplexity(
            trainer.model,
            tokenizer,
            heldout_texts,
            device,
            args.eval_max_length,
            args.eval_batch_size,
        )
    records.append(
        EvalRecord(
            step=0,
            epoch=0.0,
            phase="baseline",
            train_loss=None,
            eval_loss=baseline_loss,
            eval_perplexity=baseline_ppl,
            heldout_eval_loss=baseline_heldout_loss,
            heldout_eval_perplexity=baseline_heldout_ppl,
        )
    )
    heldout_str = ""
    if baseline_heldout_ppl is not None and baseline_heldout_loss is not None:
        heldout_str = (
            f" heldout_loss={baseline_heldout_loss:.4f} "
            f"heldout_ppl={baseline_heldout_ppl:.3f}"
        )
    print(
        f"[baseline] step=0 epoch=0.00 train_loss=n/a eval_loss={baseline_loss:.4f} "
        f"eval_ppl={baseline_ppl:.3f}{heldout_str}",
        flush=True,
    )

    trainer.add_callback(
        PerplexityCallback(
            model=trainer.model,
            tokenizer=tokenizer,
            eval_texts=eval_texts,
            heldout_texts=heldout_texts,
            device=device,
            eval_max_length=args.eval_max_length,
            eval_batch_size=args.eval_batch_size,
            eval_steps=args.eval_steps,
            records=records,
            log_fn=trainer.log,
        )
    )

    resume_from_checkpoint = None
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint.lower() == "auto":
            resume_from_checkpoint = get_last_checkpoint(str(args.output_dir))
            if resume_from_checkpoint is None:
                print(
                    f"No checkpoint found in {args.output_dir}; starting from scratch.",
                    flush=True,
                )
        else:
            resume_from_checkpoint = args.resume_from_checkpoint

    if resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train()

    if args.output_json:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        baseline_record = next(
            (record for record in records if record.phase == "baseline"), None)
        baseline_loss = baseline_record.eval_loss if baseline_record is not None else None
        baseline_ppl = baseline_record.eval_perplexity if baseline_record is not None else None
        baseline_heldout_loss = (
            baseline_record.heldout_eval_loss if baseline_record is not None else None
        )
        baseline_heldout_ppl = (
            baseline_record.heldout_eval_perplexity if baseline_record is not None else None
        )
        forgetting_curve = []
        for record in records:
            delta_loss = None
            delta_ppl = None
            if baseline_loss is not None:
                delta_loss = record.eval_loss - baseline_loss
            if baseline_ppl is not None:
                delta_ppl = record.eval_perplexity - baseline_ppl
            delta_heldout_loss = None
            delta_heldout_ppl = None
            if baseline_heldout_loss is not None and record.heldout_eval_loss is not None:
                delta_heldout_loss = record.heldout_eval_loss - baseline_heldout_loss
            if baseline_heldout_ppl is not None and record.heldout_eval_perplexity is not None:
                delta_heldout_ppl = (
                    record.heldout_eval_perplexity - baseline_heldout_ppl
                )
            forgetting_curve.append(
                {
                    "step": record.step,
                    "epoch": record.epoch,
                    "phase": record.phase,
                    "train_loss": record.train_loss,
                    "eval_loss": record.eval_loss,
                    "eval_perplexity": record.eval_perplexity,
                    "heldout_eval_loss": record.heldout_eval_loss,
                    "heldout_eval_perplexity": record.heldout_eval_perplexity,
                    "delta_eval_loss": delta_loss,
                    "delta_eval_perplexity": delta_ppl,
                    "delta_heldout_eval_loss": delta_heldout_loss,
                    "delta_heldout_eval_perplexity": delta_heldout_ppl,
                }
            )
        payload = {
            "model_id": MODEL_ID,
            "train_dataset": {
                "name": args.train_dataset_name,
                "config": train_dataset_config,
                "split": args.train_split,
                "format": args.train_format,
                "text_column": dataset_text_field,
                "instruction_column": instruction_column,
                "input_column": input_column,
                "output_column": output_column,
                "system_prompt": args.system_prompt,
                "max_samples": args.max_train_samples,
                "max_seq_length": max_seq_length,
                "packing": True,
            },
            "eval_dataset": {
                "name": args.eval_dataset_name,
                "config": eval_dataset_config,
                "split": args.eval_split,
                "text_column": args.eval_text_column,
                "max_samples": args.eval_max_samples,
            },
            "train_holdout": {
                "enabled": bool(heldout_texts),
                "fraction": args.train_holdout_fraction,
                "max_samples": args.train_holdout_max_samples,
                "num_samples": len(heldout_texts) if heldout_texts is not None else 0,
            },
            "lora": {
                "enabled": not args.full_finetune,
                "r": args.lora_r if not args.full_finetune else None,
                "alpha": args.lora_alpha if not args.full_finetune else None,
                "dropout": args.lora_dropout if not args.full_finetune else None,
                "target_modules": target_modules if not args.full_finetune else [],
            },
            "training": {
                "full_finetune": args.full_finetune,
                "learning_rate": effective_lr,
                "batch_size": args.batch_size,
                "grad_accumulation": args.grad_accumulation,
                "epochs": args.epochs,
                "max_steps": trainer.state.max_steps,
                "global_step": trainer.state.global_step,
                "warmup_steps": args.warmup_steps,
                "dtype": args.dtype,
                "device": str(device),
                "packing": packing_enabled,
                "log_steps": args.log_steps,
                "eval_steps": args.eval_steps,
                "save_strategy": args.save_strategy,
                "save_steps": args.save_steps,
                "save_total_limit": args.save_total_limit,
                "resume_from_checkpoint": args.resume_from_checkpoint,
                "output_dir": str(args.output_dir),
                "logging_dir": str(logging_dir),
                "run_name": run_name,
            },
            "records": [asdict(record) for record in records],
            "forgetting_curve": forgetting_curve,
        }
        args.output_json.write_text(json.dumps(payload, indent=2))
        print(f"Saved metrics to {args.output_json}", flush=True)

    if args.save_adapter:
        args.output_dir.mkdir(parents=True, exist_ok=True)
        trainer.model.save_pretrained(args.output_dir)
        print(f"Saved LoRA adapter to {args.output_dir}", flush=True)


# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not start a run.
if __name__ == "__main__":  # pragma: no cover
    main()
