# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch

from cosmos1.models.autoregressive.networks.transformer import Transformer
from torch.nn.attention import SDPBackend, sdpa_kernel
# Set Inductor's `cudagraph_dynamic_shape_warn_limit` to None at import time.
# NOTE(review): presumably this silences the warning Inductor emits when CUDA graphs are
# re-recorded for many distinct input shapes — confirm against the Inductor config docs.
torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None

def sample_top_p(logits, temperature, top_p, return_probs: bool = False):
    """
    Draw one token per batch element with top-p (nucleus) sampling.

    Only the logits at the last sequence position (``logits[:, -1, :]``) are used.

    Args:
        logits (torch.Tensor): Logits, indexed as ``[batch, position, vocab]``.
        temperature (float): Divisor applied to the logits before the softmax.
        top_p (float): Cumulative probability threshold of the nucleus.
        return_probs (bool): When True, also return the truncated and renormalized
            distribution laid out in the original vocabulary order.

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]: The sampled token indices (int64,
        one trailing dimension of size 1) and either the renormalized probabilities
        or None when ``return_probs`` is False.
    """
    last_step_probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(last_step_probs, dim=-1, descending=True)

    # Probability mass that precedes each entry in the descending order; an entry is
    # dropped once the mass in front of it already exceeds top_p.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs.masked_fill_(mass_before > top_p, 0.0)

    # Renormalize the surviving entries so they sum to 1 again.
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))

    # Sample in sorted space, then translate back to vocabulary indices.
    choice = multinomial_sample_one_no_sync(sorted_probs, dtype=torch.int64)
    next_token = torch.gather(sorted_idx, -1, choice)

    if not return_probs:
        return next_token, None

    # Scatter the sorted probabilities back to their original vocabulary positions.
    probs_unsorted = torch.zeros_like(sorted_probs)
    probs_unsorted.scatter_(-1, sorted_idx, sorted_probs)
    return next_token, probs_unsorted

def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
    """
    Sample one index along the last dimension without calling ``torch.multinomial``.

    Each probability is divided by independent Exponential(1) noise and the argmax is
    taken, which avoids the CUDA synchronization mentioned in the original source.
    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py

    Returns:
        torch.Tensor: Sampled indices with the last dimension kept (size 1), cast to ``dtype``.
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    scores = probs_sort / noise
    winner = scores.argmax(dim=-1, keepdim=True)
    return winner.to(dtype=dtype)

def logits_to_probs(
    logits,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
):
    """
    Turn logits into a probability distribution over the last dimension.

    Args:
        logits: Tensor of logits; the softmax is taken over the last dimension.
        temperature: Divisor for the logits, floored at 1e-5 to avoid dividing by zero.
        top_k: When given, every logit below the k-th largest one is set to -inf
            before the softmax (k is capped at the size of the last dimension).

    Returns:
        Tensor of probabilities with the same shape as ``logits``.
    """
    scaled = logits / max(temperature, 1e-5)

    if top_k is not None:
        k = min(top_k, scaled.size(-1))
        top_values, _ = torch.topk(scaled, k)
        # The smallest of the k retained values acts as the cut-off.
        threshold = top_values[..., -1:]
        scaled = scaled.masked_fill(scaled < threshold, float("-inf"))

    return torch.softmax(scaled, dim=-1)

def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """
    Pick the next token from the last position's logits using top-k sampling.
    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py

    With ``temperature == 0.0`` the choice is a plain argmax and no probabilities
    are returned; otherwise the token is drawn from the (optionally top-k filtered)
    distribution.

    Returns:
        Tuple of the chosen indices (trailing dimension of size 1) and the
        probabilities they were drawn from (None for the greedy case).
    """
    # logits: [batch_size, seq_len, vocab_size]; only the final position matters.
    last_logits = logits[:, -1, :]
    if temperature == 0.0:
        return torch.argmax(last_logits, dim=-1, keepdim=True), None

    probs = logits_to_probs(last_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs

def prefill(
    model: Transformer,
    input_pos: torch.Tensor,
    tokens: torch.Tensor = None,
    token_embeddings: torch.Tensor = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Run the model on the prompt and sample the first new token.

    Args:
        model: Model called as ``model(tokens=..., token_embeddings=..., input_pos=..., **kwargs)``.
        input_pos: Positions forwarded to the model.
        tokens: Prompt tokens forwarded to the model (may be None).
        token_embeddings: Prompt embeddings forwarded to the model (may be None).
        temperature: Sampling temperature.
        top_k: Top-k cut-off; mutually exclusive with ``top_p``.
        top_p: Nucleus threshold; mutually exclusive with ``top_k``.
        **kwargs: Extra keyword arguments forwarded to the model.

    Returns:
        The sampled token indices (first element of the sampler's result).

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are provided.
    """
    # Validate before the forward pass so an invalid call fails fast, without running the model.
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
    logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs)
    if top_p is not None:
        return sample_top_p(logits, temperature=temperature, top_p=top_p)[0]
    return sample_top_k(logits, temperature=temperature, top_k=top_k)[0]

