from typing import Dict, Optional, Callable, List
import torch
from torch import nn
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
import torch.nn.functional as F
from transformers import AutoTokenizer
import math
from .utils import th_accuracy, make_pad_mask, ras_sampling
from .simul import SimulDecoder


# Sentinel used as the padding value for targets and padded hidden states in this
# module; it is passed to the loss as `padding_idx` and to the accuracy helper as
# `ignore_label` so padded positions are excluded.
IGNORE_ID = -1


class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The target distribution puts ``1 - smoothing`` on the reference class and
    spreads ``smoothing`` evenly over the remaining ``size - 1`` classes.
    Positions whose target equals ``padding_idx`` contribute nothing.
    """

    def __init__(self, size: int, padding_idx: int, smoothing: float, normalize_length: bool = False):
        """Store the loss configuration.

        Args:
            size: Number of classes (last dimension of the logits).
            padding_idx: Target value marking positions to ignore.
            smoothing: Probability mass moved away from the reference class.
            normalize_length: Divide the summed loss by the number of
                non-ignored tokens when True, otherwise by the batch size.
        """
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the normalized label-smoothed loss.

        Args:
            x: Unnormalized scores; the third dimension must equal ``size``.
            target: Class indices with one entry per row of the flattened ``x``;
                entries equal to ``padding_idx`` are ignored.

        Returns:
            Scalar tensor with the summed loss divided by the chosen denominator.
        """
        assert x.size(2) == self.size
        num_sequences = x.size(0)

        flat_scores = x.view(-1, self.size)
        flat_target = target.view(-1)

        pad_positions = flat_target == self.padding_idx
        num_valid = len(flat_target) - pad_positions.sum().item()
        # Padded entries get a dummy in-range index so scatter_ never sees a negative index.
        safe_target = flat_target.masked_fill(pad_positions, 0)

        smoothed = torch.zeros_like(flat_scores).fill_(self.smoothing / (self.size - 1))
        smoothed.scatter_(1, safe_target.unsqueeze(1), self.confidence)

        per_class = self.criterion(torch.log_softmax(flat_scores, dim=1), smoothed)
        per_class = per_class.masked_fill(pad_positions.unsqueeze(1), 0)

        if self.normalize_length:
            return per_class.sum() / num_valid
        return per_class.sum() / num_sequences


