# finetune_codes/model.py
import os
import argparse
from typing import Optional, List
import shutil
import torch
from transformers import AutoModelForCausalLM  # (kept; may be used by downstream code)
from huggingface_hub import snapshot_download

from kimia_infer.models.tokenizer.whisper_Lv3.whisper import WhisperEncoder, WhisperModel
from kimia_infer.models.tokenizer.ced_base.modeling_ced import CedEncoder, CedForAudioClassification

from .modeling_kimia import MoonshotKimiaForCausalLM
import torch.nn.functional as F
from transformers.utils import logging as hf_logging

logger = hf_logging.get_logger(__name__)

def _top_k_top_p_filtering(logits, top_k=0, top_p=1.0, min_p: float = 0.0, filter_value=-float("inf")):
    """
    Apply per-sample top-k / top-p / min-p filtering to logits.
    - top-k: keep top k tokens
    - top-p: nucleus sampling
    - min-p: filter tokens with probability below the threshold
    """
    # top-k
    if top_k and top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_vals = torch.topk(logits, top_k)[0][..., -1, None]
        logits = torch.where(
            logits < kth_vals,
            torch.tensor(filter_value, device=logits.device, dtype=logits.dtype),
            logits,
        )

    # top-p
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = sorted_probs.cumsum(dim=-1)
        sorted_mask = cumulative_probs > top_p
        # shift mask right to keep the first token above the threshold
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = 0
        indices_to_remove = torch.scatter(
            torch.zeros_like(sorted_mask, dtype=torch.bool),
            -1,
            sorted_indices,
            sorted_mask,
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)

    # min-p (applied after top-p)
    if min_p and min_p > 0.0:
        probs = torch.softmax(logits, dim=-1)
        logits = logits.masked_fill(probs < min_p, filter_value)

    return logits