# TAG
def decode_one_token(
    model: Transformer,
    tokens: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one forward pass and sample a single next token.

    Top-p sampling is used when ``top_p`` is given; otherwise top-k sampling is used
    (which degenerates to plain sampling or argmax depending on its arguments).

    Returns:
        The ``(token, probs)`` pair produced by the selected sampler.
    """
    logits = model(tokens=tokens, input_pos=input_pos, **kwargs)
    if top_p is None:
        return sample_top_k(logits, temperature=temperature, top_k=top_k)
    return sample_top_p(logits, temperature=temperature, top_p=top_p)

# TAG
def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    stop_tokens: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    return_probs: bool = False,
    decode_one_token_function=decode_one_token,
    **kwargs,
):
    """
    Decode up to ``num_new_tokens`` tokens one at a time.
    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py

    Args:
        model: Model passed through to ``decode_one_token_function``.
        cur_token: Token(s) fed to the first decoding step; its first dimension is the batch.
        input_pos: Position tensor; it is incremented IN PLACE after every step.
        num_new_tokens: Maximum number of decoding steps.
        stop_tokens: Optional tensor of stop token ids. Decoding ends early once every
            sample in the batch has produced one; the step on which that happens is
            not appended to the output.
        temperature: Sampling temperature.
        top_p: Nucleus threshold; mutually exclusive with ``top_k``.
        top_k: Top-k cut-off; mutually exclusive with ``top_p``.
        return_probs: When True, also collect the probabilities returned by the decode
            function. NOTE: the decode function must then return a tensor; the default
            ``decode_one_token`` returns None probabilities for ``temperature == 0.0``
            and for top-p sampling, which makes this fail.
        decode_one_token_function: Callable performing a single decoding step.
        **kwargs: Extra keyword arguments forwarded to the decode function.

    Returns:
        A list of token tensors, or ``(tokens, probs)`` lists when ``return_probs`` is True.

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are provided.
    """
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    new_tokens, new_probs = [], []
    batch_size = cur_token.shape[0]

    check_stop = stop_tokens is not None and len(stop_tokens) > 0
    if check_stop:
        # Per-sample flag: has this sample emitted a stop token yet? Allocated on the
        # same device as the tokens instead of a hard-coded "cuda".
        eos_reached = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)

    for _ in range(num_new_tokens):
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token_function(
                model,
                tokens=cur_token,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                **kwargs,
            )
        input_pos += 1

        if check_stop:
            # Reduce to one flag per sample so the OR stays [batch] instead of
            # broadcasting [batch] against [batch, 1] into a [batch, batch] matrix.
            hit_stop = torch.isin(next_token, stop_tokens).reshape(batch_size, -1).any(dim=-1)
            eos_reached = eos_reached | hit_stop
            if eos_reached.all():
                break

        new_tokens.append(next_token.clone())
        if return_probs:
            new_probs.append(next_prob.clone())
        cur_token = next_token.clone()

    if return_probs:
        return new_tokens, new_probs
    return new_tokens

# Modified variants: the functions below sample/decode several positions per forward pass.

def sample_n_top_p(logits, temperature, top_p, return_probs: bool = False):
    """
    Top-p (nucleus) sampling applied independently at every sequence position.

    Unlike ``sample_top_p``, which looks only at the last position, this samples one
    token for each position present in ``logits``.

    Args:
        logits (torch.Tensor): Logits indexed as ``[batch, position, vocab]``.
        temperature (float): Divisor applied to the logits before the softmax.
        top_p (float): Cumulative probability threshold of the nucleus.
        return_probs (bool): When True, also return the truncated and renormalized
            distributions laid out in the original vocabulary order.

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]: Sampled token indices (int64, with
        a trailing dimension of size 1) and the renormalized probabilities, or None
        when ``return_probs`` is False.
    """
    # The full three-way slice is a no-op for 3-D input; every position is kept.
    all_probs = torch.softmax(logits[:, :, :] / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(all_probs, dim=-1, descending=True)

    # Mass preceding each entry in descending order; drop entries once that mass
    # already exceeds top_p.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs.masked_fill_(mass_before > top_p, 0.0)

    # Renormalize what is left so each distribution sums to 1.
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))

    # Sample in sorted space and map back to vocabulary indices.
    choice = multinomial_sample_one_no_sync(sorted_probs, dtype=torch.int64)
    next_token = torch.gather(sorted_idx, -1, choice)

    if not return_probs:
        return next_token, None

    # Restore the original vocabulary order of the probabilities.
    probs_unsorted = torch.zeros_like(sorted_probs)
    probs_unsorted.scatter_(-1, sorted_idx, sorted_probs)
    return next_token, probs_unsorted

def sample_n_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """
    Top-k sampling applied independently at every sequence position.

    Args:
        logits: Tensor of shape [batch_size, n, vocab_size].
        temperature: Sampling temperature; exactly 0.0 selects the argmax.
        top_k: Optional top-k cut-off passed to ``logits_to_n_probs``.

    Returns:
        Tuple of the chosen indices (trailing dimension of size 1) and the
        probabilities they were drawn from (None for the greedy case).
    """
    if temperature == 0.0:
        # Greedy choice for each of the n positions.
        greedy = torch.argmax(logits[:, :, :], dim=-1, keepdim=True)
        return greedy, None

    probs = logits_to_n_probs(logits, temperature, top_k)
    sampled = multinomial_sample_one_no_sync(probs)
    return sampled, probs

def logits_to_n_probs(
    logits,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
):
    """
    Convert logits to probabilities over the last dimension, optionally keeping only
    the ``top_k`` largest logits at each position.

    The temperature is floored at 1e-5 before dividing. Logits below the k-th largest
    value are replaced with -inf so they receive zero probability.
    """
    safe_temperature = max(temperature, 1e-5)
    tempered = logits / safe_temperature

    if top_k is None:
        return torch.softmax(tempered, dim=-1)

    keep = min(top_k, tempered.size(-1))
    kept_values, _ = torch.topk(tempered, keep, dim=-1)
    cutoff = kept_values[..., -1:]
    filtered = tempered.masked_fill(tempered < cutoff, float("-inf"))
    return torch.softmax(filtered, dim=-1)

def decode_some_token(
    model: Transformer,
    tokens: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one forward pass and sample a token at every returned position.

    Uses ``sample_n_top_p`` when ``top_p`` is given and ``sample_n_top_k`` otherwise,
    so the result holds one sampled token per position in the model's logits.

    Returns:
        The ``(tokens, probs)`` pair produced by the selected sampler.
    """
    logits = model(tokens=tokens, input_pos=input_pos, **kwargs)
    if top_p is None:
        return sample_n_top_k(logits, temperature=temperature, top_k=top_k)
    return sample_n_top_p(logits, temperature=temperature, top_p=top_p)
    
