# finetune.py
# SFT for KimiAudioModel (export split at each epoch; do not save large checkpoints)
from dataclasses import dataclass, field
import json, logging, os, gc, warnings
from typing import Dict, Optional, List

import torch
import torch.distributed as dist
import transformers
from transformers import Trainer, AutoTokenizer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother
from transformers.integrations import deepspeed as hf_ds
from huggingface_hub import snapshot_download

from finetune_codes.model import KimiAudioModel
from finetune_codes.datasets import LazySupervisedDataset
from functools import partial
from kimia_infer.utils.special_tokens import instantiate_extra_tokens

# safetensors is used only for Zero3 shard consolidation; DDP direct copy doesn't need it
try:
    from safetensors.torch import save_file, safe_open
    _HAVE_SAFETENSORS = True
except Exception:
    _HAVE_SAFETENSORS = False

# Configure the root logger at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the helpers and callbacks in this file.
logger = logging.getLogger(__name__)
# Ignore-label id taken from transformers' LabelSmoother.
# NOTE(review): not referenced anywhere else in the visible part of this file.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# ======================
# Args
# ======================
@dataclass
class ModelArguments:
    """Command-line options selecting the base model to fine-tune."""

    # Local directory if it exists on disk, otherwise treated as a hub id by train().
    model_name_or_path: Optional[str] = "moonshotai/Kimi-Audio-7B"
    # Only checked for existence in train(); it is not used to load weights there.
    model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Optional: local pretrained path (kept for compatibility)."},
    )


@dataclass
class DataArguments:
    """Command-line options describing the training data."""

    # Path to the jsonl file read by make_supervised_data_module().
    data_path: str = field(
        default=None,
        metadata={"help": "Training data (jsonl, one sample per line)."},
    )
    # Fraction of samples (taken from the head of the file) used for evaluation.
    eval_ratio: float = field(
        default=0.0,
        metadata={"help": "Validation split ratio; 0 disables eval split."},
    )
    # NOTE(review): not read anywhere in the visible part of this file.
    lazy_preprocess: bool = field(default=False)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training arguments with this script's defaults and extra fields."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    dataloader_pin_memory: bool = False
    dataloader_num_workers: int = 4
    # Passed to the dataset as ``max_len`` by train().
    model_max_length: int = 8192
    # Mixed-precision defaults: bfloat16 on, float16 off.
    bf16: bool = True
    fp16: bool = False


@dataclass
class ExportArguments:
    """Command-line options controlling the per-epoch split export."""

    export_split_base_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Root dir for split exports; default=output_dir/split_ckpts"},
    )
    # Export cadence in epochs, handed to ExportSplitCallback by train().
    export_split_every_n_epochs: int = 1
    # How many exported epoch directories the callback keeps; None disables pruning.
    export_split_keep_last_k: Optional[int] = 3


