# The code below is built on top of the great work going on in the x_transformers repo:
#       https://github.com/lucidrains/x-transformers/blob/main/x_transformers

from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from entmax import entmax_bisect


# nucleus

def top_p(logits, thres=0.9):
    """Mask low-probability logits with -inf using a cumulative-probability cut-off.

    Each row of ``logits`` is sorted in descending order and softmaxed. Tokens
    are kept up to and including the first one whose cumulative probability
    exceeds ``1 - thres``; everything after it is set to ``-inf``. The most
    likely token of every row is therefore always kept.

    Note that the cut-off is ``1 - thres`` (not ``thres``): a larger ``thres``
    filters more aggressively.

    Args:
        logits: 2-D tensor; rows are indexed on dim 0, candidates on dim 1.
        thres: fraction of probability mass to filter out.

    Returns:
        A new tensor shaped like ``logits`` with the filtered entries replaced
        by ``-inf``, in the original (unsorted) token order.
    """
    ordered, order = torch.sort(logits, descending=True)
    cumulative = F.softmax(ordered, dim=-1).cumsum(dim=-1)

    over_budget = cumulative > (1 - thres)

    # Shift the flags one slot to the right so that the token which first
    # crosses the budget survives; column 0 stays False (never dropped).
    drop = torch.zeros_like(over_budget)
    drop[:, 1:] = over_budget[:, :-1]

    ordered = ordered.masked_fill(drop, float('-inf'))

    # Undo the sort: write every sorted value back to its original column.
    return ordered.scatter(1, order, ordered)


# topk

def top_k(logits, thres=0.9):
    """Keep the highest-scoring fraction of logits per row; set the rest to -inf.

    The number of logits kept is ``int((1 - thres) * logits.shape[-1])``,
    clamped to the range ``[1, logits.shape[-1]]``.

    Args:
        logits: 2-D tensor; candidates are ranked along the last dimension.
        thres: fraction of candidates to filter out.

    Returns:
        A new tensor shaped like ``logits`` where the kept entries hold their
        original values and every other entry is ``-inf``.
    """
    num_tokens = logits.shape[-1]

    # Truncation (and float error, e.g. (1 - 0.9) * 10 == 0.999...) can yield
    # k == 0, which would mask every logit and turn the downstream softmax into
    # NaNs. Always keep at least one candidate, and never ask torch.topk for
    # more candidates than exist.
    k = min(max(int((1 - thres) * num_tokens), 1), num_tokens)

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs


# entmax

# `alpha` value handed to `entmax` when `AutoregressiveWrapper.generate` samples via entmax.
ENTMAX_ALPHA = 1.3
# Module-level alias for `entmax_bisect`. `generate` compares its `filter_logits_fn`
# argument against this name by identity to select the entmax sampling path.
entmax = entmax_bisect


