from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from accelerate import Accelerator

from models.token import KVCache
from utils import distance_to_next_zero_1d


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: int = None,
    loss_masks: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Return the masked mean of the per-item cross entropy.

    Args:
        source: Logits, passed straight to ``nn.functional.cross_entropy``.
        target: Targets, passed straight to ``nn.functional.cross_entropy``.
        num_items_in_batch: Accepted for interface compatibility; currently
            not used by the computation.
        loss_masks: Optional multiplicative mask with the same shape as the
            per-item loss. When ``None`` every item counts with weight 1.
        eps: Added to the denominator to avoid division by zero when the
            mask is empty.

    Returns:
        A scalar tensor: ``sum(loss * mask) / (sum(mask) + eps)``.
    """
    per_item = nn.functional.cross_entropy(source, target, reduction="none")

    if loss_masks is None:
        # No mask behaves like an all-ones mask (previously this path raised
        # AttributeError on ``None.sum()``).
        denominator = per_item.numel()
    else:
        per_item = per_item * loss_masks
        denominator = loss_masks.sum()

    return per_item.sum() / (denominator + eps)


def distribution_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    labels: torch.Tensor = None,
    num_items_in_batch: int = None,
    loss_masks: torch.Tensor = None,
    top_k: int = 0,
    eps: float = 1e-9,
    weight: torch.Tensor = None,
) -> torch.Tensor:
    """Cross entropy between predicted logits and a soft target distribution.

    Args:
        source: ``(batch_size, num_classes)`` logits.
        target: ``(batch_size, num_classes)`` target distribution (each row
            should sum to 1).
        labels: Accepted for interface compatibility; not used.
        num_items_in_batch: Accepted for interface compatibility; not used.
        loss_masks: Optional per-row multiplicative mask. When ``None`` every
            row counts with weight 1.
        top_k: When positive, only the ``top_k`` largest target probabilities
            of each row contribute (``weight`` is ignored on this path).
        eps: Added to the denominator to avoid division by zero.
        weight: Optional weights; ``weight[0, :-1]`` is multiplied into the
            per-class terms and therefore broadcasts along the last (class)
            dimension.
            NOTE(review): confirm the intended layout of ``weight`` with the
            callers — it is not visible from this function.

    Returns:
        A scalar tensor: masked sum of per-row cross entropy divided by
        ``sum(mask) + eps``.
    """
    log_probs = F.log_softmax(source, dim=-1)

    if top_k > 0:
        # Restrict the sum to the k most probable target classes per row.
        topk_probs, topk_indices = target.topk(top_k, dim=-1, largest=True)
        topk_log_probs = log_probs.gather(-1, topk_indices)
        per_row = -(topk_probs * topk_log_probs).sum(dim=-1)
    else:
        per_class = -target * log_probs
        if weight is not None:
            per_class = per_class * weight[0, :-1]
        per_row = per_class.sum(-1)

    if loss_masks is None:
        # No mask behaves like an all-ones mask (previously this path raised
        # AttributeError on ``None.sum()``).
        denominator = per_row.numel()
    else:
        per_row = per_row * loss_masks
        denominator = loss_masks.sum()

    return per_row.sum() / (denominator + eps)


class ForDraftCausalLMMetric:
    def __init__(self, model, num_seqs, ignore_index: int = -100):
        """Prepare metric state, a KV cache and process-sharding information.

        Args:
            model: Model under evaluation; its ``config``, ``device`` and
                ``dtype`` are read to build the KV cache.
            num_seqs: Forwarded to ``KVCache`` as ``num_seqs``.
            ignore_index: Label / mask value that is rewritten to 0 before
                metrics are computed.
        """
        cfg = model.config

        self.model = model
        self.ignore_index = ignore_index
        # Per-key history of metric values; averaged when a result is requested.
        self.metrics = defaultdict(list)

        # The cache length argument is 8x the model's max_position_embeddings.
        cache_length = cfg.max_position_embeddings * 8
        self.past_key_values = KVCache(
            cfg.draft_num_hidden_layers,
            cfg.num_key_value_heads,
            cache_length,
            cfg.head_dim,
            model.device,
            model.dtype,
            num_seqs=num_seqs,
        )

        # Process count / index are used to split evaluation rows across ranks.
        accelerator = Accelerator()
        self.accelerator = accelerator
        self.num_procs = accelerator.num_processes
        self.proc_idx = accelerator.process_index

    @torch.inference_mode()
    def __call__(
        self,
        eval_predictions,
        compute_result: bool = True,
    ):
        """Run five chained draft passes and record accuracy / feature metrics.

        Args:
            eval_predictions: Object whose ``label_ids`` unpacks to
                ``(labels, loss_masks, hidden_states, base_hidden_states, ...)``.
            compute_result: When true, return the mean of every metric
                accumulated so far and clear the accumulator; otherwise return
                only this call's (cross-process averaged) values.

        Returns:
            A dict mapping ``"{step}/{metric}"`` keys to Python floats.

        NOTE(review): the number of chained passes (5) is hard-coded, and the
        KV-cache slots are not released if the model call raises.
        """
        labels, loss_masks, hidden_states, base_hidden_states, *_ = eval_predictions.label_ids
        # Ignored positions are rewritten to token id 0 in place; the labels are
        # later fed to the model as input_ids.
        labels[labels == self.ignore_index] = 0

        # Each process takes every num_procs-th row, starting at its own index.
        labels = labels[self.proc_idx::self.num_procs]
        loss_masks = loss_masks[self.proc_idx::self.num_procs]
        hidden_states = hidden_states[self.proc_idx::self.num_procs]
        base_hidden_states = base_hidden_states[self.proc_idx::self.num_procs]

        inputs = {
            "input_ids": labels,
            "hidden_states": base_hidden_states,
            "return_dict": True,
        }

        # The first pass consumes the base hidden states; later passes consume
        # the previous pass's draft hidden states.
        last_hidden_states = inputs["hidden_states"]

        draft_hidden_states = []
        draft_logits = []

        kv_cache_indices = []
        # Additive mask value for "not attendable".
        min_value = torch.finfo(self.model.dtype).min

        for i in range(5):
            # Every pass appends q_len new slots; the model receives all slots
            # allocated so far.
            new_kv_cache_indices = self.past_key_values.allocate(labels.size(1))
            kv_cache_indices.extend(new_kv_cache_indices)

            q_len = labels.size(1)
            # One q_len-wide block of key columns per pass executed so far,
            # initialised to fully masked.
            attention_mask = torch.full(
                (labels.size(0), 1, q_len, q_len * (i + 1)),
                min_value,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for b in range(labels.size(0)):
                # triu keeps min_value on/above diagonal (1 - i) and zeroes the
                # rest, so in block 0 query row r sees key columns c <= r - i.
                attention_mask[b, 0, :, :q_len] = torch.triu(attention_mask[b, 0, :, :q_len], diagonal=1 - i)
                if i > 0:
                    attention_mask[b, 0, :, q_len * i:q_len * (i + 1)].fill_(min_value)
                # Net effect of the fills below: in block k (1 <= k <= i) only
                # diagonal (-i + k) is open, i.e. row r sees column r - i + k.
                for j in range(i):
                    attention_mask[b, 0, :, q_len * j:q_len * (j + 1)].diagonal(-i + j + 1).fill_(min_value)
                    attention_mask[b, 0, :, q_len * j:q_len * (j + 1)].diagonal(-i + j).fill_(0)
                    attention_mask[b, 0, :, q_len * (j + 1):q_len * (j + 2)].diagonal(-i + j + 1).fill_(0)

            inputs["hidden_states"] = last_hidden_states
            inputs["base_hidden_states"] = base_hidden_states
            inputs["position_ids"] = torch.arange(0, q_len, device=self.model.device).unsqueeze(0)
            inputs["attention_mask"] = attention_mask
            inputs["past_key_values"] = self.past_key_values
            inputs["past_key_value_indices"] = kv_cache_indices
            inputs["use_cache"] = True

            output = self.model(**inputs)
            if i == 0:
                # Base-model logits are taken from the first pass only and used
                # as the reference for every draft step.
                target_logits = output["base_logits"].float()

            draft_hidden_state = output["draft_hidden_states"].float()
            draft_logit = output["logits"].float()

            draft_hidden_states.append(draft_hidden_state)
            draft_logits.append(draft_logit)
            last_hidden_states = draft_hidden_state

        self.past_key_values.free(kv_cache_indices)

        # One metric group per draft step, keyed by the step index.
        metrics = dict()
        for idx, (draft_hidden_state, draft_logit) in enumerate(zip(draft_hidden_states, draft_logits)):
            metrics.update(
                self.get_metrics(target_logits, draft_logit, labels, hidden_states, draft_hidden_state, loss_masks, prefix=f"{idx}")
            )

        # Iterate keys in sorted order so every process gathers in the same
        # sequence; each metric becomes the mean over processes.
        sorted_keys = sorted(metrics.keys())
        for key in sorted_keys:
            value = self.accelerator.gather(metrics[key]).mean().item()
            metrics[key] = value
            self.metrics[key].append(value)

        if compute_result:
            # Average everything accumulated since the last result and reset.
            result = dict()
            for key, value in self.metrics.items():
                result[key] = np.mean(value)
            self.metrics.clear()
            return result

        return metrics

    def get_metrics(self, logits, draft_logits, labels, hidden_states, draft_hidden_states, loss_mask, prefix=""):
        loss_mask = loss_mask[..., :-1].float().reshape(-1)
        loss_mask[loss_mask == self.ignore_index] = 0
        base_topk = logits.topk(5, dim=-1, largest=True).indices
        draft_topk = draft_logits.topk(5, dim=-1, largest=True).indices
        base_labels = base_topk[..., 0]
        metrics = dict()
        base_acc_cumsum = 0
        draft_acc_cumsum = 0
        draft_base_acc_cumsum = 0
        for i in range(5):
            base_acc = (base_topk[:, :-1, i].reshape(-1) == labels[:, 1:].reshape(-1)).float()
            base_acc = (base_acc * loss_mask).sum() / (loss_mask.sum() + 1e-6)
            base_acc_cumsum += base_acc
            draft_acc = (draft_topk[:, :-1, i].reshape(-1) == labels[:, 1:].reshape(-1)).float()
            draft_acc = (draft_acc * loss_mask).sum() / (loss_mask.sum() + 1e-6)
            draft_acc_cumsum += draft_acc
            draft_base_acc = (draft_topk[:, :-1, i].reshape(-1) == base_labels[:, :-1].reshape(-1)).float()
            draft_base_acc = (draft_base_acc * loss_mask).sum() / (loss_mask.sum() + 1e-6)
            draft_base_acc_cumsum += draft_base_acc

            metrics[f"{prefix}/base_acc_{i}"] = base_acc_cumsum.clone()
            metrics[f"{prefix}/draft_acc_{i}"] = draft_acc_cumsum.clone()
            metrics[f"{prefix}/draft_base_acc_{i}"] = draft_base_acc_cumsum.clone()

        shift_logits = draft_logits[:, :-1]
        shift_labels = logits[:, 1:].argmax(-1)
        shift_logits = shift_logits.reshape(-1, shift_logits.size(-1))
        shift_labels = shift_labels.reshape(-1)

        feature_loss = nn.functional.smooth_l1_loss(
            draft_hidden_states,
            hidden_states,
            reduction="none",
        ).mean(dim=-1)

        feature_loss = (feature_loss[:, :-1].reshape(-1) * loss_mask).sum() / (loss_mask.sum() + 1e-6)
        metrics[f"{prefix}/feature_loss"] = feature_loss.clone()

        return metrics


class ForDraftCausalLMRLMetric:
    def __init__(self, model, num_seqs, ignore_index: int = -100):
        """Prepare metric state, process-sharding information and a KV cache.

        Args:
            model: Model under evaluation; its ``config``, ``device`` and
                ``dtype`` are read to build the KV cache.
            num_seqs: Forwarded to ``KVCache`` as ``num_seqs``.
            ignore_index: Label / mask value that is rewritten to 0 before
                metrics are computed.
        """
        self.model = model
        self.ignore_index = ignore_index
        # Per-key history of metric values; averaged when a result is requested.
        self.metrics = defaultdict(list)

        # Process count / index are used to split evaluation rows across ranks.
        accelerator = Accelerator()
        self.accelerator = accelerator
        self.num_procs = accelerator.num_processes
        self.proc_idx = accelerator.process_index

        # The cache length argument is 8x the model's max_position_embeddings.
        cfg = model.config
        cache_length = cfg.max_position_embeddings * 8
        self.past_key_values = KVCache(
            cfg.draft_num_hidden_layers,
            cfg.num_key_value_heads,
            cache_length,
            cfg.head_dim,
            model.device,
            model.dtype,
            num_seqs=num_seqs,
        )

        # Number of chained draft steps evaluated per call.
        self.depth = 3

    @torch.inference_mode()
    def __call__(
        self,
        eval_predictions,
        compute_result: bool = True,
    ):
        """Measure how often the draft's top-k contains the base top-1 token.

        The draft model is run for ``self.depth`` chained steps. Step 0 sees
        the whole sequence; later steps only see the positions selected by the
        loss mask and are fed the base model's top-1 token of the previous
        step.

        Args:
            eval_predictions: Object whose ``label_ids`` unpacks to ``(labels,
                loss_masks, hidden_states, base_hidden_states,
                output_topk_ids, ...)``.
            compute_result: When true, return the mean of every metric
                accumulated so far and clear the accumulator; otherwise return
                only this call's (cross-process averaged) values.

        Returns:
            Dict mapping ``"{depth}/draft_base_acc_{i}"`` to Python floats,
            where ``i`` means cumulative top-(i+1) agreement.

        NOTE(review): the mask handling indexes ``attention_mask[0, 0]`` and
        reshapes to ``(1, q_len_valid)``, so one sequence per process appears
        to be assumed — confirm with the data loader.
        """
        labels, loss_masks, hidden_states, base_hidden_states, output_topk_ids, *_ = eval_predictions.label_ids
        # In-place cleanup, as before: ignored entries become 0.
        labels[labels == self.ignore_index] = 0
        loss_masks[loss_masks == self.ignore_index] = 0

        # Each process takes every num_procs-th row, starting at its own index.
        shard = slice(self.proc_idx, None, self.num_procs)
        labels = labels[shard]
        loss_masks = loss_masks[shard]
        base_hidden_states = base_hidden_states[shard]
        output_topk_ids = output_topk_ids[shard]

        device = self.model.device
        dtype = self.model.dtype
        min_value = torch.finfo(dtype).min

        q_len = labels.size(1)
        q_len_valid = int(loss_masks.sum().item())

        if q_len_valid == 0:
            # Nothing to score. Emit zeros under the same keys so every
            # process performs the same gathers. float32 matches the dtype
            # produced on the normal path; gathering tensors of different
            # dtypes across processes is not safe.
            metrics = {
                f"{depth}/draft_base_acc_{i}": torch.tensor(0.0, device=device, dtype=torch.float32)
                for depth in range(self.depth)
                for i in range(5)
            }
            return self._aggregate(metrics, compute_result)

        # Boolean selector of scored positions; constant for the whole call.
        valid = loss_masks.bool().flatten()

        last_input_ids = labels
        last_hidden_states = base_hidden_states
        kv_cache_indices = []
        total_output_log_probs = []

        try:
            for i in range(self.depth):
                # Step 0 caches the full sequence, later steps only the
                # selected positions.
                new_kv_cache_indices = self.past_key_values.allocate(q_len if i == 0 else q_len_valid)
                kv_cache_indices.extend(new_kv_cache_indices)

                attention_mask = torch.full(
                    (labels.size(0), 1, q_len, q_len + q_len_valid * i),
                    min_value,
                    device=device,
                    dtype=dtype,
                )
                # Causal mask over the original sequence: triu keeps min_value
                # strictly above the main diagonal and zeroes the rest.
                attention_mask[0, 0, :, :q_len] = torch.triu(
                    attention_mask[0, 0, :, :q_len],
                    diagonal=1,
                )
                if i > 0:
                    # Keep only the query rows of the selected positions.
                    attention_mask = attention_mask[:, :, valid, :]
                for j in range(i):
                    # Each selected position additionally sees its own entry
                    # in every later step's cache block.
                    start = q_len + j * q_len_valid
                    attention_mask[0, 0, :, start:start + q_len_valid].diagonal(0).fill_(0)

                position_ids = torch.arange(0, q_len, device=device).unsqueeze(0) + i
                if i > 0:
                    position_ids = position_ids[:, valid]

                model_inputs = {
                    "input_ids": last_input_ids,
                    "position_ids": position_ids,
                    "hidden_states": last_hidden_states,
                    "attention_mask": attention_mask,
                    "past_key_values": self.past_key_values,
                    "past_key_value_indices": kv_cache_indices,
                    "use_cache": True,
                }
                output = self.model(**model_inputs, shift_tokens=i == 0, cut_last_token=i == 0)

                output_log_probs = output["logits"].log_softmax(-1)

                # The next step is conditioned on the base model's top-1 token
                # rather than on the draft's own prediction.
                last_input_ids = output_topk_ids[:, valid, i, 0].reshape(1, q_len_valid)
                last_hidden_states = output["draft_hidden_states"]

                if i == 0:
                    # Step 0 produced outputs for every position; narrow to
                    # the selected ones.
                    last_hidden_states = last_hidden_states[:, valid]
                    output_log_probs = output_log_probs[:, valid]

                total_output_log_probs.append(output_log_probs)
        finally:
            # The cache lives on the metric object, so slots must be returned
            # even when a step raises.
            if kv_cache_indices:
                self.past_key_values.free(kv_cache_indices)

        metrics = dict()
        for depth in range(self.depth):
            base_top1 = output_topk_ids[:, valid, depth][..., 0].reshape(-1)
            draft_topk = total_output_log_probs[depth].topk(5, dim=-1, largest=True).indices

            cumulative = 0
            for i in range(5):
                hit_rate = (draft_topk[:, :, i].reshape(-1) == base_top1).float().mean()
                # Running sum turns per-rank hit rates into top-(i+1) accuracy.
                cumulative = cumulative + hit_rate
                metrics[f"{depth}/draft_base_acc_{i}"] = cumulative

        return self._aggregate(metrics, compute_result)

    def _aggregate(self, metrics, compute_result):
        """Average ``metrics`` across processes, record them, and build the result."""
        # Sorted order keeps the sequence of gathers identical on every process.
        for key in sorted(metrics):
            value = self.accelerator.gather(metrics[key]).mean().item()
            metrics[key] = value
            self.metrics[key].append(value)

        if not compute_result:
            return metrics

        # Average everything accumulated since the last result and reset.
        result = {key: np.mean(values) for key, values in self.metrics.items()}
        self.metrics.clear()
        return result


class DraftCausalLMDistillLoss:
    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        **kwargs
    ):
        """Store the loss hyper-parameters and build a KV cache for the model.

        Args:
            model: Model whose ``config``, ``device`` and ``dtype`` size the cache.
            num_seqs: Forwarded to ``KVCache`` as ``num_seqs``.
            ignore_index: Mask value treated as "not counted".
            eps: Denominator guard for the masked feature-loss mean.
            p_weight: Multiplier applied to the distribution loss.
            top_k: Forwarded to ``distribution_cross_entropy``.
            mult: Final multiplier applied to the total loss.
            **kwargs: Accepted and ignored.
        """
        cfg = model.config

        self.model = model
        self.ignore_index, self.eps = ignore_index, eps
        self.p_weight, self.top_k, self.mult = p_weight, top_k, mult

        self.past_key_values = KVCache(
            cfg.draft_num_hidden_layers,
            cfg.num_key_value_heads,
            cfg.max_position_embeddings,
            cfg.head_dim,
            model.device,
            model.dtype,
            num_seqs=num_seqs,
        )

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Compute the distillation loss for one forward pass of ``model``.

        The loss is ``(p_weight * distribution_ce + feature_loss) * mult``,
        where ``distribution_ce`` compares draft logits with the softmax of the
        base logits and ``feature_loss`` is a masked smooth-L1 between draft
        and reference hidden states. The last position is dropped everywhere.

        Args:
            model: Callable invoked as ``model(**inputs)``; its output must
                provide ``logits``, ``base_logits``, ``loss_masks``,
                ``draft_hidden_states`` and ``hidden_states``.
            inputs: Keyword arguments for the model.
            labels: Token ids; shifted and forwarded to
                ``distribution_cross_entropy`` (which currently ignores them).
            return_outputs: When true, also return the raw model output.
            weight: Accepted for interface compatibility; not used.

        Returns:
            The scalar loss, or ``(loss, model_output)`` when
            ``return_outputs`` is true.
        """
        model_output = model(**inputs)

        # Upcast to float32 before the softmax / log-softmax to avoid
        # precision issues with reduced-precision logits.
        logits = model_output["logits"].float()
        target_logits = model_output["base_logits"].float()

        # Drop the last position, then flatten to (N, vocab).
        shift_logits = logits[..., :-1, :].contiguous()
        shift_target_probs = F.softmax(target_logits[..., :-1, :].contiguous(), dim=-1)
        vocab_size = shift_logits.size(-1)
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_target_probs = shift_target_probs.view(-1, vocab_size)

        shift_labels = labels.to(logits.device)[..., 1:].contiguous().view(-1)
        shift_labels = shift_labels.to(shift_logits.device)

        # masked_fill is out-of-place. The previous in-place assignment could
        # write through a view into model_output["loss_masks"] (batch size 1),
        # altering the output handed back to the caller.
        loss_masks = model_output["loss_masks"][..., :-1]
        loss_masks = loss_masks.masked_fill(loss_masks == self.ignore_index, 0).reshape(-1)

        loss = distribution_cross_entropy(
            shift_logits,
            shift_target_probs,
            shift_labels,
            None,
            loss_masks,
            self.top_k,
        )

        feature_loss = nn.functional.smooth_l1_loss(
            model_output["draft_hidden_states"],
            model_output["hidden_states"],
            reduction="none",
        ).mean(dim=-1)[:, :-1].reshape(-1)
        feature_loss = (feature_loss * loss_masks).sum() / (loss_masks.sum() + self.eps)

        total_loss = (loss * self.p_weight + feature_loss) * self.mult

        if return_outputs:
            return total_loss, model_output

        return total_loss


    def loss(self, logits, target_logits, loss_masks):
        """Distribution cross entropy between draft logits and base logits.

        Both logit tensors are upcast to float32 and must be 3-D; the last
        position is dropped from the logits and the mask before flattening.

        Args:
            logits: Draft-model logits.
            target_logits: Base-model logits; their softmax is the target.
            loss_masks: Per-position mask; a detached copy is used.

        Returns:
            Scalar tensor from ``distribution_cross_entropy``.
        """
        student = logits.float()[..., :-1, :].contiguous()
        teacher = target_logits.float()[..., :-1, :].contiguous()
        teacher_probs = F.softmax(teacher, dim=-1)

        # Unpacking three sizes also enforces the 3-D layout.
        _, _, vocab = student.size()

        flat_mask = loss_masks.detach().clone()[..., :-1].reshape(-1)

        return distribution_cross_entropy(
            student.view(-1, vocab),
            teacher_probs.view(-1, vocab),
            None,
            None,
            flat_mask,
            self.top_k,
        )


