# finetune_codes/lora_utils.py
import os
import re
import gc
import types
import inspect
import logging
from typing import List, Optional

import torch
from torch import nn
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

from finetune_codes.model import KimiAudioModel
from transformers import TrainerCallback

logger = logging.getLogger(__name__)


# ------------------------
# Utilities
# ------------------------
def rank0_print(*args):
    """Print ``args`` only from the main process.

    If torch.distributed is unavailable or no process group is initialized,
    the caller is treated as the main process and always prints.
    """
    dist = torch.distributed
    is_main_process = (
        not dist.is_available()
        or not dist.is_initialized()
        or dist.get_rank() == 0
    )
    if is_main_process:
        print(*args)


def print_trainable_params(model):
    """Print a parameter summary for ``model`` via rank0_print.

    First prints one line with the total parameter count (millions), the
    trainable count (millions) and the trainable share in percent. It then
    prints ``name shape`` for every parameter with ``requires_grad=True``.
    """
    total = 0
    trainable = []
    # Single pass over the parameters; trainable ones are remembered for the listing.
    for name, p in model.named_parameters():
        total += p.numel()
        if p.requires_grad:
            trainable.append((name, p))
    n_trainable = sum(p.numel() for _, p in trainable)
    # Guard against parameter-less models (previously a ZeroDivisionError).
    pct = 100 * n_trainable / total if total else 0.0
    rank0_print(f"Total params: {total/1e6:.1f} M | Trainable: {n_trainable/1e6:.3f} M ({pct:.2f} %)")
    for name, p in trainable:
        rank0_print(name, p.shape)


def _unwrap_ddp(model):
    """Return the underlying module if wrapped by DDP/accelerate, else the model itself."""
    return getattr(model, "module", model)


# ------------------------
# LoRA target selection
# ------------------------
def collect_lora_targets(
    model: nn.Module,
    include_mlp_llm: bool,
    include_adapter: bool,
    exclude_mimo: bool,
    target_modules_override: Optional[List[str]] = None,
    excluded_prefixes: Optional[List[str]] = None,
) -> List[str]:
    """
    Return the sorted, de-duplicated names of ``nn.Linear`` submodules to wrap with LoRA.

    Selection rules:
      - target_modules_override (non-empty): keep every Linear whose name contains
        any of the given substrings. include_mlp_llm / include_adapter / exclude_mimo
        are not consulted in this mode.
      - otherwise:
          * attention projections (q_proj/k_proj/v_proj/o_proj) are always selected;
          * MLP projections (up_proj/gate_proj/down_proj) only when include_mlp_llm;
          * adapter Linears (model.ced_processor*, model.vq_adaptor*, names containing
            aggregator_projection / audio_aggregator, aggregator_layer_{1,2}.{intermediate,output})
            only when include_adapter;
          * names containing "model.mimo_layers" are skipped when exclude_mimo.
      - excluded_prefixes: names starting with any of these are never selected, in
        either mode (useful together with modules_to_save for full-precision tuning).
    """
    skip_prefixes = tuple(excluded_prefixes or [])
    linear_names = [n for n, mod in model.named_modules() if isinstance(mod, nn.Linear)]

    if target_modules_override:
        return sorted({
            n for n in linear_names
            if not n.startswith(skip_prefixes) and any(t in n for t in target_modules_override)
        })

    attn_pat = re.compile(r"\b(q_proj|k_proj|v_proj|o_proj)\b")
    mlp_pat = re.compile(r"\b(up_proj|gate_proj|down_proj)\b")
    agg_layer_pat = re.compile(r"aggregator_layer_(1|2)\.(intermediate|output)$")

    def is_adapter(n: str) -> bool:
        return (
            n.startswith(("model.ced_processor", "model.vq_adaptor"))
            or "aggregator_projection" in n
            or "audio_aggregator" in n
            or agg_layer_pat.search(n) is not None
        )

    picked = set()
    for n in linear_names:
        if exclude_mimo and "model.mimo_layers" in n:
            continue
        if n.startswith(skip_prefixes):
            continue
        if (
            attn_pat.search(n)
            or (include_mlp_llm and mlp_pat.search(n))
            or (include_adapter and is_adapter(n))
        ):
            picked.add(n)
    return sorted(picked)


