# kimia_infer/api/kimia.py
import os

import tqdm
import torch
from loguru import logger
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

from kimia_infer.models.detokenizer import get_audio_detokenizer
from .prompt_manager import KimiAPromptManager
from kimia_infer.utils.sampler import KimiASampler


class KimiAudio(object):
    """High-level inference wrapper for Kimi-Audio.

    Bundles the causal LM (``self.alm``), the prompt manager that turns chat
    messages into model inputs, and an optional audio detokenizer that turns
    generated audio token ids into a waveform.
    """

    def __init__(self, model_path: str, load_detokenizer: bool = True):
        """Load the model, the prompt manager and (optionally) the detokenizer.

        Args:
            model_path: An existing local directory, or otherwise an identifier
                passed to ``huggingface_hub.snapshot_download``.
            load_detokenizer: If False, ``self.detokenizer`` is None and
                ``generate`` cannot return a waveform.
        """
        logger.info("Loading Kimi-Audio main model")

        if os.path.exists(model_path):
            cache_path = model_path
        else:
            cache_path = snapshot_download(model_path)

        logger.info(f"Looking for resources in {cache_path}")
        logger.info("Loading causal LM (with KimiAudio head)")
        # Use the current CUDA device when available, otherwise fall back to the CPU.
        device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu")
        self.alm = AutoModelForCausalLM.from_pretrained(
            cache_path, torch_dtype=torch.bfloat16, trust_remote_code=True
        ).to(device).eval()

        model_config = self.alm.config
        # Number of leading decode steps during which the audio stream is forced to blank.
        self.kimia_text_audiodelaytokens = model_config.kimia_mimo_audiodelaytokens
        # `generate` treats ids >= this offset as audio tokens and ids below it as text tokens.
        self.kimia_token_offset = model_config.kimia_token_offset

        self.prompt_manager = KimiAPromptManager(
            model_path=cache_path,
            kimia_token_offset=self.kimia_token_offset,
            kimia_text_audiodelaytokens=self.kimia_text_audiodelaytokens,
        )

        if load_detokenizer:
            logger.info("Loading detokenizer (first run may compile extensions)")
            self.detokenizer = get_audio_detokenizer(cache_path)
        else:
            # In this case, audio (wav) generation is disabled.
            self.detokenizer = None

        # NOTE(review): presumably makes sure text resources (tokenizer / extra
        # token ids) are loaded before `extra_tokens` is read — verify in KimiAPromptManager.
        self.prompt_manager._maybe_load_text()
        self.extra_tokens = self.prompt_manager.extra_tokens
        # Audio-stream ids that stop "both"-mode generation (see `_generate_loop`).
        self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end]

    @torch.inference_mode()
    def _generate_loop(
        self,
        audio_input_ids: torch.Tensor,
        text_input_ids: torch.Tensor = None,
        max_new_tokens: int = 50,
        audio_top_k: int = 5,
        audio_temperature: float = 0.0,
        audio_repetition_penalty: float = 1.0,
        audio_repetition_window_size: int = 64,
        text_top_k: int = 5,
        text_temperature: float = 0.0,
        text_repetition_penalty: float = 1.0,
        text_repetition_window_size: int = 16,
        is_continuous_mask: torch.Tensor = None,
        whisper_input_feature: torch.Tensor = None,
        ced_input_feature: list[torch.Tensor] = None,
        output_type: str = "text",
    ):
        """Run the dual-stream (audio + text) autoregressive sampling loop.

        The first step feeds the full prompt together with the continuous
        (whisper / ced) features. Every later step feeds only the newly sampled
        audio and text token and reuses the KV cache.

        If the model returns ``audio_logits=None`` (audio head disabled) on the
        first step and ``output_type == "both"``, generation restarts from the
        prompt via ``_generate_loop_text_only`` and no audio tokens are
        returned. In other cases a missing audio head simply yields blank audio
        placeholders.

        Assumes batch size 1: sampled tokens are converted with ``int(...)``.

        Args:
            audio_input_ids / text_input_ids: Prompt ids for the two streams;
                ``shape[1]`` is used as the prompt length.
            max_new_tokens: Decode step budget (clamped to at least 1).
            audio_* / text_*: Sampling parameters forwarded to ``KimiASampler``.
            is_continuous_mask, whisper_input_feature, ced_input_feature:
                Forwarded to the model on the first step only.
            output_type: "text" blanks the audio stream; "both" samples audio
                and stops when the audio stream emits one of ``self.eod_ids``.

        Returns:
            ``(audio_tokens, text_tokens)`` as lists of ints. ``text_tokens``
            are the text ids sampled before the text EOS. ``audio_tokens`` are
            the ids sampled after the initial delay window, including the
            end-of-stream id that stopped generation (if any).
        """
        # Defensive: ensure positive steps
        max_new_tokens = max(1, int(max_new_tokens))
        delay = self.kimia_text_audiodelaytokens

        sampler = KimiASampler(
            audio_top_k=audio_top_k,
            audio_temperature=audio_temperature,
            audio_repetition_penalty=audio_repetition_penalty,
            audio_repetition_window_size=audio_repetition_window_size,
            text_top_k=text_top_k,
            text_temperature=text_temperature,
            text_repetition_penalty=text_repetition_penalty,
            text_repetition_window_size=text_repetition_window_size,
        )

        text_stream_is_finished = False

        device = next(self.alm.parameters()).device
        # Per-step histories (one slot per decode step), also used as repetition-penalty context.
        previous_audio_tokens = torch.empty((max_new_tokens,), dtype=torch.long, device=device)
        text_previous_tokens = torch.empty((max_new_tokens,), dtype=torch.long, device=device)

        decoder_input_audio_ids = audio_input_ids.clone()
        decoder_input_text_ids = text_input_ids.clone()
        decoder_position_ids = torch.arange(0, decoder_input_audio_ids.shape[1], device=device).unsqueeze(0).long()
        decoder_input_whisper_feature = whisper_input_feature
        decoder_input_ced_feature = ced_input_feature
        decoder_is_continuous_mask = is_continuous_mask
        past_key_values = None

        last_position_id = decoder_input_audio_ids.shape[1] - 1
        valid_text_length = 0
        valid_audio_length = 0

        def _collect():
            # Text ids are stored from slot 0 and counted only until EOS; audio ids
            # are counted only from slot `delay` onwards.
            text_out = text_previous_tokens[:valid_text_length].detach().cpu().tolist()
            audio_out = previous_audio_tokens[delay : delay + valid_audio_length].detach().cpu().tolist()
            return audio_out, text_out

        for i in tqdm.tqdm(range(max_new_tokens), desc="Generating tokens", disable=False):
            out = self.alm.forward(
                audio_input_ids=decoder_input_audio_ids,
                text_input_ids=decoder_input_text_ids,
                whisper_input_feature=decoder_input_whisper_feature,
                ced_input_feature=decoder_input_ced_feature,
                is_continuous_mask=decoder_is_continuous_mask,
                position_ids=decoder_position_ids,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            audio_logits, text_logits = out.logits
            past_key_values = out.past_key_values

            # On the first step: if audio head is disabled -> switch to text-only fast path.
            # At i == 0 the decoder inputs are still the full prompt, so the helper restarts cleanly.
            if i == 0 and (audio_logits is None) and (output_type == "both"):
                logger.warning("Audio logits are disabled in the current model; falling back to text-only generation.")
                _, text_prev = self._generate_loop_text_only(
                    audio_input_ids=decoder_input_audio_ids,
                    text_input_ids=decoder_input_text_ids,
                    max_new_tokens=max_new_tokens,
                    text_temperature=text_temperature,
                    text_top_k=text_top_k,
                    text_repetition_penalty=text_repetition_penalty,
                    text_repetition_window_size=text_repetition_window_size,
                    is_continuous_mask=decoder_is_continuous_mask,
                    whisper_input_feature=decoder_input_whisper_feature,
                    ced_input_feature=decoder_input_ced_feature,
                )
                return [], text_prev

            # === Text sampling ===
            next_token_text = sampler.sample_text_logits(
                text_logits, recent_tokens=(text_previous_tokens[:i] if i > 0 else None)
            )

            # === Audio sampling (or write blank) ===
            if audio_logits is None:
                next_audio_token = torch.full_like(next_token_text, self.extra_tokens.kimia_text_blank)
            else:
                next_audio_token = sampler.sample_audio_logits(
                    audio_logits, recent_tokens=(previous_audio_tokens[:i] if i > 0 else None)
                )

            # Text end-of-sequence handling: after EOS the text stream only emits blanks.
            if text_stream_is_finished:
                next_token_text.fill_(self.extra_tokens.kimia_text_blank)
            elif int(next_token_text) == self.extra_tokens.kimia_text_eos:
                text_stream_is_finished = True
            else:
                valid_text_length += 1

            text_previous_tokens[i:i+1] = next_token_text

            # Audio delay / text-only scenarios: the audio stream is blank during the delay window,
            # in "text" mode, and when the model has no audio head.
            if i < delay or output_type == "text" or audio_logits is None:
                next_audio_token.fill_(self.extra_tokens.kimia_text_blank)
            else:
                valid_audio_length += 1

            previous_audio_tokens[i:i+1] = next_audio_token

            audio_stream_is_finished = int(next_audio_token) in self.eod_ids

            # Stop conditions
            if (
                (output_type == "text" and text_stream_is_finished)
                or (output_type == "both" and audio_stream_is_finished)
            ):
                return _collect()

            # Next-step feeding: only the new token; multi-modal features are only needed on the first step
            decoder_input_audio_ids = next_audio_token.unsqueeze(1)
            decoder_input_text_ids = next_token_text.unsqueeze(1)
            decoder_position_ids = torch.full((1, 1), last_position_id + 1, device=device, dtype=torch.long)
            last_position_id += 1

            decoder_input_whisper_feature = None
            decoder_input_ced_feature = None
            decoder_is_continuous_mask = None

        # Hit step limit
        return _collect()

    @torch.inference_mode()
    def generate(
        self,
        chats: list[dict],
        output_type="text",
        audio_temperature=0.0,
        audio_top_k=5,
        text_temperature=0.0,
        text_top_k=5,
        audio_repetition_penalty=1.0,
        audio_repetition_window_size=64,
        text_repetition_penalty=1.0,
        text_repetition_window_size=16,
        max_new_tokens=-1,
    ):
        """Generate a response for a chat.

        Args:
            chats: Chat messages, passed unchanged to
                ``prompt_manager.get_prompt`` (message schema is defined there).
            output_type: "text" | "both"
                - "text": text-only fast path, no audio sampling.
                - "both": sample text and audio. If the model has its audio
                  head disabled, this falls back to text-only with a warning.
            audio_* / text_*: Sampling parameters.
            max_new_tokens: Step budget for "text"; -1 means
                ``7500 - prompt_length`` (at least 1). Ignored for "both",
                which always uses ``int(12.5 * 120) - prompt_length``
                (at least 1).

        Returns:
            ``(wav, text)``. ``wav`` is the detokenizer output, or None when
            ``output_type == "text"``, when no audio tokens were generated, or
            when no detokenizer is loaded. ``text`` is the decoded string.
        """
        assert output_type in ["text", "both"]

        history = self.prompt_manager.get_prompt(chats, output_type=output_type)

        audio_input_ids, text_input_ids, is_continuous_mask, _, _ = history.to_tensor()
        whisper_features = history.continuous_feature
        ced_features = history.ced_hidden_states

        # Decide max steps dynamically from the remaining budget after the prompt.
        prompt_length = audio_input_ids.shape[1]
        if output_type == "both":
            # Presumably ~120 s of audio at 12.5 audio tokens per second — confirm with the model card.
            max_new_tokens = max(1, int(12.5 * 120) - prompt_length)
        elif max_new_tokens == -1:
            max_new_tokens = max(1, 7500 - prompt_length)

        device = next(self.alm.parameters()).device
        audio_input_ids = audio_input_ids.to(device)
        text_input_ids = text_input_ids.to(device)
        is_continuous_mask = is_continuous_mask.to(device)
        whisper_features, ced_features = self._normalize_features(whisper_features, ced_features, device)

        generated_wav_tokens = []
        if output_type == "text":
            # Text-only fast path
            _, generated_text_tokens = self._generate_loop_text_only(
                audio_input_ids=audio_input_ids,
                text_input_ids=text_input_ids,
                max_new_tokens=max_new_tokens,
                text_temperature=text_temperature,
                text_top_k=text_top_k,
                text_repetition_penalty=text_repetition_penalty,
                text_repetition_window_size=text_repetition_window_size,
                is_continuous_mask=is_continuous_mask,
                whisper_input_feature=whisper_features,
                ced_input_feature=ced_features,
            )
        else:
            # Generate both audio + text
            generated_wav_tokens, generated_text_tokens = self._generate_loop(
                audio_input_ids=audio_input_ids,
                text_input_ids=text_input_ids,
                max_new_tokens=max_new_tokens,
                audio_temperature=audio_temperature,
                audio_top_k=audio_top_k,
                audio_repetition_penalty=audio_repetition_penalty,
                audio_repetition_window_size=audio_repetition_window_size,
                text_top_k=text_top_k,
                text_temperature=text_temperature,
                text_repetition_penalty=text_repetition_penalty,
                text_repetition_window_size=text_repetition_window_size,
                is_continuous_mask=is_continuous_mask,
                whisper_input_feature=whisper_features,
                ced_input_feature=ced_features,
                output_type=output_type,
            )

        # Post-processing
        generated_wav = None
        if output_type == "both":
            # Keep only audio-vocabulary ids; control/blank ids sit below the offset.
            audio_ids = [t for t in generated_wav_tokens if t >= self.kimia_token_offset]
            if not audio_ids:
                # Covers the audio-head-disabled fallback as well as a run that produced no audio;
                # never hand an empty sequence to the detokenizer.
                logger.warning("No audio tokens were generated; returning text only.")
            elif self.detokenizer is None:
                logger.warning("Detokenizer not initialized; cannot return audio waveform.")
            else:
                wav_tokens = torch.tensor(audio_ids, dtype=torch.long, device=device).unsqueeze(0)
                # Shift ids down by the offset before detokenization.
                generated_wav = self.detokenize_audio(wav_tokens - self.kimia_token_offset)

        text_ids = [t for t in generated_text_tokens if t < self.kimia_token_offset]
        generated_text = self.detokenize_text(text_ids)

        return generated_wav, generated_text

    def detokenize_audio(self, audio_tokens, *, chunk_size: int = 30, first_chunk_size: int = 30):
        """Convert audio token ids to a waveform by streaming fixed-size chunks.

        Args:
            audio_tokens: Tensor of audio token ids, shape ``(1, T)``; a 1-D
                ``(T,)`` tensor is treated as a single row. Ids are expected to
                be already shifted by ``kimia_token_offset`` (as ``generate`` does).
            chunk_size: Tokens per chunk after the first one (default 30).
            first_chunk_size: Tokens in the first chunk (default 30).

        Returns:
            The chunk outputs of ``detokenizer.detokenize_streaming``
            concatenated along the last dimension.

        Raises:
            ValueError: If no detokenizer is loaded or a chunk size is not positive.
        """
        if self.detokenizer is None:
            raise ValueError("Detokenizer is not initialized")
        if chunk_size <= 0 or first_chunk_size <= 0:
            raise ValueError("chunk_size and first_chunk_size must be positive")
        # Reset streaming state so chunks from a previous call do not leak into this one.
        self.detokenizer.clear_states()

        device = next(self.alm.parameters()).device
        audio_tokens = audio_tokens.to(device).long()
        if audio_tokens.dim() == 1:
            audio_tokens = audio_tokens.unsqueeze(0)

        num_audio_tokens = audio_tokens.size(1)
        speech_chunks = [
            self.detokenizer.detokenize_streaming(
                audio_tokens[:, :first_chunk_size],
                is_final=(num_audio_tokens <= first_chunk_size),
                upsample_factor=4,
            )
        ]
        # Remaining tokens in `chunk_size` pieces; the piece reaching the end is marked final.
        for start in range(first_chunk_size, num_audio_tokens, chunk_size):
            speech_chunks.append(
                self.detokenizer.detokenize_streaming(
                    audio_tokens[:, start : start + chunk_size],
                    upsample_factor=4,
                    is_final=(start + chunk_size >= num_audio_tokens),
                )
            )

        return torch.cat(speech_chunks, dim=-1)

    def detokenize_text(self, text_tokens):
        """Decode text token ids up to (not including) the first text EOS id."""
        valid_text_ids = []
        for x in text_tokens:
            if x == self.extra_tokens.kimia_text_eos:
                break
            valid_text_ids.append(x)
        return self.prompt_manager.text_tokenizer.decode(valid_text_ids)

    @torch.inference_mode()
    def _generate_loop_text_only(
        self,
        audio_input_ids: torch.Tensor,
        text_input_ids: torch.Tensor,
        max_new_tokens: int = 512,
        text_top_k: int = 5,
        text_temperature: float = 0.0,
        text_repetition_penalty: float = 1.0,
        text_repetition_window_size: int = 16,
        is_continuous_mask: torch.Tensor = None,
        whisper_input_feature: torch.Tensor = None,
        ced_input_feature: list[torch.Tensor] = None,
    ):
        """Text-only sampling loop.

        The first step feeds the full prompt with the continuous features.
        Later steps feed only the new text token, with a blank audio token as
        the audio-stream placeholder. Stops after ``max_new_tokens`` steps or
        when the text EOS id is sampled.

        Returns:
            ``([], text_tokens)``, where ``text_tokens`` is the list of sampled
            text ids including the EOS id if it was reached.
        """
        device = next(self.alm.parameters()).device
        sampler = KimiASampler(
            audio_top_k=0, audio_temperature=0.0,
            audio_repetition_penalty=1.0, audio_repetition_window_size=64,
            text_top_k=text_top_k, text_temperature=text_temperature,
            text_repetition_penalty=text_repetition_penalty, text_repetition_window_size=text_repetition_window_size
        )

        num_steps = max(0, int(max_new_tokens))
        # Preallocated history used as repetition-penalty context; avoids rebuilding a
        # device tensor from the whole Python list on every step.
        text_history = torch.empty((num_steps,), dtype=torch.long, device=device)

        last_pos = audio_input_ids.size(1) - 1
        decoder_audio_ids = audio_input_ids.to(device)
        decoder_text_ids = text_input_ids.to(device)
        decoder_pos_ids = torch.arange(0, decoder_audio_ids.shape[1], device=device).unsqueeze(0).long()
        decoder_cont_mask = is_continuous_mask.to(device) if is_continuous_mask is not None else None

        # Allow features on the first step
        whisper_feat = whisper_input_feature
        ced_feat = ced_input_feature

        text_prev = []
        past_kv = None

        for i in range(num_steps):
            out = self.alm.forward(
                audio_input_ids=decoder_audio_ids,
                text_input_ids=decoder_text_ids,
                whisper_input_feature=whisper_feat,
                ced_input_feature=ced_feat,
                is_continuous_mask=decoder_cont_mask,
                position_ids=decoder_pos_ids,
                past_key_values=past_kv,
                use_cache=True,
                return_dict=True,
            )
            # out.logits is (audio_logits, text_logits); only the text head is used here.
            text_logits = out.logits[1]
            past_kv = out.past_key_values

            next_token_text = sampler.sample_text_logits(
                text_logits,
                recent_tokens=(text_history[:i] if i > 0 else None),
            )
            t = next_token_text.item()
            text_history[i] = t
            text_prev.append(t)

            if t == self.extra_tokens.kimia_text_eos:
                break

            # Next step only feeds the new text; audio gets blank placeholder
            decoder_audio_ids = torch.full_like(next_token_text, self.extra_tokens.kimia_text_blank).unsqueeze(1)
            decoder_text_ids = next_token_text.unsqueeze(1)
            decoder_pos_ids = torch.full((1, 1), last_pos + 1, device=device, dtype=torch.long)
            last_pos += 1

            # No continuous features after the first step
            whisper_feat = None
            ced_feat = None
            decoder_cont_mask = None

        return [], text_prev

    def _normalize_features(self, whisper_features, ced_features, device):
        """Move continuous features to ``device`` as bfloat16 and merge per-chunk lists.

        Args:
            whisper_features: None, a Tensor, or a list/tuple of Tensors.
                Non-tensor list entries are ignored. A list is concatenated
                along dim 1 (presumably time for ``(B, T, D)`` — confirm).
            ced_features: None; a 3-tuple of Tensors ``(f4, f8, flast)``; a
                list/tuple of such per-chunk 3-tuples (each level is
                concatenated along dim 1); or a single Tensor (repeated three
                times as a fallback).
            device: Target device.

        Returns:
            ``(whisper, ced)``: ``whisper`` is a Tensor or None, ``ced`` is a
            3-tuple of Tensors or None. Empty or unrecognised inputs map to None.
        """
        # whisper: Tensor or list[Tensor] -> single Tensor
        if whisper_features is None:
            whisper = None
        elif isinstance(whisper_features, torch.Tensor):
            whisper = whisper_features.to(device, dtype=torch.bfloat16)
        elif isinstance(whisper_features, (list, tuple)):
            parts = [t.to(device, dtype=torch.bfloat16) for t in whisper_features if isinstance(t, torch.Tensor)]
            if len(parts) == 0:
                whisper = None
            else:
                # Require same batch/hidden; concatenate along dim 1
                whisper = torch.cat(parts, dim=1)
        else:
            whisper = None

        # ced:
        #   (a) already a tuple (f4, f8, flast) -> keep as is
        #   (b) list/tuple of per-chunk tuples -> concatenate each level then pack to a tuple
        def _as_tuple3(x):
            if x is None:
                return None
            if isinstance(x, (list, tuple)):
                if len(x) == 0:
                    # `all()` over an empty list is True, which would send [] to torch.cat
                    # and raise; mirror the whisper branch and report "no features".
                    return None
                # case (a)
                if len(x) == 3 and all(isinstance(t, torch.Tensor) for t in x):
                    return tuple(t.to(device, dtype=torch.bfloat16) for t in x)
                # case (b)
                if all(isinstance(t, (list, tuple)) and len(t) == 3 for t in x):
                    f4 = [t[0].to(device, dtype=torch.bfloat16) for t in x]
                    f8 = [t[1].to(device, dtype=torch.bfloat16) for t in x]
                    fl = [t[2].to(device, dtype=torch.bfloat16) for t in x]
                    return (torch.cat(f4, dim=1), torch.cat(f8, dim=1), torch.cat(fl, dim=1))
            if isinstance(x, torch.Tensor):
                # Edge compatibility: only last layer provided -> repeat thrice as a fallback
                return (x.to(device, dtype=torch.bfloat16),) * 3
            return None

        ced = _as_tuple3(ced_features)
        return whisper, ced