class DraftCausalLMChainedLoss:
    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        depth: int = 3,
        **kwargs
    ):
        self.model = model
        self.device = model.device
        self.dtype = model.dtype
        self.ignore_index = ignore_index
        self.num_seqs = num_seqs
        self.eps = eps
        self.p_weight = p_weight
        self.top_k = top_k
        self.mult = mult
        self.depth = depth

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Run ``self.depth`` chained draft passes and return the combined loss.

        Each pass feeds the previous pass's draft hidden states back into the
        model. The result is ``(p_weight * mean(distribution losses) +
        mean(feature losses)) * mult``.

        Args:
            model: Model to call for every pass.
            inputs: Dict providing ``input_ids``, ``hidden_states``,
                ``base_hidden_states`` and ``loss_masks``.
            labels: Only its shape is used (batch size and sequence length).
            return_outputs: When true, also return the model output of the
                last pass.
            weight: Forwarded to ``self.loss`` for every pass.

        Returns:
            The scalar loss, or ``(loss, last_output)`` when
            ``return_outputs`` is true.

        NOTE(review): the ``weight`` argument is rebound to the per-depth
        weights near the end of this method, after its last use as an input.
        """
        # A fresh cache per call, with room for one sequence length per pass.
        past_key_values = KVCache(
            self.model.config.draft_num_hidden_layers,
            self.model.config.num_key_value_heads,
            inputs["input_ids"].size(1) * self.depth,
            self.model.config.head_dim,
            self.model.device,
            self.model.dtype,
            num_seqs=self.num_seqs,
        )

        draft_hidden_states = []
        draft_logits = []
        base_hidden_states = inputs["base_hidden_states"]
        last_hidden_states = inputs["hidden_states"]
        # The same input ids are used for every pass (never updated below).
        last_input_ids = inputs["input_ids"]
        # Work on a detached float copy so the caller's mask is left untouched.
        loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
        loss_masks[loss_masks == self.ignore_index] = 0

        kv_cache_indices = []
        # Additive mask value for "not attendable".
        min_value = torch.finfo(self.dtype).min

        for i in range(self.depth):
            # Every pass appends q_len new slots; the model receives all slots
            # allocated so far.
            new_kv_cache_indices = past_key_values.allocate(labels.size(1))
            kv_cache_indices.extend(new_kv_cache_indices)

            q_len = labels.size(1)

            # One q_len-wide block of key columns per pass executed so far,
            # initialised to fully masked.
            attention_mask = torch.full(
                (labels.size(0), 1, q_len, q_len * (i + 1)),
                min_value,
                device=self.device,
                dtype=self.dtype,
            )

            for b in range(labels.size(0)):
                # triu keeps min_value on/above diagonal (1 - i) and zeroes the
                # rest, so in block 0 query row r sees key columns c <= r - i.
                attention_mask[b, 0, :, :q_len] = torch.triu(attention_mask[b, 0, :, :q_len], diagonal=1 - i)
                if i > 0:
                    attention_mask[b, 0, :, q_len * i:q_len * (i + 1)].fill_(min_value)
                # Net effect of the fills below: in block k (1 <= k <= i) only
                # diagonal (-i + k) is open, i.e. row r sees column r - i + k.
                for j in range(i):
                    attention_mask[b, 0, :, q_len * j:q_len * (j + 1)].diagonal(-i + j + 1).fill_(min_value)
                    attention_mask[b, 0, :, q_len * j:q_len * (j + 1)].diagonal(-i + j).fill_(0)
                    attention_mask[b, 0, :, q_len * (j + 1):q_len * (j + 2)].diagonal(-i + j + 1).fill_(0)

            output = model(
                input_ids=last_input_ids,
                position_ids=torch.arange(0, q_len, device=self.device).unsqueeze(0),
                hidden_states=last_hidden_states,
                base_hidden_states=base_hidden_states,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                past_key_value_indices=kv_cache_indices,
                use_cache=True,
            )

            if i == 0:
                # Base logits come from the first pass only and serve as the
                # gradient-free target for every pass.
                target_logits = output["base_logits"].float().detach().clone().requires_grad_(False)

            draft_hidden_state = output["draft_hidden_states"]
            # Upcast to float32 before the loss to avoid precision issues.
            draft_logit = output["logits"].float()

            draft_hidden_states.append(draft_hidden_state)
            draft_logits.append(draft_logit)
            last_hidden_states = draft_hidden_state

        past_key_values.free(kv_cache_indices)

        # Distribution loss of every pass against the same base target.
        losses = []

        for draft_logit in draft_logits:
            loss = self.loss(draft_logit, target_logits, loss_masks, weight)
            losses.append(loss)

        # Masked smooth-L1 between each pass's hidden states and the base
        # hidden states. Unlike the logit loss, no position is dropped here.
        feature_losses = []

        for draft_hidden_state in draft_hidden_states:
            feature_loss = nn.functional.smooth_l1_loss(
                draft_hidden_state,
                inputs["base_hidden_states"],
                reduction="none",
            ).mean(dim=-1)

            feature_loss = (feature_loss * loss_masks).sum() / (loss_masks.sum() + self.eps)

            feature_losses.append(feature_loss)

        # Uniform per-depth weights, so both reductions below are plain means.
        weight = torch.ones(len(losses), device=loss.device, dtype=loss.dtype)
        loss = torch.stack(losses) * weight
        loss = loss.sum() / weight.sum()
        feature_loss = torch.stack(feature_losses) * weight
        feature_loss = feature_loss.sum() / weight.sum()

        total_loss = loss * self.p_weight + feature_loss
        total_loss = total_loss * self.mult

        if return_outputs:
            # Only the output of the final pass is returned.
            return total_loss, output

        return total_loss

    def loss(self, logits, target_logits, loss_masks, weight=None):
        """Distribution cross entropy between draft logits and base logits.

        Both logit tensors are upcast to float32 and must be 3-D; the last
        position is dropped from the logits and the mask before flattening.

        Args:
            logits: Draft-model logits.
            target_logits: Base-model logits; their softmax is the target.
            loss_masks: Per-position mask; a detached copy is used.
            weight: Optional weights forwarded to
                ``distribution_cross_entropy``.

        Returns:
            Scalar tensor from ``distribution_cross_entropy``.
        """
        student = logits.float()[..., :-1, :].contiguous()
        teacher = target_logits.float()[..., :-1, :].contiguous()
        teacher_probs = F.softmax(teacher, dim=-1)

        # Unpacking three sizes also enforces the 3-D layout.
        _, _, vocab = student.size()

        flat_mask = loss_masks.detach().clone()[..., :-1].reshape(-1)

        return distribution_cross_entropy(
            student.view(-1, vocab),
            teacher_probs.view(-1, vocab),
            None,
            None,
            flat_mask,
            self.top_k,
            weight=weight,
        )


class DraftCausalLMChainedRLLoss:
    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        depth: int = 3,
        **kwargs
    ):
        self.model = model
        self.device = model.device
        self.dtype = model.dtype
        self.ignore_index = ignore_index
        self.num_seqs = num_seqs
        self.eps = eps
        self.p_weight = p_weight
        self.top_k = top_k
        self.mult = mult
        self.depth = depth

        self._log_idx = 0
        self._log_metrics = defaultdict(list)

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Chained draft loss against the base model's stored top-k targets.

        The draft model is run for ``self.depth`` steps. Step 0 sees the whole
        sequence; later steps only see positions selected by the loss mask and
        are fed the base model's top-1 token of the previous step. At every
        step the loss is the cross entropy between the draft log-probabilities
        and the stored top-k probabilities, averaged over positions where the
        draft's argmax matched the base top-1 at all earlier steps.

        Args:
            model: Model to call for every step.
            inputs: Dict providing ``input_ids``, ``hidden_states``,
                ``loss_masks``, ``output_topk_ids`` and ``output_topk_probs``.
            labels: Only its shape is used (batch size and sequence length).
            return_outputs: When true, also return ``{"logits": ...}`` holding
                the float32 logits of step 0.
            weight: Optional weights; ``weight[:, valid]`` is multiplied into
                the per-candidate terms.
                NOTE(review): must broadcast against a
                ``(1, n_valid, top_k)`` tensor — confirm the expected shape.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is
            true. With an empty mask the loss is a fresh zero leaf tensor.

        NOTE(review): ``attention_mask[0, 0]`` and ``reshape(1, ...)`` imply
        one sequence per call — confirm with the data loader.
        """
        # Work on a detached float copy so the caller's mask is left untouched.
        loss_masks = inputs["loss_masks"].detach().clone().float().requires_grad_(False)
        loss_masks[loss_masks == self.ignore_index] = 0
        # Boolean selector of scored positions; constant for the whole call.
        valid = loss_masks.bool().flatten()

        output_topk_ids = inputs["output_topk_ids"]
        output_topk_probs = inputs["output_topk_probs"]

        past_key_values = KVCache(
            self.model.config.draft_num_hidden_layers,
            model.config.num_key_value_heads,
            inputs["input_ids"].size(1) * self.depth,
            model.config.head_dim,
            model.device,
            model.dtype,
            num_seqs=self.num_seqs,
        )

        min_value = torch.finfo(model.dtype).min

        q_len = labels.size(1)
        q_len_valid = int(loss_masks.sum().item())

        last_input_ids = inputs["input_ids"]
        last_hidden_states = inputs["hidden_states"]
        kv_cache_indices = []
        total_output_log_probs = []

        try:
            for i in range(self.depth):
                # Step 0 caches the full sequence, later steps only the
                # selected positions.
                new_kv_cache_indices = past_key_values.allocate(q_len if i == 0 else q_len_valid)
                kv_cache_indices.extend(new_kv_cache_indices)

                attention_mask = torch.full(
                    (labels.size(0), 1, q_len, q_len + q_len_valid * i),
                    min_value,
                    device=model.device,
                    dtype=model.dtype,
                )
                # Causal mask over the original sequence: triu keeps min_value
                # strictly above the main diagonal and zeroes the rest.
                attention_mask[0, 0, :, :q_len] = torch.triu(
                    attention_mask[0, 0, :, :q_len],
                    diagonal=1,
                )
                if i > 0:
                    # Keep only the query rows of the selected positions.
                    attention_mask = attention_mask[:, :, valid, :]
                for j in range(i):
                    # Each selected position additionally sees its own entry
                    # in every later step's cache block.
                    start = q_len + j * q_len_valid
                    attention_mask[0, 0, :, start:start + q_len_valid].diagonal(0).fill_(0)

                position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0) + i
                if i > 0:
                    position_ids = position_ids[:, valid]

                model_inputs = {
                    "input_ids": last_input_ids,
                    "position_ids": position_ids,
                    "hidden_states": last_hidden_states,
                    "attention_mask": attention_mask,
                    "past_key_values": past_key_values,
                    "past_key_value_indices": kv_cache_indices,
                    "use_cache": True,
                }
                output = model(**model_inputs, shift_tokens=i == 0, cut_last_token=i == 0)

                if q_len_valid == 0:
                    # Nothing to score: return a zero loss after the first
                    # forward pass (its logits are still reported). The
                    # ``finally`` below now releases the cache on this path.
                    loss = torch.tensor(0.0, device=self.device, dtype=self.dtype, requires_grad=True)
                    if return_outputs:
                        return loss, {"logits": output["logits"].float()}
                    return loss

                output_log_probs = output["logits"].log_softmax(-1)

                # The next step is conditioned on the base model's top-1 token
                # rather than on the draft's own prediction.
                last_input_ids = output_topk_ids[:, valid, i, 0].reshape(1, q_len_valid)
                last_hidden_states = output["draft_hidden_states"]

                if i == 0:
                    output_logits = output["logits"].float()
                    # Step 0 produced outputs for every position; narrow to
                    # the selected ones.
                    last_hidden_states = last_hidden_states[:, valid]
                    output_log_probs = output_log_probs[:, valid]

                total_output_log_probs.append(output_log_probs)

            # Positions still "alive": the draft's argmax matched the base
            # top-1 at every previous step. Shaped from the filtered step-0
            # log-probs so it also matches when depth == 1 (the last step's
            # raw output is unfiltered in that case).
            valid_mask = torch.ones(
                total_output_log_probs[0].shape[:-1],
                device=model.device,
                dtype=torch.bool,
            )
            losses = []
            for i, output_log_probs in enumerate(total_output_log_probs):
                target_output_ids = output_topk_ids[:, valid, i]
                target_output_probs = output_topk_probs[:, valid, i]
                # Draft log-probabilities of the base model's top-k candidates.
                sample_output_log_probs = torch.gather(output_log_probs, -1, target_output_ids)
                loss_weight = weight[:, valid] if weight is not None else 1.0

                loss = -torch.sum(loss_weight * sample_output_log_probs * target_output_probs, dim=-1)
                alive = valid_mask.float()
                loss = (loss * alive).sum() / (alive.sum() + self.eps)

                # Updated after the loss, so step i is gated by steps < i only.
                valid_mask = valid_mask & (output_log_probs.argmax(-1) == target_output_ids[..., 0])

                losses.append(loss)
        finally:
            if kv_cache_indices:
                past_key_values.free(kv_cache_indices)

        total_loss = torch.stack(losses).mean()
        total_loss = total_loss * self.mult

        if return_outputs:
            return total_loss, {"logits": output_logits}

        return total_loss