# ------------------------
# PEFT/Transformers compatibility
# ------------------------
def _ensure_prepare_inputs_for_generation(model: nn.Module) -> nn.Module:
    """
    PEFT will query prepare_inputs_for_generation (pifg). In multimodal cases we just pass through kwargs.
    """
    if hasattr(model, "prepare_inputs_for_generation"):
        return model

    def _prepare_inputs_for_generation(self, **kwargs):
        return {k: v for k, v in kwargs.items()}

    model.prepare_inputs_for_generation = types.MethodType(_prepare_inputs_for_generation, model)
    return model


def _patch_peft_forward_filter_kwargs(peft_model):
    """
    Replace ``peft_model.forward`` with a wrapper that drops non-allowlisted kwargs.

    The allowlist is the union of two sets:
      - the parameter names of the base model's ``forward``, when its signature can
        be inspected;
      - a fixed set of generation and multimodal keys.
    Positional arguments and surviving keyword arguments are passed to
    ``peft_model.base_model``. The patch is applied in place and nothing is returned.
    """
    fixed_keys = {
        "use_cache", "output_attentions", "output_hidden_states", "return_dict",
        "position_ids", "attention_mask", "generation_mode", "past_key_values", "inputs_embeds",
        "audio_input_ids", "text_input_ids", "waveform", "is_continuous_mask",
        "whisper_input_feature", "ced_input_feature", "labels",
    }

    base = peft_model.get_base_model()
    try:
        signature_keys = set(inspect.signature(base.forward).parameters)
    except Exception:
        # Signature not introspectable: rely on the fixed allowlist only.
        signature_keys = set()
    allowed = signature_keys | fixed_keys

    def _filtered_forward(self, *args, **kwargs):
        kept = {name: value for name, value in kwargs.items() if name in allowed}
        return self.base_model(*args, **kept)

    peft_model.forward = types.MethodType(_filtered_forward, peft_model)
    logger.info(f"[LoRA] Patched PeftModel.forward to filter kwargs. Allowed: {sorted(allowed)}")


# ------------------------
# Build LoRA
# ------------------------
def attach_lora(model: KimiAudioModel, lora_args) -> nn.Module:
    """
    Inject LoRA adapters into ``model`` with PEFT and choose which parameters train.

    ``lora_args`` attributes. Each is read with ``getattr``; the default used when
    it is absent is shown in parentheses:
      - lora_r (16), lora_alpha (32), lora_dropout (0.05)
      - include_mlp (False), top_k_layers (0 = no layer restriction), exclude_mimo (True)
      - adapter_name ("default")
      - target_modules (comma-separated str | None): substring override for LoRA targets
      - modules_to_save (comma-separated str | None): module-name prefixes that are
        excluded from LoRA and trained in full instead

    After this call only LoRA parameters and the trainable copies of
    ``modules_to_save`` have ``requires_grad=True``. With ``top_k_layers > 0``,
    LoRA parameters in ``model.layers.<i>`` with ``i`` below the top-k are frozen again.

    Returns:
        The PEFT-wrapped model. Its ``forward`` is patched to filter kwargs, and it
        carries a ``_plain_modules_to_save`` attribute read by the export helpers.

    Raises:
        RuntimeError: if no LoRA target modules are found.
    """
    target_modules_override = None
    if getattr(lora_args, "target_modules", None):
        target_modules_override = [m.strip() for m in lora_args.target_modules.split(",") if m.strip()]

    # Ensure pifg exists
    model = _ensure_prepare_inputs_for_generation(model)

    modules_to_save_raw = []
    if getattr(lora_args, "modules_to_save", None):
        modules_to_save_raw = [m.strip() for m in lora_args.modules_to_save.split(",") if m.strip()]

    # Exclude these prefixes from LoRA (kept trainable in full precision)
    targets = collect_lora_targets(
        model,
        include_mlp_llm=getattr(lora_args, "include_mlp", False),
        include_adapter=False,                        # adapters excluded by default
        exclude_mimo=getattr(lora_args, "exclude_mimo", True),
        target_modules_override=target_modules_override,
        excluded_prefixes=modules_to_save_raw,
    )
    if not targets:
        raise RuntimeError("No LoRA target modules found.")

    lconf = LoraConfig(
        r=getattr(lora_args, "lora_r", 16),
        lora_alpha=getattr(lora_args, "lora_alpha", 32),
        lora_dropout=getattr(lora_args, "lora_dropout", 0.05),
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=targets,
        modules_to_save=modules_to_save_raw,
    )
    peft_model = get_peft_model(model, lconf, adapter_name=getattr(lora_args, "adapter_name", "default"))

    # Keep a record for "full-copy" during export
    peft_model._plain_modules_to_save = modules_to_save_raw

    _patch_peft_forward_filter_kwargs(peft_model)

    # Trainable = LoRA params + the trainable copies of modules_to_save.
    # PEFT's ModulesToSaveWrapper keeps the untouched pre-training weights under
    # ".original_module." next to the trained ".modules_to_save.<adapter>." copy.
    # The original copy is not used while the adapter is active, so it must stay
    # frozen. Otherwise it inflates the trainable count and, under DDP, appears as
    # a parameter that requires grad but never receives one.
    for n, p in peft_model.named_parameters():
        is_lora = "lora_" in n
        is_saved_copy = (
            not is_lora
            and ".original_module." not in n
            and any(pref in n for pref in modules_to_save_raw)
        )
        p.requires_grad = is_lora or is_saved_copy

    # Train only top-k layers (if set)
    top_k_layers = getattr(lora_args, "top_k_layers", 0)
    if top_k_layers and top_k_layers > 0:
        try:
            num_layers = peft_model.get_base_model().config.num_hidden_layers
            start = max(0, num_layers - top_k_layers)
        except Exception as e:
            # Still best-effort, but visible: silently training every layer is surprising.
            logger.warning(f"[LoRA] top_k_layers={top_k_layers} ignored: could not read num_hidden_layers ({e})")
        else:
            for n, p in peft_model.named_parameters():
                if "model.layers." not in n or "lora_" not in n:
                    continue
                try:
                    idx = int(n.split("model.layers.", 1)[1].split(".", 1)[0])
                except ValueError:
                    # Segment after "model.layers." is not a layer index; leave as is.
                    continue
                if idx < start:
                    p.requires_grad = False

    return peft_model