# ======================
# Utils
# ======================
# Local rank of this process. Stays None until train() overwrites it with
# ``training_args.local_rank``.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only on the process that should own console output.

    Output is produced when ``local_rank`` is None (not yet set), 0 (first
    local process), or negative. Hugging Face ``TrainingArguments.local_rank``
    is -1 for non-distributed runs, so a negative value means this is the only
    process and must not be silenced.
    """
    if local_rank is None or local_rank <= 0:
        print(*args)

def print_trainable_params(model):
    """Report total and trainable parameter counts (in millions) via rank0_print.

    Args:
        model: Object exposing ``named_parameters()`` whose items provide
            ``numel()`` and ``requires_grad``.

    A model without any parameters reports 0.00 % instead of raising
    ZeroDivisionError.
    """
    total = 0
    trainable = 0
    for _, param in model.named_parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    # Guard the ratio: an empty model would otherwise divide by zero.
    percent = 100 * trainable / total if total else 0.0
    rank0_print(f"Total params: {total/1e6:.1f} M | Trainable: {trainable/1e6:.3f} M ({percent:.2f} %)")

def make_supervised_data_module(text_tokenizer, data_args, max_len, kimia_token_offset) -> Dict:
    """Load the jsonl training file and build the train/eval datasets.

    Args:
        text_tokenizer: Tokenizer forwarded to ``LazySupervisedDataset``.
        data_args: Object providing ``data_path`` (jsonl file) and ``eval_ratio``.
        max_len: Forwarded to the dataset as ``max_len``.
        kimia_token_offset: Forwarded to the dataset unchanged.

    Returns:
        ``{"train_dataset": ..., "eval_dataset": ...}``; ``eval_dataset`` is
        None when ``eval_ratio`` is 0 / unset.

    Raises:
        ValueError: if ``eval_ratio`` leaves either the train or the eval
            split empty.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    # Blank lines (e.g. a trailing newline) are skipped rather than fed to json.loads.
    with open(data_args.data_path, "r", encoding="utf-8") as f:
        all_data = [json.loads(line) for line in f if line.strip()]

    eval_data, train_data = None, all_data
    if data_args.eval_ratio and data_args.eval_ratio > 0:
        # The eval split is the head of the file; no shuffling is performed here.
        ev_n = int(len(all_data) * data_args.eval_ratio)
        eval_data, train_data = all_data[:ev_n], all_data[ev_n:]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not eval_data or not train_data:
            raise ValueError(
                f"eval_ratio={data_args.eval_ratio} splits {len(all_data)} samples into "
                f"{len(eval_data)} eval / {len(train_data)} train; both must be non-empty"
            )

    dataset_kwargs = dict(
        text_tokenizer=text_tokenizer, max_len=max_len, kimia_token_offset=kimia_token_offset
    )
    train_dataset = LazySupervisedDataset(train_data, **dataset_kwargs)
    eval_dataset = LazySupervisedDataset(eval_data, **dataset_kwargs) if eval_data else None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)