class DraftCausalLMTreeMetric:
    """Tree-drafting hit-rate metric for a draft causal LM.

    Calling an instance unrolls the draft model for ``depth`` steps. At every
    step the ``top_draft`` best-scoring continuations of each valid position
    are kept as branches for the next step, with the reference token (the
    top-1 entry of ``inputs["output_topk_ids"]``) always forced into the kept
    set. The returned value is, averaged over valid positions, the number of
    consecutive depths (starting at depth 0) at which the reference token's
    own score ranks inside the first ``top_draft`` of the ``top_node`` gathered
    candidate scores.

    The shape assertions in ``__call__`` only hold for a batch size of 1.
    """

    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        depth: int = 3,
        top_draft: int = 4,
        top_node: int = 16,
        **kwargs
    ):
        """Store the model handle and the tree-search hyper-parameters.

        Args:
            model: object exposing ``device``, ``dtype`` and ``config``; its
                config is read when building the KV cache in ``__call__``.
            num_seqs: forwarded to ``KVCache`` as ``num_seqs``.
            ignore_index: mask value that is rewritten to 0 in ``loss_masks``.
            eps, p_weight, top_k, mult: stored, but not read by ``__call__``
                of this class.
            depth: number of draft steps to unroll.
            top_draft: number of branches kept per valid position and step.
            top_node: number of candidate scores gathered per step for ranking.
            **kwargs: accepted and ignored.
        """
        self.model = model
        self.device = model.device
        self.dtype = model.dtype
        self.ignore_index = ignore_index
        self.num_seqs = num_seqs
        self.eps = eps
        self.p_weight = p_weight
        self.top_k = top_k
        self.mult = mult
        self.depth = depth
        self.top_draft = top_draft
        self.top_node = top_node

        # Logging state; not used by __call__ of this class.
        self._log_idx = 0
        self._log_metrics = defaultdict(list)

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Unroll the draft tree and return the mean consecutive-hit count.

        Args:
            model: callable invoked with the keyword inputs assembled below
                plus ``shift_tokens`` / ``cut_last_token``; its result is
                indexed with ``"logits"`` and ``"draft_hidden_states"``.
            inputs: mapping read for ``"loss_masks"``, ``"output_topk_ids"``,
                ``"output_topk_probs"``, ``"input_ids"`` and
                ``"hidden_states"``.
            labels: only its sizes are used (dim 1 is the sequence length,
                dim 0 sizes the batch axis of the attention mask).
            return_outputs: only consulted on the empty-mask early exit.
            weight: unused.

        Returns:
            A scalar tensor with the metric. If no position is valid, a
            constant ``-1.0`` tensor is returned instead (paired with the model
            output when ``return_outputs`` is true).
        """
        # Work on a detached float copy of the mask so the in-place edit below
        # does not touch the caller's tensor.

        loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
        loss_masks[loss_masks == self.ignore_index] = 0
        output_topk_ids = inputs["output_topk_ids"]
        # NOTE(review): output_topk_probs is never read in this method.
        output_topk_probs = inputs["output_topk_probs"]

        min_value = torch.finfo(model.dtype).min
        max_value = torch.finfo(model.dtype).max

        q_len = labels.size(1)
        # The mask sum is used as the number of valid positions
        # (assumes the remaining mask entries are 0/1 -- TODO confirm).
        q_len_valid = int(loss_masks.sum().item())

        # Capacity: the full sequence at depth 0 plus top_draft branches per
        # valid position for each of the remaining depth - 1 steps.
        past_key_values = KVCache(
            self.model.config.draft_num_hidden_layers,
            self.model.config.num_key_value_heads,
            q_len + q_len_valid * (self.depth - 1) * self.top_draft,
            self.model.config.head_dim,
            self.model.device,
            self.model.dtype,
            num_seqs=self.num_seqs,
        )

        valid_mask = loss_masks.flatten().bool()

        kv_cache_indices = []

        # State carried from one draft step to the next.
        last_input_ids = inputs["input_ids"]
        last_hidden_states = inputs["hidden_states"]
        last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)
        last_output_score = torch.zeros((q_len, 1), device=model.device, dtype=model.dtype)

        # Additive causal mask: min_value strictly above the diagonal, 0 elsewhere.
        # Only batch element 0 is turned into a causal mask.
        last_attention_mask = torch.full(
            (labels.size(0), 1, q_len, q_len),
            min_value,
            device=model.device,
            dtype=model.dtype,
            requires_grad=False,
        )

        last_attention_mask[0, 0] = torch.triu(
            last_attention_mask[0, 0],
            diagonal=1,
        )

        total_output_log_probs = []
        total_output_scores = []

        for i in range(self.depth):
            # Depth 0 processes every position with a single branch; deeper
            # steps process top_draft branches for each valid position.
            ql = q_len if i == 0 else q_len_valid
            nc = self.top_draft if i > 0 else 1

            new_kv_cache_indices = past_key_values.allocate(ql * nc)
            kv_cache_indices.extend(new_kv_cache_indices)

            model_inputs = {}
            model_inputs["input_ids"] = last_input_ids
            model_inputs["position_ids"] = last_position_ids
            model_inputs["hidden_states"] = last_hidden_states
            model_inputs["attention_mask"] = last_attention_mask
            model_inputs["past_key_values"] = past_key_values
            model_inputs["past_key_value_indices"] = kv_cache_indices
            model_inputs["use_cache"] = True

            output = model(
                **model_inputs,
                shift_tokens=i == 0,
                cut_last_token=i == 0,
            )

            # NOTE(review): this exit is taken after the cache allocation and
            # the model call, and it does not call past_key_values.free --
            # confirm that skipping the release is intended.
            if q_len_valid == 0:
                if return_outputs:
                    return torch.tensor(-1.0, device=self.device, dtype=self.dtype, requires_grad=True), output
                return torch.tensor(-1.0, device=self.device, dtype=self.dtype, requires_grad=True)

            # Flatten (branch, vocab) into one axis so that candidates from all
            # branches of a position compete in a single top-k below.
            v = output["logits"].size(-1)
            output_logits = output["logits"].log_softmax(-1).reshape(ql, nc * v)
            assert output_logits.size() == (ql, nc * v)

            assert last_output_score.size() == (ql, nc)
            # Candidates are scored by this step's raw logits; the cumulative
            # variant (parent score + log-prob) is kept below for reference.
            # last_output_score = last_output_score.reshape(ql, nc, 1) + output_logits.reshape(ql, nc, v)
            last_output_score = output["logits"].reshape(ql, nc, v)
            last_output_score = last_output_score.reshape(ql, nc * v)

            last_hidden_states = output["draft_hidden_states"]
            h = last_hidden_states.size(-1)
            assert last_hidden_states.size() == (1, ql * nc, h)

            last_position_ids = last_position_ids + 1
            assert last_position_ids.size() == (1, ql * nc)

            if i == 0:
                # After the full-sequence pass, keep only the rows selected by
                # valid_mask in every carried tensor.
                assert last_hidden_states.size() == (1, q_len * nc, h)
                last_hidden_states = last_hidden_states.reshape(1, q_len, nc, h)
                last_hidden_states = last_hidden_states[:, valid_mask]
                assert last_hidden_states.size() == (1, q_len_valid, nc, h)
                last_hidden_states = last_hidden_states.reshape(1, -1, h)
                assert last_hidden_states.size() == (1, q_len_valid * nc, h)

                kv_len = last_attention_mask.size(-1)
                assert last_attention_mask.size() == (1, 1, q_len, kv_len)
                last_attention_mask = last_attention_mask.reshape(1, 1, ql, kv_len)
                last_attention_mask = last_attention_mask[:, :, valid_mask]
                assert last_attention_mask.size() == (1, 1, q_len_valid, kv_len)

                assert last_position_ids.size() == (1, q_len)
                last_position_ids = last_position_ids[:, valid_mask]
                assert last_position_ids.size() == (1, q_len_valid)

                assert output_logits.size() == (q_len, nc * v)
                output_logits = output_logits[valid_mask]
                assert output_logits.size() == (q_len_valid, nc * v)

                assert last_output_score.size() == (q_len, nc * v)
                last_output_score = last_output_score[valid_mask]
                assert last_output_score.size() == (q_len_valid, nc * v)

            # Log-probs of branch 0 at the reference top-k ids of this depth
            # (output_topk_ids is indexed as [batch, position, depth, rank]).
            # The stacked result is only shape-checked after the loop.
            output_log_probs = output_logits.reshape(q_len_valid, nc, v)[:, 0].gather(
                -1,
                output_topk_ids[0, valid_mask, i, :self.top_node].reshape(q_len_valid, self.top_node),
            )
            output_ids = output_topk_ids[0, valid_mask, i, 0].reshape(q_len_valid, 1)

            assert output_ids.size() == (q_len_valid, 1)

            # Give the reference token the largest representable score so the
            # top-k selections below always include it. Its flat index is the
            # bare token id, i.e. it lives in branch 0's slice of the nc * v axis.
            last_output_score_for_sampling = last_output_score.scatter(
                -1,
                output_ids,
                max_value,
            )

            sampled_draft_idxs = torch.topk(
                last_output_score_for_sampling,
                self.top_draft,
                dim=-1,
            ).indices

            sampled_node_idxs = torch.topk(
                last_output_score_for_sampling,
                self.top_node,
                dim=-1,
            ).indices

            # Gather the real (un-forced) scores at the selected indices.
            last_output_node_score = last_output_score.gather(
                -1,
                sampled_node_idxs,
            )
            assert last_output_node_score.size() == (q_len_valid, self.top_node)

            last_output_score = last_output_score.gather(
                -1,
                sampled_draft_idxs,
            )
            assert last_output_score.size() == (q_len_valid, self.top_draft)

            # A flat index decomposes into (parent branch, token id) as
            # (idx // v, idx % v); the token id becomes the next input.
            last_input_ids = (sampled_draft_idxs % v).reshape(1, q_len_valid * self.top_draft)
            assert last_input_ids.size() == (1, q_len_valid * self.top_draft)

            # Each new branch inherits position, hidden state and attention row
            # from the parent branch it extends.
            last_position_ids = last_position_ids.reshape(q_len_valid, nc).gather(
                1, (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft)
            )
            assert last_position_ids.size() == (q_len_valid, self.top_draft)
            last_position_ids = last_position_ids.reshape(1, q_len_valid * self.top_draft)
            assert last_position_ids.size() == (1, q_len_valid * self.top_draft)

            last_hidden_states = last_hidden_states.reshape(q_len_valid, nc, h).gather(
                1,
                (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, h)
            )
            assert last_hidden_states.size() == (q_len_valid, self.top_draft, h)
            last_hidden_states = last_hidden_states.reshape(1, q_len_valid * self.top_draft, h)
            assert last_hidden_states.size() == (1, q_len_valid * self.top_draft, h)

            past_kv_len = last_attention_mask.size(-1)
            assert last_attention_mask.size() == (1, 1, q_len_valid * nc, past_kv_len)
            last_attention_mask = last_attention_mask.reshape(q_len_valid, nc, past_kv_len).gather(
                1, (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, past_kv_len)
            )
            assert last_attention_mask.size() == (q_len_valid, self.top_draft, past_kv_len)
            last_attention_mask = last_attention_mask.reshape(
                1, 1, q_len_valid * self.top_draft, past_kv_len,
            )
            assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len)
            # Columns for the tokens added in the next step: every new token may
            # attend only to itself among them (zero on the diagonal).
            new_attention_mask = torch.full(
                (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft),
                min_value,
                device=model.device,
                dtype=model.dtype,
            )
            new_attention_mask[0, 0].diagonal(0).fill_(0)
            assert new_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft)
            last_attention_mask = torch.cat([
                last_attention_mask,
                new_attention_mask,
            ], dim=-1).requires_grad_(False)
            assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len + q_len_valid * self.top_draft)

            total_output_log_probs.append(output_log_probs)
            total_output_scores.append(last_output_node_score)

        past_key_values.free(kv_cache_indices)

        total_output_log_probs = torch.stack(total_output_log_probs, dim=0)
        assert total_output_log_probs.size() == (self.depth, q_len_valid, self.top_node)
        total_output_scores = torch.stack(total_output_scores, dim=0)
        assert total_output_scores.size() == (self.depth, q_len_valid, self.top_node)

        # The double argsort turns scores into descending ranks. Column 0 holds
        # the forced reference token, so this marks a hit when its rank is
        # below top_draft.
        metric = total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < self.top_draft

        # cumprod over the depth axis zeroes everything after the first miss;
        # the sum then counts the leading run of hits, averaged over positions.
        metric = metric.float().cumprod(dim=0).sum(dim=0).mean()

        # NOTE(review): return_outputs is ignored on this path; only the
        # early exit above honours it.
        return metric


class DraftCausalLMTreeLoss:
    """Tree-drafting training loss for a draft causal LM.

    Calling an instance unrolls the draft model for ``depth`` steps along a
    candidate tree whose nodes are read from ``inputs["sample_idxs"]``. The
    loss is the negative sum, over the ``top_node`` candidates of each step,
    of the model's per-step log-probabilities weighted by
    ``exp(inputs["sample_logits"])``, averaged over valid positions and depths
    and scaled by ``mult``. A hit-count metric is also computed and logged
    periodically on the main process.

    The shape assertions in ``__call__`` only hold for a batch size of 1.
    """

    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        depth: int = 3,
        top_draft: int = 4,
        top_node: int = 16,
        **kwargs
    ):
        """Store the model handle and the tree-search hyper-parameters.

        Args:
            model: object exposing ``device``, ``dtype`` and ``config``; its
                config is read when building the KV cache in ``__call__``.
            num_seqs: forwarded to ``KVCache`` as ``num_seqs``.
            ignore_index: mask value that is rewritten to 0 in ``loss_masks``.
            eps, p_weight, top_k: stored, but not read by the active code of
                ``__call__`` (``eps`` only appears in commented-out variants).
            mult: multiplier applied to the final loss.
            depth: number of draft steps to unroll.
            top_draft: number of branches expanded per valid position and step.
            top_node: number of candidates per step that enter the loss.
            **kwargs: accepted and ignored.
        """
        self.model = model
        self.device = model.device
        self.dtype = model.dtype
        self.ignore_index = ignore_index
        self.num_seqs = num_seqs
        self.eps = eps
        self.p_weight = p_weight
        self.top_k = top_k
        self.mult = mult
        self.depth = depth
        self.top_draft = top_draft
        self.top_node = top_node

        # Call counter and running buffers for the periodic console report.
        self._log_idx = 0
        self._log_metrics = defaultdict(list)

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Unroll the draft tree and return the training loss.

        Args:
            model: callable invoked with the keyword inputs assembled below
                plus ``shift_tokens`` / ``cut_last_token``; its result is
                indexed with ``"logits"`` and ``"draft_hidden_states"``.
            inputs: mapping read for ``"loss_masks"``, ``"output_topk_ids"``,
                ``"output_topk_probs"``, ``"input_ids"``, ``"hidden_states"``,
                ``"sample_idxs"`` and ``"sample_logits"``.
            labels: only its sizes are used (dim 1 is the sequence length,
                dim 0 sizes the batch axis of the attention mask).
            return_outputs: when true, also return the model output of the
                last executed step.
            weight: unused.

        Returns:
            The scalar loss, or ``(loss, output)`` when ``return_outputs`` is
            true. If no position is valid, the loss is a constant ``-1.0``
            tensor that is not connected to the model.
        """
        # Work on a detached float copy of the mask so the in-place edit below
        # does not touch the caller's tensor.

        loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
        loss_masks[loss_masks == self.ignore_index] = 0
        output_topk_ids = inputs["output_topk_ids"]
        output_topk_probs = inputs["output_topk_probs"]

        min_value = torch.finfo(model.dtype).min
        # NOTE(review): max_value is only referenced by commented-out code below.
        max_value = torch.finfo(model.dtype).max

        q_len = labels.size(1)
        # The mask sum is used as the number of valid positions
        # (assumes the remaining mask entries are 0/1 -- TODO confirm).
        q_len_valid = int(loss_masks.sum().item())

        # Capacity: the full sequence at depth 0 plus top_draft branches per
        # valid position for each of the remaining depth - 1 steps.
        past_key_values = KVCache(
            self.model.config.draft_num_hidden_layers,
            self.model.config.num_key_value_heads,
            q_len + q_len_valid * (self.depth - 1) * self.top_draft,
            self.model.config.head_dim,
            self.model.device,
            self.model.dtype,
            num_seqs=self.num_seqs,
        )

        valid_mask = loss_masks.flatten().bool()

        kv_cache_indices = []

        # State carried from one draft step to the next.
        last_input_ids = inputs["input_ids"]
        last_hidden_states = inputs["hidden_states"]
        last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)
        last_output_score = torch.zeros((q_len, 1), device=model.device, dtype=model.dtype)

        # Additive causal mask: min_value strictly above the diagonal, 0 elsewhere.
        # Only batch element 0 is turned into a causal mask.
        last_attention_mask = torch.full(
            (labels.size(0), 1, q_len, q_len),
            min_value,
            device=model.device,
            dtype=model.dtype,
        )

        last_attention_mask[0, 0] = torch.triu(
            last_attention_mask[0, 0],
            diagonal=1,
        )

        total_output_log_probs = []
        total_output_scores = []
        total_sampled_logits = []

        for i in range(self.depth):
            # Depth 0 processes every position with a single branch; deeper
            # steps process top_draft branches for each valid position.
            ql = q_len if i == 0 else q_len_valid
            nc = self.top_draft if i > 0 else 1

            new_kv_cache_indices = past_key_values.allocate(ql * nc)
            kv_cache_indices.extend(new_kv_cache_indices)

            model_inputs = {}
            model_inputs["input_ids"] = last_input_ids
            model_inputs["position_ids"] = last_position_ids
            model_inputs["hidden_states"] = last_hidden_states
            model_inputs["attention_mask"] = last_attention_mask
            model_inputs["past_key_values"] = past_key_values
            model_inputs["past_key_value_indices"] = kv_cache_indices
            model_inputs["use_cache"] = True

            output = model(
                **model_inputs,
                shift_tokens=i == 0,
                cut_last_token=i == 0,
            )

            # NOTE(review): this exit is taken after the cache allocation and
            # the model call, and it does not call past_key_values.free --
            # confirm that skipping the release is intended. The returned
            # tensor is a fresh leaf, so backward() on it reaches no parameter.
            if q_len_valid == 0:
                if return_outputs:
                    return torch.tensor(-1.0, device=self.device, dtype=self.dtype, requires_grad=True), output
                return torch.tensor(-1.0, device=self.device, dtype=self.dtype, requires_grad=True)

            # Flatten (branch, vocab) into one axis; candidate indices below
            # address this flattened axis.
            v = output["logits"].size(-1)
            output_logits = output["logits"].log_softmax(-1).reshape(ql, nc * v)
            assert output_logits.size() == (ql, nc * v)

            assert last_output_score.size() == (ql, nc)
            # Path score = parent branch score + this step's log-probs.
            # Author's alternative (labelled "left"): detach the parent score so
            # gradients do not flow through earlier steps.
            # last_output_score = last_output_score.reshape(ql, nc, 1).detach() + output_logits.reshape(ql, nc, v)
            # Active variant (labelled "right"): keep the parent score attached.
            last_output_score = last_output_score.reshape(ql, nc, 1) + output_logits.reshape(ql, nc, v)
            last_output_score = last_output_score.reshape(ql, nc * v)

            last_hidden_states = output["draft_hidden_states"]
            h = last_hidden_states.size(-1)
            assert last_hidden_states.size() == (1, ql * nc, h)

            last_position_ids = last_position_ids + 1
            assert last_position_ids.size() == (1, ql * nc)

            if i == 0:
                # After the full-sequence pass, keep only the rows selected by
                # valid_mask in every carried tensor.
                assert last_hidden_states.size() == (1, q_len * nc, h)
                last_hidden_states = last_hidden_states.reshape(1, q_len, nc, h)
                last_hidden_states = last_hidden_states[:, valid_mask]
                assert last_hidden_states.size() == (1, q_len_valid, nc, h)
                last_hidden_states = last_hidden_states.reshape(1, -1, h)
                assert last_hidden_states.size() == (1, q_len_valid * nc, h)

                kv_len = last_attention_mask.size(-1)
                assert last_attention_mask.size() == (1, 1, q_len, kv_len)
                last_attention_mask = last_attention_mask.reshape(1, 1, ql, kv_len)
                last_attention_mask = last_attention_mask[:, :, valid_mask]
                assert last_attention_mask.size() == (1, 1, q_len_valid, kv_len)

                assert last_position_ids.size() == (1, q_len)
                last_position_ids = last_position_ids[:, valid_mask]
                assert last_position_ids.size() == (1, q_len_valid)

                assert output_logits.size() == (q_len, nc * v)
                output_logits = output_logits[valid_mask]
                assert output_logits.size() == (q_len_valid, nc * v)

                assert last_output_score.size() == (q_len, nc * v)
                last_output_score = last_output_score[valid_mask]
                assert last_output_score.size() == (q_len_valid, nc * v)

            # Log-probs of branch 0 at the reference top-k ids of this depth
            # (output_topk_ids is indexed as [batch, position, depth, rank]).
            # NOTE(review): the stacked result is only shape-checked and
            # aliased to an unused local after the loop.
            output_log_probs = output_logits.reshape(q_len_valid, nc, v)[:, 0].gather(
                -1,
                output_topk_ids[0, valid_mask, i, :self.top_node].reshape(q_len_valid, self.top_node),
            )
            output_ids = output_topk_ids[0, valid_mask, i, 0].reshape(q_len_valid, 1)

            assert output_ids.size() == (q_len_valid, 1)

            # Disabled: on-the-fly candidate selection by top-k with the
            # reference token forced in. Candidates are read from
            # inputs["sample_idxs"] instead.
            # last_output_score_for_sampling = last_output_score.scatter(
            #     -1,
            #     output_ids,
            #     max_value,
            # )

            # sampled_draft_idxs = torch.topk(
            #     last_output_score_for_sampling,
            #     self.top_draft,
            #     dim=-1,
            # ).indices

            # sampled_node_idxs = torch.topk(
            #     last_output_score_for_sampling,
            #     self.top_node,
            #     dim=-1,
            # ).indices

            # Flat indices into the nc * v axis; the first top_draft of the
            # top_node candidates are the ones expanded at the next step.
            sampled_node_idxs = inputs["sample_idxs"][:, valid_mask, i, :self.top_node].reshape(q_len_valid, self.top_node)
            sampled_draft_idxs = inputs["sample_idxs"][:, valid_mask, i, :self.top_draft].reshape(q_len_valid, self.top_draft)

            assert sampled_node_idxs.size() == (q_len_valid, self.top_node)
            assert sampled_draft_idxs.size() == (q_len_valid, self.top_draft)

            # Cumulative path scores of the candidates (used for the metric).
            last_output_node_score = last_output_score.gather(
                -1,
                sampled_node_idxs,
            )
            assert last_output_node_score.size() == (q_len_valid, self.top_node)

            last_output_score = last_output_score.gather(
                -1,
                sampled_draft_idxs,
            )
            assert last_output_score.size() == (q_len_valid, self.top_draft)

            # Per-step log-probs of the candidates (used for the loss).
            sampled_logits = output_logits.gather(
                -1,
                sampled_node_idxs,
            )
            assert sampled_logits.size() == (q_len_valid, self.top_node)

            # A flat index decomposes into (parent branch, token id) as
            # (idx // v, idx % v); the token id becomes the next input.
            last_input_ids = (sampled_draft_idxs % v).reshape(1, q_len_valid * self.top_draft)
            assert last_input_ids.size() == (1, q_len_valid * self.top_draft)

            # Each new branch inherits position, hidden state and attention row
            # from the parent branch it extends.
            last_position_ids = last_position_ids.reshape(q_len_valid, nc).gather(
                1, (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft)
            )
            assert last_position_ids.size() == (q_len_valid, self.top_draft)
            last_position_ids = last_position_ids.reshape(1, q_len_valid * self.top_draft)
            assert last_position_ids.size() == (1, q_len_valid * self.top_draft)

            last_hidden_states = last_hidden_states.reshape(q_len_valid, nc, h).gather(
                1,
                (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, h)
            )
            assert last_hidden_states.size() == (q_len_valid, self.top_draft, h)
            last_hidden_states = last_hidden_states.reshape(1, q_len_valid * self.top_draft, h)
            assert last_hidden_states.size() == (1, q_len_valid * self.top_draft, h)

            past_kv_len = last_attention_mask.size(-1)
            assert last_attention_mask.size() == (1, 1, q_len_valid * nc, past_kv_len)
            last_attention_mask = last_attention_mask.reshape(q_len_valid, nc, past_kv_len).gather(
                1, (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, past_kv_len)
            )
            assert last_attention_mask.size() == (q_len_valid, self.top_draft, past_kv_len)
            last_attention_mask = last_attention_mask.reshape(
                1, 1, q_len_valid * self.top_draft, past_kv_len,
            )
            assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len)
            # Columns for the tokens added in the next step: every new token may
            # attend only to itself among them (zero on the diagonal).
            new_attention_mask = torch.full(
                (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft),
                min_value,
                device=model.device,
                dtype=model.dtype,
            )
            new_attention_mask[0, 0].diagonal(0).fill_(0)
            assert new_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft)
            last_attention_mask = torch.cat([
                last_attention_mask,
                new_attention_mask,
            ], dim=-1)
            assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len + q_len_valid * self.top_draft)

            total_output_log_probs.append(output_log_probs)
            total_output_scores.append(last_output_node_score)
            total_sampled_logits.append(sampled_logits)

        past_key_values.free(kv_cache_indices)

        # Stack per-depth results along a new depth axis (dim 1).
        total_output_log_probs = torch.stack(total_output_log_probs, dim=1)
        assert total_output_log_probs.size() == (q_len_valid, self.depth, self.top_node)
        total_output_scores = torch.stack(total_output_scores, dim=1)
        assert total_output_scores.size() == (q_len_valid, self.depth, self.top_node)
        total_sampled_logits = torch.stack(total_sampled_logits, dim=1)
        assert total_sampled_logits.size() == (q_len_valid, self.depth, self.top_node)

        # Negative log-prob of each candidate weighted by exp of the supplied
        # "sample_logits", summed over candidates and averaged over the
        # (position, depth) axes.
        loss = -total_sampled_logits * inputs["sample_logits"][0, valid_mask, :self.depth, :self.top_node].exp()
        loss = loss.sum(dim=-1).mean()

        # NOTE(review): the two locals below are only referenced by the
        # commented-out loss variants that follow.
        sampled_logits = total_output_log_probs
        sampled_target_probs = output_topk_probs[0, valid_mask, :self.depth, :self.top_node].reshape(q_len_valid, self.depth, self.top_node)
        # Disabled loss variants kept by the author for reference:
        # sampled_target_cum_scores = output_topk_probs[0, :self.depth, valid_mask, :1].log().cumsum(dim=0)
        # sampled_target_scores = output_topk_probs[0, :self.depth, valid_mask, :self.top_node].reshape(self.depth, q_len_valid, self.top_node).log().cumsum(dim=0)
        # sampled_target_scores[1:] = sampled_target_scores[1:] + sampled_target_cum_scores[:-1]
        # loss = -torch.sum(sampled_logits * sampled_target_probs, dim=-1).mean()
        # loss = torch.sum((sampled_logits.cumsum(dim=0) - sampled_target_scores) ** 2, dim=-1).mean()
        # ce_loss = -sampled_logits[:, :, :1]

        # right
        # ce_loss_values, ce_loss_idx = (1 - (total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < self.top_draft).float().cumprod(dim=0)).max(dim=1)
        # ce_loss_mask = torch.zeros_like(ce_loss)
        # ce_loss_mask[torch.arange(0, q_len_valid), ce_loss_idx[:, 0], 0] = 1
        # ce_loss_mask[ce_loss_values.flatten() == 0] = 0

        # left
        # ce_loss_mask = (total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < self.top_draft).float().cumprod(dim=0)

        # ce_loss_mask = (total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < self.top_draft).float().cumprod(dim=0).flip(0).cumsum(dim=0).flip(0)

        # target_sampled_scores = inputs["sampled_scores"][:, valid_mask, :self.depth, :self.top_draft].reshape(q_len_valid, self.depth, self.top_draft)
        # predict_sampled_scores = total_output_scores[..., :self.top_draft]
        # loss = F.l1_loss(target_sampled_scores, predict_sampled_scores, reduction="none")
        # print(loss.mean())
        # loss = F.mse_loss(target_sampled_scores, predict_sampled_scores, reduction="none")
        # loss[..., 1:] = loss[..., 1:] * 0.01
        # loss = loss.sum(dim=-1).mean()
        # loss = loss * 0.01
        # loss = loss[-1, ..., 0].mean()
        # loss = loss[:, ..., 0].mean()

        # ce_loss_mask = \
        #     (1 - (total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < self.top_draft).float().cumprod(dim=0)) \
        #     * (total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < int(self.top_draft * 1.5)).float().cumprod(dim=0)

        # ce_loss_mask = torch.ones_like(ce_loss, device=model.device, dtype=torch.bool)

        # left
        # ce_loss = (ce_loss * ce_loss_mask).sum() / ((ce_loss_mask > 0).float().sum() + self.eps)
        # right
        # ce_loss = (ce_loss * ce_loss_mask).sum() / (ce_loss_mask.sum() + self.eps)
        # total_loss = loss + 1 * ce_loss
        # total_loss = ce_loss # [TODO]
        total_loss = loss
        total_loss = total_loss * self.mult

        # The double argsort turns scores into descending ranks. This marks a
        # hit when candidate 0's cumulative score ranks below top_draft
        # (presumably candidate 0 of "sample_idxs" is the reference token --
        # TODO confirm against the data pipeline).
        metric = total_output_scores.argsort(dim=-1, descending=True).argsort(dim=-1, descending=False).float()[:, :, :1] < self.top_draft

        # cumprod over the depth axis zeroes everything after the first miss;
        # the sum then counts the leading run of hits, averaged over positions.
        metric = metric.float().cumprod(dim=1).sum(dim=1).mean()

        # Console report: accumulate on the main process and print the running
        # means every 512 calls, then reset the buffers.
        # NOTE(review): Accelerator() is constructed on every call -- confirm
        # this is cheap enough or cache the flag.
        if Accelerator().is_main_process:
            self._log_idx += 1
            # NaN losses are excluded from the running statistics.
            if not total_loss.isnan().item():
                self._log_metrics["total_loss"].append(total_loss.item())
                self._log_metrics["logits_mean"].append(output["logits"].abs().mean().item())
                # NOTE(review): the f-prefix below has no placeholder.
                self._log_metrics[f"metric"].append(metric.item())
            if self._log_idx == 512:
                self._log_idx = 0
                print()
                for key, value in self._log_metrics.items():
                    print(f"{key}: {np.mean(value)}")
                self._log_metrics.clear()

        # `output` here is the model output of the last depth step.
        if return_outputs:
            return total_loss, output

        return total_loss


