import logging
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F


def get_unique_ids(seq_tensor):
    """Return the indices of the first occurrence of each distinct row.

    Rows (elements, for 1-D input) of ``seq_tensor`` are compared by value.
    ``torch.Tensor`` hashes by object identity, so adding the row tensors
    themselves to a set never detects a duplicate; instead the values are
    copied to Python once and converted to hashable (nested) tuples.

    Args:
        seq_tensor: Tensor whose first dimension indexes the items to compare.

    Returns:
        Long tensor of ascending indices, on ``seq_tensor.device``.
    """

    def _freeze(value):
        # Nested lists from ``tolist()`` -> nested tuples, so they can be hashed.
        return tuple(_freeze(v) for v in value) if isinstance(value, list) else value

    seen = set()
    ids = []
    for i, row in enumerate(seq_tensor.tolist()):
        key = _freeze(row)
        if key not in seen:
            seen.add(key)
            ids.append(i)
    return torch.tensor(ids, device=seq_tensor.device).long()



def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """Decide whether beam search may stop early.

    For every length ``i, i - 1, ..., i - M + 1`` there must be at least one
    ended hypothesis of that ``len(yseq)``, and the best of them must satisfy
    ``best_same_length.score - best_overall.score < D_end``. The default
    ``D_end`` equals ``log(exp(-10)) == -10``.

    Args:
        ended_hyps: Sequence of objects exposing ``.score`` and ``.yseq``.
        i: Length compared against ``len(yseq)`` (together with the M - 1
            shorter lengths).
        M: Number of consecutive lengths that must pass.
        D_end: Score-gap threshold.

    Returns:
        True when all ``M`` lengths pass; False otherwise or when
        ``ended_hyps`` is empty.
    """
    if not ended_hyps:
        return False

    top_score = max(ended_hyps, key=lambda hyp: hyp.score).score

    passed = 0
    for length in range(i, i - M, -1):
        scores_at_length = [hyp.score for hyp in ended_hyps if len(hyp.yseq) == length]
        if scores_at_length and max(scores_at_length) - top_score < D_end:
            passed += 1
    return passed == M

# Module-level logger (not referenced by the code visible in this chunk).
logger = logging.getLogger(__name__)

@dataclass
class Hypothesis:
    """Batchfied/Vectorized hypothesis data type.

    Bundles a batch of token sequences, their scores, the encoder states and
    the decoder caches belonging to them. Every tensor field gets a fresh
    default per instance (``default_factory``): a plain tensor default would
    be a single object shared by all instances, so an in-place operation on
    one default-constructed hypothesis would silently change all others.
    """

    yseq: torch.Tensor = field(default_factory=lambda: torch.tensor([]).long())  # (batch, maxlen)
    score: torch.Tensor = field(default_factory=lambda: torch.zeros((1)))  # (batch,)
    self_kv: torch.Tensor = field(default_factory=lambda: torch.zeros((0, 1, 0, 0)))  # (layer, batch, ctx, dim)
    cross_kv: torch.Tensor = field(default_factory=lambda: torch.zeros((0, 1, 0, 0)))  # (layer, batch, ctx, dim)
    cross_kv_lite: torch.Tensor = field(default_factory=lambda: torch.zeros((0, 1, 0, 0)))  # (layer, batch, ctx, dim)
    hs: torch.Tensor = field(default_factory=lambda: torch.tensor([]))  # (batch, maxlen, adim)

    def __len__(self) -> int:
        """Return the batch size, i.e. ``len(self.score)``.

        Raises ``TypeError`` when ``score`` is 0-d, as produced by ``select``
        with an int index.
        """
        return len(self.score)

    def select(self, ids):
        """Return a new hypothesis holding only the batch entries ``ids``.

        ``ids`` indexes dim 0 of ``yseq``/``score`` and dim 1 of the caches;
        an int index drops that dimension. ``hs`` is shared as-is, not
        indexed. Int/slice indices yield views of this object's tensors,
        index tensors yield copies.
        """
        return Hypothesis(
            yseq=self.yseq[ids],
            score=self.score[ids],
            hs=self.hs,
            self_kv=self.self_kv[:, ids],
            cross_kv=self.cross_kv[:, ids],
            cross_kv_lite=self.cross_kv_lite[:, ids],
        )

    def extend(self, hs):
        """Set ``hs`` to the given states with a leading batch dim of 1.

        Despite the name, the previous ``hs`` is replaced, not appended to.
        """
        self.hs = hs.unsqueeze(0)




