import logging
import time
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


logger = logging.getLogger(__name__)


@dataclass
class Hypothesis:
    """Batchfied/Vectorized hypothesis data type.

    Every field defaults to a freshly built empty tensor (via
    ``default_factory``), so instances never share default tensor objects
    and an in-place operation on one instance's default cannot leak into
    another instance.
    """

    # Shape comments below are the layouts declared by the original author.
    yseq: torch.Tensor = field(
        default_factory=lambda: torch.tensor([]).long()
    )  # (batch, maxlen)
    self_kv: torch.Tensor = field(
        default_factory=lambda: torch.zeros((0, 1, 0, 0))
    )  # (layer, batch, ctx, dim)
    cross_kv: torch.Tensor = field(
        default_factory=lambda: torch.zeros((0, 1, 0, 0))
    )  # (layer, batch, ctx, dim)
    hs: torch.Tensor = field(
        default_factory=lambda: torch.tensor([])
    )  # (batch, maxlen, adim)

    def __len__(self) -> int:
        """Return a batch size."""
        return len(self.yseq)

    def extend(self, hs: torch.Tensor) -> None:
        """Install new encoder states and empty the cross-attention cache.

        Args:
            hs: New encoder output; it is stored with an extra leading axis
                of size 1 (``hs.unsqueeze(0)``).

        The cross cache is truncated to length zero along its third axis
        while its other axes are kept.
        NOTE(review): presumably the cached cross-attention entries were
        computed from the previous ``hs`` and are therefore stale - confirm.
        """
        self.hs = hs.unsqueeze(0)
        self.cross_kv = self.cross_kv[:, :, :0, :]


class NGramRepeatBlockProcessor:
    """Blocks repeated generation of n-grams of a specified size."""

    _ngram_size: int

    def __init__(self, ngram_size: int) -> None:
        """
        :param ngram_size:
            The size of repeated n-grams to block.

        :raises ValueError:
            If ``ngram_size`` is not a positive number.
        """
        # Negative sizes used to slip through and produced meaningless
        # slices in ``__call__``; reject everything below 1.
        if ngram_size < 1:
            raise ValueError("`ngram_size` must be greater than 0.")

        self._ngram_size = ngram_size

    def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None:
        """Suppress, in place, tokens that would complete a repeated n-gram.

        :param seqs:
            The token ids generated so far, indexed as (N, S).
        :param probs:
            The next-token scores, indexed as (N, V). Modified in place.
        :param lprob:
            If ``True``, blocked entries are set to ``-inf`` (log-probability
            scores); otherwise they are set to ``0``.
        """
        ngram_size = self._ngram_size

        # Length of the sequence once the next token has been appended.
        seq_len = seqs.size(1) + 1

        # Too short to contain a full n-gram followed by a repeat.
        if ngram_size >= seq_len:
            return

        blocked_value = -torch.inf if lprob else 0

        # This is an edge case where we do not allow any of the previous values.
        if ngram_size == 1:
            # (N, 1) row indices, broadcast against the (N, S) token ids.
            rows = torch.arange(seqs.size(0), device=probs.device).unsqueeze(1)

            probs[rows, seqs] = blocked_value

            return

        # (N, G - 1): the tokens the next n-gram would start with.
        ngram_prefixes = seqs[:, -ngram_size + 1 :]

        for start in range(seq_len - ngram_size):
            follower = start + ngram_size - 1

            # (N): rows whose earlier window equals the current prefix.
            matches = (seqs[:, start:follower] == ngram_prefixes).all(dim=-1)

            # Ban the token that followed that earlier occurrence.
            probs[matches, seqs[matches, follower]] = blocked_value