# class DraftCausalLMGRPOLoss:
#     def __init__(
#         self,
#         model,
#         num_seqs: int = 1,
#         ignore_index: int = -100,
#         eps: float = 1e-6,
#         p_weight: float = 0.1,
#         top_k: int = 0,
#         mult: float = 1.0,
#         depth: int = 3,
#         repeat: int = 4,
#         top_draft: int = 4,
#         **kwargs
#     ):
#         self.model = model
#         self.device = model.device
#         self.dtype = model.dtype
#         self.ignore_index = ignore_index
#         self.num_seqs = num_seqs
#         self.eps = eps
#         self.p_weight = p_weight
#         self.top_k = top_k
#         self.mult = mult
#         self.depth = depth
#         self.repeat = repeat
#         self.top_draft = top_draft

#         self._log_idx = 0
#         self._log_metrics = defaultdict(list)

#         self.chained_rl_loss = DraftCausalLMChainedRLLoss(
#             model,
#             num_seqs=num_seqs,
#             ignore_index=ignore_index,
#             eps=eps,
#             p_weight=p_weight,
#             top_k=top_k,
#             mult=mult,
#             depth=depth,
#         )

#     def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
#         # Upcast to float if we need to compute the loss to avoid potential precision issues

#         loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
#         loss_masks[loss_masks == self.ignore_index] = 0
#         output_topk_ids = inputs["output_topk_ids"]