# Layout constants shared by the parallel ("diagd" / "insane") decoders below.
# A row counts as finished once it holds COLUMN tokens.
COLUMN = 64
# Number of rows per image/frame; one image/frame spans COLUMN * ROW tokens.
ROW = 40
# Offset added to token indices when building the position ids passed to the model.
# NOTE(review): presumably the length of the conditioning prompt in tokens — confirm with the caller.
PROMPT_LEN = 5120
# Number of frames handled by the video decoders; global row ids range over ROW * FRAME.
FRAME = 3

def decode_diagd_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    max_new_tokens: int,
    stop_tokens: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    return_probs: bool = False,
    decode_one_token_function=decode_some_token,
    window_size: int = 2, 
    **kwargs,
):
    """
    Decode image tokens with several rows of the ROW x COLUMN grid in flight at once.

    Every row listed in ``ongoing_row_list`` receives one new token per model call.
    A row's successor is started once the row holds ``window_size`` tokens, and a row
    is retired once it holds COLUMN tokens. New tokens are inserted into ``cur_token``
    at their raster-order position.

    Args:
        model: Model passed through to ``decode_one_token_function``.
        cur_token: 2-D tensor of tokens generated so far; its first column is dropped
            from the result.
        input_pos: Ignored; positions are recomputed on every step.
        max_new_tokens: Decoding stops once ``cur_token`` holds ``max_new_tokens + 1`` columns.
        stop_tokens: Accepted for interface compatibility; not used.
        temperature: Sampling temperature.
        top_p: Nucleus threshold; mutually exclusive with ``top_k``.
        top_k: Top-k cut-off; mutually exclusive with ``top_p``.
        return_probs: Accepted for interface compatibility; not used.
        decode_one_token_function: Callable that samples one token per input position.
        window_size: Number of tokens a row must hold before the next row is started.
        **kwargs: Forwarded to the decode function. ``prompt_tokens`` is popped and
            handed to ``prepare_diagd_inputs``, which reads it for rows with no tokens yet.

    Returns:
        A list with one ``[batch, 1]`` tensor per generated column.

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are provided.
    """
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    max_new_tokens += 1
    prompt_tokens = kwargs.pop("prompt_tokens", None)
    _, cur_len = cur_token.shape

    last_token_in_row_idx = None
    row_token_num = torch.zeros((ROW,), dtype=torch.long)
    ongoing_row_list = []
    image_start_token_id_index = 0
    image_num = 1

    while cur_len < max_new_tokens:
        if cur_len % (COLUMN * ROW) == 0 and image_start_token_id_index is None:
            image_start_token_id_index = cur_len
            image_num += 1

        image_token_num = len(cur_token[0][image_start_token_id_index:])

        input_pos = None
        if image_token_num == 1:
            # First token of an image: open row 0 (and row 1 if the window is already met).
            ongoing_row_list.append(0)
            row_token_num[0] += 1
            if row_token_num[0] == window_size:
                ongoing_row_list.append(1)
        if image_token_num >= 1:
            input_id, input_pos = prepare_diagd_inputs(
                cur_token,
                last_token_in_row_idx,
                ongoing_row_list=ongoing_row_list,
                row_token_num=row_token_num,
                prompt_tokens=prompt_tokens,
                image_num=image_num,
            )
        elif image_token_num == 0:
            input_id = cur_token[:, -1:]
            # NOTE(review): this branch passes a Python list while the branch above
            # passes a tensor — confirm the model accepts both.
            input_pos = [cur_len - 1 + PROMPT_LEN]

        num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, _ = decode_one_token_function(
                model,
                tokens=input_id,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                **kwargs,
            )

        if len(ongoing_row_list) == 0:
            cur_len += 1
            cur_token = torch.cat([cur_token, next_token[0]], dim=-1)
            continue

        need_remove_row = None
        # When the previous step finished a row, its last token was fed as the first
        # input; the prediction at that slot is skipped.
        if last_token_in_row_idx is not None:
            num_skip_tokens = 1
            last_token_in_row_idx = None
        else:
            num_skip_tokens = 0
        cur_len += num_new_tokens - num_skip_tokens

        for i in range(0, num_new_tokens - num_skip_tokens):
            row = ongoing_row_list[i]
            # Index in cur_token at which this row's next token belongs.
            position = torch.sum(row_token_num[:(row + 1)], dim=0) + (image_num - 1) * ROW * COLUMN
            cur_token = position_insert(cur_token, next_token[:, i + num_skip_tokens], position)
            row_token_num[row] += 1

            if row_token_num[row] == window_size and row < ROW - 1:
                ongoing_row_list.append(row + 1)
            elif row == ROW - 1 and row_token_num[row] == COLUMN:
                # Last row of the image is complete: reset all per-image state.
                last_token_in_row_idx = None
                row_token_num = torch.zeros((ROW,), dtype=torch.long)
                ongoing_row_list = []
                image_start_token_id_index = None
                need_remove_row = None
                break
            if row_token_num[row] == COLUMN:  # this row is done
                last_token_in_row_idx = position
                need_remove_row = row

        if need_remove_row is not None:
            ongoing_row_list.remove(need_remove_row)

    # Drop the seed column that was present before decoding started.
    cur_token = cur_token[:, 1:]
    return [cur_token[:, i].unsqueeze(1) for i in range(cur_token.size(1))]

