import torch
import torch.nn as nn
import math
from torch.nn import functional as F

class Embedding(nn.Module):
    """Token embedding, optionally plus fixed sinusoidal positions, followed by LayerNorm.

    Args:
        d_model: embedding width (odd values are supported).
        vocab_size: number of token ids; id 0 is the padding index.
        maxlen: number of precomputed sinusoidal positions (longest usable sequence
            when ``rpe`` is falsy).
        rpe: if truthy, the sinusoidal table is NOT added (position information is
            expected to come from a relative bias elsewhere).
    """

    def __init__(self, d_model, vocab_size, maxlen, rpe):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model, padding_idx=0)

        # Standard sin/cos table: even columns get sin, odd columns get cos.
        position = torch.arange(0, maxlen).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe = torch.zeros(maxlen, d_model).float()
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so the
        # frequency vector must be trimmed to fit.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: moves with .to()/.cuda(), saved in state_dict,
        # never trained.
        self.register_buffer('pe', pe.unsqueeze(0))

        self.norm = nn.LayerNorm(d_model)
        self.rpe = rpe

    def forward(self, x):
        """Embed token ids ``x`` of shape (B, T) into (B, T, d_model).

        When ``rpe`` is falsy, T must not exceed ``maxlen``.
        """
        embedding = self.tok_embed(x)
        if not self.rpe:
            embedding = embedding + self.pe[:, :x.size(1)]
        return self.norm(embedding)

class NewGELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))`` elementwise.
    """

    def forward(self, x):
        cubic_term = 0.044715 * torch.pow(x, 3.0)
        tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
        gate = 1.0 + torch.tanh(tanh_arg)
        return 0.5 * x * gate

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an optional per-head linear distance bias.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        drop: dropout probability applied to the attention weights and to the
            output projection.
        maxlen: largest sequence length supported by the precomputed causal mask
            (``bias``) and distance bias (``RPE``).
        rpe: if truthy, ``RPE`` is added to the attention scores before masking.
    """

    def __init__(self, d_model, nhead, drop, maxlen, rpe):
        super().__init__()
        assert d_model % nhead == 0
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        self.c_proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(drop)
        self.resid_dropout = nn.Dropout(drop)
        # Lower-triangular ones: query i may attend to keys j <= i.
        self.register_buffer("bias", torch.tril(torch.ones(maxlen, maxlen)).view(1, 1, maxlen, maxlen))
        self.rpe = rpe

        # RPE[0, h, i, j] = -slope_h * max(i - j, 0), with slope_h = 2 ** (-8 / nhead * (h + 1)).
        # Built in closed form (O(maxlen^2)) instead of summing maxlen shifted
        # triangular masks, which costs O(maxlen^3) time and allocations.
        pos = torch.arange(maxlen, dtype=torch.get_default_dtype())
        # [i, j] -> min(j - i, 0) == -max(i - j, 0)
        neg_distance = (pos.view(1, -1) - pos.view(-1, 1)).clamp(max=0)
        slopes = torch.tensor(
            [2 ** (-8 / nhead * (h + 1)) for h in range(nhead)],
            dtype=torch.get_default_dtype(),
        )
        alibi = slopes.view(1, nhead, 1, 1) * neg_distance.view(1, 1, maxlen, maxlen)
        self.register_buffer("RPE", alibi.contiguous())

        self.n_head = nhead
        self.n_embd = d_model

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape.

        Args:
            x: input activations; T must not exceed ``maxlen`` and C must equal ``d_model``.
            mask: optional tensor broadcastable to (B, nhead, T, T); positions where
                ``mask == 0`` are excluded from attention. A query row whose keys
                are all excluded produces NaN after the softmax.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, nhead, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if self.rpe:
            att = att + self.RPE[:, :, :T, :T]
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v
        # Re-merge heads: (B, nhead, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))

