# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

from cosmos_predict1.autoregressive.networks.transformer import Transformer


def sample_top_p(logits, temperature, top_p, return_probs: bool = False):
    """
    Perform top-p (nucleus) sampling on the logits of the last sequence position.

    Args:
        logits (torch.Tensor): Logits indexed as (batch, seq, vocab); only ``logits[:, -1, :]``
            is used.
        temperature (float): Temperature for sampling. It is clamped below at 1e-5 (as in
            ``logits_to_probs``) so that ``temperature=0`` gives near-greedy sampling instead of
            a division by zero and NaN probabilities.
        top_p (float): Probability threshold for top-p sampling.
        return_probs (bool): If True, also return the filtered, renormalized distribution
            scattered back to the original vocabulary order.

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]: Sampled token indices of shape (batch, 1)
        and dtype ``torch.int64``, plus the (batch, vocab) filtered probabilities when
        ``return_probs`` is True (otherwise ``None``).

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs = torch.softmax(logits[:, -1, :] / max(temperature, 1e-5), dim=-1)
    # Sort the probabilities in descending order and get their indices.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # Mass of the tokens ranked strictly above each token: a token is dropped only when the
    # tokens before it already exceed top_p, so the most likely token is always kept.
    mask = torch.cumsum(probs_sort, dim=-1) - probs_sort > top_p
    probs_sort[mask] = 0.0
    # Renormalize the surviving probabilities so each row sums to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample with the sync-free helper instead of torch.multinomial; the result is a
    # position in the *sorted* order.
    sorted_pos = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
    # Map sorted positions back to vocabulary ids.
    next_token = torch.gather(probs_idx, -1, sorted_pos)
    probs_unsorted = None
    if return_probs:
        # Scatter the filtered probabilities back to their original vocabulary order.
        probs_unsorted = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
    return next_token, probs_unsorted


def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
    """
    Draw one categorical sample per row without a CUDA synchronization.

    Uses the exponential-race trick: with independent E_i ~ Exp(1), argmax_i p_i / E_i is
    distributed according to p (up to normalization), so torch.multinomial is not needed.
    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py

    Args:
        probs_sort (torch.Tensor): Non-negative weights; sampling is over the last dimension.
        dtype (torch.dtype): dtype of the returned indices (default ``torch.int``).

    Returns:
        torch.Tensor: Sampled indices along the last dimension, which is kept with size 1.
    """
    exp_noise = torch.empty_like(probs_sort)
    exp_noise.exponential_(1)
    race_scores = probs_sort / exp_noise
    winner = race_scores.argmax(dim=-1, keepdim=True)
    return winner.to(dtype=dtype)


def logits_to_probs(
    logits,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
):
    """
    Turn logits into probabilities with temperature scaling and optional top-k filtering.

    Args:
        logits (torch.Tensor): Logits; the last dimension is treated as the vocabulary.
        temperature (float): Divisor applied to the logits, clamped below at 1e-5 so that
            ``temperature=0`` does not divide by zero.
        top_k (Optional[int]): If given, logits strictly below the k-th largest value (k capped
            at the vocabulary size) are masked to -inf. Values tied with the k-th largest are
            kept, so more than ``top_k`` entries may survive.

    Returns:
        torch.Tensor: Softmax over the last dimension.
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)

    k = min(top_k, scaled.size(-1))
    # k-th largest value per row, kept as a trailing size-1 dim for broadcasting.
    kth_largest = torch.topk(scaled, k).values[..., -1:]
    scaled = torch.where(scaled < kth_largest, -float("Inf"), scaled)
    return torch.nn.functional.softmax(scaled, dim=-1)


def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """
    Sample the next token from the last position's logits using top-k sampling.
    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py

    Args:
        logits (torch.Tensor): Logits indexed as (batch, seq, vocab); only ``logits[:, -1, :]``
            is used.
        temperature (float): Sampling temperature. Exactly ``0.0`` selects greedy (argmax)
            decoding.
        top_k (Optional[int]): If given, restrict sampling to the ``top_k`` highest logits.

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]: Sampled token indices of shape (batch, 1)
        with dtype ``torch.int64``, and the (batch, vocab) sampling probabilities (``None`` for
        greedy decoding).
    """
    last_logits = logits[:, -1, :]
    if temperature == 0.0:
        # Greedy decoding; no distribution is materialized.
        return torch.argmax(last_logits, dim=-1, keepdim=True), None
    probs = logits_to_probs(last_logits, temperature, top_k)
    # Request int64 explicitly so that both branches (and sample_top_p) return the same
    # dtype; torch.argmax in the greedy branch already yields int64.
    idx_next = multinomial_sample_one_no_sync(probs, dtype=torch.int64)
    return idx_next, probs