class KimiAudioTrainer(Trainer):
    """Trainer whose loss is a masked cross-entropy over the text logits only."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the masked mean cross-entropy of the text stream.

        ``inputs["labels"]`` is removed from ``inputs`` and must unpack into
        four entries, of which the second (text labels) and fourth (text mask)
        are used. ``outputs.logits`` must unpack into two entries, of which the
        second (text logits) is used. Per the original author's note, label
        shifting and masking were already done in datasets.py.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
        just ``loss``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)

        _unused_logits, text_logits = outputs.logits
        _unused_a, text_labels, _unused_b, text_mask = labels

        # Per-token loss over flattened positions; the mask decides what counts.
        vocab_size = text_logits.shape[-1]
        per_token = torch.nn.functional.cross_entropy(
            text_logits.view(-1, vocab_size),
            text_labels.view(-1),
            reduction="none",
        )
        flat_mask = text_mask.view(-1)
        # The small epsilon keeps the division finite when the mask is all zero.
        loss = (per_token * flat_mask).sum() / (flat_mask.sum() + 1e-6)

        if return_outputs:
            return loss, outputs
        return loss
    
def _unwrap_ddp(model):
    return getattr(model, "module", model)

def _to_cpu_dtype(t: torch.Tensor, dtype: torch.dtype):
    if t.device.type != "cpu":
        t = t.detach().to("cpu", non_blocking=True)
    if t.is_floating_point() and t.dtype != dtype:
        t = t.to(dtype)
    return t

# ====== Zero3 only ======
def _zero3_consolidate_to_safetensors(model, safepath: str):
    if not hasattr(model, "_zero3_consolidated_16bit_state_dict"):
        raise RuntimeError("Zero-3 consolidate API not found.")
    sd = model._zero3_consolidated_16bit_state_dict()  # already CPU / 16-bit
    if _HAVE_SAFETENSORS:
        save_file(sd, safepath)
        del sd; gc.collect()
    else:
        warnings.warn("safetensors is not installed; Zero-3 export will use more memory. Consider `pip install safetensors`.")
        torch.save(sd, safepath)
        del sd; gc.collect()

# ====== DDP: direct copy (no full state_dict assembly) ======
def _build_fresh_from_config(base_model_id_or_path: str) -> torch.nn.Module:
    """Instantiate a new KimiAudioModel from the base config, on CPU in bfloat16.

    Only the configuration is loaded from ``base_model_id_or_path`` (with
    ``trust_remote_code=True``); no pretrained weights are read here, so the
    caller is expected to fill the weights afterwards.
    """
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(base_model_id_or_path, trust_remote_code=True)
    new_model = KimiAudioModel(config)
    new_model.to(device="cpu", dtype=torch.bfloat16)
    return new_model

@torch.inference_mode()
def _direct_copy_from_model_to_fresh(src_model: torch.nn.Module, fresh: torch.nn.Module):
    """Copy same-named parameters and buffers from ``src_model`` into ``fresh``.

    Tensors are moved one at a time via ``_to_cpu_dtype`` (CPU, destination
    dtype) so no full state dict is assembled. Destination tensors with no
    same-named source tensor are left untouched; their names are reported in a
    single warning, because they keep whatever values ``fresh`` was created with.

    Args:
        src_model: Trained model, optionally wrapped (``.module`` is unwrapped).
        fresh: Destination model whose tensors are overwritten in place.

    NOTE: the unwrapped source model is switched to eval mode and is not
    switched back by this function.
    """
    src = _unwrap_ddp(src_model).eval()
    lookups = (dict(src.named_parameters()), dict(src.named_buffers()))
    destinations = (fresh.named_parameters(), fresh.named_buffers())

    missing = []
    copied = 0
    for lookup, named_dst in zip(lookups, destinations):
        for name, dst in named_dst:
            src_tensor = lookup.get(name)
            if src_tensor is None:
                missing.append(name)
                continue
            dst.data.copy_(_to_cpu_dtype(src_tensor.data, dst.dtype))
            copied += 1
            # Periodic collection keeps temporary CPU copies from piling up.
            # A plain counter is used because hash(str) is randomised per
            # process, which made the old trigger non-deterministic.
            if copied % 64 == 0:
                gc.collect()

    if missing:
        logger.warning(
            "[Export] %d tensor(s) in the fresh model have no same-named source "
            "tensor and keep their initial values (first: %s)",
            len(missing), missing[:5],
        )

def _copy_matching_tensors(fresh, fetch):
    """Copy ``fetch(name)`` into each same-named parameter/buffer of ``fresh``; None means skip."""
    copied = 0
    for named in (fresh.named_parameters(), fresh.named_buffers()):
        for name, dst in named:
            src = fetch(name)
            if src is None:
                continue
            if src.dtype != dst.dtype:
                src = src.to(dst.dtype)
            dst.data.copy_(src)
            copied += 1
            # Periodically release the temporaries created by the casts above.
            if copied % 64 == 0:
                gc.collect()


def export_split_from_model(model, out_dir: str, base_model_id_or_path: str):
    """Export the current weights of ``model`` as a split checkpoint in ``out_dir``.

    Two paths, selected by ``hf_ds.is_deepspeed_zero3_enabled()``:

    * ZeRO-3: consolidate the shards into a temporary file inside ``out_dir``,
      load a fresh base model via ``KimiAudioModel.init_from_pretrained`` and
      overwrite its same-named tensors from that file. The temporary file is
      removed even when loading or copying fails.
    * otherwise: build a fresh model from the base config and copy tensors
      directly from the training model.

    The fresh model is then handed to ``KimiAudioModel.export_model`` together
    with the unwrapped training model as ``src_submodules``.

    Args:
        model: Training model (possibly wrapped; ``.module`` is unwrapped).
        out_dir: Destination directory; created if missing.
        base_model_id_or_path: Base model id/path used to create the fresh model.

    Returns:
        ``out_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Best-effort GPU cleanup before the memory-heavy export; never fatal.
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        torch.cuda.empty_cache()
    except Exception:
        logger.debug("[Export] CUDA cleanup before export failed; continuing.", exc_info=True)
    gc.collect()

    if hf_ds.is_deepspeed_zero3_enabled():
        # ========== Zero-3 path ==========
        tmp_safe = os.path.join(out_dir, "_tmp_consolidated.safetensors")
        rank0_print("[Export] Zero-3 detected: consolidating shards to safetensors...")
        try:
            _zero3_consolidate_to_safetensors(model, tmp_safe)

            rank0_print(f"[Export] Loading fresh base via init_from_pretrained: {base_model_id_or_path}")
            fresh = KimiAudioModel.init_from_pretrained(
                model_name_or_path=base_model_id_or_path,
                model_load_kwargs={"low_cpu_mem_usage": True, "torch_dtype": torch.bfloat16}
            )
            fresh.to("cpu")

            if _HAVE_SAFETENSORS:
                # Tensors are read lazily, one at a time, from the file.
                with safe_open(tmp_safe, framework="pt", device="cpu") as f:
                    keys = set(f.keys())
                    _copy_matching_tensors(
                        fresh, lambda name: f.get_tensor(name) if name in keys else None
                    )
            else:
                # Without safetensors the temp file was written by torch.save.
                sd = torch.load(tmp_safe, map_location="cpu")
                # pop() releases each tensor from the dict as soon as it is consumed.
                _copy_matching_tensors(fresh, lambda name: sd.pop(name, None))
                sd.clear()
                gc.collect()
        finally:
            # Always remove the (potentially very large) temp file, also on failure.
            try:
                os.remove(tmp_safe)
            except OSError:
                pass
    else:
        # ========== DDP direct-copy path ==========
        fresh = _build_fresh_from_config(base_model_id_or_path)
        _direct_copy_from_model_to_fresh(model, fresh)

    # Export the split packages (prefer submodules from the training model as source weights)
    KimiAudioModel.export_model(fresh, out_dir, src_submodules=_unwrap_ddp(model))
    rank0_print(f"[Export] Exported split to: {out_dir}")

    del fresh
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        logger.debug("[Export] CUDA cache cleanup after export failed; continuing.", exc_info=True)

    return out_dir