# ------------------------
# Merge LoRA and export split packages
# ------------------------
def _build_cpu_shadow_model_and_merge(original_model: nn.Module) -> nn.Module:
    """
    Build a merged, plain (non-PEFT) CPU copy of the current training model.

    Steps:
      1. Rebuild an isomorphic ``KimiAudioModel`` on CPU (float32).
      2. Load the current (possibly PEFT) weights.
      3. If PEFT: inject the same adapter config, load the adapter weights and
         ``merge_and_unload``.
      4. Copy the trained ``modules_to_save`` weights into the merged base.

    Args:
        original_model: the training model, optionally wrapped (``.module``,
            e.g. DDP/accelerate) and optionally a ``PeftModel``.

    Returns:
        The merged model on CPU, in eval mode.
    """
    base_or_peft = _unwrap_ddp(original_model)
    is_peft = isinstance(base_or_peft, PeftModel)

    # 1) Try to rebuild from config._name_or_path to ensure module name alignment
    name_or_path = None
    try:
        base_ref = base_or_peft.get_base_model() if is_peft else base_or_peft
        name_or_path = getattr(base_ref.config, "_name_or_path", None)
    except Exception:
        pass

    if name_or_path:
        base_cpu = KimiAudioModel.init_from_pretrained(
            model_name_or_path=name_or_path,
            model_load_kwargs={"torch_dtype": torch.float32, "device_map": {"": "cpu"}, "low_cpu_mem_usage": True},
        )
    else:
        base_cfg = base_or_peft.get_base_model().config if is_peft else base_or_peft.config
        base_cpu = KimiAudioModel(base_cfg)

    base_cpu.eval()
    _ensure_prepare_inputs_for_generation(base_cpu)

    # 2) Pull current state dict to CPU
    with torch.no_grad():
        sd_cpu = {k: v.detach().to("cpu") for k, v in base_or_peft.state_dict().items()}

    # 3) Non-PEFT: load and return
    if not is_peft:
        missing, unexpected = base_cpu.load_state_dict(sd_cpu, strict=False)
        if len(unexpected) > 0:
            logger.warning(f"[Export] Unexpected keys when loading into CPU base: {len(unexpected)}")
        if len(missing) > 0:
            logger.warning(f"[Export] Missing keys when loading into CPU base: {len(missing)}")
        return base_cpu

    # 4) Get active adapter config
    peft_cfg_map = base_or_peft.peft_config
    active_name = getattr(base_or_peft, "active_adapter", None) or next(iter(peft_cfg_map.keys()))
    active_cfg = peft_cfg_map[active_name]

    # 5) Inject LoRA into CPU base, then load training weights
    shadow_peft_cpu = get_peft_model(base_cpu, active_cfg, adapter_name=active_name)
    try:
        shadow_peft_cpu.set_adapter(active_name)
    except Exception:
        pass

    missing, unexpected = shadow_peft_cpu.load_state_dict(sd_cpu, strict=False)
    if len(unexpected) > 0:
        logger.warning(f"[Export] Unexpected keys when loading into CPU PEFT: {len(unexpected)}")
    if len(missing) > 0:
        logger.warning(f"[Export] Missing keys when loading into CPU PEFT: {len(missing)}")

    # 6) Merge LoRA
    with torch.no_grad():
        merged_cpu = shadow_peft_cpu.merge_and_unload()
    merged_cpu.eval()

    # 7) Copy modules_to_save weights back into merged_cpu
    modules_to_copy = set(getattr(active_cfg, "modules_to_save", None) or [])
    modules_to_copy.update(getattr(base_or_peft, "_plain_modules_to_save", []) or [])

    if not modules_to_copy:
        logger.info("[Export] No modules_to_save to copy into merged CPU base.")
        return merged_cpu

    # PeftModel state-dict keys are "base_model.model.<base key>". PEFT's
    # ModulesToSaveWrapper stores the trained copy under
    # "<module>.modules_to_save.<adapter>.<param>" and the frozen pre-training
    # weights under "<module>.original_module.<param>". Map the trained copy back
    # to the plain "<module>.<param>" key of merged_cpu. Never copy
    # original_module: that would undo the fine-tuning.
    peft_prefix = "base_model.model."
    saved_marker = f".modules_to_save.{active_name}."

    def to_merged_key(k: str) -> Optional[str]:
        if not k.startswith(peft_prefix):
            return None
        k = k[len(peft_prefix):]
        if ".original_module." in k:
            return None
        return k.replace(saved_marker, ".", 1)

    dst = merged_cpu.state_dict()
    moved = 0
    for k, v in sd_cpu.items():
        k2 = to_merged_key(k)
        if k2 is None or k2 not in dst:
            continue
        if any(k2.startswith(pref) for pref in modules_to_copy):
            dst[k2] = v
            moved += 1
    merged_cpu.load_state_dict(dst, strict=False)
    logger.info(f"[Export] Copied {moved} parameters from modules_to_save into merged CPU base.")

    return merged_cpu


