# finetune_lora.py
# LoRA SFT for KimiAudioModel
# - Reuse LoRA build & export utilities from finetune_codes/lora_utils.py
# - Keep dataset wiring and base SFT training logic unchanged

from dataclasses import dataclass, field
import os, json, logging
from typing import Dict, Optional, List

import torch
import transformers
from transformers import Trainer, AutoTokenizer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother
from huggingface_hub import snapshot_download

from finetune_codes.model import KimiAudioModel
from finetune_codes.datasets import LazySupervisedDataset
from functools import partial
from kimia_infer.utils.special_tokens import instantiate_extra_tokens

# Shared LoRA/export utilities
from finetune_codes.lora_utils import (
    attach_lora, ExportSplitCallback, print_trainable_params
)

# Configure the root logger at INFO level (runs as a side effect of importing this module).
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Copy of LabelSmoother.ignore_index; not referenced elsewhere in the visible part of this file.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


# ------------------------
# Args
# ------------------------
@dataclass
class ModelArguments:
    """Command-line options that select the pretrained model."""

    # Existing local path, or an id handed to snapshot_download when no such path exists.
    model_name_or_path: Optional[str] = "moonshotai/Kimi-Audio-7B"
    # Optional local path; when set, the training entry point only checks that it exists.
    model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Local pretrained path (optional)."},
    )