def prepare_diagd_inputs(
    input_ids,
    last_token_in_row_idx,
    ongoing_row_list,
    row_token_num,
    prompt_tokens,
    image_num,
):
    """
    Gather the input token and position id for every row that is currently decoding.

    For each row in ``ongoing_row_list`` the most recent token of that row is selected
    from ``input_ids``; a row with no tokens yet is seeded from ``prompt_tokens``
    instead. When ``last_token_in_row_idx`` is given, that token is placed first.

    Returns:
        Tuple of the gathered tokens (one column per entry) and a tensor holding the
        matching position ids, created on the tokens' device.
    """
    image_offset = (image_num - 1) * ROW * COLUMN
    token_columns = []
    positions = []  # the global index of tokens

    if last_token_in_row_idx is not None:
        token_columns.append(input_ids[:, last_token_in_row_idx].unsqueeze(-1))
        positions.append(last_token_in_row_idx)

    for row in ongoing_row_list:
        row_start = row * COLUMN
        filled = row_token_num[row]
        if filled == 0:
            # Empty row: take the token just before the row start from the prompt.
            source = prompt_tokens
            source_idx = row_start - 1
            position = PROMPT_LEN + row_start - 1 + image_offset
        else:
            # Otherwise take the row's latest token from the tokens decoded so far.
            source = input_ids
            source_idx = torch.sum(row_token_num[: row + 1], dim=0) - 1 + image_offset
            position = PROMPT_LEN + row_start + filled - 1 + image_offset
        token_columns.append(source[:, source_idx].unsqueeze(-1))
        positions.append(position)

    gathered = torch.cat(token_columns, dim=1)
    position_ids = torch.tensor(positions, device=gathered.device)
    return gathered, position_ids

def position_insert(input_ids, next_tokens, position):
    """
    Return a copy of ``input_ids`` with ``next_tokens`` spliced in along dim 1
    so that they start at column ``position``.
    """
    head = input_ids[:, :position]
    tail = input_ids[:, position:]
    return torch.cat([head, next_tokens[:, :], tail], dim=1)