class ExportSplitCallback(TrainerCallback):
    """Export a split checkpoint at epoch ends and prune old exports.

    The export runs only on the process with ``args.process_index == 0``; all
    processes synchronise with a barrier before and after it.

    NOTE(review): under DeepSpeed ZeRO-3 the consolidation used by the export
    is presumably a collective operation; confirm that running it on process 0
    only is valid before relying on the ZeRO-3 path.
    """

    def __init__(self, export_base_dir: str, base_model_id_or_path: str,
                 every_n_epochs: int = 1, keep_last_k: Optional[int] = 3):
        """Store the export settings and create ``export_base_dir``.

        Args:
            export_base_dir: Root directory; each export goes to ``epoch_NNN`` below it.
            base_model_id_or_path: Base model id/path forwarded to the export.
            every_n_epochs: Export cadence; values below 1 are treated as 1.
            keep_last_k: Number of exports to keep; None disables pruning.
        """
        self.export_base_dir = export_base_dir
        self.base_model = base_model_id_or_path
        self.every = max(1, int(every_n_epochs))
        self.keep_last_k = keep_last_k
        os.makedirs(self.export_base_dir, exist_ok=True)
        # Directories exported by this callback instance, oldest first.
        self._epoch_dirs: List[str] = []

    def _barrier(self):
        """Synchronise all processes when torch.distributed is initialised; otherwise do nothing."""
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    def _track_and_prune(self, out_dir: str):
        """Record ``out_dir`` as the newest export and delete exports beyond ``keep_last_k``."""
        import shutil

        # A directory exported again (same truncated epoch number) must be tracked
        # once only, otherwise pruning could delete a directory still counted as kept.
        if out_dir in self._epoch_dirs:
            self._epoch_dirs.remove(out_dir)
        self._epoch_dirs.append(out_dir)

        if self.keep_last_k is None:
            return
        limit = max(0, self.keep_last_k)
        while len(self._epoch_dirs) > limit:
            to_rm = self._epoch_dirs.pop(0)
            try:
                # No ignore_errors: a failed removal must reach the warning below
                # instead of being reported as removed.
                shutil.rmtree(to_rm)
            except OSError as e:
                logger.warning(f"[ExportSplitCallback] Failed to remove {to_rm}: {e}")
            else:
                rank0_print(f"[ExportSplitCallback] Removed old split dir: {to_rm}")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Export the model on process 0 when the (truncated) epoch matches the cadence."""
        if state.epoch is None:
            return
        ep = int(state.epoch)
        should_export = (ep % self.every == 0)

        # Every process reaches both barriers, whether or not it exports.
        self._barrier()

        if should_export and getattr(args, "process_index", 0) == 0:
            model = kwargs.get("model", None)
            if model is None:
                logger.warning("[ExportSplitCallback] 'model' not found in kwargs; skip this epoch.")
            else:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                out_dir = os.path.join(self.export_base_dir, f"epoch_{ep:03d}")
                rank0_print(f"[ExportSplitCallback] Exporting epoch {ep} -> {out_dir}")

                # Best-effort cache release before the memory-heavy export.
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                gc.collect()

                export_split_from_model(model, out_dir, self.base_model)
                self._track_and_prune(out_dir)

        self._barrier()

# ======================
# Train
# ======================
def train():
    """Parse CLI arguments, fine-tune KimiAudioModel and export splits per epoch.

    Steps: parse the four argument dataclasses, validate paths, load the base
    model and tokenizer, build the datasets, freeze everything except the
    selected submodules, then run the trainer with ``ExportSplitCallback``
    attached.

    Raises:
        ValueError: if ``data_path`` is missing, or ``model_path`` is given but
            does not exist.
    """
    global local_rank
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, ExportArguments))
    model_args, data_args, training_args, export_args = parser.parse_args_into_dataclasses()
    training_args.remove_unused_columns = False
    local_rank = training_args.local_rank

    # Fail fast on bad arguments, before the expensive model download/load.
    if not data_args.data_path:
        raise ValueError("data_path is required: pass --data_path pointing to a jsonl training file")
    if model_args.model_path and (not os.path.exists(model_args.model_path)):
        raise ValueError(f"Model path {model_args.model_path} does not exist")

    logger.info("Loading Kimi-Audio base model")
    # Use the local directory when it exists, otherwise fetch a hub snapshot.
    if os.path.exists(model_args.model_name_or_path):
        cache_path = model_args.model_name_or_path
    else:
        cache_path = snapshot_download(model_args.model_name_or_path)

    model = KimiAudioModel.init_from_pretrained(
        model_name_or_path=model_args.model_name_or_path,
        model_load_kwargs={"low_cpu_mem_usage": True, "torch_dtype": torch.bfloat16}
    )
    tokenizer = AutoTokenizer.from_pretrained(cache_path, trust_remote_code=True)

    # The collator pads with the project's extra pad token, not the tokenizer's.
    extra = instantiate_extra_tokens(tokenizer)
    pad_id = extra.pad
    tok_pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if tok_pad_token_id is not None and tok_pad_token_id != pad_id:
        rank0_print(f"[Warn] tokenizer.pad_token_id={tok_pad_token_id} != extra.pad={pad_id}; using extra.pad.")

    data_module = make_supervised_data_module(
        text_tokenizer=tokenizer,
        data_args=data_args,
        max_len=training_args.model_max_length,
        kimia_token_offset=model.config.kimia_token_offset
    )

    # Freeze all params; train only the selected submodules (matched by name prefix).
    trainable_prefixes = ("model.ced_processor", "model.vq_adaptor")
    for name, p in model.named_parameters():
        p.requires_grad = name.startswith(trainable_prefixes)

    print_trainable_params(model)

    trainer = KimiAudioTrainer(
        model=model,
        args=training_args,
        train_dataset=data_module["train_dataset"],
        eval_dataset=data_module["eval_dataset"],
        data_collator=partial(LazySupervisedDataset.collate_fn, pad_token_id=pad_id),
        processing_class=tokenizer,
    )

    split_base = export_args.export_split_base_dir or os.path.join(training_args.output_dir, "split_ckpts")
    trainer.add_callback(ExportSplitCallback(
        export_base_dir=split_base,
        base_model_id_or_path=model_args.model_name_or_path,
        every_n_epochs=export_args.export_split_every_n_epochs,
        keep_last_k=export_args.export_split_keep_last_k
    ))

    trainer.train()
    # Saving the trainer state is best-effort, but a failure should be visible.
    try:
        trainer.save_state()
    except Exception:
        logger.warning("trainer.save_state() failed; continuing without saved state.", exc_info=True)

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