#         min_value = torch.finfo(self.dtype).min

#         q_len = labels.size(1)
#         q_len_valid = int(loss_masks.sum().item())

#         past_key_values = KVCache(
#             self.model.config.num_hidden_layers,
#             self.model.config.num_key_value_heads,
#             q_len + q_len_valid * (self.depth - 1) * self.top_draft,
#             self.model.config.head_dim,
#             self.model.device,
#             self.model.dtype,
#             num_seqs=self.num_seqs,
#         )

#         valid_lens = torch.cat([
#             distance_to_next_zero_1d(output_topk_ids[0, 0, loss_masks.bool().flatten(), 0][:-1] == labels[0, loss_masks.bool().flatten()][1:]),
#             torch.zeros((1,), device=model.device, dtype=torch.long),
#         ], dim=-1)
#         valid_mask = loss_masks.bool().flatten()

#         repeat_total_output_log_probs = []
#         repeat_total_rewards = []

#         for re in range(self.repeat):
#             kv_cache_indices = []

#             last_input_ids = inputs["input_ids"]
#             last_hidden_states = inputs["hidden_states"]
#             last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)
#             last_output_score = torch.zeros((q_len, 1), device=model.device, dtype=torch.float64)
#             last_rewards = torch.zeros((1, q_len_valid, 1), device=model.device, dtype=model.dtype)

#             last_attention_mask = torch.full(
#                 (labels.size(0), 1, q_len, q_len),
#                 min_value,
#                 device=model.device,
#                 dtype=model.dtype,
#                 requires_grad=False,
#             )

#             last_attention_mask[0, 0] = torch.triu(
#                 last_attention_mask[0, 0],
#                 diagonal=1,
#             )

#             total_output_ids = []
#             total_output_idxs = []
#             total_output_log_probs = []
#             total_rewards = []

#             for i in range(self.depth):
#                 ql = q_len if i == 0 else q_len_valid
#                 nc = self.top_draft if i > 0 else 1

#                 new_kv_cache_indices = past_key_values.allocate(ql * nc)
#                 kv_cache_indices.extend(new_kv_cache_indices)

#                 model_inputs = {}
#                 model_inputs["input_ids"] = last_input_ids
#                 model_inputs["position_ids"] = last_position_ids
#                 model_inputs["hidden_states"] = last_hidden_states
#                 model_inputs["attention_mask"] = last_attention_mask
#                 model_inputs["past_key_values"] = past_key_values
#                 model_inputs["past_key_value_indices"] = kv_cache_indices
#                 model_inputs["use_cache"] = True

#                 output = model(
#                     **model_inputs,
#                     shift_tokens=i == 0,
#                     cut_last_token=i == 0,
#                 )

#                 if q_len_valid < self.depth:
#                     if return_outputs:
#                         return torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True), output
#                     return torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True)

#                 v = output["logits"].size(-1)
#                 # output_logits = -F.softplus(output["logits"]).reshape(ql, nc * v)
#                 # output_logits = F.sigmoid(output["logits"]).log().reshape(ql, nc * v)
#                 output_logits = output["logits"].to(torch.float64).reshape(ql, nc * v).log_softmax(-1)
#                 assert output_logits.size() == (ql, nc * v)
#                 assert last_output_score.size() == (ql, nc)
#                 output_score = output_logits
#                 assert output_score.size() == (ql, nc * v)
#                 output_probs = output_score.exp()
#                 output_log_probs = output_score
#                 assert output_probs.size() == (ql, nc * v)

#                 sampled_idxs = torch.multinomial(output_probs, self.top_draft, replacement=True)
#                 assert sampled_idxs.size() == (ql, self.top_draft)
#                 sampled_log_probs = output_log_probs.gather(-1, sampled_idxs)
#                 sampled_log_probs = sampled_log_probs.reshape(1, ql, self.top_draft)
#                 assert sampled_log_probs.size() == (1, ql, self.top_draft)

#                 sampled_idxs = sampled_idxs.reshape(1, ql, self.top_draft)
#                 assert sampled_idxs.size() == (1, ql, self.top_draft)
#                 sampled_p_ids = sampled_idxs % v
#                 assert sampled_p_ids.size() == (1, ql, self.top_draft)
#                 sampled_p_idxs = sampled_idxs // v
#                 assert sampled_idxs.size() == (1, ql, self.top_draft)

#                 last_hidden_states = output["draft_hidden_states"]
#                 h = last_hidden_states.size(-1)
#                 assert last_hidden_states.size() == (1, ql * nc, h)

#                 if i == 0:
#                     assert last_hidden_states.size() == (1, q_len, h)
#                     last_hidden_states = last_hidden_states[:, valid_mask]
#                     assert last_hidden_states.size() == (1, q_len_valid, h)

#                     kv_len = last_attention_mask.size(-1)
#                     assert last_attention_mask.size() == (1, 1, q_len , kv_len)
#                     last_attention_mask = last_attention_mask[:, :, valid_mask]
#                     assert last_attention_mask.size() == (1, 1, q_len_valid, kv_len)

#                     assert last_position_ids.size() == (1, q_len)
#                     last_position_ids = last_position_ids[:, valid_mask]
#                     assert last_position_ids.size() == (1, q_len_valid)

#                     assert output_score.size() == (q_len, v)
#                     output_score = output_score[valid_mask]
#                     assert output_score.size() == (q_len_valid, v)

#                     assert sampled_p_idxs.size() == (1, q_len, self.top_draft)
#                     sampled_p_idxs = sampled_p_idxs[:, valid_mask]
#                     assert sampled_p_idxs.size() == (1, q_len_valid, self.top_draft)

#                     assert sampled_p_ids.size() == (1, q_len, self.top_draft)
#                     sampled_p_ids = sampled_p_ids[:, valid_mask]
#                     assert sampled_p_ids.size() == (1, q_len_valid, self.top_draft)

#                     assert sampled_log_probs.size() == (1, q_len, self.top_draft)
#                     sampled_log_probs = sampled_log_probs[:, valid_mask]
#                     assert sampled_log_probs.size() == (1, q_len_valid, self.top_draft)

#                     assert sampled_idxs.size() == (1, q_len, self.top_draft)
#                     sampled_idxs = sampled_idxs[:, valid_mask]
#                     assert sampled_idxs.size() == (1, q_len_valid, self.top_draft)

#                 last_input_ids = sampled_p_ids.reshape(1, q_len_valid * self.top_draft)
#                 assert last_input_ids.size() == (1, q_len_valid * self.top_draft)

#                 assert last_hidden_states.size() == (1, q_len_valid * nc, h)
#                 last_hidden_states = last_hidden_states.reshape(q_len_valid, nc, h)
#                 assert last_hidden_states.size() == (q_len_valid, nc, h)
#                 last_hidden_states = last_hidden_states.gather(
#                     1,
#                     sampled_p_idxs.reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, h)
#                 )
#                 assert last_hidden_states.size() == (q_len_valid, self.top_draft, h)
#                 last_hidden_states = last_hidden_states.reshape(1, q_len_valid * self.top_draft, h)
#                 assert last_hidden_states.size() == (1, q_len_valid * self.top_draft, h)