class KimiAudioModel(MoonshotKimiaForCausalLM):
    """For training; contains external Whisper and CED encoders."""
    def __init__(self, config):
        """Build the backbone from ``config`` and leave both audio-encoder slots empty."""
        super().__init__(config)
        # The Whisper and CED encoders are not constructed here; the slots stay
        # ``None`` until something assigns them (``init_from_pretrained`` does).
        self.whisper_model, self.ced_model = None, None

    @classmethod
    def init_from_pretrained(cls, model_name_or_path: str, model_load_kwargs: Optional[dict] = None):
        """
        Load the LLM backbone plus the external Whisper and CED encoders.

        Args:
            model_name_or_path: Local directory, or a hub repo id that is
                downloaded with ``snapshot_download`` when no such path exists.
            model_load_kwargs: Extra keyword arguments forwarded to
                ``from_pretrained``; ``None`` is treated as an empty dict.

        Returns:
            The loaded model with ``whisper_model`` and ``ced_model`` attached,
            both moved to CUDA when available (CPU otherwise) and set to eval mode.

        Raises:
            FileNotFoundError: if the checkpoint directory has no
                ``whisper-large-v3`` or ``ced-base`` sub-directory.
        """
        load_kwargs = {} if model_load_kwargs is None else model_load_kwargs

        if os.path.exists(model_name_or_path):
            cache_path = model_name_or_path
        else:
            cache_path = snapshot_download(model_name_or_path)

        # Validate the encoder directories first so a broken checkpoint layout
        # fails before the (expensive) backbone load.
        whisper_path = os.path.join(cache_path, "whisper-large-v3")
        ced_path = os.path.join(cache_path, "ced-base")
        if not os.path.exists(whisper_path):
            raise FileNotFoundError(f"Whisper model directory not found at: {whisper_path}")
        if not os.path.exists(ced_path):
            raise FileNotFoundError(f"CED model directory not found at: {ced_path}")

        # 1) Load LLM backbone
        model = super(KimiAudioModel, cls).from_pretrained(
            cache_path,
            trust_remote_code=True,
            **load_kwargs,
        )

        # 2) Load external audio encoders
        model.whisper_model = WhisperEncoder(whisper_path, mel_batch_size=20)
        model.ced_model = CedEncoder(ced_path)

        # 3) Precision policy:
        #    Whisper -> BF16; CED -> FP32 (the original author notes that its
        #    LayerNorm requires FP32).
        dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        try:
            model.whisper_model.to(device=dev, dtype=torch.bfloat16).eval()
        except Exception as exc:
            # Best-effort fallback when the dtype cast is unsupported: move to
            # the device and switch to eval only, but say so instead of hiding it.
            logger.warning(
                "Could not cast the Whisper encoder to bfloat16 (%s); keeping its current dtype.", exc
            )
            model.whisper_model.to(device=dev).eval()
        # CED stays in FP32
        model.ced_model.to(device=dev, dtype=torch.float32).eval()

        return model

    @staticmethod
    def export_model(
        model_or_path,
        output_dir,
        src_submodules: Optional[object] = None,
        base_model_path: str = "YOUR_BASE_MODEL_PATH",
        code_dir: str = "finetune_codes",
    ):
        """
        Split a model into backbone, Whisper and CED checkpoints under ``output_dir``.

        Args:
            model_or_path: A model instance, or a path/id that is loaded with
                ``KimiAudioModel.from_pretrained``.
            output_dir: Destination directory. The backbone is saved at its
                root, the encoders under ``whisper-large-v3`` and ``ced-base``.
            src_submodules: Optional object whose ``whisper_model`` /
                ``ced_model`` attributes are used when the model itself has
                none (attribute missing or ``None``).
            base_model_path: Directory holding the reference ``whisper-large-v3``
                and ``ced-base`` checkpoints and the tokenizer/generation files
                that are copied verbatim. The default is the placeholder the
                original code had hard-coded.
            code_dir: Directory holding ``configuration_moonshot_kimia.py`` and
                ``modeling_kimia.py``. The default is relative to the current
                working directory, as before.

        Raises:
            AssertionError: if the encoder weights contain keys the reference
                model does not expect, or if a non-decoder Whisper key is missing.
        """
        if isinstance(model_or_path, str):
            logger.info("Loading model from %s", model_or_path)
            kimiaudio = KimiAudioModel.from_pretrained(model_or_path)
        else:
            kimiaudio = model_or_path

        # --- Backbone: everything except the two external encoders, in BF16.
        logger.info("Saving Kimi-Audio LM to %s", output_dir)
        audio_model = MoonshotKimiaForCausalLM(kimiaudio.config)
        audio_model = audio_model.to(dtype=torch.bfloat16)
        encoder_prefixes = ("whisper_model", "ced_model")
        audio_model_state_dict = {
            k: v.to(torch.bfloat16)
            for k, v in kimiaudio.state_dict().items()
            if not k.startswith(encoder_prefixes)
        }
        audio_model.load_state_dict(audio_model_state_dict, strict=False)

        audio_model.save_pretrained(output_dir, torch_dtype=torch.bfloat16)
        shutil.copyfile(
            os.path.join(code_dir, "configuration_moonshot_kimia.py"),
            os.path.join(output_dir, "configuration_moonshot_kimia.py"),
        )
        # The modeling file is renamed on copy.
        shutil.copyfile(
            os.path.join(code_dir, "modeling_kimia.py"),
            os.path.join(output_dir, "modeling_moonshot_kimia.py"),
        )

        # --- Whisper: prefer the model's own encoder, otherwise src_submodules.
        src_whisper = getattr(kimiaudio, "whisper_model", None)
        if src_whisper is None and src_submodules is not None:
            src_whisper = getattr(src_submodules, "whisper_model", None)

        whisper_model = WhisperModel.from_pretrained(os.path.join(base_model_path, "whisper-large-v3"))
        if src_whisper is not None:
            # Rename ``speech_encoder.*`` keys to the ``encoder.*`` names used
            # by the reference model; other keys are dropped.
            kimiaudio_whisper_encoder_state_dict = {
                k.replace("speech_encoder.", "encoder."): v.to(torch.bfloat16)
                for k, v in src_whisper.state_dict().items() if k.startswith("speech_encoder")
            }
            missing_keys, unexpected_keys = whisper_model.load_state_dict(
                kimiaudio_whisper_encoder_state_dict, strict=False
            )
            assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
            # Only decoder weights may be absent; they keep the reference values.
            for k in missing_keys:
                assert k.startswith("decoder"), f"Missing keys: {k}"
        else:
            logger.warning("[export_model] whisper_model is None; exporting base Whisper weights.")
        whisper_model = whisper_model.to(dtype=torch.bfloat16)
        whisper_model.save_pretrained(os.path.join(output_dir, "whisper-large-v3"), torch_dtype=torch.bfloat16)

        # --- CED: same source preference, kept in FP32.
        src_ced = getattr(kimiaudio, "ced_model", None)
        if src_ced is None and src_submodules is not None:
            src_ced = getattr(src_submodules, "ced_model", None)

        ced_model = CedForAudioClassification.from_pretrained(os.path.join(base_model_path, "ced-base"))
        if src_ced is not None:
            kimiaudio_ced_encoder_state_dict = {
                k.replace("audio_encoder.", "encoder."): v.to(torch.float32)
                for k, v in src_ced.state_dict().items() if k.startswith("audio_encoder")
            }
            # Missing keys are tolerated here; only unexpected ones are rejected.
            missing_keys, unexpected_keys = ced_model.load_state_dict(kimiaudio_ced_encoder_state_dict, strict=False)
            assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
        else:
            logger.warning("[export_model] ced_model is None; exporting base CED weights.")

        ced_model = ced_model.to(dtype=torch.float32)
        try:
            ced_model.save_pretrained(os.path.join(output_dir, "ced-base"), dtype=torch.float32)
        except TypeError:
            # Older transformers versions do not support the `dtype` argument
            ced_model.save_pretrained(os.path.join(output_dir, "ced-base"))

        # --- Tokenizer / generation files are copied unchanged from the base model.
        passthrough_files = (
            "special_tokens_map.json",
            "tiktoken.model",
            "tokenization_kimia.py",
            "tokenizer_config.json",
            "generation_config.json",
        )
        for filename in passthrough_files:
            shutil.copyfile(os.path.join(base_model_path, filename), os.path.join(output_dir, filename))
        logger.info("Exported Kimi-Audio LM, Whisper, and CED models to %s", output_dir)

    def forward(
        self,
        audio_input_ids: torch.LongTensor = None,
        text_input_ids: torch.LongTensor = None,
        waveform: Optional[torch.FloatTensor] = None,
        is_continuous_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        generation_mode: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Training forward pass:
        - Whisper feature extraction: BF16 (autocast), always without gradients
        - CED feature extraction: FP32 (autocast disabled), always without gradients
        - Backbone: BF16 autocast when CUDA is available, plain forward otherwise

        When ``waveform`` is ``None`` or empty, both encoders are skipped and
        the backbone receives ``None`` for the Whisper and CED features.

        NOTE(review): ``labels`` is accepted but is not passed on to the
        backbone; presumably the loss is computed by the caller — confirm.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        whisper_feats = None
        ced_feats_tuple = None

        if waveform is not None and waveform.numel() > 0:
            # The encoders act as frozen feature extractors here.
            with torch.no_grad():
                # Move waveform to the same device as the model
                dev = next(self.parameters()).device
                waveform = waveform.to(dev)

                # Whisper feature extraction (allow BF16)
                try:
                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                        whisper_output = self.whisper_model(waveform)
                except Exception as exc:
                    # Best-effort retry without autocast; report it instead of
                    # hiding the first failure.
                    logger.warning(
                        "Whisper forward under CUDA autocast failed (%s); retrying without autocast.", exc
                    )
                    whisper_output = self.whisper_model(waveform)

                # Fold every 4 consecutive frames into the feature dimension:
                # (B, T, D) -> (B, T // 4, D * 4).
                whisper_feats = whisper_output.reshape(
                    whisper_output.shape[0],
                    int(whisper_output.shape[1] // 4),
                    whisper_output.shape[2] * 4,
                ).to(torch.bfloat16)

                # CED feature extraction (force FP32: disable autocast)
                try:
                    with torch.autocast(device_type="cuda", enabled=False):
                        ced_output = self.ced_model(waveform.float())
                except Exception as exc:
                    logger.warning(
                        "CED forward with autocast disabled failed (%s); retrying without the context.", exc
                    )
                    ced_output = self.ced_model(waveform.float())

                all_ced_hidden_states = ced_output.hidden_states
                if len(all_ced_hidden_states) > 7:
                    # Keep FP32 in encoders; cast to BF16 before feeding to backbone
                    ced_feat_4 = all_ced_hidden_states[3].to(torch.bfloat16)
                    ced_feat_8 = all_ced_hidden_states[7].to(torch.bfloat16)
                    ced_feat_last = all_ced_hidden_states[-1].to(torch.bfloat16)
                    ced_feats_tuple = (ced_feat_4, ced_feat_8, ced_feat_last)
                else:
                    # Too few hidden states: the backbone gets no CED features.
                    logger.warning(
                        f"CED model has only {len(all_ced_hidden_states)} layers. Cannot extract features from layers 4 and 8."
                    )

        backbone_kwargs = dict(
            audio_input_ids=audio_input_ids,
            text_input_ids=text_input_ids,
            whisper_input_feature=whisper_feats,
            ced_input_feature=ced_feats_tuple,
            is_continuous_mask=is_continuous_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            generation_mode=generation_mode,
            return_dict=return_dict,
        )

        # Backbone BF16 compute
        if torch.cuda.is_available():
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                return super().forward(**backbone_kwargs)
        # No CUDA: run the backbone as-is. This must NOT be wrapped in
        # ``torch.no_grad()`` — that would silently disable gradients and make
        # training a no-op.
        return super().forward(**backbone_kwargs)

    @torch.no_grad()
    def generate(
        self,
        *,
        text_input_ids: torch.LongTensor,
        audio_input_ids: torch.LongTensor = None,
        is_continuous_mask: torch.Tensor = None,
        waveform: torch.FloatTensor = None,
        attention_mask: torch.LongTensor = None,
        position_ids: torch.LongTensor = None,
        kimia_processor=None,   # optionally used to fetch blank/eos etc.
        max_new_tokens: int = 128,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        min_p: float = 0.0,
        eos_token_id: int = None,
        pad_token_id: int = None,
        use_cache: bool = True,
        **unused,
    ):
        """
        Custom text generation with multimodal first step; subsequent steps feed text only
        (audio stream writes blank placeholders).

        Args:
            text_input_ids: Prompt token ids, 2-D ``[B, T]``. Required.
            audio_input_ids / is_continuous_mask / waveform: Multimodal inputs,
                passed to the model on the first step only.
            attention_mask: Prompt mask; when omitted it is derived from the
                positions where the text (or audio) ids differ from the pad id.
            position_ids: Passed as given on the first step, recomputed from
                the attention mask afterwards.
            kimia_processor: Optional fallback source for eos/pad ids and for
                the blank id written to the audio stream (``0`` without it).
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Sample from the filtered distribution, else take argmax.
            temperature / top_k / top_p / min_p: Sampling controls; a falsy or
                non-positive temperature leaves the logits unscaled.
            eos_token_id / pad_token_id: Explicit ids; fall back to the model
                config, then the processor, then (pad only) ``0``.
            use_cache: Step-wise decoding feeds a single token per step and
                relies on the KV cache for history, so a falsy value is
                overridden with a warning.
            **unused: Ignored.

        Returns:
            ``text_input_ids`` with the generated tokens appended along dim 1.
            Rows that already produced EOS keep emitting EOS until every row is
            done or ``max_new_tokens`` is reached.
        """
        device = text_input_ids.device
        batch_size, _ = text_input_ids.shape

        eos = eos_token_id
        if eos is None:
            eos = getattr(self.config, "eos_token_id", None)
        if eos is None and kimia_processor is not None:
            eos = getattr(kimia_processor.tokenizer, "eos_token_id", None)

        pad = pad_token_id
        if pad is None:
            pad = getattr(self.config, "pad_token_id", None)
        if pad is None and kimia_processor is not None:
            pad = getattr(kimia_processor, "pad_token_id", None)
        if pad is None:
            pad = 0

        if not use_cache:
            # After the first step only the newest token is fed; without a KV
            # cache the model would lose the whole prompt.
            logger.warning("generate() requires the KV cache for step-wise decoding; forcing use_cache=True.")
            use_cache = True

        # Fallback attention mask: a position counts if either stream is non-pad.
        if attention_mask is None:
            attention_mask = text_input_ids != pad
            if audio_input_ids is not None:
                attention_mask = attention_mask | (audio_input_ids != pad)

        # Id written to the audio stream on text-only steps (loop-invariant).
        blank = kimia_processor.extra.kimia_text_blank if kimia_processor is not None else 0

        past_key_values = None
        cur_text = text_input_ids
        cur_audio = audio_input_ids
        cur_mask = is_continuous_mask
        cur_wave = waveform

        generated = []
        finished_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_new_tokens):
            outputs = self(
                audio_input_ids=cur_audio,
                text_input_ids=cur_text,
                waveform=cur_wave,
                is_continuous_mask=cur_mask,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                return_dict=True,
            )
            past_key_values = outputs.past_key_values
            # presumably logits[1] is the text head — confirm against the parent model
            text_logits = outputs.logits[1][:, -1, :]  # [B, V]

            # Finished rows may only emit EOS: everything is -inf except EOS at 0.
            if eos is not None:
                text_logits = text_logits.masked_fill(finished_mask.unsqueeze(-1), float("-inf"))
                text_logits[:, eos] = torch.where(
                    finished_mask,
                    torch.tensor(0.0, device=device, dtype=text_logits.dtype),
                    text_logits[:, eos],
                )

            # temperature & filtering
            if temperature and temperature > 0:
                text_logits = text_logits / float(temperature)
            text_logits = _top_k_top_p_filtering(text_logits, top_k=top_k, top_p=top_p, min_p=min_p)

            # sampling/greedy
            if do_sample:
                probs = torch.softmax(text_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1)  # [B,1]
            else:
                next_tokens = torch.argmax(text_logits, dim=-1, keepdim=True)

            # update finished flags
            if eos is not None:
                finished_mask |= (next_tokens.squeeze(-1) == eos)
            generated.append(next_tokens)

            if torch.all(finished_mask):
                break

            # subsequent steps: feed text only; audio uses blank placeholders; no waveform/continuous mask
            cur_text = next_tokens
            cur_audio = torch.full_like(cur_text, blank)
            cur_mask = torch.zeros_like(cur_text, dtype=torch.bool)
            cur_wave = None

            # Grow the mask by one attended position, keeping its dtype/device.
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
                dim=1,
            )
            # Position of the new token = number of attended positions - 1
            # (the appended column is always attended).
            position_ids = attention_mask.long().sum(dim=-1, keepdim=True) - 1

        if not generated:
            return text_input_ids
        new_seq = torch.cat(generated, dim=1)  # [B, L_new]
        return torch.cat([text_input_ids, new_seq], dim=1)

if __name__ == "__main__":
    # Command-line entry point: either load a model with its external encoders
    # ("init_from_pretrained") or split a checkpoint into backbone / Whisper /
    # CED parts ("separate").
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="YOUR_MODEL_NAME_OR_PATH")
    parser.add_argument("--output_dir", type=str, default="OUTPUT_DIR")
    parser.add_argument("--action", type=str, choices=["init_from_pretrained", "separate"], default="separate")
    parser.add_argument(
        "--new_initial_model",
        action="store_true",
        help="With --action init_from_pretrained: initialize newly added modules and export the result.",
    )
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Previously a hard-coded ``False``; the default is unchanged.
    new_initial_model: bool = args.new_initial_model

    if args.action == "init_from_pretrained":
        # initialize model and save
        model = KimiAudioModel.init_from_pretrained(args.model_name, model_load_kwargs={})
        # NOTE(review): the export only happens together with the new-module
        # initialization; without the flag this action just loads the model.
        if new_initial_model:
            model._initialize_newly_added_modules()
            KimiAudioModel.export_model(model, args.output_dir)
    elif args.action == "separate":
        # split a pretrained model into three submodels
        KimiAudioModel.export_model(args.model_name, args.output_dir)