@dataclass
class DataArguments:
    """Command-line options that describe the training data."""

    # JSONL file, one example per line.
    data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data (jsonl)."},
    )
    # Fraction of the examples held out for evaluation.
    eval_ratio: float = field(
        default=0.0,
        metadata={"help": "0 to disable eval split."},
    )
    # kept for compatibility; not used in current implementation
    lazy_preprocess: bool = field(default=False)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """transformers.TrainingArguments with extra fields and overridden defaults for this script."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    dataloader_pin_memory: bool = False
    # Forwarded as max_len when the datasets are built.
    model_max_length: int = 8192
    # bfloat16 on, fp16 off by default.
    bf16: bool = True
    fp16: bool = False
    # we export split checkpoints ourselves
    save_strategy: str = "no"


@dataclass
class LoRAArguments:
    """Command-line options forwarded to attach_lora when use_lora is set."""

    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # include up/gate/down projections
    include_mlp: bool = False
    # 0=all layers; >0=top-K only
    top_k_layers: int = 0
    exclude_mimo: bool = True
    adapter_name: str = "default"
    target_modules: Optional[str] = field(
        default=None,
        metadata={"help": "Comma-separated LoRA targets override"},
    )
    modules_to_save: Optional[str] = field(
        default=None,
        metadata={"help": "Comma-separated prefixes to keep trainable (non-LoRA)"},
    )


@dataclass
class SaveArguments:
    """Command-line options forwarded to ExportSplitCallback."""

    export_split_base_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Base dir for split checkpoints; default=output_dir/split_ckpts"},
    )
    # Passed to the callback as every_n_epochs.
    export_split_every_n_epochs: int = 1
    # Passed to the callback as keep_last_k.
    export_split_keep_last_k: Optional[int] = 3


# ------------------------
# Data
# ------------------------
def make_supervised_data_module(text_tokenizer, data_args, max_len, kimia_token_offset) -> Dict:
    """Load a JSONL file and wrap it in train/eval ``LazySupervisedDataset`` objects.

    Args:
        text_tokenizer: Tokenizer forwarded unchanged to ``LazySupervisedDataset``.
        data_args: Object exposing ``data_path`` (JSONL file, one JSON document per
            line) and ``eval_ratio`` (fraction of examples held out for evaluation;
            a falsy or non-positive value disables the split).
        max_len: Forwarded to ``LazySupervisedDataset`` as ``max_len``.
        kimia_token_offset: Forwarded to ``LazySupervisedDataset`` as
            ``kimia_token_offset``.

    Returns:
        A dict with keys ``train_dataset``, ``eval_dataset`` (``None`` when the
        split is disabled) and ``data_collator``.

    Raises:
        ValueError: If ``data_args.data_path`` is unset, or if the requested eval
            split would leave the train or the eval portion empty.
    """
    if not data_args.data_path:
        raise ValueError("data_path is required: path to the training data (jsonl)")

    # Explicit UTF-8 keeps parsing independent of the process locale; whitespace-only
    # lines (e.g. stray blank lines at the end of the file) are skipped instead of
    # aborting the whole load with a JSON decode error.
    with open(data_args.data_path, "r", encoding="utf-8") as f:
        all_data = [json.loads(line) for line in f if line.strip()]

    eval_ratio = data_args.eval_ratio
    if eval_ratio and eval_ratio > 0:
        # The eval split is the head of the file; everything after it is training data.
        ev_n = int(len(all_data) * eval_ratio)
        eval_data, train_data = all_data[:ev_n], all_data[ev_n:]
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if not eval_data or not train_data:
            raise ValueError(
                f"eval_ratio={eval_ratio} splits {len(all_data)} examples into "
                f"{len(eval_data)} eval / {len(train_data)} train; both must be non-empty"
            )
    else:
        eval_data, train_data = None, all_data

    train_dataset = LazySupervisedDataset(
        train_data, text_tokenizer=text_tokenizer, max_len=max_len, kimia_token_offset=kimia_token_offset
    )
    eval_dataset = LazySupervisedDataset(
        eval_data, text_tokenizer=text_tokenizer, max_len=max_len, kimia_token_offset=kimia_token_offset
    ) if eval_data else None

    # Use dataset's collate_fn directly (token left-pad, waveform right-pad)
    def collate(examples):
        return LazySupervisedDataset.collate_fn(examples, pad_token_id=train_dataset.pad_token)

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=collate)


# ------------------------
# Trainer (SFT: text-only cross-entropy)
# ------------------------
class LogAfterOptStepTrainer(Trainer):
    """Trainer with a masked, text-only cross-entropy loss and optional alpha logging.

    ``compute_loss`` ignores the audio logits/labels and averages the per-token
    text cross-entropy over the positions selected by the text loss mask.
    ``optimizer_step`` additionally logs the ``ced_processor.alpha`` parameter
    when the model has one.
    """

    def optimizer_step(self, *args, **kwargs):
        """Run the parent ``optimizer_step``, then log ``ced_processor.alpha`` if present.

        NOTE(review): confirm that the installed ``transformers.Trainer`` actually
        defines and invokes ``optimizer_step``; if it does not, this hook never runs.
        """
        result = super().optimizer_step(*args, **kwargs)
        self._log_ced_alpha()
        # Propagate whatever the parent returned instead of silently dropping it.
        return result

    def _log_ced_alpha(self):
        """Best-effort: log the first parameter whose name ends in ``ced_processor.alpha``."""
        try:
            # Unwrap a wrapper that exposes the real model as `.module`, if any.
            base = getattr(self.model, "module", self.model)
            alpha = next(
                (p for name, p in base.named_parameters() if name.endswith("ced_processor.alpha")),
                None,
            )
            if alpha is None:
                return
            # reshape(()) requires exactly one element, i.e. a scalar parameter.
            self.log({"alpha": float(alpha.detach().cpu().reshape(()))})
        except Exception:
            # Diagnostics must never interrupt training, but leave a trace of the failure.
            logger.debug("Could not log ced_processor.alpha", exc_info=True)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the mask-weighted mean cross-entropy over text tokens.

        Args:
            model: Model called as ``model(**inputs)`` after ``labels`` is removed;
                its ``outputs.logits`` must unpack into ``(audio_logits, text_logits)``.
            inputs: Batch dict; ``inputs["labels"]`` must unpack into
                ``(audio_labels, text_labels, audio_loss_mask, text_loss_mask)``.
                The ``labels`` entry is popped from the dict.
            return_outputs: When true, return ``(loss, outputs)`` instead of ``loss``.
            **kwargs: Accepted for compatibility with the caller and ignored.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # compute CE on text only; the audio parts are unpacked but unused
        _audio_logits, text_logits = outputs.logits
        _audio_labels, text_labels, _audio_loss_mask, text_loss_mask = labels

        ce = torch.nn.CrossEntropyLoss(reduction="none")
        # reshape (not view) so non-contiguous tensors are flattened instead of raising.
        per_token = ce(
            text_logits.reshape(-1, text_logits.shape[-1]),
            text_labels.reshape(-1),
        )
        flat_mask = text_loss_mask.reshape(-1)
        # The epsilon keeps the division finite when the mask selects no position.
        loss = (per_token * flat_mask).sum() / (flat_mask.sum() + 1e-6)
        return (loss, outputs) if return_outputs else loss


# ------------------------
# Train
# ------------------------
def train():
    """Parse CLI arguments, attach LoRA to KimiAudioModel and run supervised fine-tuning.

    Steps: parse the five argument dataclasses, validate the cheap-to-check paths,
    load the base model and tokenizer, build the datasets, freeze every parameter
    and (optionally) attach LoRA, then train with ``LogAfterOptStepTrainer`` while
    ``ExportSplitCallback`` handles the split-checkpoint export.

    Raises:
        ValueError: If ``model_path`` is given but does not exist, or if
            ``data_path`` is not set.
        FileNotFoundError: If ``data_path`` does not exist.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoRAArguments, SaveArguments)
    )
    model_args, data_args, training_args, lora_args, save_args = parser.parse_args_into_dataclasses()

    # ===== Validate arguments =====
    # Done before any download or model load so a typo fails in seconds, not minutes.
    # NOTE(review): model_path is only validated here; the weights below are loaded
    # from model_name_or_path — confirm this is intended.
    if model_args.model_path and not os.path.exists(model_args.model_path):
        raise ValueError(f"Model path {model_args.model_path} does not exist")
    if not data_args.data_path:
        raise ValueError("data_path is required: path to the training data (jsonl)")
    if not os.path.exists(data_args.data_path):
        raise FileNotFoundError(f"Training data file {data_args.data_path} does not exist")

    # ===== Load base & tokenizer =====
    logger.info("Loading Kimi-Audio base model")
    # Use the argument as a local directory when it exists, otherwise resolve it via the hub.
    if os.path.exists(model_args.model_name_or_path):
        cache_path = model_args.model_name_or_path
    else:
        cache_path = snapshot_download(model_args.model_name_or_path)

    model = KimiAudioModel.init_from_pretrained(
        model_name_or_path=model_args.model_name_or_path,
        model_load_kwargs={'low_cpu_mem_usage': True, 'torch_dtype': torch.bfloat16}
    )
    tokenizer = AutoTokenizer.from_pretrained(cache_path, trust_remote_code=True)
    extra = instantiate_extra_tokens(tokenizer)
    pad_id = extra.pad

    # ===== Data =====
    data_module = make_supervised_data_module(
        text_tokenizer=tokenizer,
        data_args=data_args,
        max_len=training_args.model_max_length,
        kimia_token_offset=model.config.kimia_token_offset
    )

    # ===== Freeze & attach LoRA =====
    # Everything is frozen first; with use_lora unset, nothing is re-enabled here.
    for _, p in model.named_parameters():
        p.requires_grad = False
    if lora_args.use_lora:
        model = attach_lora(model, lora_args)

    # Replace the model's FLOPs estimate with a constant 0 (best-effort).
    try:
        model.floating_point_ops = lambda inputs: 0
    except Exception:
        logger.debug("Could not override model.floating_point_ops", exc_info=True)

    print_trainable_params(model)

    train_dataset = data_module["train_dataset"]
    eval_dataset = data_module["eval_dataset"]

    # ===== Trainer =====
    # The collator is built from the tokenizer's extra-token pad id rather than
    # taken from data_module["data_collator"].
    trainer = LogAfterOptStepTrainer(
        model=model,
        args=training_args,
        data_collator=partial(LazySupervisedDataset.collate_fn, pad_token_id=pad_id),
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Explicitly set label field to silence warnings
    trainer.label_names = ["labels"]

    # Export split model (LoRA merged) every epoch
    split_base = save_args.export_split_base_dir or os.path.join(training_args.output_dir, "split_ckpts")
    trainer.add_callback(ExportSplitCallback(
        export_base_dir=split_base,
        every_n_epochs=save_args.export_split_every_n_epochs,
        keep_last_k=save_args.export_split_keep_last_k
    ))

    trainer.train()
    # Saving the trainer state stays best-effort, but a failure is now reported.
    try:
        trainer.save_state()
    except Exception:
        logger.warning("trainer.save_state() failed; trainer state was not written", exc_info=True)


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    train()