def decode_diagd_video_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    max_new_tokens: int,
    stop_tokens: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    return_probs: bool = False,
    decode_one_token_function=decode_some_token,
    window_size: int = 2, 
    **kwargs,
):
    """
    Decode video tokens with several rows in flight, continuing across frames.

    Rows are numbered globally (``row_id // ROW`` is the frame, ``row_id % ROW`` the
    row inside it). Each active row receives one token per model call; row ``r + 1``
    is started once row ``r`` holds ``window_size`` tokens, and a row is retired once
    it holds COLUMN tokens. Per-frame token counts are kept in ``row_token_num_v``,
    one tensor per frame.

    Args:
        model: Model passed through to ``decode_one_token_function``.
        cur_token: 2-D tensor of tokens so far; its first column is dropped from the result.
        input_pos: Ignored; positions are recomputed on every step.
        max_new_tokens: Decoding stops once ``cur_token`` holds ``max_new_tokens + 1`` columns.
        stop_tokens: Accepted for interface compatibility; not used.
        temperature: Sampling temperature.
        top_p: Nucleus threshold; mutually exclusive with ``top_k``.
        top_k: Top-k cut-off; mutually exclusive with ``top_p``.
        return_probs: Accepted for interface compatibility; not used.
        decode_one_token_function: Callable that samples one token per input position.
        window_size: Number of tokens a row must hold before the next row is started.
        **kwargs: Forwarded to the decode function. ``prompt_tokens`` is popped and
            handed to ``prepare_diagd_video_inputs``.

    Returns:
        A list with one ``[batch, 1]`` tensor per generated column.

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are provided.
    """
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    max_new_tokens += 1
    prompt_tokens = kwargs.pop("prompt_tokens", None)
    _, cur_len = cur_token.shape

    last_token_in_row_idx = None
    # Row 0 of frame 0 starts out active with a count of one.
    ongoing_row_list_v = [0]
    row_token_num_v = [torch.zeros((ROW,), dtype=torch.long)]
    row_token_num_v[0][0] += 1
    if row_token_num_v[0][0] == window_size:
        ongoing_row_list_v.append(1)

    while cur_len < max_new_tokens:
        input_id, input_pos = prepare_diagd_video_inputs(
            cur_token,
            last_token_in_row_idx,
            ongoing_row_list_v=ongoing_row_list_v,
            row_token_num_v=row_token_num_v,
            prompt_tokens=prompt_tokens,
        )

        num_new_tokens = input_id.shape[1]
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, _ = decode_one_token_function(
                model,
                tokens=input_id,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                **kwargs,
            )

        need_remove_row = None
        # When the previous step finished a row, its last token was fed as the first
        # input; the prediction at that slot is skipped.
        if last_token_in_row_idx is not None:
            num_skip_tokens = 1
            last_token_in_row_idx = None
        else:
            num_skip_tokens = 0
        cur_len += num_new_tokens - num_skip_tokens

        for i in range(0, num_new_tokens - num_skip_tokens):
            row_id = ongoing_row_list_v[i]
            frame_idx, row_in_frame = divmod(row_id, ROW)

            # Tokens held by all earlier frames, then by rows up to this one.
            if frame_idx > 0:
                last_frame = torch.stack(row_token_num_v[:frame_idx]).sum()
            else:
                last_frame = torch.tensor(0, dtype=torch.long)
            position = last_frame + torch.sum(row_token_num_v[frame_idx][:(row_in_frame + 1)], dim=0)

            cur_token = position_insert(cur_token, next_token[:, i + num_skip_tokens], position)
            row_token_num_v[frame_idx][row_in_frame] += 1

            if row_token_num_v[frame_idx][row_in_frame] == window_size and row_id < ROW * FRAME - 1:
                ongoing_row_list_v.append(row_id + 1)
                # The new row opens a new frame: allocate its counters.
                if ongoing_row_list_v[-1] % ROW == 0:
                    row_token_num_v.append(torch.zeros((ROW,), dtype=torch.long))

            if row_token_num_v[frame_idx][row_in_frame] == COLUMN:  # this row is done
                last_token_in_row_idx = position
                need_remove_row = row_id

        if need_remove_row is not None:
            ongoing_row_list_v.remove(need_remove_row)

    # Drop the seed column that was present before decoding started.
    cur_token = cur_token[:, 1:]
    return [cur_token[:, i].unsqueeze(1) for i in range(cur_token.size(1))]