class Qwen2Encoder(torch.nn.Module):
    """Wrap a pretrained ``Qwen2ForCausalLM`` so it consumes embeddings and returns hidden states."""

    def __init__(self, pretrain_path):
        """Load the Qwen2 causal LM from ``pretrain_path``."""
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        """Run the whole (padded) batch through the model.

        Args:
            xs: Input embeddings; dimension 1 is treated as the time axis.
            xs_lens: Per-sequence lengths handed to ``make_pad_mask``.

        Returns:
            Tuple of the last hidden-state tensor and the attention mask with an
            extra dimension inserted at index 1.
        """
        max_len = xs.size(1)
        # The helper's output is inverted before use as the attention mask
        # (presumably it marks padded positions — confirm against make_pad_mask).
        attention_mask = ~make_pad_mask(xs_lens, max_len)
        outputs = self.model(
            inputs_embeds=xs,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden = outputs.hidden_states[-1]
        return last_hidden, attention_mask.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        """Run one incremental decoding step using the key/value cache.

        Args:
            xs: Input embeddings for this step.
            masks: 3-D mask; only its last row along dimension 1 is passed to the model.
            cache: ``past_key_values`` from the previous step, or None on the first step.

        Returns:
            Tuple of the last hidden-state tensor and the updated cache.
        """
        outputs = self.model(
            inputs_embeds=xs,
            attention_mask=masks[:, -1, :],
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        return outputs.hidden_states[-1], outputs.past_key_values


class Qwen2LM(nn.Module):
    def __init__(
        self,
        qwen_pretrain_path: str = None,
        llm_input_size: int = 896,
        llm_output_size: int = 896,
        speech_token_size: int = 6561,
        llm: torch.nn.Module = None,
        length_normalized_loss: bool = True,
        lsm_weight: float = 0.0,
        mix_ratio: List[int] = [5, 15],
    ):
        """Build the speech-token language model and load its checkpoint.

        Args:
            qwen_pretrain_path: Directory of the pretrained Qwen2 model. It is used
                to build the default ``llm`` and to locate the checkpoint at
                ``<qwen_pretrain_path>/../llm.pt``. When None, no checkpoint is loaded.
            llm_input_size: Embedding width fed to the LLM.
            llm_output_size: Hidden width produced by the LLM.
            speech_token_size: Number of speech tokens; the output layer and the
                speech embedding table both have ``speech_token_size + 3`` entries.
            llm: Optional pre-built LLM module; a ``Qwen2Encoder`` is created from
                ``qwen_pretrain_path`` when omitted.
            length_normalized_loss: Forwarded to the loss as ``normalize_length``.
            lsm_weight: Label-smoothing weight forwarded to the loss.
            mix_ratio: Stored on the instance as ``self.mix_ratio``.
        """
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size

        # Indices of the special tokens.
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        # Speech-token language model modules.
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        if llm is None:
            llm = Qwen2Encoder(qwen_pretrain_path)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        n_state = llm_input_size
        self.simul_blocks = SimulDecoder(n_state=n_state, n_head=n_state // 64, n_enc_dim=n_state, n_layer=6)

        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # Embedding table for speech tokens.
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)

        # Copy so instances never share the list object used as the default argument.
        self.mix_ratio = list(mix_ratio)

        # Without a path there is no checkpoint location (this case used to fail on
        # ``None + str`` when an ``llm`` was injected).
        if qwen_pretrain_path is not None:
            # map_location="cpu" lets a checkpoint saved on GPU load on any host;
            # load_state_dict copies the values into the existing parameters.
            state_dict = torch.load(qwen_pretrain_path + "/../llm.pt", map_location="cpu")
            self.load_state_dict(state_dict, strict=False)

    def prepare_lm_input_target(
        self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len
    ):
        """Assemble padded LLM inputs and speech-token targets for a batch.

        Each sample's input is laid out as
        ``[sos_eos emb, text embs, task_id emb, speech embs]`` and its target is
        the speech tokens followed by ``self.speech_token_size`` as the end marker.

        Args:
            text_token: Padded text token ids (batch first).
            text_token_emb: Padded text embeddings (batch first).
            text_token_len: Per-sample text lengths.
            speech_token: Padded speech token ids (batch first).
            speech_token_emb: Padded speech embeddings (batch first).
            speech_token_len: Per-sample speech lengths.

        Returns:
            Tuple ``(lm_target, lm_input, lm_input_len)``; targets and inputs are
            padded with ``IGNORE_ID`` and ``lm_input_len`` is an int32 tensor of
            the unpadded input lengths.
        """
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)

        sample_ids = range(len(text_token))
        lm_target = []
        lm_input = []
        for idx in sample_ids:
            # Target: the speech tokens plus the end marker, built from a Python list.
            target_ids = speech_token[idx].tolist() + [self.speech_token_size]
            lm_target.append(torch.tensor(target_ids))

            pieces = [
                self.llm_embedding.weight[self.sos_eos].reshape(1, -1),
                text_token_emb[idx],
                self.llm_embedding.weight[self.task_id].reshape(1, -1),
                speech_token_emb[idx],
            ]
            lm_input.append(torch.concat(pieces, dim=0))

        lm_input_len = torch.tensor([seq.size(0) for seq in lm_input], dtype=torch.int32)
        padded_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        padded_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return padded_target, padded_input, lm_input_len

    def _forward(
        self,
        text_token: torch.Tensor,
        text_token_len: torch.Tensor,
        speech_token: torch.Tensor,
        speech_token_len: torch.Tensor,
    ):
        """Run the LLM on one text/speech pairing and split its hidden states.

        Args:
            text_token: Padded text token ids.
            text_token_len: Per-sample text lengths.
            speech_token: Padded speech token ids.
            speech_token_len: Per-sample speech lengths.

        Returns:
            Tuple ``(out_text, out_speech, lm_target)``. ``out_text`` holds the
            hidden states of the first ``text_len + 1`` positions of each sample,
            ``out_speech`` the positions from there up to the sample's input
            length; both are re-padded with ``IGNORE_ID``. ``lm_target`` is the
            padded speech-token target moved to the device of ``text_token``.
        """
        device = text_token.device

        # Embed both token streams.
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        speech_token_emb = self.speech_embedding(speech_token)

        # Build the LLM inputs and targets.
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(
            text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len
        )
        lm_target = lm_target.to(device)

        # Run the LLM; the returned mask is not needed here.
        lm_output, _ = self.llm(lm_input, lm_input_len.to(device))

        # Cut each sample at text_len + 1 (the leading sos_eos position plus the text).
        text_parts = []
        speech_parts = []
        for idx in range(len(text_token)):
            boundary = text_token_len[idx] + 1
            text_parts.append(lm_output[idx, :boundary])
            speech_parts.append(lm_output[idx, boundary : lm_input_len[idx]])

        out_text = pad_sequence(text_parts, batch_first=True, padding_value=IGNORE_ID)
        out_speech = pad_sequence(speech_parts, batch_first=True, padding_value=IGNORE_ID)
        return out_text, out_speech, lm_target

    def _criterion(self, logits: torch.Tensor, lm_target: torch.Tensor) -> dict:
        """Compute the label-smoothed loss and token accuracy for one set of logits.

        Args:
            logits: (B, T, C)
            lm_target: (B, T)

        Returns:
            Dict with keys ``"loss"`` (output of ``self.criterion_ce``), ``"acc"``
            (output of ``th_accuracy`` with ``IGNORE_ID`` positions ignored) and
            ``"logits"`` (the input logits, unchanged).
        """
        loss = self.criterion_ce(logits, lm_target)
        # The accuracy helper receives logits flattened to (B * T, C).
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {"loss": loss, "acc": acc, "logits": logits}

    def _kld_loss(self, logits_stu, logits_tea, token_mask=None, reduction="mean"):
        """Per-token KL divergence from the teacher to the student distribution.

        Args:
            logits_stu: Student logits of shape (B, T, C).
            logits_tea: Teacher logits with the same number of rows after
                flattening; detached, so no gradient reaches the teacher.
            token_mask: Optional boolean mask reshaped to (B, T); positions that
                are False contribute zero loss.
            reduction: ``"mean"`` divides the summed loss by the number of
                selected tokens (or ``B * T`` without a mask), ``"sum"`` returns
                the sum; any other value returns the unreduced (B, T) loss.

        Returns:
            The reduced scalar loss, or the (B, T) per-token loss. With
            ``"mean"`` and no selected token the result is 0 instead of NaN.
        """
        b, t, _ = logits_stu.shape
        logits_stu = logits_stu.reshape((-1, logits_stu.shape[-1])).float()
        logits_tea = logits_tea.reshape((-1, logits_tea.shape[-1])).float().detach()
        loss = F.kl_div(
            F.log_softmax(logits_stu, dim=-1),
            F.softmax(logits_tea, dim=-1),
            reduction="none",
        ).sum(-1)
        loss = loss.reshape((b, t))
        if token_mask is not None:
            token_mask = token_mask.reshape((b, t))
            loss = loss.masked_fill(~token_mask, 0.0)
            # Clamp so an all-False mask gives 0 / 1 = 0 rather than 0 / 0 = NaN.
            num_token = token_mask.sum().clamp(min=1)
        else:
            num_token = max(b * t, 1)
        # Bound each token's contribution to [0, 10].
        loss = loss.clamp(max=10, min=0)
        if reduction == "mean":
            loss = loss.sum() / num_token
        elif reduction == "sum":
            loss = loss.sum()
        return loss

    def forward(
        self,
        batch: dict,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the training losses for truncated-text and full-text inputs.

        Args:
            batch: Mapping that provides the keys ``"tunc_text"``,
                ``"tunc_text_len"``, ``"text"``, ``"text_len"``, ``"speech"``,
                ``"speech_len"`` and ``"norm_target"``.

        Returns:
            Dict with ``"full"``, ``"trunc"`` and ``"simul"`` (each the dict
            returned by ``_criterion``), plus ``"loss_norm"``, ``"loss_kd"``,
            ``"route_score"`` and ``"loss_full_selected"``.
        """
        speech = batch["speech"]
        speech_len = batch["speech_len"]
        trunc_text_len = batch["tunc_text_len"]
        full_text_len = batch["text_len"]

        # Two LLM passes over the same speech: truncated text first, then full text.
        out_text_trunc, out_speech_trunc, lm_target = self._forward(
            batch["tunc_text"], trunc_text_len, speech, speech_len
        )
        out_text, out_speech, _ = self._forward(batch["text"], full_text_len, speech, speech_len)

        # Text lengths are passed with +1, matching the extra leading position
        # that _forward keeps in the text hidden states.
        simul_out, route_scores = self.simul_blocks(
            out_speech_trunc,
            out_text_trunc,
            trunc_text_len + 1,
            out_text,
            full_text_len + 1,
        )

        # Zero the routing scores at padded target positions.
        padded = (lm_target == IGNORE_ID).expand_as(route_scores)
        route_scores = route_scores.masked_fill(padded, 0.0)
        positive = route_scores > 0.0
        num_positive = torch.sum(positive)
        ave_score = torch.sum(route_scores) / num_positive
        sent_ave_score = torch.sum(route_scores, dim=-1) / torch.sum(positive, dim=-1)

        # Targets whose routing score exceeds 0.3 are dropped for the truncated branch.
        lm_target_trunc = lm_target.masked_fill(route_scores > 0.3, IGNORE_ID)

        logits_trunc = self.llm_decoder(out_speech_trunc)
        logits_full = self.llm_decoder(out_speech)
        logits_simul = self.llm_decoder(simul_out)

        # Routing regularizers: match the per-sentence mean to the batch target
        # and pull the first position's score toward zero.
        first_score = route_scores[:, 0]
        loss_norm = F.smooth_l1_loss(sent_ave_score, batch["norm_target"], beta=0.05)
        loss_norm = loss_norm + F.smooth_l1_loss(first_score, torch.zeros_like(first_score), beta=0.05)

        # Penalize drops larger than 0.1 between consecutive routing scores.
        score_drop = route_scores[:, :-1] - route_scores[:, 1:]
        large_drop = score_drop > 0.1
        if large_drop.sum() > 0:
            loss_norm = loss_norm + score_drop[large_drop].sum() / num_positive

        kd_loss = self._kld_loss(logits_trunc, logits_full, lm_target_trunc != IGNORE_ID)
        loss_full_selected = self.criterion_ce(logits_full, lm_target_trunc)

        result = {
            "full": self._criterion(logits_full, lm_target),
            "trunc": self._criterion(logits_trunc, lm_target_trunc),
            "simul": self._criterion(logits_simul, lm_target),
        }
        result["loss_norm"] = loss_norm
        result["loss_kd"] = kd_loss
        result["route_score"] = ave_score
        result["loss_full_selected"] = loss_full_selected
        return result

    @torch.inference_mode()
    def inference(
        self,
        text: torch.Tensor,
        prompt_text: torch.Tensor,
        prompt_speech_token: torch.Tensor = None,
        sampling: int = 25,
        max_token_text_ratio: float = 20000,
        min_token_text_ratio: float = 2,
        is_final=False,
        threshold=0.05,
    ):
        """Generate speech tokens for one text chunk, yielding them one at a time.

        Args:
            text: Token ids of the new text chunk, shape (1, L).
            prompt_text: Token ids of the text seen so far, shape (1, P); it is
                prepended to ``text``.
            prompt_speech_token: Speech tokens already produced, or None.
            sampling: Forwarded to ``sampling_ids``.
            max_token_text_ratio: Upper bound on decoding steps per new text token.
            min_token_text_ratio: Scales the minimum number of tokens that must be
                emitted before the router may stop decoding.
            is_final: True for the last chunk. It forces the stop threshold to 0.9
                and lets the end-of-speech token be sampled.
            threshold: Router score above which decoding stops (non-final chunks).

        Yields:
            Speech token ids (Python ints) below ``self.speech_token_size``.
        """
        device = text.device
        chunk_text_len = text.shape[1]
        prompt_text_len = prompt_text.shape[1]
        total_text_len = chunk_text_len + prompt_text_len

        if prompt_speech_token is None:
            prev_tokens = []
        else:
            prev_tokens = prompt_speech_token.flatten().cpu().tolist()

        text_emb = self.llm.model.model.embed_tokens(torch.concat([prompt_text, text], dim=1))

        # Assemble the LLM input: sos_eos, text, task_id, previously generated speech.
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token is None:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text_emb.dtype).to(device)
        else:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        lm_input = torch.concat([sos_eos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)

        # Decoding-length bounds; the final chunk also uses a fixed stop threshold.
        num_prev = len(prev_tokens)
        if is_final:
            min_len = int(total_text_len * min_token_text_ratio) - num_prev
            threshold = 0.9
        else:
            min_len = max(int((total_text_len - 4) * min_token_text_ratio) - num_prev, 1)
        max_len = int(chunk_text_len * max_token_text_ratio)

        # Step-by-step decoding with a key/value cache.
        out_tokens = []
        cache = None
        for _ in range(max_len):
            step_len = lm_input.shape[1]
            step_mask = torch.tril(torch.ones((1, step_len, step_len), device=lm_input.device)).to(torch.bool)
            y_pred, cache = self.llm.forward_one_step(lm_input, masks=step_mask, cache=cache)
            last_hidden = y_pred[:, -1]

            # The router decides whether to stop and wait for more text.
            r_score = torch.sigmoid(self.simul_blocks.router(last_hidden)).flatten()
            if r_score > threshold and len(out_tokens) >= min_len:
                break

            logp = self.llm_decoder(last_hidden).log_softmax(dim=-1)
            top_ids = self.sampling_ids(
                logp.squeeze(dim=0), (prev_tokens + out_tokens), sampling, ignore_eos=not is_final
            ).item()

            if top_ids == self.speech_token_size:
                break
            if top_ids > self.speech_token_size:
                # NOTE(review): lm_input is left unchanged here although `cache`
                # was already advanced, so the next step feeds the same
                # embedding again — confirm this is intended.
                continue
            # In stream mode, hand the token to the caller immediately.
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    def simulate_stream_inference(
        self,
        text,
        tokenizer,
        chunk_size=4,
        threshold=0.05,
        prompt_speech_text="",
        prompt_speech_token=[],
        min_token_text_ratio=0.0,
    ):
        """Simulate streaming by feeding ``text`` in fixed-size token chunks.

        Args:
            text: The full input text.
            tokenizer: A tokenizer object, or a name/path loaded with
                ``AutoTokenizer.from_pretrained``.
            chunk_size: Number of text tokens per chunk.
            threshold: Router stop threshold forwarded to ``inference``.
            prompt_speech_text: Transcript of the speech prompt; empty for no prompt.
            prompt_speech_token: When ``prompt_speech_text`` is non-empty, it is
                indexed with the key ``"flow_prompt_speech_token"``.
            min_token_text_ratio: Forwarded to ``inference``.

        Returns:
            List of speech token ids: the prompt's tokens (if any) followed by
            every generated token.
        """
        device = next(self.parameters()).device
        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        prompt_speech_text_token = None
        if prompt_speech_text != "":
            prompt_speech_text_token = tokenizer([prompt_speech_text], return_tensors="pt")["input_ids"]
        has_prompt = prompt_speech_text_token is not None

        encoded = tokenizer([text], return_tensors="pt")
        eos = torch.full([1, 1], tokenizer.eos_token_id).long()
        if has_prompt:
            out_speech_token = prompt_speech_token["flow_prompt_speech_token"].cpu().flatten().tolist()
        else:
            out_speech_token = []

        # The text is terminated with EOS before it is split into chunks.
        text_tokens = torch.cat([encoded["input_ids"], eos], dim=1)
        num_chunk = math.ceil(text_tokens.shape[1] / chunk_size)
        print(tokenizer.decode(text_tokens[0, :chunk_size]))

        for i in range(num_chunk):
            start = i * chunk_size
            stop = start + chunk_size

            # Everything generated so far becomes the speech prompt for this chunk.
            if out_speech_token:
                speech_prompt = torch.Tensor(out_speech_token).long().unsqueeze(0).to(device)
            else:
                speech_prompt = None

            # Text already consumed (plus the prompt transcript) is the text prompt.
            seen_text = text_tokens[:, :start]
            if has_prompt:
                seen_text = torch.cat([prompt_speech_text_token, seen_text], dim=1)
            prompt_text = seen_text.to(device)

            text_chunk = text_tokens[:, start:stop].to(device)
            print(tokenizer.decode(text_chunk[0]))

            chunk_output = list(
                self.inference(
                    text_chunk,
                    prompt_text,
                    speech_prompt,
                    threshold=threshold,
                    min_token_text_ratio=min_token_text_ratio,
                    is_final=(i == num_chunk - 1),
                )
            )
            print(chunk_output)
            out_speech_token.extend(chunk_output)
        return out_speech_token

    def stream_inference(
        self,
        contents,
        tokenizer,
        prompt_speech_text="",
        prompt_speech_token=[],
        threshold=0.3,
        min_token_text_ratio=4.0,
    ):
        """Generate speech tokens for text that arrives as a sequence of pieces.

        Args:
            contents: Sized sequence of mappings; each item's ``"text"`` is
                tokenized, and its ``"latency"`` is recorded once per token
                generated for that item.
            tokenizer: A tokenizer object, or a name/path loaded with
                ``AutoTokenizer.from_pretrained``.
            prompt_speech_text: Transcript of the speech prompt; empty for no prompt.
            prompt_speech_token: When the prompt transcript tokenizes to at least
                one token, it is indexed with the key ``"flow_prompt_speech_token"``.
            threshold: Router stop threshold forwarded to ``inference``.
            min_token_text_ratio: Forwarded to ``inference``.

        Returns:
            Tuple ``(out_speech_token, latencys)``: the prompt's speech tokens (if
            any) followed by all generated tokens, and one latency entry per
            generated token.
        """
        device = next(self.parameters()).device
        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        if prompt_speech_text == "":
            prompt_text_token = torch.zeros(1, 0).long().to(device)
        else:
            prompt_text_token = tokenizer([prompt_speech_text], return_tensors="pt")["input_ids"]

        out_speech_token = []
        if prompt_text_token.shape[1] > 0:
            out_speech_token = prompt_speech_token["flow_prompt_speech_token"].cpu().flatten().tolist()
        prompt_text_token = prompt_text_token.to(device)

        latencys = []
        for i, content in enumerate(contents):
            is_last = i == len(contents) - 1

            new_text_tokens = tokenizer([content["text"]], return_tensors="pt")["input_ids"]
            if is_last:
                # Only the final piece is terminated with EOS.
                eos = torch.full([1, 1], tokenizer.eos_token_id).long()
                new_text_tokens = torch.cat([new_text_tokens, eos], dim=1)

            # Everything generated so far becomes the speech prompt for this piece.
            if out_speech_token:
                speech_prompt = torch.Tensor(out_speech_token).long().unsqueeze(0).to(device)
            else:
                speech_prompt = None
            new_text_tokens = new_text_tokens.to(device)

            chunk_output = []
            token_stream = self.inference(
                new_text_tokens,
                prompt_text_token,
                speech_prompt,
                threshold=threshold,
                min_token_text_ratio=min_token_text_ratio,
                is_final=is_last,
            )
            for token in token_stream:
                chunk_output.append(token)
                latencys.append(content["latency"])

            out_speech_token.extend(chunk_output)
            # The piece just consumed joins the text prompt for the next one.
            prompt_text_token = torch.cat([prompt_text_token, new_text_tokens], dim=-1)
        return out_speech_token, latencys

    def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        decoded_tokens: List,
        sampling: int,
        ignore_eos: bool = True,
    ):
        """Sample a token with ``ras_sampling``, resampling while EOS is drawn if it must be ignored.

        Args:
            weighted_scores: Scores passed to ``ras_sampling``.
            decoded_tokens: Previously decoded tokens passed to ``ras_sampling``.
            sampling: Sampling parameter passed to ``ras_sampling``.
            ignore_eos: When True, a draw containing ``self.speech_token_size``
                (the end-of-speech id) is rejected and sampling is retried.

        Returns:
            The sampled ids as returned by ``ras_sampling``.

        Raises:
            RuntimeError: If EOS is still drawn after ``max_trials + 1`` attempts
                while ``ignore_eos`` is True.
        """
        max_trials = 100
        # One initial attempt plus max_trials retries.
        for _ in range(max_trials + 1):
            top_ids = ras_sampling(weighted_scores, decoded_tokens, sampling)
            if not ignore_eos:
                return top_ids
            if self.speech_token_size not in top_ids:
                return top_ids
        raise RuntimeError(
            "sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!".format(
                max_trials
            )
        )