class AutoregressiveWrapper(nn.Module):
    """Wrap a sequence model for next-token training, prediction and sampling.

    The wrapped ``net`` must expose ``with_values`` and ``max_seq_len``
    attributes. It is called as ``net(tokens, **kwargs)`` and, when
    ``with_values`` is true, as ``net(tokens, quants=..., **kwargs)`` returning
    a ``(token_logits, quant_logits)`` pair. Logits are indexed as
    ``(batch, position, class)``.

    Args:
        net: the model to wrap.
        quant_guides: accepted for interface compatibility; not used by this
            class.
        ignore_index: target value ignored by the token cross-entropy loss.
        ignore_quant_index: target value ignored by the quant cross-entropy
            loss. ``None`` falls back to -100 (PyTorch's default) when the
            loss is computed.
    """

    def __init__(self, net, quant_guides=None, ignore_index=-100, ignore_quant_index=None):
        super().__init__()
        self.ignore_index = ignore_index
        self.ignore_quant_index = ignore_quant_index
        self.net = net
        self.with_values = net.with_values
        self.max_seq_len = net.max_seq_len

    @staticmethod
    def _trim_mask(kwargs, full_len):
        """Drop the last column of ``kwargs['mask']`` if it spans the un-shifted sequence.

        The net only sees the sequence without its final token, so a mask
        whose length equals the full sequence length is off by one.
        """
        mask = kwargs.get('mask', None)
        if mask is not None and mask.shape[1] == full_len:
            kwargs['mask'] = mask[:, :-1]

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9,
                 **kwargs):
        """Sample up to ``seq_len`` new tokens after ``start_tokens``.

        NOTE(review): the net is called without ``quants`` and its result is
        indexed as a single tensor, so this appears to support only nets with
        ``with_values`` false - confirm.

        Args:
            start_tokens: 1-D or 2-D (batch, length) tensor of prompt tokens.
            seq_len: maximum number of tokens to generate.
            eos_token: optional token id. Generation stops once every row has
                sampled it at least once. Rows that finished early keep
                receiving sampled tokens until then.
            temperature: logits are divided by this value before being turned
                into probabilities; must be non-zero.
            filter_logits_fn: ``entmax`` to sample from an entmax distribution,
                ``None`` to sample from the unfiltered softmax, or any callable
                ``fn(logits, thres=...)`` (such as ``top_k`` / ``top_p``) whose
                output is softmaxed.
            filter_thres: forwarded to ``filter_logits_fn`` as ``thres``.
            **kwargs: forwarded to the net. ``mask`` is extracted, cropped to
                the last ``max_seq_len`` positions and extended with ``True``
                for every generated token.

        Returns:
            Only the newly generated tokens, with the batch dimension removed
            again when ``start_tokens`` was 1-D.

        Raises:
            TypeError: if ``filter_logits_fn`` is neither ``None`` nor callable.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape
        out = start_tokens

        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = torch.ones_like(out, dtype=torch.bool)
        elif len(mask.shape) == 1:
            # Mirror the batching applied to a 1-D prompt above.
            mask = mask[None, :]

        # Per-row flag: has this row emitted `eos_token` yet?
        finished = torch.zeros(out.shape[0], dtype=torch.bool, device=out.device)

        self.net.eval()
        try:
            for _ in range(seq_len):
                # The net only ever sees the most recent `max_seq_len` positions.
                x = out[:, -self.max_seq_len:]
                mask = mask[:, -self.max_seq_len:]

                logits = self.net(x, mask=mask, **kwargs)[:, -1, :]

                if filter_logits_fn is entmax:
                    probs = entmax(logits / temperature, alpha=ENTMAX_ALPHA, dim=-1)
                elif filter_logits_fn is None:
                    probs = F.softmax(logits / temperature, dim=-1)
                elif callable(filter_logits_fn):
                    filtered_logits = filter_logits_fn(logits, thres=filter_thres)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)
                else:
                    raise TypeError(
                        'filter_logits_fn must be None or a callable, got '
                        f'{type(filter_logits_fn).__name__}'
                    )

                sample = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)
                mask = F.pad(mask, (0, 1), value=True)

                if eos_token is not None:
                    # Remember completion across steps; requiring all rows to
                    # sample EOS on the very same step would almost never fire
                    # for batches larger than one.
                    finished = finished | (sample == eos_token).squeeze(-1)
                    if finished.all():
                        break
        finally:
            # Restore the caller's train/eval mode even if sampling raised.
            self.net.train(was_training)

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    @torch.no_grad()
    def predict(self, x, **kwargs):
        """Return the argmax prediction at every position of ``x`` minus its last token.

        Args:
            x: token tensor, or an indexable ``(tokens, quants)`` pair when the
                net was built ``with_values``.
            **kwargs: forwarded to the net. A ``mask`` as long as the full
                sequence is trimmed by one to match the shifted input.

        Returns:
            The argmax over the last logits dimension; a
            ``(token_pred, quant_pred)`` pair when ``with_values`` is true.
        """
        tokens = x[0] if self.with_values else x
        self._trim_mask(kwargs, tokens.shape[1])

        if not self.with_values:
            out = self.net(tokens[:, :-1], **kwargs)
            return torch.argmax(out, dim=-1)

        out, quants_out = self.net(tokens[:, :-1], quants=x[1][:, :-1], **kwargs)
        return torch.argmax(out, dim=-1), torch.argmax(quants_out, dim=-1)

    def forward(self, x, **kwargs):
        """Compute the next-token cross-entropy loss(es) for ``x``.

        The net is fed every position except the last and is scored against
        the sequence shifted left by one.

        Args:
            x: token tensor, or an indexable ``(tokens, quants)`` pair when the
                net was built ``with_values``.
            **kwargs: forwarded to the net. A ``mask`` as long as the full
                sequence is trimmed by one to match the shifted input.

        Returns:
            The token loss; a ``(token_loss, quant_loss)`` pair when
            ``with_values`` is true.
        """
        # help auto-solve a frequent area of confusion around input masks in auto-regressive
        # if user supplies a mask that is only off by one from the source sequence, resolve it for them
        # (measured against the token tensor, since `x` itself is a pair when `with_values` is set)
        tokens = x[0] if self.with_values else x
        self._trim_mask(kwargs, tokens.shape[1])

        xi = tokens[:, :-1]
        xo = tokens[:, 1:]

        if not self.with_values:
            out = self.net(xi, **kwargs)
            return F.cross_entropy(out.transpose(1, 2), xo, ignore_index=self.ignore_index)

        quants = x[1]
        qi = quants[:, :-1]
        qo = quants[:, 1:]

        out, quants_out = self.net(xi, quants=qi, **kwargs)

        # F.cross_entropy requires an int `ignore_index`; an unset (None) quant
        # index falls back to PyTorch's own default of -100.
        quant_ignore = -100 if self.ignore_quant_index is None else self.ignore_quant_index

        # cross_entropy expects the class dimension second, hence the transposes.
        token_loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index=self.ignore_index)
        quant_loss = F.cross_entropy(quants_out.transpose(1, 2), qo, ignore_index=quant_ignore)
        return token_loss, quant_loss