class BeamSearch(torch.nn.Module):
    """Batch beam search implementation.

    Decodes encoder output block by block: every outer step exposes
    ``block_size`` more frames of ``asr_out`` to the decoder. While the input
    is still incomplete, decoding inside a block stops once the number of
    emitted tokens reaches the accumulated ``cif_score`` minus ``wait_k``, or
    as soon as some hypothesis emits eos (which, with ``conservative``, also
    rolls back one step). Ended hypotheses are only collected once the whole
    input is visible.
    """

    def __init__(
        self,
        model,
        beam_size: int,
        d_feature=1024,
        block_size: int = 32,
        wait_k: int = 3,
        normalize_length: bool = True,
        forced_lang_id=None,
        task='transcribe',
        use_lite=False,
    ):
        """Store decoding options and select the decoder for ``task``.

        Args:
            model: Provides ``sos``/``eos`` ids and ``asr_decoder`` (used when
                ``task == 'transcribe'``) or ``ast_decoder`` (any other task).
            beam_size: Number of hypotheses kept per step.
            d_feature: Stored as ``self.d_feature``; not read elsewhere in
                this class.
            block_size: Frames added per streaming block; a falsy value
                decodes the whole input as a single block.
            wait_k: Token lag kept behind the accumulated ``cif_score`` in
                non-final blocks.
            normalize_length: Rank ended hypotheses by score / length.
            forced_lang_id: Optional token placed after ``sos`` in the primer.
            task: ``'transcribe'`` selects ``asr_decoder``; anything else
                selects ``ast_decoder``.
            use_lite: Use ``forward_cache_lite`` on odd steps; not allowed
                for ``'transcribe'``.
        """
        super().__init__()
        self.model = model
        self.sos = model.sos
        self.eos = model.eos
        self.d_feature = d_feature

        self.block_size = block_size
        self.wait_k = wait_k

        self.beam_size = beam_size
        self.current_size = None  # set per utterance in forward()
        self.normalize_length = normalize_length
        self.forced_lang_id = forced_lang_id
        # The selected decoder must provide forward_cache / forward_cache_lite.
        self.decoder = model.asr_decoder if task == 'transcribe' else model.ast_decoder
        self.use_lite = use_lite
        if task == 'transcribe':
            assert not self.use_lite

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Return ``xs`` with token ``x`` appended (same dtype and device)."""
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def batch_beam(
        self, weighted_scores: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the ``current_size`` best (hypothesis, token) pairs.

        Args:
            weighted_scores: Accumulated scores, assumed (n_batch, n_vocab).

        Returns:
            ``(prev_hyp_ids, new_token_ids)``: row and column of each pick in
            the flattened score matrix, best first.
        """
        n_vocab = weighted_scores.shape[-1]
        top_ids = weighted_scores.reshape(-1).topk(self.current_size)[1]

        prev_hyp_ids = torch.div(top_ids, n_vocab, rounding_mode="trunc")
        new_token_ids = top_ids % n_vocab
        return prev_hyp_ids, new_token_ids

    def init_hyp(self, x: torch.Tensor) -> "Hypothesis":
        """Build the single starting hypothesis.

        Only ``x``'s device, dtype and last dimension are used. The primer is
        ``[sos]`` or ``[sos, forced_lang_id]``; caches and ``hs`` start empty.
        """
        device, dtype = x.device, x.dtype
        primer = [self.sos] if self.forced_lang_id is None else [self.sos, self.forced_lang_id]

        return Hypothesis(
            yseq=torch.tensor(primer, device=device).long().unsqueeze(0),
            score=torch.zeros(1, device=device, dtype=dtype),
            hs=torch.zeros(1, 0, x.shape[-1], device=device, dtype=dtype),
            self_kv=torch.zeros(0, 0, 0, 0, device=device, dtype=dtype),
            cross_kv=torch.zeros(0, 0, 0, 0, device=device, dtype=dtype),
            cross_kv_lite=torch.zeros(0, 1, 0, 0, device=device, dtype=dtype),
        )

    def search(
        self,
        running_hyps: "Hypothesis",
        x: torch.Tensor,
        use_lite: bool = False,
    ) -> "Hypothesis":
        """Advance the running hypotheses by one decoder step.

        Args:
            running_hyps: Current beam.
            x: Only its dtype and device are used, to cast the running scores.
            use_lite: Call ``decoder.forward_cache_lite`` with
                ``cross_kv_lite`` instead of ``decoder.forward_cache`` with
                ``cross_kv``.

        Returns:
            The ``current_size`` best expansions. When the decoder yields two
            score rows per hypothesis (``shape[1] == 2``), row 0 re-scores the
            last emitted token and row 1 scores the next token.
        """
        forward_fn = self.decoder.forward_cache_lite if use_lite else self.decoder.forward_cache

        n_batch = len(running_hyps)
        next_token_logits, new_self_caches, new_cross_caches = forward_fn(
            running_hyps.yseq,
            running_hyps.hs.expand(n_batch, -1, -1),
            running_hyps.self_kv,
            running_hyps.cross_kv_lite if use_lite else running_hyps.cross_kv,
        )

        weighted_scores = F.log_softmax(next_token_logits, dim=-1)
        prev_scores = running_hyps.score.to(dtype=x.dtype, device=x.device).unsqueeze(1)

        if weighted_scores.shape[1] == 2:
            assert not use_lite
            revise_score, weighted_scores = weighted_scores[:, 0], weighted_scores[:, 1]

            # The decision is taken from hypothesis 0 only and applies to the
            # whole beam: if its last token is not among the top-3 revision
            # candidates, every hypothesis gets its last token replaced.
            if running_hyps.yseq[0, -1] not in torch.topk(revise_score[0], 3)[1]:
                revise_score += prev_scores
                return self._revise_last_token(
                    running_hyps, revise_score, new_self_caches, new_cross_caches
                )

            # Keep the last token and add its revision score to the next-token scores.
            last_ids = running_hyps.yseq[:, -1]
            weighted_scores += revise_score[torch.arange(n_batch), last_ids].unsqueeze(1)
        else:
            weighted_scores = weighted_scores.squeeze(1)

        weighted_scores += prev_scores

        prev_hyp_id, new_token_id = self.batch_beam(weighted_scores)
        if use_lite:
            # Lite step: only cross_kv_lite takes the new decoder cache; score,
            # self_kv and cross_kv are carried over from the parent hypothesis.
            new_score = torch.tensor([running_hyps.score[i] for i in prev_hyp_id])
            new_self_kv = torch.stack([running_hyps.self_kv[:, i] for i in prev_hyp_id], dim=1)
            new_cross_kv = torch.stack([running_hyps.cross_kv[:, i] for i in prev_hyp_id], dim=1)
            new_cross_kv_lite = torch.stack([new_cross_caches[:, i] for i in prev_hyp_id], dim=1)
        else:
            # Regular step: take the accumulated score and the new self/cross caches.
            new_score = torch.tensor([weighted_scores[i, j] for i, j in zip(prev_hyp_id, new_token_id)])
            new_self_kv = torch.stack([new_self_caches[:, i] for i in prev_hyp_id], dim=1)
            new_cross_kv = torch.stack([new_cross_caches[:, i] for i in prev_hyp_id], dim=1)
            new_cross_kv_lite = torch.stack([running_hyps.cross_kv_lite[:, i] for i in prev_hyp_id], dim=1)

        return Hypothesis(
            yseq=torch.stack([self.append_token(running_hyps.yseq[i], j) for i, j in zip(prev_hyp_id, new_token_id)]),
            score=new_score,
            hs=running_hyps.hs,
            self_kv=new_self_kv,
            cross_kv=new_cross_kv,
            cross_kv_lite=new_cross_kv_lite,
        )

    def _revise_last_token(
        self,
        running_hyps: "Hypothesis",
        scores: torch.Tensor,
        new_self_caches: torch.Tensor,
        new_cross_caches: torch.Tensor,
    ) -> "Hypothesis":
        """Replace (rather than extend) the last token of the beam.

        ``scores`` are the accumulated revision scores per (hypothesis, token).
        Hypotheses whose prefix (all but the last token) duplicates an earlier
        one are penalised by 100, so that the same prefix is not kept twice.
        The caches drop their last position, matching the shortened prefix.
        ``scores`` is modified in place.
        """
        n_batch = len(running_hyps)
        unique_prev_ids = get_unique_ids(running_hyps.yseq[:, :-1])
        for i in range(n_batch):
            if i not in unique_prev_ids:
                scores[i] -= 100  # remove repeated beam

        prev_hyp_id, new_token_id = self.batch_beam(scores)
        return Hypothesis(
            yseq=torch.stack([self.append_token(running_hyps.yseq[i, :-1], j) for i, j in zip(prev_hyp_id, new_token_id)]),
            score=torch.tensor([scores[i, j] for i, j in zip(prev_hyp_id, new_token_id)]),
            hs=running_hyps.hs,
            self_kv=torch.stack([new_self_caches[:, i, :-1] for i in prev_hyp_id], dim=1),
            cross_kv=torch.stack([new_cross_caches[:, i, :-1] for i in prev_hyp_id], dim=1),
            cross_kv_lite=torch.stack([running_hyps.cross_kv_lite[:, i, :-1] for i in prev_hyp_id], dim=1),
        )

    def forward(
        self,
        asr_out: torch.Tensor,
        cif_score: torch.Tensor = None,
        maxlenratio: float = 0.0,
    ):
        """Run blockwise beam search and return the best ended hypothesis.

        Args:
            asr_out: Encoder output, frames along dim 0 (presumably
                (frames, feature_dim) — confirm against the encoder).
            cif_score: Per-frame weights along dim 0; their prefix sum minus
                ``wait_k`` caps the tokens emitted in a non-final block. If
                None, non-final blocks have no token cap and only stop when a
                hypothesis emits eos.
            maxlenratio: 0 means ``maxlen = min(frames, 300)``; otherwise
                ``maxlen = max(1, int(maxlenratio * frames))``.

        Returns:
            A single hypothesis (``select`` with an int index), ranked by
            score, or by score / (len(yseq) - 1) when ``normalize_length``.

        Raises:
            IndexError: If no hypothesis reached eos (e.g. empty input).
        """
        self.conservative = True  # always true
        self.current_size = self.beam_size
        total_frames = asr_out.shape[0]

        # set length bounds
        if maxlenratio == 0:
            maxlen = min(total_frames, 300)
        else:
            maxlen = max(1, int(maxlenratio * asr_out.size(0)))

        # init_hyp only reads device, dtype and the last dimension, which are
        # the same for asr_out and for any prefix of it.
        running_hyps = self.init_hyp(asr_out)
        prev_hyps = []
        ended_hyps = []
        process_idx = 0
        cur_end_frame = 0

        continue_decode = True
        while continue_decode:
            # Expose the next block of frames (everything when not blocking).
            if self.block_size and cur_end_frame + int(self.block_size) < total_frames:
                cur_end_frame += int(self.block_size)
            else:
                cur_end_frame = total_frames
            is_final_block = cur_end_frame >= total_frames

            if is_final_block:
                h = asr_out
                keep_decoding = True
            else:
                h = asr_out.narrow(0, 0, cur_end_frame)
                n_speech_tokens = running_hyps.yseq.shape[1] - 1
                # Without cif_score there is no token budget (as in the
                # per-step check below).
                keep_decoding = cif_score is None or bool(torch.all(
                    n_speech_tokens < cif_score[:cur_end_frame].sum() - self.wait_k
                ))

            running_hyps.extend(h)

            while process_idx < maxlen and keep_decoding:
                use_lite = self.use_lite and process_idx % 2 == 1
                best = self.search(running_hyps, h, use_lite=use_lite)

                forced_end = process_idx == maxlen - 1
                if forced_end:
                    # end decoding: post_process appends eos to every row of
                    # ``best`` (so all of them are detected as ended below)
                    # and returns the survivors, i.e. none. It is called only
                    # once for this step so eos is not appended twice, and
                    # running_hyps stays intact for the non-final-block path.
                    next_hyps = self.post_process(process_idx, maxlen, best)

                is_local_eos = best.yseq[:, -1] == self.eos
                local_ended_hyps = [
                    best.select(k) for k, ended in enumerate(is_local_eos.tolist()) if ended
                ]

                # NOTE(review): local_ended_hyps all share one yseq length, so
                # with end_detect's default M=3 this can never return True;
                # presumably the accumulated ended_hyps was intended — confirm.
                if maxlenratio == 0.0 and end_detect(local_ended_hyps, process_idx):
                    continue_decode = False
                    break

                if local_ended_hyps and not is_final_block:
                    # eos before the input is complete: wait for more frames,
                    # stepping back once when conservative.
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_hyps = running_hyps
                if not forced_end:
                    next_hyps = self.post_process(process_idx, maxlen, best)
                running_hyps = next_hyps

                if is_final_block:
                    ended_hyps.extend(local_ended_hyps)
                    self.current_size -= len(local_ended_hyps)

                if len(running_hyps) == 0:
                    continue_decode = False
                    break
                process_idx += 1

                if not is_final_block:
                    n_speech_tokens = running_hyps.yseq.shape[1] - 1
                    if cif_score is not None and torch.any(n_speech_tokens >= cif_score[:cur_end_frame].sum() - self.wait_k):
                        break

            if is_final_block:
                # All input is visible: nothing more can happen. This also
                # stops the search when the inner loop could not run at all
                # (e.g. maxlen == 0 for empty input), which would otherwise
                # repeat this block forever.
                break

        if not ended_hyps:
            raise IndexError("beam search finished without any hypothesis reaching eos")

        if self.normalize_length:
            nbest_hyps = sorted(
                ended_hyps, key=lambda x: x.score / (len(x.yseq) - 1), reverse=True
            )
        else:
            nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)

        return nbest_hyps[0]

    def post_process(
        self,
        i: int,
        maxlen: int,
        running_hyps: "Hypothesis",
    ) -> "Hypothesis":
        """Drop hypotheses whose last token is eos.

        At the last step (``i == maxlen - 1``) eos is first appended to every
        row, so every hypothesis ends. ``running_hyps.yseq`` is rebound to the
        extended tensor (callers still see the eos column on the object they
        passed in) instead of being resized in place. An in-place ``resize_``
        reallocates the shared storage and corrupts any view taken earlier
        (e.g. via ``select`` with an int index).
        """
        n_batch = running_hyps.yseq.shape[0]
        if i == maxlen - 1:
            eos_column = torch.full(
                (n_batch, 1),
                self.eos,
                device=running_hyps.yseq.device,
                dtype=running_hyps.yseq.dtype,
            )
            running_hyps.yseq = torch.cat((running_hyps.yseq, eos_column), 1)
        is_eos = running_hyps.yseq[:, -1] == self.eos
        remained_ids = torch.nonzero(~is_eos, as_tuple=False).view(-1).cpu()
        return running_hyps.select(remained_ids)