#                 past_kv_len = last_attention_mask.size(-1)
#                 assert last_attention_mask.size() == (1, 1, q_len_valid * nc, past_kv_len)
#                 last_attention_mask = last_attention_mask.reshape(q_len_valid, nc, past_kv_len).gather(
#                     1, sampled_p_idxs.reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, past_kv_len)
#                 )
#                 assert last_attention_mask.size() == (q_len_valid, self.top_draft, past_kv_len)
#                 last_attention_mask = last_attention_mask.reshape(
#                     1, 1, q_len_valid * self.top_draft, past_kv_len,
#                 )
#                 assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len)
#                 new_attention_mask = torch.full(
#                     (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft),
#                     0,
#                     device=model.device,
#                     dtype=model.dtype,
#                 )
#                 # new_attention_mask[0, 0].diagonal(0).fill_(0)
#                 assert new_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft)
#                 last_attention_mask = torch.cat([
#                     last_attention_mask,
#                     new_attention_mask,
#                 ], dim=-1).requires_grad_(False)
#                 assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len + q_len_valid * self.top_draft)
#                 last_position_ids = last_position_ids + 1
#                 assert last_position_ids.size() == (1, q_len_valid * nc)
#                 last_position_ids = last_position_ids.reshape(q_len_valid, nc).gather(
#                     1, sampled_p_idxs.reshape(q_len_valid, self.top_draft)
#                 )
#                 assert last_position_ids.size() == (q_len_valid, self.top_draft)
#                 last_position_ids = last_position_ids.reshape(1, q_len_valid * self.top_draft)
#                 assert last_position_ids.size() == (1, q_len_valid * self.top_draft)

#                 last_output_score = output_score.reshape(1, q_len_valid, nc * v).gather(-1, sampled_idxs)
#                 assert last_output_score.size() == (1, q_len_valid, self.top_draft)
#                 last_output_score = last_output_score.reshape(q_len_valid, self.top_draft)
#                 assert last_output_score.size() == (q_len_valid, self.top_draft)
#                 last_output_score = last_output_score - last_output_score.max(-1, keepdim=True).values.detach()

#                 with torch.no_grad():
#                     targets = labels[:, i + 1:]
#                     current_rewards = (
#                         last_input_ids.reshape(1, q_len_valid, self.top_draft) \
#                         == torch.cat([
#                             targets,
#                             torch.zeros((1, q_len - targets.size(1)), device=model.device, dtype=torch.long),
#                         ], dim=-1).reshape(1, -1, 1)[:, valid_mask]
#                     )
#                     assert current_rewards.size() == (1, q_len_valid, self.top_draft)
#                     assert last_rewards.size() == (1, q_len_valid, nc)
#                     assert sampled_p_idxs.size() == (1, q_len_valid, self.top_draft)
#                     last_rewards = last_rewards.gather(-1, sampled_p_idxs)
#                     assert last_rewards.size() == (1, q_len_valid, self.top_draft)

#                     is_last_rewards_cont = (last_rewards == i) & (valid_lens >= i + 1).reshape(1, -1, 1)
#                     last_rewards = last_rewards + (current_rewards & is_last_rewards_cont).float()

#                 total_output_ids.append(last_input_ids)
#                 total_output_idxs.append(sampled_p_idxs)
#                 total_output_log_probs.append(sampled_log_probs)
#                 total_rewards.append(last_rewards)

#             with torch.no_grad():
#                 for i in range(self.depth - 1, 0, -1):
#                     prev_rewards = total_rewards[i - 1][0]  # shape: (q_len_valid, ?)
#                     current_rewards = total_rewards[i][0]   # shape: (q_len_valid, top_draft)
#                     indices = total_output_idxs[i][0]         # shape: (q_len_valid, top_draft)

#                     total_rewards[i - 1][0] = prev_rewards.scatter_reduce(
#                         dim=1,
#                         index=indices,
#                         src=current_rewards,
#                         reduce='amax',
#                         include_self=True
#                     )
#                 # import pdb; pdb.set_trace()
#                 for i in range(self.depth):
#                     total_rewards[i][0] = total_rewards[0][0].max(1, keepdim=True).values.expand(-1, self.top_draft)

#                 # import pdb; pdb.set_trace()

#             past_key_values.free(kv_cache_indices)

#             total_rewards = torch.stack(total_rewards, dim=1)
#             total_output_log_probs = torch.stack(total_output_log_probs, dim=1)

#             repeat_total_rewards.append(total_rewards)
#             repeat_total_output_log_probs.append(total_output_log_probs)

#         repeat_total_rewards = torch.stack(repeat_total_rewards, dim=1)
#         repeat_total_output_log_probs = torch.stack(repeat_total_output_log_probs, dim=1)

#         with torch.no_grad():
#             assert repeat_total_rewards.size() == (1, self.repeat, self.depth, q_len_valid, self.top_draft)
#             rewards = repeat_total_rewards.permute(0, 1, 3, 2, 4).reshape(1, self.repeat, q_len_valid, self.depth * self.top_draft).max(-1).values
#             rewards_mean = rewards.mean(1, keepdim=True)
#             rewards_std = rewards.std(1, keepdim=True)
#             advantages = (rewards - rewards_mean) / (rewards_std + self.eps)
#             advantages = advantages.reshape(1, self.repeat, 1, q_len_valid, 1)

#             rewards_mean = repeat_total_rewards.max(dim=-1, keepdim=True).values.mean(1, keepdim=True)

#         loss_per_token = -advantages * (repeat_total_output_log_probs - repeat_total_output_log_probs.detach()).exp()
#         assert loss_per_token.size() == (1, self.repeat, self.depth, q_len_valid, self.top_draft)
#         losses = []
#         mean_advantages = []
#         mean_returns = []
#         top_returns = []
#         for depth in range(1, self.depth + 1):
#             if depth == self.depth:
#                 valid_mask = valid_lens >= depth
#             else:
#                 valid_mask = valid_lens == depth
#             if valid_mask.sum() == 0:
#                 losses.append(torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True))
#                 mean_returns.append(torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True))
#                 mean_advantages.append(torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True))
#                 top_returns.append(torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True))
#                 continue

#             loss = loss_per_token[:, :, :depth, valid_mask].mean().float()
#             losses.append(loss)
#             mean_returns.append(rewards_mean[:, :, 0, valid_mask].mean())
#             mean_advantages.append(advantages[:, :, 0, valid_mask].mean())
#             top_returns.append(repeat_total_rewards[:, 0, 0, valid_mask].max(-1).values.mean())

#         # weight = len(losses) - torch.arange(0, len(losses), device=loss.device, dtype=loss.dtype)
#         # weight = torch.arange(0, len(losses), device=loss.device, dtype=loss.dtype) + 1
#         weight = torch.ones(len(losses), device=losses[0].device, dtype=losses[0].dtype)
#         loss = torch.stack(losses) * weight
#         loss = loss.sum() / weight.sum()

#         # chained_rl_loss, output = self.chained_rl_loss(model, inputs, labels, return_outputs=True)
#         total_loss = loss #+ chained_rl_loss
#         total_loss = total_loss * self.mult

#         if Accelerator().is_main_process:
#             # total_loss.backward(retain_graph=True)
#             # for weight in model.parameters():
#             #     if weight.grad is not None:
#             #         if weight.grad.isnan().any() or weight.grad.isinf().any():
#             #             import pdb; pdb.set_trace()
#             self._log_idx += 1
#             if not total_loss.isnan().item():
#                 self._log_metrics["total_loss"].append(total_loss.item())
#                 self._log_metrics["logits_mean"].append(output["logits"].abs().mean().item())
#                 for i in range(self.depth):
#                     self._log_metrics[f"returns_{i}"].append(mean_returns[i].item())
#                     # self._log_metrics[f"advantages_{i}"].append(mean_advantages[i].item())
#                     self._log_metrics[f"top_returns_{i}"].append(top_returns[i].item())
#             if self._log_idx == 16:
#                 self._log_idx = 0
#                 for key, value in self._log_metrics.items():
#                     print(f"{key}: {np.mean(value)}")
#                 self._log_metrics.clear()

#         if return_outputs:
#             return total_loss, output

#         return total_loss


# class DraftCausalLMGRPOLoss:
#     def __init__(
#         self,
#         model,
#         num_seqs: int = 1,
#         ignore_index: int = -100,
#         eps: float = 1e-6,
#         p_weight: float = 0.1,
#         top_k: int = 0,
#         mult: float = 1.0,
#         depth: int = 3,
#         repeat: int = 4,
#         top_draft: int = 4,
#         **kwargs
#     ):
#         self.model = model
#         self.device = model.device
#         self.dtype = model.dtype
#         self.ignore_index = ignore_index
#         self.num_seqs = num_seqs
#         self.eps = eps
#         self.p_weight = p_weight
#         self.top_k = top_k
#         self.mult = mult
#         self.depth = depth
#         self.repeat = repeat
#         self.top_draft = top_draft

#         self._log_idx = 0
#         self._log_metrics = defaultdict(list)

#         self._tree_metric = DraftCausalLMTreeMetric(
#             model,
#             num_seqs=num_seqs,
#             ignore_index=ignore_index,
#             eps=eps,
#             top_k=top_k,
#             mult=mult,
#             depth=depth,
#             top_draft=top_draft,
#             top_node=16,
#         )

#     def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
#         # Upcast to float if we need to compute the loss to avoid potential precision issues

#         loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
#         loss_masks[loss_masks == self.ignore_index] = 0

#         min_value = torch.finfo(self.dtype).min

#         q_len = labels.size(1)
#         q_len_valid = int(loss_masks.sum().item())

#         past_key_values = KVCache(
#             self.model.config.num_hidden_layers,
#             self.model.config.num_key_value_heads,
#             q_len + q_len_valid * (self.depth - 1) * self.top_draft,
#             self.model.config.head_dim,
#             self.model.device,
#             self.model.dtype,
#             num_seqs=self.num_seqs,
#         )

#         valid_mask = loss_masks.bool().flatten()

#         repeat_total_sampled_ids = []
#         repeat_total_sampled_log_probs = []

#         for re in range(self.repeat):
#             kv_cache_indices = []

#             last_input_ids = inputs["input_ids"]
#             last_hidden_states = inputs["hidden_states"]
#             last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)

#             last_attention_mask = torch.full(
#                 (labels.size(0), 1, q_len, q_len),
#                 min_value,
#                 device=model.device,
#                 dtype=model.dtype,
#                 requires_grad=False,
#             )

#             last_attention_mask[0, 0] = torch.triu(
#                 last_attention_mask[0, 0],
#                 diagonal=1,
#             )

#             total_sampled_ids = []
#             total_sampled_log_probs = []

#             for i in range(self.depth):
#                 ql = q_len if i == 0 else q_len_valid

#                 new_kv_cache_indices = past_key_values.allocate(ql)
#                 kv_cache_indices.extend(new_kv_cache_indices)

#                 model_inputs = {}
#                 model_inputs["input_ids"] = last_input_ids
#                 model_inputs["position_ids"] = last_position_ids
#                 model_inputs["hidden_states"] = last_hidden_states
#                 model_inputs["attention_mask"] = last_attention_mask
#                 model_inputs["past_key_values"] = past_key_values
#                 model_inputs["past_key_value_indices"] = kv_cache_indices
#                 model_inputs["use_cache"] = True

#                 output = model(
#                     **model_inputs,
#                     shift_tokens=i == 0,
#                     cut_last_token=i == 0,
#                 )

#                 if q_len_valid < self.depth:
#                     if return_outputs:
#                         return torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True), output
#                     return torch.tensor(0.0, device=model.device, dtype=model.dtype, requires_grad=True)

#                 v = output["logits"].size(-1)
#                 output_logits = output["logits"].reshape(ql, v).log_softmax(-1)
#                 assert output_logits.size() == (ql, v)
#                 output_probs = output_logits.exp()
#                 output_log_probs = output_logits
#                 assert output_probs.size() == (ql, v)

#                 sampled_ids = torch.multinomial(output_probs, 1, replacement=True)
#                 assert sampled_ids.size() == (ql, 1)
#                 sampled_log_probs = output_log_probs.gather(-1, sampled_ids)
#                 sampled_log_probs = sampled_log_probs.reshape(1, ql)
#                 assert sampled_log_probs.size() == (1, ql)

#                 sampled_ids = sampled_ids.reshape(1, ql)
#                 assert sampled_ids.size() == (1, ql)

#                 last_hidden_states = output["draft_hidden_states"]
#                 h = last_hidden_states.size(-1)
#                 assert last_hidden_states.size() == (1, ql, h)

#                 if i == 0:
#                     assert last_hidden_states.size() == (1, q_len, h)
#                     last_hidden_states = last_hidden_states[:, valid_mask]
#                     assert last_hidden_states.size() == (1, q_len_valid, h)

#                     kv_len = last_attention_mask.size(-1)
#                     assert last_attention_mask.size() == (1, 1, q_len , kv_len)
#                     last_attention_mask = last_attention_mask[:, :, valid_mask]
#                     assert last_attention_mask.size() == (1, 1, q_len_valid, kv_len)

#                     assert last_position_ids.size() == (1, q_len)
#                     last_position_ids = last_position_ids[:, valid_mask]
#                     assert last_position_ids.size() == (1, q_len_valid)

#                     assert sampled_ids.size() == (1, q_len)
#                     sampled_ids = sampled_ids[:, valid_mask]
#                     assert sampled_ids.size() == (1, q_len_valid)

#                     assert sampled_log_probs.size() == (1, q_len)
#                     sampled_log_probs = sampled_log_probs[:, valid_mask]
#                     assert sampled_log_probs.size() == (1, q_len_valid)

#                 last_input_ids = sampled_ids.reshape(1, q_len_valid)
#                 assert last_input_ids.size() == (1, q_len_valid)

#                 assert last_hidden_states.size() == (1, q_len_valid, h)

#                 past_kv_len = last_attention_mask.size(-1)
#                 assert last_attention_mask.size() == (1, 1, q_len_valid, past_kv_len)
#                 new_attention_mask = torch.full(
#                     (1, 1, q_len_valid, q_len_valid),
#                     0,
#                     device=model.device,
#                     dtype=model.dtype,
#                 )
#                 assert new_attention_mask.size() == (1, 1, q_len_valid, q_len_valid)
#                 last_attention_mask = torch.cat([
#                     last_attention_mask,
#                     new_attention_mask,
#                 ], dim=-1).requires_grad_(False)
#                 assert last_attention_mask.size() == (1, 1, q_len_valid, past_kv_len + q_len_valid)

#                 last_position_ids = last_position_ids + 1
#                 assert last_position_ids.size() == (1, q_len_valid)

#                 total_sampled_ids.append(sampled_ids)
#                 total_sampled_log_probs.append(sampled_log_probs)

#             past_key_values.free(kv_cache_indices)

#             total_sampled_ids = torch.stack(total_sampled_ids, dim=1)
#             total_sampled_log_probs = torch.stack(total_sampled_log_probs, dim=1)

#             assert total_sampled_ids.size() == (1, self.depth, q_len_valid)
#             assert total_sampled_log_probs.size() == (1, self.depth, q_len_valid)

#             repeat_total_sampled_ids.append(total_sampled_ids)
#             repeat_total_sampled_log_probs.append(total_sampled_log_probs)