def prefill(
    model: Transformer,
    input_pos: torch.Tensor,
    tokens: Optional[torch.Tensor] = None,
    token_embeddings: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Run the model over the prompt and sample the first new token.

    Args:
        model (Transformer): Model called with ``tokens``, ``token_embeddings``, ``input_pos``
            and ``**kwargs``; its output is treated as logits.
        input_pos (torch.Tensor): Positions of the prompt, forwarded to the model.
        tokens (Optional[torch.Tensor]): Prompt token ids, forwarded to the model.
        token_embeddings (Optional[torch.Tensor]): Prompt embeddings, forwarded to the model.
        temperature (float): Sampling temperature.
        top_k (Optional[int]): Top-k cutoff (mutually exclusive with ``top_p``).
        top_p (Optional[float]): Nucleus threshold (mutually exclusive with ``top_k``).
        **kwargs: Extra keyword arguments forwarded to the model.

    Returns:
        torch.Tensor: Sampled token indices from the last position's logits.

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are given.
    """
    # Validate before the (potentially expensive) forward pass so that a bad call fails
    # immediately. The message is an f-string so the offending values are actually shown.
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
    logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs)
    if top_p is not None:
        next_token, _ = sample_top_p(logits, temperature=temperature, top_p=top_p)
    else:
        next_token, _ = sample_top_k(logits, temperature=temperature, top_k=top_k)
    return next_token


def decode_one_token(
    model: Transformer,
    tokens: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one decoding step of the autoregressive model and sample the next token.

    Args:
        model (Transformer): Model called with ``tokens``, ``input_pos`` and ``**kwargs``.
        tokens (torch.Tensor): Input tokens for this step.
        input_pos (torch.Tensor): Positions of ``tokens``, forwarded to the model.
        temperature (float): Sampling temperature forwarded to the sampler.
        top_k (Optional[int]): Top-k cutoff; used only when ``top_p`` is None.
        top_p (Optional[float]): Nucleus threshold; when set, top-p sampling is used.
        **kwargs: Extra keyword arguments forwarded to the model.

    Returns:
        The ``(token_indices, probs)`` pair from ``sample_top_k`` or ``sample_top_p``.
        ``sample_top_p`` is called without ``return_probs``, so on that path ``probs`` is
        always None; ``sample_top_k`` returns None for greedy decoding.
    """
    logits = model(tokens=tokens, input_pos=input_pos, **kwargs)
    if top_p is None:
        return sample_top_k(logits, temperature=temperature, top_k=top_k)
    return sample_top_p(logits, temperature=temperature, top_p=top_p)


def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    stop_tokens: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    return_probs: bool = False,
    decode_one_token_function=decode_one_token,
    **kwargs,
):
    """
    Decode n tokens from the autoregressive model.
    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py

    Args:
        model (Transformer): Model passed to ``decode_one_token_function``.
        cur_token (torch.Tensor): Token(s) fed to the first decoding step; dim 0 is the batch.
        input_pos (torch.Tensor): Position tensor for the current step. It is incremented
            in place after every step, so the caller's tensor advances as well.
        num_new_tokens (int): Maximum number of decoding steps.
        stop_tokens (torch.Tensor, optional): Token ids that end generation. Decoding stops
            early once every sample in the batch has produced one of them; the step that
            completes the batch is not appended to the output.
        temperature (float): Sampling temperature.
        top_p (Optional[float]): Nucleus threshold (mutually exclusive with ``top_k``).
        top_k (Optional[int]): Top-k cutoff (mutually exclusive with ``top_p``).
        return_probs (bool): If True, also collect the per-step probabilities.
        decode_one_token_function: Callable with the ``decode_one_token`` signature returning
            ``(next_token, next_prob)``.
        **kwargs: Forwarded to ``decode_one_token_function``.

    Returns:
        A list of per-step token tensors, or ``(tokens, probs)`` when ``return_probs`` is
        True. An entry of ``probs`` is None when the sampler returned no distribution for
        that step (e.g. greedy decoding, or top-p via ``decode_one_token``).

    Raises:
        AssertionError: If both ``top_p`` and ``top_k`` are given.
    """
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
    new_tokens, new_probs = [], []
    batch_size = cur_token.shape[0]
    check_stop = stop_tokens is not None and len(stop_tokens) > 0
    if check_stop:
        # One flag per sample: has it emitted a stop token yet? Allocated on the tokens'
        # device rather than a hard-coded "cuda" so that CPU decoding also works.
        eos_reached = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    # Enter the MATH SDPA backend once for the whole loop
    # (gpt-fast: "Actually better for Inductor to codegen attention here").
    with sdpa_kernel([SDPBackend.MATH]):
        for _ in range(num_new_tokens):
            next_token, next_prob = decode_one_token_function(
                model,
                tokens=cur_token,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                **kwargs,
            )
            input_pos += 1
            if check_stop:
                # Reduce to one flag per sample. OR-ing the (batch,) mask directly with a
                # (batch, 1) isin result would broadcast it to (batch, batch).
                hit = torch.isin(next_token, stop_tokens).reshape(batch_size, -1).any(dim=-1)
                eos_reached = eos_reached | hit
                # Converting the device tensor to a Python bool synchronizes with the device.
                if eos_reached.all():
                    break
            # Clone so that later steps cannot overwrite stored outputs (presumably needed
            # when decode_one_token_function is compiled / CUDA-graphed; verify).
            new_tokens.append(next_token.clone())
            if return_probs:
                # Samplers may return None (greedy top-k, or top-p without return_probs).
                new_probs.append(next_prob.clone() if next_prob is not None else None)
            cur_token = next_token.clone()

    if return_probs:
        return new_tokens, new_probs
    return new_tokens