class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention then an MLP, each with a residual.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        drop: dropout probability used by attention and the MLP output.
        maxlen: longest sequence the attention layer supports.
        rpe: forwarded to ``CausalSelfAttention`` to enable its distance bias.
    """

    def __init__(self, d_model, nhead, drop, maxlen, rpe):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model=d_model, nhead=nhead, drop=drop, maxlen=maxlen, rpe=rpe)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(d_model, 4 * d_model),
            c_proj  = nn.Linear(4 * d_model, d_model),
            act     = NewGELU(),
            dropout = nn.Dropout(drop),
        ))

    def mlpf(self, x):
        """Apply the MLP: Linear(d, 4d) -> NewGELU -> Linear(4d, d) -> Dropout.

        A method rather than a lambda stored on the instance: a lambda closing over
        ``self.mlp`` is shared (not copied) by ``copy.deepcopy``, so a copied block
        would silently use the original block's MLP weights, and it also makes the
        module unpicklable.
        """
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x, mask=None):
        """Return ``x`` updated by the attention and MLP residual branches."""
        x = x + self.attn(self.ln_1(x), mask)
        x = x + self.mlpf(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            embedding = Embedding(d_model=args.dmodel, vocab_size=args.vocab, maxlen=args.maxlen, rpe=args.rpe),
            drop = nn.Dropout(args.drop),
            h = nn.ModuleList([Block(d_model=args.dmodel, nhead=args.head, drop=args.drop, maxlen=args.maxlen, rpe=args.rpe) for _ in range(args.num_layer)]),
            ln_f = nn.LayerNorm(args.dmodel),
        ))
        self.lm_head = nn.Linear(args.dmodel, args.vocab, bias=True)
        self.maxlen = args.maxlen

    def forward(self, idx):
        b, t = idx.size()
        emb = self.transformer.embedding(idx)
        x = self.transformer.drop(emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def generate(self,
                 idx: torch.Tensor,
                 max_new_tokens: int | None,
                 temperature: int =1.0,
                 top_k: int | None =None,
                 top_p: float | None =None,
                 eos_token_id: int=2):
        """
        Generates text in an autoregressive manner using the GPT model.

        This function generates new tokens one at a time, conditioning on the previously generated tokens.

        It supports multiple decoding strategies, including greedy decoding, temperature scaling, top-k sampling,
        and top-p (nucleus) sampling. 
        
        It also supports batch processing and early stopping based on the <eos> token.

        Args:
            idx (torch.Tensor): The initial prefix of token indices of shape (batch_size, sequence_length).
            max_new_tokens (int): The maximum number of new tokens to generate.
            temperature (float, optional): Controls the randomness of predictions by scaling the logits before applying softmax.
                Higher values (e.g., > 1.0) make the output more random, while lower values (e.g., < 1.0) make it more deterministic.
                Defaults to 1.0.
            top_k (int, optional): If specified, only the top-k tokens with the highest probabilities are considered for sampling.
                This reduces the diversity of the output. Defaults to None.
            top_p (float, optional): If specified, uses nucleus sampling, where the smallest set of tokens with a cumulative probability
                greater than top_p is selected. This allows for dynamic token selection based on the probability distribution.
                Defaults to None.
                If both top_k and top_p are specified, top_p is applied first.
                If neither top_k nor top_p is specified, greedy decoding is used.
            eos_token_id (int, optional): The token ID of the <eos> (end-of-sequence) token. 
                If all sequences in the batch generate this token, the function stops early.
                Defaults to 2.

        Returns:
            torch.Tensor: The generated sequence of token indices, including the initial input and the newly generated tokens.
        """
        idx = idx.cuda()
        for _ in range(max_new_tokens):
            # If the sequence length exceeds the model's maximum length, truncate the sequence to the last `maxlen` tokens.
            idx_cond = idx if idx.size(1) <= self.maxlen \
                else idx[:, -self.maxlen:]

            # Get the logits for the current sequence.
            logits = self(idx_cond)
            
            # Extract the logits for the last token in the sequence and apply temperature scaling.
            logits = logits[:, -1, :] / temperature
            
            if (top_k is None) and (top_p is None):
                # If neither top_k nor top_p is specified, use greedy decoding.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                # If top_k is specified, mask out logits that are not in the top-k tokens.
                if top_k is not None:
                    v, _ = torch.topk(logits, top_k)
                    logits[logits < v[:, [-1]]] = -float('Inf')
                
                # If top_p is specified, perform nucleus sampling.
                if top_p is not None:
                    # Sort the logits in descending order and get the sorted indices.
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    # Compute the cumulative probabilities of the sorted logits.
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Identify tokens to remove: those with cumulative probability greater than top_p.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Ensure at least one token is kept by shifting the mask.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    
                    # Convert the mask to indices and set the corresponding logits to negative infinity.
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[..., indices_to_remove] = -float('Inf')
                
                # Convert the logits to probabilities using softmax.
                probs = F.softmax(logits, dim=-1)
                
                # Sample the next token from the probability distribution.
                idx_next = torch.multinomial(probs, num_samples=1)
            
            # Append the newly generated token to the sequence.
            idx = torch.cat((idx, idx_next), dim=1)
            
            # Early stopping: If all sequences in the batch generate the <eos> token, stop generation.
            if torch.all(idx_next == eos_token_id):
                break
        
        return idx