#         repeat_total_sampled_ids = torch.stack(repeat_total_sampled_ids, dim=1)
#         repeat_total_sampled_log_probs = torch.stack(repeat_total_sampled_log_probs, dim=1)

#         assert repeat_total_sampled_ids.size() == (1, self.repeat, self.depth, q_len_valid)
#         assert repeat_total_sampled_log_probs.size() == (1, self.repeat, self.depth, q_len_valid)

#         target_sampled_ids = inputs["sampled_ids"][:, :, :self.depth, valid_mask]
#         target_repeat = target_sampled_ids.size(1)
#         assert target_sampled_ids.size() == (1, target_repeat, self.depth, q_len_valid)

#         is_correct = target_sampled_ids.reshape(1, 1, target_repeat, self.depth, q_len_valid) \
#             == repeat_total_sampled_ids.reshape(1, self.repeat, 1, self.depth, q_len_valid)
#         rewards = is_correct.float().cumprod(dim=-2).sum(dim=-2).mean(dim=2)
#         assert rewards.size() == (1, self.repeat, q_len_valid)

#         with torch.no_grad():
#             assert rewards.size() == (1, self.repeat, q_len_valid)
#             rewards_mean = rewards.mean(1, keepdim=True)
#             rewards_std = rewards.std(1, keepdim=True)
#             advantages = (rewards - rewards_mean) / (rewards_std + self.eps)
#             advantages = advantages.reshape(1, self.repeat, 1, q_len_valid)

#         loss_per_token = -advantages * (repeat_total_sampled_log_probs - repeat_total_sampled_log_probs.detach()).exp()
#         assert loss_per_token.size() == (1, self.repeat, self.depth, q_len_valid)
#         losses = []

#         losses.append(loss_per_token.mean())

#         # weight = len(losses) - torch.arange(0, len(losses), device=loss.device, dtype=loss.dtype)
#         # weight = torch.arange(0, len(losses), device=loss.device, dtype=loss.dtype) + 1
#         weight = torch.ones(len(losses), device=losses[0].device, dtype=losses[0].dtype)
#         loss = torch.stack(losses) * weight
#         loss = loss.sum() / weight.sum()

#         # chained_rl_loss, output = self.chained_rl_loss(model, inputs, labels, return_outputs=True)
#         total_loss = loss #+ chained_rl_loss
#         total_loss = total_loss * self.mult

#         if Accelerator().is_main_process:
#             self._log_idx += 1
#             if not total_loss.isnan().item():
#                 self._log_metrics["total_loss"].append(total_loss.item())
#                 self._log_metrics["logits_mean"].append(output["logits"].abs().mean().item())

#                 # if self._log_idx % 16 == 0:
#                 #     metric = self._tree_metric(model, inputs, labels, return_outputs=False)
#                 #     self._log_metrics[f"metric"].append(metric.item())
#             if self._log_idx == 16:
#                 self._log_idx = 0
#                 for key, value in self._log_metrics.items():
#                     print(f"{key}: {np.mean(value)}")
#                 self._log_metrics.clear()

#         if return_outputs:
#             return total_loss, output

#         return total_loss