def prepare_diagd_video_inputs(
    input_ids,
    last_token_in_row_idx,
    ongoing_row_list_v,
    row_token_num_v,
    prompt_tokens,
):
    """
    Gather the input token and position id for every active row of the video decoder.

    ``row_token_num_v`` is a list with one per-frame counter tensor; a global row id
    ``j`` maps to frame ``j // ROW`` and row ``j % ROW``. Rows with no tokens yet are
    seeded from ``prompt_tokens``; all others use their latest token in ``input_ids``.
    When ``last_token_in_row_idx`` is given, that token is placed first.

    Returns:
        Tuple of the gathered tokens (one column per entry) and a tensor holding the
        matching position ids, created on the tokens' device.
    """
    token_columns = []
    positions = []  # the global index of tokens

    if last_token_in_row_idx is not None:
        token_columns.append(input_ids[:, last_token_in_row_idx].unsqueeze(-1))
        positions.append(last_token_in_row_idx)

    for row_id in ongoing_row_list_v:
        frame_idx, row_in_frame = divmod(row_id, ROW)
        filled = row_token_num_v[frame_idx][row_in_frame]
        if filled == 0:
            source = prompt_tokens
            source_idx = row_in_frame * COLUMN - 1
            position = PROMPT_LEN + row_id * COLUMN - 1
        else:
            # Tokens held by all earlier frames.
            if frame_idx > 0:
                earlier_frames = torch.stack(row_token_num_v[:frame_idx]).sum()
            else:
                earlier_frames = torch.tensor(0, dtype=torch.long)
            source = input_ids
            source_idx = earlier_frames + torch.sum(row_token_num_v[frame_idx][: row_in_frame + 1], dim=0) - 1
            position = PROMPT_LEN + row_id * COLUMN + filled - 1
        token_columns.append(source[:, source_idx].unsqueeze(-1))
        positions.append(position)

    gathered = torch.cat(token_columns, dim=1)
    position_ids = torch.tensor(positions, device=gathered.device)
    return gathered, position_ids

def decode_insane_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    max_new_tokens: int,
    stop_tokens: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    return_probs: bool = False,
    decode_one_token_function=decode_some_token,
    window_size: int = 2,
    **kwargs,
):
    """
    Decode video tokens with rows of several frames in flight simultaneously.

    Rows are numbered globally (``row_id // ROW`` is the frame, ``row_id % ROW`` the
    row inside it). Each active row receives one token per model call. When a row
    reaches ``window_size`` tokens, the next row is started if it is still empty, and
    the first row of a frame additionally starts the first row of the following frame.
    A row is retired once it holds COLUMN tokens; decoding ends when no row is active.

    Args:
        model: Model passed through to ``decode_one_token_function``.
        cur_token: 2-D tensor of tokens so far; its first column is dropped from the result.
        input_pos: Ignored; positions are recomputed on every step.
        max_new_tokens: Accepted for interface compatibility; not used.
        stop_tokens: Accepted for interface compatibility; not used.
        temperature: Sampling temperature.
        top_p: Nucleus threshold; mutually exclusive with ``top_k``.
        top_k: Top-k cut-off; mutually exclusive with ``top_p``.
        return_probs: Accepted for interface compatibility; not used.
        decode_one_token_function: Callable that samples one token per input position.
        window_size: Number of tokens a row must hold before further rows are started.
        **kwargs: Forwarded to the decode function. ``prompt_tokens`` is popped and
            handed to ``prepare_insane_inputs``.

    Returns:
        A list with one ``[batch, 1]`` tensor per generated column.

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are provided.
    """
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    # Per-frame, per-row token counters, kept on the same device as the tokens
    # rather than a hard-coded "cuda".
    row_token_num_insane = torch.zeros([FRAME, ROW], dtype=int, device=cur_token.device)
    prompt_tokens = kwargs.pop("prompt_tokens", None)

    # Never set in this variant; still passed on because the helper's signature expects it.
    last_token_in_row_idx = None

    # Row 0 of frame 0 starts out active with a count of one.
    ongoing_row_list_v = [0]
    row_token_num_insane[0][0] += 1
    if row_token_num_insane[0][0] == window_size:
        ongoing_row_list_v.append(1)

    while len(ongoing_row_list_v) > 0:
        input_id, input_pos = prepare_insane_inputs(
            cur_token,
            last_token_in_row_idx,
            ongoing_row_list_v=ongoing_row_list_v,
            row_token_num_v=row_token_num_insane,
            prompt_tokens=prompt_tokens,
        )

        num_new_tokens = input_id.shape[1]
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, _ = decode_one_token_function(
                model,
                tokens=input_id,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                **kwargs,
            )

        need_remove_row = []
        for i in range(num_new_tokens):
            assert len(ongoing_row_list_v) > i
            row_id = ongoing_row_list_v[i]
            frame_idx, row_in_frame = divmod(row_id, ROW)

            # Tokens held by all earlier frames, then by rows up to this one.
            if frame_idx > 0:
                last_frame = row_token_num_insane[:frame_idx].sum()
            else:
                last_frame = torch.tensor(0, dtype=torch.long)
            position = last_frame + torch.sum(row_token_num_insane[frame_idx][:(row_in_frame + 1)], dim=0)

            cur_token = position_insert(cur_token, next_token[:, i], position)
            row_token_num_insane[frame_idx][row_in_frame] += 1

            if row_token_num_insane[frame_idx][row_in_frame] == window_size and row_id < ROW * FRAME - 1:
                # The first row of a frame also kicks off the first row of the next frame.
                if row_in_frame == 0 and row_id + ROW < ROW * FRAME:
                    ongoing_row_list_v.append(row_id + ROW)
                # Start the following row unless it has already been started.
                next_frame, next_row = divmod(row_id + 1, ROW)
                if row_token_num_insane[next_frame][next_row] == 0:
                    ongoing_row_list_v.append(row_id + 1)

            if row_token_num_insane[frame_idx][row_in_frame] == COLUMN:  # this row is done
                need_remove_row.append(row_id)

        if need_remove_row:
            ongoing_row_list_v = [item for item in ongoing_row_list_v if item not in need_remove_row]

    # Drop the seed column that was present before decoding started.
    cur_token = cur_token[:, 1:]
    return [cur_token[:, i].unsqueeze(1) for i in range(cur_token.size(1))]