class GreedySearch(torch.nn.Module):
    """Greedy, block-wise streaming decoder around a wrapped model.

    The input is consumed in growing prefixes of ``block_size`` frames. For
    each prefix the encoder is re-run and tokens are picked one at a time by
    argmax (after 3-gram repeat blocking) until the decoder signals that more
    input is needed or decoding terminates.
    """

    def __init__(
        self,
        model,
        d_feature=1024,
        block_size: int = 32,
        wait_k: int = 3,
        forced_lang_id=None,
    ):
        """Bind the sub-modules of ``model`` used during decoding.

        Args:
            model: Object exposing ``sos``, ``eos``, ``ast_decoder``,
                ``ast_encoder`` and ``trunc_emb``.
            d_feature: Stored on the instance; not read elsewhere in this class.
            block_size: Number of input frames added per block. A falsy value
                (``None``/``0``) makes the whole input a single block.
            wait_k: Stored on the instance; not read elsewhere in this class.
            forced_lang_id: Optional token id placed right after ``sos`` in
                the initial hypothesis.
        """
        super().__init__()
        self.model = model
        self.sos = model.sos
        self.eos = model.eos
        self.d_feature = d_feature

        self.block_size = block_size
        self.wait_k = wait_k

        self.forced_lang_id = forced_lang_id
        self.decoder = model.ast_decoder
        self.encoder = model.ast_encoder
        self.trunc_emb = model.trunc_emb
        self.logits_processor = NGramRepeatBlockProcessor(3)

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Return ``xs`` with token ``x`` concatenated on the last axis.

        Works for a 1-D sequence and for a 2-D ``(1, len)`` sequence; the new
        token takes the dtype and device of ``xs``.
        """
        token = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        if xs.dim() == 2:
            token = token.unsqueeze(0)
        return torch.cat((xs, token), dim=-1)

    def init_hyp(self, device, dtype) -> "Hypothesis":
        """Build the initial hypothesis: ``[sos]`` or ``[sos, forced_lang_id]``.

        Encoder states and both caches start out empty on ``device`` with
        ``dtype``.
        """
        if self.forced_lang_id is None:
            primer = [self.sos]
        else:
            primer = [self.sos, self.forced_lang_id]

        return Hypothesis(
            yseq=torch.tensor(primer, device=device).long().unsqueeze(0),
            hs=torch.zeros(1, 0, 0, device=device, dtype=dtype),
            self_kv=torch.zeros(0, 0, 0, 0, device=device, dtype=dtype),
            cross_kv=torch.zeros(0, 0, 0, 0, device=device, dtype=dtype),
        )

    def search(
        self,
        running_hyps: "Hypothesis",
    ):
        """Run one decoder step and greedily pick the next token.

        Args:
            running_hyps: Current hypothesis (token prefix, encoder states
                and attention caches).

        Returns:
            A tuple ``(new_hyps, simul_score)``: the hypothesis extended by
            the argmax token with the caches returned by the decoder, and the
            decoder's fourth output flattened to 1-D.
        """
        n_batch = len(running_hyps)

        next_token_logits, new_self_caches, new_cross_caches, simul_score = (
            self.decoder.inference(
                running_hyps.yseq,
                running_hyps.hs.expand(n_batch, -1, -1),
                running_hyps.self_kv,
                running_hyps.cross_kv,
            )
        )

        weighted_scores = F.log_softmax(next_token_logits, dim=-1).squeeze(1)

        # Masks, in place, every token that would repeat an earlier 3-gram.
        self.logits_processor(running_hyps.yseq, weighted_scores, lprob=True)

        # ``append_token`` is declared to take a plain int.
        new_token_id = weighted_scores.flatten().argmax().item()

        new_hyps = Hypothesis(
            yseq=self.append_token(running_hyps.yseq, new_token_id),
            hs=running_hyps.hs,
            self_kv=new_self_caches,
            cross_kv=new_cross_caches,
        )

        return new_hyps, simul_score.flatten()

    @staticmethod
    def _add_compute_time(compute_times, times, current_cost):
        """Extend ``compute_times`` to the length of ``times``, adding ``current_cost``."""
        pending = times[len(compute_times):]
        compute_times.extend(t + current_cost for t in pending)
        return compute_times

    def forward(
        self,
        asr_out: torch.Tensor,
        simul_threhold: float = 0.5,
        maxlenratio: float = 0.0,
    ):
        """Decode ``asr_out`` block by block.

        Args:
            asr_out: Input features; the first axis is the frame axis.
            simul_threhold: On a non-final block, a ``simul_score`` above this
                value stops decoding on the current block and requests more
                input.
            maxlenratio: If ``0``, at most ``min(n_frames, 150)`` decoding
                steps are taken; otherwise ``max(1, int(maxlenratio * n_frames))``.

        Returns:
            ``(running_hyps, times, compute_times)``. ``running_hyps`` holds
            the accepted tokens only: the candidate that triggered
            termination (EOS or the length limit) is not appended. ``times``
            starts with two ``None`` placeholders followed by
            ``cur_end_frame * 0.02`` for each accepted token, and
            ``compute_times`` mirrors it with the elapsed decoding time of the
            emitting block added.
        """
        n_frames = asr_out.shape[0]
        cur_end_frame = 0 if self.block_size is not None else n_frames
        process_idx = 0

        # set length bounds
        if maxlenratio == 0:
            maxlen = min(n_frames, 150)
        else:
            maxlen = max(1, int(maxlenratio * n_frames))

        running_hyps = self.init_hyp(asr_out.device, asr_out.dtype)
        times = [None, None]
        compute_times = [None, None]

        # With no decoding budget (empty input) the inner loop below would
        # never run and never return, so the outer loop would spin forever.
        if maxlen <= 0:
            return running_hyps, times, compute_times

        # main loop of prefix search
        while True:
            move_to_next_block = False
            if self.block_size and cur_end_frame + int(self.block_size) < n_frames:
                cur_end_frame += int(self.block_size)
            else:
                cur_end_frame = n_frames

            is_final = cur_end_frame >= n_frames
            h = asr_out if is_final else asr_out.narrow(0, 0, cur_end_frame)

            # Prepend an embedding flagging whether the input is truncated
            # (index 1) or complete (index 0), then re-encode the prefix.
            is_trunc = torch.tensor([not is_final]).long().to(h.device)
            h = torch.cat([self.trunc_emb(is_trunc), h], dim=0)
            h = self.encoder(h.unsqueeze(0), None)[0]

            running_hyps.extend(h)
            if is_final:
                running_hyps.self_kv = running_hyps.self_kv[:, :, :0]
            elif running_hyps.self_kv.shape[2] + 1 == running_hyps.yseq.shape[1]:
                # Cache covers all but the newest token: drop its last five
                # positions. NOTE(review): reason for the constant 5 is not
                # visible here - confirm.
                running_hyps.self_kv = running_hyps.self_kv[:, :, :-5]
            running_hyps.cross_kv = running_hyps.cross_kv[:, :, :0]

            time_previous = time.perf_counter()
            while process_idx < maxlen:
                best, simul_score = self.search(running_hyps)

                if process_idx == maxlen - 1:
                    # end decoding
                    self._add_compute_time(
                        compute_times, times, time.perf_counter() - time_previous
                    )
                    return running_hyps, times, compute_times

                is_local_eos = best.yseq[0, -1] == self.eos

                if is_local_eos:
                    if not is_final:
                        move_to_next_block = True
                    else:
                        self._add_compute_time(
                            compute_times, times, time.perf_counter() - time_previous
                        )
                        return running_hyps, times, compute_times

                if not is_final and simul_score > simul_threhold:
                    move_to_next_block = True

                if move_to_next_block:
                    break

                # presumably 0.02 s per frame - TODO confirm against the frontend
                times.append(cur_end_frame * 0.02)
                running_hyps = best
                process_idx += 1

            self._add_compute_time(
                compute_times, times, time.perf_counter() - time_previous
            )