def export_split_from_model(model_or_wrapper: nn.Module, export_dir: str) -> str:
    """
    Merge LoRA on CPU, then call ``KimiAudioModel.export_model`` to write the split
    packages (LM / whisper / ced) into ``export_dir``.

    Args:
        model_or_wrapper: the training model, optionally wrapped (``.module``) and/or PEFT.
        export_dir: output directory; created if missing.

    Returns:
        ``export_dir``.

    The unwrapped training model is passed as ``src_submodules`` so that the
    whisper/ced submodule weights can be exported from it.
    """
    os.makedirs(export_dir, exist_ok=True)

    # Best-effort: flush pending CUDA work and drop cached blocks before the
    # CPU-side rebuild. Failures here must not abort the export.
    cuda = torch.cuda
    try:
        if cuda.is_available():
            cuda.synchronize()
            cuda.empty_cache()
    except Exception:
        pass
    gc.collect()

    merged = _build_cpu_shadow_model_and_merge(model_or_wrapper)
    training_module = _unwrap_ddp(model_or_wrapper)
    KimiAudioModel.export_model(merged, export_dir, src_submodules=training_module)

    logger.info(f"[Export] Exported split model to: {export_dir}")
    return export_dir


# ------------------------
# Callbacks: export by epoch (legacy behavior kept)
# ------------------------
class ExportSplitCallback(TrainerCallback):
    """
    Export the merged split model at the end of every ``every_n_epochs`` epochs.

    Only the process with ``args.process_index == 0`` exports. Each export goes to
    ``<export_base_dir>/epoch_XXX``, where XXX is ``int(state.epoch)``. When
    ``keep_last_k`` is set, the oldest tracked export directory is deleted once
    more than ``keep_last_k`` are tracked.
    """
    def __init__(self, export_base_dir: str, every_n_epochs: int = 1, keep_last_k: Optional[int] = None):
        super().__init__()
        self.export_base_dir = export_base_dir
        self.every = max(1, int(every_n_epochs))
        self.keep_last_k = keep_last_k
        os.makedirs(self.export_base_dir, exist_ok=True)
        # Tracked export dirs, oldest first, without duplicates (used for rotation).
        self._epoch_dirs: List[str] = []

    def _maybe_cleanup(self):
        """Delete the oldest tracked export dir if more than keep_last_k are tracked."""
        if self.keep_last_k is not None and len(self._epoch_dirs) > self.keep_last_k:
            import shutil
            to_rm = self._epoch_dirs.pop(0)
            try:
                shutil.rmtree(to_rm, ignore_errors=True)
                logger.info(f"[ExportSplitCallback] Removed old split dir: {to_rm}")
            except Exception as e:
                logger.warning(f"[ExportSplitCallback] Failed to remove {to_rm}: {e}")

    def on_epoch_end(self, args, state, control, **kwargs):
        if getattr(args, "process_index", 0) != 0:
            return
        if state.epoch is None:
            return
        ep = int(state.epoch)
        if ep % self.every != 0:
            return

        model = kwargs.get("model", None)
        if model is None:
            return

        out_dir = os.path.join(self.export_base_dir, f"epoch_{ep:03d}")
        logger.info(f"[ExportSplitCallback] Exporting split model for epoch {ep} -> {out_dir}")

        try:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        except Exception:
            pass
        gc.collect()

        export_split_from_model(model, out_dir)

        # A fractional epoch (e.g. 2.5) maps via int() onto an epoch dir that was
        # already exported. Re-track that dir as the newest instead of adding a
        # duplicate entry; a duplicate could make rotation delete the directory
        # that was just written.
        if out_dir in self._epoch_dirs:
            self._epoch_dirs.remove(out_dir)
        self._epoch_dirs.append(out_dir)
        self._maybe_cleanup()