# NOTE(review): this definition is shadowed by an identical `prepare_insane_inputs`
# defined again further down in this file; only the later one is live at call time.
# Consider deleting one of the two copies.
def prepare_insane_inputs(
    input_ids,
    last_token_in_row_idx,
    ongoing_row_list_v,
    row_token_num_v,
    prompt_tokens,
):
    """
    Gather the input token and position id for every active row.

    A global row id ``j`` maps to frame ``j // ROW`` and row ``j % ROW`` of
    ``row_token_num_v``. Rows with no tokens yet are seeded from ``prompt_tokens``;
    all others use their latest token in ``input_ids``. ``last_token_in_row_idx`` is
    accepted but not used here.

    Returns:
        Tuple of the gathered tokens (one column per active row) and a tensor with the
        matching position ids on the tokens' device.
    """
    global COLUMN
    global ROW
    global PROMPT_LEN
    new_input_ids = []
    new_position_ids = [] ## the global index of tokens
    # if last_token_in_row_idx is not None:
    #     new_input_ids.append(input_ids[:, last_token_in_row_idx].unsqueeze(-1))
    #     new_position_ids.append(last_token_in_row_idx)

    for j in ongoing_row_list_v:
        if row_token_num_v[j//ROW][j%ROW] == 0:
            # Empty row: seed from the prompt token just before the row's start offset.
            idx_in_input_ids = (j%ROW) * COLUMN - 1
            global_idx = PROMPT_LEN + j * COLUMN - 1
            new_input_ids.append(prompt_tokens[:, idx_in_input_ids].unsqueeze(-1))
        else:
            # Tokens held by all earlier frames, plus rows up to and including this one,
            # minus one: the index of this row's most recent token in input_ids.
            last_frame = row_token_num_v[:j // ROW].sum() if j // ROW > 0 else torch.tensor(0, dtype=torch.long)
            idx_in_input_ids = last_frame + torch.sum(row_token_num_v[j // ROW][:(j % ROW + 1)], dim=0) - 1
            global_idx = PROMPT_LEN + j * COLUMN + row_token_num_v[j//ROW][j%ROW] - 1
            new_input_ids.append(input_ids[:, idx_in_input_ids].unsqueeze(-1))
        new_position_ids.append(global_idx)
    input_ids = torch.cat(new_input_ids, dim=1)
    position_ids = torch.tensor(new_position_ids, device=input_ids.device)

    return input_ids, position_ids


def prepare_insane_inputs(
    input_ids,
    last_token_in_row_idx,
    ongoing_row_list_v,
    row_token_num_v,
    prompt_tokens,
):
    """
    Gather the input token and position id for every active row.

    ``row_token_num_v`` is indexed as ``[frame][row]``; a global row id ``j`` maps to
    frame ``j // ROW`` and row ``j % ROW``. Rows with no tokens yet are seeded from
    ``prompt_tokens``; all others use their latest token in ``input_ids``.
    ``last_token_in_row_idx`` is accepted for signature compatibility but not used.

    Returns:
        Tuple of the gathered tokens (one column per active row) and a tensor with the
        matching position ids on the tokens' device.
    """
    token_columns = []
    positions = []  # the global index of tokens

    for row_id in ongoing_row_list_v:
        frame_idx, row_in_frame = divmod(row_id, ROW)
        filled = row_token_num_v[frame_idx][row_in_frame]
        if filled == 0:
            source = prompt_tokens
            source_idx = row_in_frame * COLUMN - 1
            position = PROMPT_LEN + row_id * COLUMN - 1
        else:
            # Tokens held by all earlier frames.
            if frame_idx > 0:
                earlier_frames = row_token_num_v[:frame_idx].sum()
            else:
                earlier_frames = torch.tensor(0, dtype=torch.long)
            source = input_ids
            source_idx = earlier_frames + torch.sum(row_token_num_v[frame_idx][: row_in_frame + 1], dim=0) - 1
            position = PROMPT_LEN + row_id * COLUMN + filled - 1
        token_columns.append(source[:, source_idx].unsqueeze(-1))
        positions.append(position)

    gathered = torch.cat(token_columns, dim=1)
    position_ids = torch.tensor(positions, device=gathered.device)
    return gathered, position_ids


def decode_insane_tokens_new(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    max_new_tokens: int,
    stop_tokens: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    return_probs: bool = False,
    decode_one_token_function=decode_some_token,
    window_size: int = 2,
    frame_size: int =2,
    **kwargs,
):
    """
    Decode video tokens with rows of several frames in flight simultaneously.

    Same scheme as ``decode_insane_tokens``, except for when the next frame is started:
    here the first row of the following frame is started when row ``frame_size - 1``
    of the current frame reaches ``window_size`` tokens.

    Rows are numbered globally (``row_id // ROW`` is the frame, ``row_id % ROW`` the
    row inside it). Each active row receives one token per model call, a row is
    retired once it holds COLUMN tokens, and decoding ends when no row is active.

    Args:
        model: Model passed through to ``decode_one_token_function``.
        cur_token: 2-D tensor of tokens so far; its first column is dropped from the result.
        input_pos: Ignored; positions are recomputed on every step.
        max_new_tokens: Accepted for interface compatibility; not used.
        stop_tokens: Accepted for interface compatibility; not used.
        temperature: Sampling temperature.
        top_p: Nucleus threshold; mutually exclusive with ``top_k``.
        top_k: Top-k cut-off; mutually exclusive with ``top_p``.
        return_probs: Accepted for interface compatibility; not used.
        decode_one_token_function: Callable that samples one token per input position.
        window_size: Number of tokens a row must hold before further rows are started.
        frame_size: 1-based index of the row within a frame whose progress starts the next frame.
        **kwargs: Forwarded to the decode function. ``prompt_tokens`` is popped and
            handed to ``prepare_insane_inputs``.

    Returns:
        A list with one ``[batch, 1]`` tensor per generated column.

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are provided.
    """
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    # Per-frame, per-row token counters, kept on the same device as the tokens
    # rather than a hard-coded "cuda".
    row_token_num_insane = torch.zeros([FRAME, ROW], dtype=int, device=cur_token.device)
    prompt_tokens = kwargs.pop("prompt_tokens", None)

    # Never set in this variant; still passed on because the helper's signature expects it.
    last_token_in_row_idx = None

    # Row 0 of frame 0 starts out active with a count of one.
    ongoing_row_list_v = [0]
    row_token_num_insane[0][0] += 1
    if row_token_num_insane[0][0] == window_size:
        ongoing_row_list_v.append(1)

    while len(ongoing_row_list_v) > 0:
        input_id, input_pos = prepare_insane_inputs(
            cur_token,
            last_token_in_row_idx,
            ongoing_row_list_v=ongoing_row_list_v,
            row_token_num_v=row_token_num_insane,
            prompt_tokens=prompt_tokens,
        )

        num_new_tokens = input_id.shape[1]
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, _ = decode_one_token_function(
                model,
                tokens=input_id,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                **kwargs,
            )

        need_remove_row = []
        for i in range(num_new_tokens):
            assert len(ongoing_row_list_v) > i
            row_id = ongoing_row_list_v[i]
            frame_idx, row_in_frame = divmod(row_id, ROW)

            # Tokens held by all earlier frames, then by rows up to this one.
            if frame_idx > 0:
                last_frame = row_token_num_insane[:frame_idx].sum()
            else:
                last_frame = torch.tensor(0, dtype=torch.long)
            position = last_frame + torch.sum(row_token_num_insane[frame_idx][:(row_in_frame + 1)], dim=0)

            cur_token = position_insert(cur_token, next_token[:, i], position)
            row_token_num_insane[frame_idx][row_in_frame] += 1

            if row_token_num_insane[frame_idx][row_in_frame] == window_size and row_id < ROW * FRAME - 1:
                # Start the following row unless it has already been started.
                next_frame, next_row = divmod(row_id + 1, ROW)
                if row_token_num_insane[next_frame][next_row] == 0:
                    ongoing_row_list_v.append(row_id + 1)
                # Row `frame_size - 1` of a frame kicks off the first row of the next frame.
                if row_in_frame == frame_size - 1 and frame_idx < FRAME - 1:
                    ongoing_row_list_v.append(frame_idx * ROW + ROW)

            if row_token_num_insane[frame_idx][row_in_frame] == COLUMN:  # this row is done
                need_remove_row.append(row_id)

        if need_remove_row:
            ongoing_row_list_v = [item for item in ongoing_row_list_v if item not in need_remove_row]

    # Drop the seed column that was present before decoding started.
    cur_token = cur_token[:, 1:]
    return [cur_token[:, i].unsqueeze(1) for i in range(cur_token.size(1))]