class DraftCausalLMGRPOLoss:
    """Teacher-forced negative log-likelihood over pre-sampled draft rollouts.

    For each of ``repeat`` rollouts the draft model is run ``depth`` times.
    The first pass covers the whole sequence; later passes only cover the
    positions whose loss mask is non-zero.  At every depth the log-probability
    of the token stored in ``inputs["sampled_ids"]`` is gathered, and that
    token (together with the model's ``draft_hidden_states``) becomes the
    input of the next pass.  The loss is the negated mean of all gathered
    log-probabilities, scaled by ``mult``.

    Note that, despite the class name, the active code applies no reward or
    advantage weighting: it is a plain log-likelihood objective on the
    provided samples.

    The shape assertions in ``__call__`` require a batch size of 1.
    """

    # Number of main-process calls between two metric print-outs.
    _LOG_INTERVAL = 512

    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        depth: int = 3,
        repeat: int = 4,
        top_draft: int = 4,
        **kwargs
    ):
        """Store the hyper-parameters and build the auxiliary tree metric.

        Args:
            model: Draft model; its ``device``, ``dtype`` and ``config`` are read.
            num_seqs: Forwarded to ``KVCache`` and the tree metric.
            ignore_index: Mask value that is treated as "no loss" (mapped to 0).
            eps: Stored and forwarded to the tree metric; unused in ``__call__``.
            p_weight: Stored only; unused in ``__call__``.
            top_k: Stored and forwarded to the tree metric; unused in ``__call__``.
            mult: Final multiplier applied to the loss.
            depth: Number of chained draft passes per rollout.
            repeat: Number of rollouts read from ``inputs["sampled_ids"]``.
            top_draft: Stored and forwarded to the tree metric; unused in ``__call__``.
            **kwargs: Accepted and ignored.
        """
        self.model = model
        self.device = model.device
        self.dtype = model.dtype
        self.ignore_index = ignore_index
        self.num_seqs = num_seqs
        self.eps = eps
        self.p_weight = p_weight
        self.top_k = top_k
        self.mult = mult
        self.depth = depth
        self.repeat = repeat
        self.top_draft = top_draft

        # Rolling logging state: call counter and per-metric value lists.
        self._log_idx = 0
        self._log_metrics = defaultdict(list)

        # Built for evaluation purposes; it is not invoked from __call__.
        self._tree_metric = DraftCausalLMTreeMetric(
            model,
            num_seqs=num_seqs,
            ignore_index=ignore_index,
            eps=eps,
            top_k=top_k,
            mult=mult,
            depth=depth,
            top_draft=top_draft,
            top_node=16,
        )

    def _initial_attention_mask(self, model, batch_size, q_len):
        """Return the (batch_size, 1, q_len, q_len) mask used for the first pass.

        Entries strictly above the diagonal hold the most negative finite value
        of ``self.dtype``; the diagonal and everything below it are 0
        (presumably an additive causal mask -- confirm against the model).
        """
        mask = torch.full(
            (batch_size, 1, q_len, q_len),
            torch.finfo(self.dtype).min,
            device=model.device,
            dtype=model.dtype,
            requires_grad=False,
        )
        # Only batch element 0 is made causal; the rest of __call__ asserts a
        # batch size of 1, so no other element is expected to exist.
        mask[0, 0] = torch.triu(mask[0, 0], diagonal=1)
        return mask

    def _record_metrics(self, total_loss, output):
        """Accumulate scalar metrics and print their means every _LOG_INTERVAL calls."""
        self._log_idx += 1
        # NaN losses are counted towards the interval but not averaged in.
        if not total_loss.isnan().item():
            self._log_metrics["total_loss"].append(total_loss.item())
            self._log_metrics["logits_mean"].append(output["logits"].abs().mean().item())

        if self._log_idx == self._LOG_INTERVAL:
            self._log_idx = 0
            for key, value in self._log_metrics.items():
                print(f"{key}: {np.mean(value)}")
            self._log_metrics.clear()

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Compute the loss for one batch.

        Args:
            model: Callable draft model.  It is invoked with keyword arguments
                (``input_ids``, ``position_ids``, ``hidden_states``,
                ``attention_mask``, ``past_key_values``,
                ``past_key_value_indices``, ``use_cache``, ``shift_tokens``,
                ``cut_last_token``) and must return a mapping with ``"logits"``
                and ``"draft_hidden_states"``.
            inputs: Mapping with ``"loss_masks"``, ``"input_ids"``,
                ``"hidden_states"`` and ``"sampled_ids"``.  ``"sampled_ids"``
                is indexed as ``[:, rollout, depth]`` and reshaped to
                ``(q_len, 1)``.
            labels: Only its first two sizes (batch, sequence length) are read.
            return_outputs: When true, also return the last model output.
            weight: Unused; accepted for interface compatibility.

        Returns:
            The scalar loss, or ``(loss, output)`` when ``return_outputs`` is
            true.  If fewer than ``depth`` positions are unmasked, a zero
            tensor with ``requires_grad=True`` is returned instead, right
            after the first model pass.
        """
        # Private float copy: the caller's mask tensor is never modified.
        loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
        loss_masks[loss_masks == self.ignore_index] = 0

        q_len = labels.size(1)
        # The sum is used as a count, which assumes the remaining mask values
        # are 0/1 -- TODO confirm against the data pipeline.
        q_len_valid = int(loss_masks.sum().item())
        valid_mask = loss_masks.bool().flatten()

        # One full-length pass plus (depth - 1) passes over the valid positions.
        past_key_values = KVCache(
            self.model.config.num_hidden_layers,
            self.model.config.num_key_value_heads,
            q_len + q_len_valid * (self.depth - 1),
            self.model.config.head_dim,
            self.model.device,
            self.model.dtype,
            num_seqs=self.num_seqs,
        )

        repeat_total_sampled_ids = []
        repeat_total_sampled_log_probs = []

        for repeat_idx in range(self.repeat):
            kv_cache_indices = []

            last_input_ids = inputs["input_ids"]
            last_hidden_states = inputs["hidden_states"]
            last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)
            last_attention_mask = self._initial_attention_mask(model, labels.size(0), q_len)

            total_sampled_ids = []
            total_sampled_log_probs = []

            for i in range(self.depth):
                first_pass = i == 0
                # The first pass sees every position, later ones only valid ones.
                ql = q_len if first_pass else q_len_valid

                kv_cache_indices.extend(past_key_values.allocate(ql))

                output = model(
                    input_ids=last_input_ids,
                    position_ids=last_position_ids,
                    hidden_states=last_hidden_states,
                    attention_mask=last_attention_mask,
                    past_key_values=past_key_values,
                    past_key_value_indices=kv_cache_indices,
                    use_cache=True,
                    shift_tokens=first_pass,
                    cut_last_token=first_pass,
                )

                # Checked after the model call so that `output` exists for
                # callers that asked for it.
                if q_len_valid < self.depth:
                    zero_loss = torch.tensor(
                        0.0, device=model.device, dtype=model.dtype, requires_grad=True
                    )
                    if return_outputs:
                        return zero_loss, output
                    return zero_loss

                v = output["logits"].size(-1)
                output_log_probs = output["logits"].reshape(ql, v).log_softmax(-1)
                assert output_log_probs.size() == (ql, v)

                # Teacher forcing: score the externally provided sample for
                # this rollout and depth instead of sampling from the model.
                sampled_ids = inputs["sampled_ids"][:, repeat_idx, i].reshape(q_len, 1)
                if not first_pass:
                    sampled_ids = sampled_ids[valid_mask, :]
                assert sampled_ids.size() == (ql, 1)
                sampled_log_probs = output_log_probs.gather(-1, sampled_ids).reshape(1, ql)
                assert sampled_log_probs.size() == (1, ql)

                sampled_ids = sampled_ids.reshape(1, ql)
                assert sampled_ids.size() == (1, ql)

                last_hidden_states = output["draft_hidden_states"]
                h = last_hidden_states.size(-1)
                assert last_hidden_states.size() == (1, ql, h)

                if first_pass:
                    # Drop the masked-out positions from every per-position
                    # tensor so that later passes only carry valid positions.
                    assert last_hidden_states.size() == (1, q_len, h)
                    last_hidden_states = last_hidden_states[:, valid_mask]
                    assert last_hidden_states.size() == (1, q_len_valid, h)

                    kv_len = last_attention_mask.size(-1)
                    assert last_attention_mask.size() == (1, 1, q_len, kv_len)
                    last_attention_mask = last_attention_mask[:, :, valid_mask]
                    assert last_attention_mask.size() == (1, 1, q_len_valid, kv_len)

                    assert last_position_ids.size() == (1, q_len)
                    last_position_ids = last_position_ids[:, valid_mask]
                    assert last_position_ids.size() == (1, q_len_valid)

                    assert sampled_ids.size() == (1, q_len)
                    sampled_ids = sampled_ids[:, valid_mask]
                    assert sampled_ids.size() == (1, q_len_valid)

                    assert sampled_log_probs.size() == (1, q_len)
                    sampled_log_probs = sampled_log_probs[:, valid_mask]
                    assert sampled_log_probs.size() == (1, q_len_valid)

                last_input_ids = sampled_ids.reshape(1, q_len_valid)
                assert last_input_ids.size() == (1, q_len_valid)

                assert last_hidden_states.size() == (1, q_len_valid, h)

                # Extend the mask with one column per token fed to the next pass.
                # NOTE(review): the appended block is all zeros, so if the mask
                # is additive every row can see every newly appended column,
                # not only its own -- confirm this is intended.
                past_kv_len = last_attention_mask.size(-1)
                assert last_attention_mask.size() == (1, 1, q_len_valid, past_kv_len)
                new_attention_mask = torch.full(
                    (1, 1, q_len_valid, q_len_valid),
                    0,
                    device=model.device,
                    dtype=model.dtype,
                )
                assert new_attention_mask.size() == (1, 1, q_len_valid, q_len_valid)
                last_attention_mask = torch.cat(
                    [last_attention_mask, new_attention_mask],
                    dim=-1,
                ).requires_grad_(False)
                assert last_attention_mask.size() == (1, 1, q_len_valid, past_kv_len + q_len_valid)

                last_position_ids = last_position_ids + 1
                assert last_position_ids.size() == (1, q_len_valid)

                total_sampled_ids.append(sampled_ids)
                total_sampled_log_probs.append(sampled_log_probs)

            # Hand the slots back so the next rollout can reuse the cache.
            past_key_values.free(kv_cache_indices)

            total_sampled_ids = torch.stack(total_sampled_ids, dim=1)
            total_sampled_log_probs = torch.stack(total_sampled_log_probs, dim=1)

            assert total_sampled_ids.size() == (1, self.depth, q_len_valid)
            assert total_sampled_log_probs.size() == (1, self.depth, q_len_valid)

            repeat_total_sampled_ids.append(total_sampled_ids)
            repeat_total_sampled_log_probs.append(total_sampled_log_probs)

        repeat_total_sampled_ids = torch.stack(repeat_total_sampled_ids, dim=1)
        repeat_total_sampled_log_probs = torch.stack(repeat_total_sampled_log_probs, dim=1)

        assert repeat_total_sampled_ids.size() == (1, self.repeat, self.depth, q_len_valid)
        assert repeat_total_sampled_log_probs.size() == (1, self.repeat, self.depth, q_len_valid)

        # Mean negative log-likelihood over rollouts, depths and valid positions.
        total_loss = -repeat_total_sampled_log_probs.mean() * self.mult

        # NOTE(review): an Accelerator is constructed on every call just to
        # read `is_main_process`; consider reusing a shared instance.
        if Accelerator().is_main_process:
            self._record_metrics(total_loss, output)

        if return_outputs:
            return total_loss, output

        return total_loss


class DraftCausalLMQLoss:
    """Tree-structured value-regression loss for a draft causal LM.

    For every supervised position the draft model is unrolled ``depth`` levels
    deep. At each level the softplus of the logits is used as a per-node
    score; the ``top_node`` best-scoring candidates are kept for the loss and
    the first ``top_draft`` of them are fed back in as the next level's
    inputs. A node earns reward 1 when its own token and every ancestor's
    token equal the reference tokens read from ``inputs["output_topk_ids"]``.
    The regression target of a node is its reward plus 0.9 times the largest
    target among its children, and the loss is the squared error between the
    scores and those targets.

    The shape assertions in ``__call__`` require a batch size of 1.
    """

    def __init__(
        self,
        model,
        num_seqs: int = 1,
        ignore_index: int = -100,
        eps: float = 1e-6,
        p_weight: float = 0.1,
        top_k: int = 0,
        mult: float = 1.0,
        depth: int = 3,
        top_draft: int = 4,
        top_node: int = 1024,
        **kwargs
    ):
        """Store the hyper-parameters and build the companion tree metric.

        Args:
            model: Draft model; its ``device``, ``dtype`` and ``config`` are read.
            num_seqs: Forwarded to ``KVCache`` and to the tree metric.
            ignore_index: Value in ``inputs["loss_masks"]`` that is mapped to 0.
            eps: Stored and forwarded to the tree metric; not used by the loss.
            p_weight: Stored only; not used by the loss.
            top_k: Stored and forwarded to the tree metric; not used by the loss.
            mult: Scalar multiplier applied to the final loss.
            depth: Number of tree levels that are unrolled.
            top_draft: Number of nodes per position expanded at the next level.
            top_node: Number of nodes per position and level that enter the loss.
                Must be at least ``top_draft`` (enforced by a shape assertion).
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.model = model
        self.device = model.device
        self.dtype = model.dtype
        self.ignore_index = ignore_index
        self.num_seqs = num_seqs
        self.eps = eps
        self.p_weight = p_weight
        self.top_k = top_k
        self.mult = mult
        self.depth = depth
        self.top_draft = top_draft
        self.top_node = top_node

        # Rolling log state: call counter and per-key lists of logged values.
        self._log_idx = 0
        self._log_metrics = defaultdict(list)

        # Metric evaluated periodically for logging. Note that it is built with
        # a fixed top_node of 16 rather than this loss's own ``top_node``.
        self._tree_metric = DraftCausalLMTreeMetric(
            model,
            num_seqs=num_seqs,
            ignore_index=ignore_index,
            eps=eps,
            top_k=top_k,
            mult=mult,
            depth=depth,
            top_draft=top_draft,
            top_node=16,
        )

    def __call__(self, model, inputs, labels, return_outputs=False, weight=None):
        """Compute the tree value-regression loss for one batch.

        Args:
            model: Callable invoked with ``input_ids``, ``position_ids``,
                ``hidden_states``, ``attention_mask``, ``past_key_values``,
                ``past_key_value_indices``, ``use_cache``, ``shift_tokens`` and
                ``cut_last_token``. Its result must provide ``"logits"`` and
                ``"draft_hidden_states"``.
            inputs: Mapping providing ``"loss_masks"``, ``"output_topk_ids"``,
                ``"input_ids"`` and ``"hidden_states"``. ``output_topk_ids`` is
                indexed as ``[0, level, position, 0]`` to obtain the reference
                token of each level.
            labels: Only ``labels.size(0)`` and ``labels.size(1)`` are read.
            return_outputs: When true, also return the last model output.
            weight: Unused.

        Returns:
            The scalar loss, or ``(loss, output)`` when ``return_outputs`` is
            true. If no position is supervised, a constant placeholder loss of
            1.0 (with ``requires_grad=True``) is returned instead.
        """
        loss_masks = inputs["loss_masks"].float().detach().clone().requires_grad_(False)
        loss_masks[loss_masks == self.ignore_index] = 0
        output_topk_ids = inputs["output_topk_ids"]

        min_value = torch.finfo(model.dtype).min
        max_value = torch.finfo(model.dtype).max

        q_len = labels.size(1)
        # Number of supervised positions (assumes the mask holds only 0/1 once
        # ignore_index has been zeroed -- TODO confirm against the data loader).
        q_len_valid = int(loss_masks.sum().item())

        # Room for the q_len prefix plus top_draft nodes per supervised
        # position for every level after the first.
        past_key_values = KVCache(
            self.model.config.draft_num_hidden_layers,
            self.model.config.num_key_value_heads,
            q_len + q_len_valid * (self.depth - 1) * self.top_draft,
            self.model.config.head_dim,
            self.model.device,
            self.model.dtype,
            num_seqs=self.num_seqs,
        )

        valid_mask = loss_masks.flatten().bool()

        kv_cache_indices = []

        last_input_ids = inputs["input_ids"]
        last_hidden_states = inputs["hidden_states"]
        last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)
        # Only the shape of this tensor is consumed (by the size assertion in
        # the loop); its values are replaced by the softplus scores.
        last_output_score = torch.full(
            (q_len, 1),
            fill_value=-self.depth,
            device=model.device,
            dtype=model.dtype,
        )

        last_attention_mask = torch.full(
            (labels.size(0), 1, q_len, q_len),
            min_value,
            device=model.device,
            dtype=model.dtype,
            requires_grad=False,
        )

        # Causal mask: keep min_value strictly above the diagonal, 0 elsewhere.
        last_attention_mask[0, 0] = torch.triu(
            last_attention_mask[0, 0],
            diagonal=1,
        )

        total_output_scores = []
        total_sampled_node_idxs = []

        for i in range(self.depth):
            # Level 0 runs over the whole sequence with a single candidate per
            # position; deeper levels run over top_draft nodes per valid position.
            ql = q_len if i == 0 else q_len_valid
            nc = self.top_draft if i > 0 else 1

            new_kv_cache_indices = past_key_values.allocate(ql * nc)
            kv_cache_indices.extend(new_kv_cache_indices)

            model_inputs = {}
            model_inputs["input_ids"] = last_input_ids
            model_inputs["position_ids"] = last_position_ids
            model_inputs["hidden_states"] = last_hidden_states
            model_inputs["attention_mask"] = last_attention_mask
            model_inputs["past_key_values"] = past_key_values
            model_inputs["past_key_value_indices"] = kv_cache_indices
            model_inputs["use_cache"] = True

            output = model(
                **model_inputs,
                shift_tokens=i == 0,
                cut_last_token=i == 0,
            )

            if q_len_valid == 0:
                # Nothing to supervise. Release the cache slots allocated above
                # (the normal path frees them after the loop) and return a
                # constant placeholder loss.
                past_key_values.free(kv_cache_indices)
                placeholder = torch.tensor(
                    1.0, device=self.device, dtype=self.dtype, requires_grad=True,
                )
                if return_outputs:
                    return placeholder, output
                return placeholder

            v = output["logits"].size(-1)
            output_logits = output["logits"]
            output_logits = output_logits.reshape(ql, nc * v)
            assert output_logits.size() == (ql, nc * v)

            assert last_output_score.size() == (ql, nc)
            # Per-node score: softplus of the raw logits, flattened so that
            # index = parent_slot * v + token_id.
            last_output_score = F.softplus(output_logits.reshape(ql, nc, v))
            last_output_score = last_output_score.reshape(ql, nc * v)

            last_hidden_states = output["draft_hidden_states"]
            h = last_hidden_states.size(-1)
            assert last_hidden_states.size() == (1, ql * nc, h)

            last_position_ids = last_position_ids + 1
            assert last_position_ids.size() == (1, ql * nc)

            if i == 0:
                # From here on only the supervised positions are expanded, so
                # drop every row that belongs to a masked-out position.
                assert last_hidden_states.size() == (1, q_len * nc, h)
                last_hidden_states = last_hidden_states.reshape(1, q_len, nc, h)
                last_hidden_states = last_hidden_states[:, valid_mask]
                assert last_hidden_states.size() == (1, q_len_valid, nc, h)
                last_hidden_states = last_hidden_states.reshape(1, -1, h)
                assert last_hidden_states.size() == (1, q_len_valid * nc, h)

                kv_len = last_attention_mask.size(-1)
                assert last_attention_mask.size() == (1, 1, q_len, kv_len)
                last_attention_mask = last_attention_mask.reshape(1, 1, ql, kv_len)
                last_attention_mask = last_attention_mask[:, :, valid_mask]
                assert last_attention_mask.size() == (1, 1, q_len_valid, kv_len)

                assert last_position_ids.size() == (1, q_len)
                last_position_ids = last_position_ids[:, valid_mask]
                assert last_position_ids.size() == (1, q_len_valid)

                assert output_logits.size() == (q_len, nc * v)
                output_logits = output_logits[valid_mask]
                assert output_logits.size() == (q_len_valid, nc * v)

                assert last_output_score.size() == (q_len, nc * v)
                last_output_score = last_output_score[valid_mask]
                assert last_output_score.size() == (q_len_valid, nc * v)

            # Reference token of this level for every supervised position.
            output_ids = output_topk_ids[0, i, valid_mask, 0].reshape(q_len_valid, 1)
            assert output_ids.size() == (q_len_valid, 1)

            # Force the reference token to rank first. Its flat index is below
            # v, i.e. it sits in the block of parent slot 0, which is the
            # previous level's top-ranked (forced) node.
            last_output_score_for_sampling = last_output_score.scatter(
                -1,
                output_ids,
                max_value,
            )

            sampled_node_idxs = torch.topk(
                last_output_score_for_sampling,
                self.top_node,
                dim=-1,
            ).indices
            assert sampled_node_idxs.size() == (q_len_valid, self.top_node)

            # The best top_draft nodes are the ones expanded at the next level.
            sampled_draft_idxs = sampled_node_idxs[..., :self.top_draft]
            assert sampled_draft_idxs.size() == (q_len_valid, self.top_draft)

            # Scores entering the loss come from the unforced tensor, so the
            # reference token contributes its real score, not max_value.
            last_output_node_score = last_output_score.gather(
                -1,
                sampled_node_idxs,
            )
            assert last_output_node_score.size() == (q_len_valid, self.top_node)

            total_output_scores.append(last_output_node_score)
            total_sampled_node_idxs.append(sampled_node_idxs)

            if i == self.depth - 1:
                # No further level follows: skip building inputs (including a
                # quadratic attention mask) that would never be consumed.
                break

            last_output_score = last_output_score.gather(
                -1,
                sampled_draft_idxs,
            )
            assert last_output_score.size() == (q_len_valid, self.top_draft)

            # Flat index -> (parent slot, token id) via // v and % v.
            last_input_ids = (sampled_draft_idxs % v).reshape(1, q_len_valid * self.top_draft)
            assert last_input_ids.size() == (1, q_len_valid * self.top_draft)

            last_position_ids = last_position_ids.reshape(q_len_valid, nc).gather(
                1, (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft)
            )
            assert last_position_ids.size() == (q_len_valid, self.top_draft)
            last_position_ids = last_position_ids.reshape(1, q_len_valid * self.top_draft)
            assert last_position_ids.size() == (1, q_len_valid * self.top_draft)

            # Every new node inherits its parent's hidden state ...
            last_hidden_states = last_hidden_states.reshape(q_len_valid, nc, h).gather(
                1,
                (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, h)
            )
            assert last_hidden_states.size() == (q_len_valid, self.top_draft, h)
            last_hidden_states = last_hidden_states.reshape(1, q_len_valid * self.top_draft, h)
            assert last_hidden_states.size() == (1, q_len_valid * self.top_draft, h)

            # ... and its parent's attention row, extended by a block in which
            # each new node can additionally attend only to itself.
            past_kv_len = last_attention_mask.size(-1)
            assert last_attention_mask.size() == (1, 1, q_len_valid * nc, past_kv_len)
            last_attention_mask = last_attention_mask.reshape(q_len_valid, nc, past_kv_len).gather(
                1, (sampled_draft_idxs // v).reshape(q_len_valid, self.top_draft, 1).expand(q_len_valid, self.top_draft, past_kv_len)
            )
            assert last_attention_mask.size() == (q_len_valid, self.top_draft, past_kv_len)
            last_attention_mask = last_attention_mask.reshape(
                1, 1, q_len_valid * self.top_draft, past_kv_len,
            )
            assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len)
            new_attention_mask = torch.full(
                (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft),
                min_value,
                device=model.device,
                dtype=model.dtype,
            )
            new_attention_mask[0, 0].diagonal(0).fill_(0)
            assert new_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, q_len_valid * self.top_draft)
            last_attention_mask = torch.cat([
                last_attention_mask,
                new_attention_mask,
            ], dim=-1).requires_grad_(False)
            assert last_attention_mask.size() == (1, 1, q_len_valid * self.top_draft, past_kv_len + q_len_valid * self.top_draft)

        past_key_values.free(kv_cache_indices)

        total_output_scores = torch.stack(total_output_scores, dim=0)
        assert total_output_scores.size() == (self.depth, q_len_valid, self.top_node)

        total_sampled_node_idxs = torch.stack(total_sampled_node_idxs, dim=0)
        assert total_sampled_node_idxs.size() == (self.depth, q_len_valid, self.top_node)

        # Forward pass over the levels: a node's reward is 1 only if its parent
        # had reward 1 and its own token equals the reference token.
        last_reward = torch.ones((1, q_len_valid, 1), device=model.device, dtype=torch.long)
        rewards = []
        for i in range(self.depth):
            # Look up each node's parent reward through its parent slot.
            last_reward = last_reward.gather(-1, (total_sampled_node_idxs[i] // v).reshape(1, q_len_valid, self.top_node))
            last_reward = torch.where(
                last_reward == 1,
                last_reward & ((total_sampled_node_idxs[i] % v) == output_topk_ids[0, i, valid_mask, :1]).reshape(1, q_len_valid, self.top_node),
                torch.zeros_like(last_reward),
            )
            rewards.append(last_reward)

        # Backward pass over the levels: target = reward + 0.9 * best child
        # target. Nodes without any child keep a child contribution of 0.
        last_goal = torch.zeros_like(rewards[-1], dtype=torch.float)
        goals = []
        for i in range(self.depth - 1, -1, -1):
            if i < self.depth - 1:
                # Reduce the children's targets onto their parent slot.
                last_goal = torch.full_like(last_goal, 0) \
                    .scatter_reduce(
                        dim=-1,
                        index=(total_sampled_node_idxs[i + 1] // v).reshape(1, q_len_valid, self.top_node),
                        src=last_goal,
                        reduce='amax',
                        include_self=True,
                    )
            reward = rewards[i].reshape(1, q_len_valid, self.top_node)
            last_goal = reward + last_goal * 0.9
            goals.insert(0, last_goal)

        losses = []
        for i in range(self.depth):
            # Upcast the scores to float32 so they share the targets' dtype;
            # this keeps the regression well-defined for half/bfloat16 models.
            loss = F.mse_loss(total_output_scores[i].float(), goals[i][0], reduction='none')
            # Sum over the nodes of a position, average over positions.
            loss = loss.sum(-1).mean()
            losses.append(loss)

        loss = sum(losses)

        total_loss = loss
        total_loss = total_loss * self.mult

        if Accelerator().is_main_process:
            self._log_idx += 1
            if not total_loss.isnan().item():
                self._log_metrics["total_loss"].append(total_loss.item())
                self._log_metrics["logits_mean"].append(output["logits"].abs().mean().item())
                if self._log_idx % 16 == 0:
                    metric = self._tree_metric(model, inputs, labels, return_outputs=False)
                    self._log_metrics["metric"].append(metric.item())
            # Every 128 calls print the running means and start a new window.
            if self._log_idx == 128:
                self._log_idx = 0
                for key, value in self._log_metrics.items():
                    print(f"{key}: {np.mean(value)}")
                self._log_metrics.clear()

        if return_outputs:
            return total_loss, output

        return total_loss