# ------------------------
# Callbacks: export by step (for GRPO, etc.)
# ------------------------
class ExportSplitByStepCallback(TrainerCallback):
    """
    Export the merged split model whenever ``state.global_step`` is a positive
    multiple of ``every_n_steps``.

    - Only the process with ``args.process_index == 0`` exports.
    - At training end, one extra export (``step_XXXXXX_final``) is written unless
      the final global step was already exported by ``on_step_end``.
    - With ``keep_last_k`` set, the oldest tracked export dir is deleted once more
      than ``keep_last_k`` are tracked.
    """
    def __init__(self, export_base_dir: str, every_n_steps: int = 100, keep_last_k: Optional[int] = None):
        super().__init__()
        assert every_n_steps > 0, "every_n_steps must be > 0"
        self.export_base_dir = export_base_dir
        self.every_n_steps = int(every_n_steps)
        self.keep_last_k = keep_last_k
        os.makedirs(self.export_base_dir, exist_ok=True)
        # Export dirs in creation order (oldest first) and the last step exported.
        self._step_dirs: List[str] = []
        self._last_export_step: int = -1

    def _maybe_cleanup(self):
        limit = self.keep_last_k
        if limit is None or len(self._step_dirs) <= limit:
            return
        import shutil
        oldest = self._step_dirs.pop(0)
        try:
            shutil.rmtree(oldest, ignore_errors=True)
            logger.info(f"[ExportSplitByStepCallback] Removed old split dir: {oldest}")
        except Exception as e:
            logger.warning(f"[ExportSplitByStepCallback] Failed to remove {oldest}: {e}")

    @staticmethod
    def _is_main_process(args) -> bool:
        return getattr(args, "process_index", 0) == 0

    def _export_and_track(self, model, out_dir: str):
        """Best-effort CUDA/GC cleanup, export to ``out_dir``, then record it for rotation."""
        try:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        except Exception:
            pass
        gc.collect()

        export_split_from_model(model, out_dir)
        self._step_dirs.append(out_dir)

    def on_step_end(self, args, state, control, **kwargs):
        # Invoked after an optimizer step, so global_step has already advanced.
        if not self._is_main_process(args):
            return
        step = int(state.global_step or 0)
        if step <= 0 or step % self.every_n_steps:
            return

        model = kwargs.get("model")
        if model is None:
            return

        out_dir = os.path.join(self.export_base_dir, f"step_{step:06d}")
        logger.info(f"[ExportSplitByStepCallback] Exporting split model for step {step} -> {out_dir}")
        self._export_and_track(model, out_dir)
        self._last_export_step = step
        self._maybe_cleanup()

    def on_train_end(self, args, state, control, **kwargs):
        # Safety-net export when the final step did not land on the interval.
        if not self._is_main_process(args):
            return
        step = int(state.global_step or 0)
        if step <= 0 or step == self._last_export_step:
            return

        model = kwargs.get("model")
        if model is None:
            return

        out_dir = os.path.join(self.export_base_dir, f"step_{step:06d}_final")
        logger.info(f"[ExportSplitByStepCallback] Final export split model at step {step} -> {out_dir}")
        self._export_and_track(model, out_dir)
        self._maybe_cleanup()
