"""Base LLM model (nanoGPT) with next-token prediction loss.

This is the nanoGPT model from https://github.com/karpathy/nanoGPT, lightly
refactored with unused code blocks removed. Attention additionally applies
rotary position embeddings (rotary_embedding_torch).

"""
import inspect
import math
from rotary_embedding_torch import RotaryEmbedding
import torch
import torch.nn as nn
from torch.nn import functional as F


class Block(nn.Module):
    """Pre-norm transformer block: attention, then MLP, each in a residual.

    Each sub-layer sees a LayerNorm-ed copy of its input, and its output is
    added back onto the un-normalized stream.
    """

    def __init__(self, config):
        super().__init__()

        # Registration order (ln_1, attn, ln_2, mlp) fixes the parameter order
        # (and presumably the order of constructor RNG draws); keep it stable.
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Queries, keys and values come from one fused linear projection. Rotary
    embeddings are applied to the full-width queries and keys before they
    are split into heads. ``F.scaled_dot_product_attention`` is used when
    available; otherwise attention falls back to an explicit masked softmax.
    """

    def __init__(self, config):
        super().__init__()

        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        # Fused q/k/v projection for all heads, followed by the output
        # projection. Keep this registration order: it fixes parameter order.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd,
            bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd,
            bias=config.bias)

        # Dropout on the attention weights (manual path) and on the output.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            print('WARNING: flash Attention requires PyTorch >= 2.0')
            # Lower-triangular mask, broadcastable over (B, nh, T, T), so a
            # position can attend only to itself and earlier positions.
            size = config.block_size
            mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
            self.register_buffer('bias', mask)

        # NOTE(review): the rotary embedding is sized n_embd, not
        # head_dim = n_embd // n_head, and is applied before the head split.
        # Each head therefore presumably sees a different slice of the
        # frequency spectrum. Standard RoPE is per-head; changing this would
        # alter model behaviour (and likely any stored rotary state), so
        # confirm the intent before touching it.
        self.rotary_emb = RotaryEmbedding(config.n_embd)

    def forward(self, x):
        """Attend causally over ``x`` (B, T, C) and return a (B, T, C) tensor."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)

        # Rotary position embedding on the full-width keys and queries.
        # TODO: replace with torchtune version mb
        key = self.rotary_emb.rotate_queries_or_keys(key)
        query = self.rotary_emb.rotate_queries_or_keys(query)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        if self.flash:
            out = F.scaled_dot_product_attention(
                query, key, value,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True)
        else:
            # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
            scale = 1.0 / math.sqrt(key.size(-1))
            scores = (query @ key.transpose(-2, -1)) * scale
            causal = self.bias[:, :, :seq_len, :seq_len]
            scores = scores.masked_fill(causal == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
            out = weights @ value

        # (B, nh, T, hs) -> (B, T, C): lay the heads side by side again.
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(out))


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Args:
        ndim: size of the normalized (last) dimension.
        bias: if True, learn an additive bias initialized to zeros; if False,
            ``self.bias`` is None and no bias is applied.
        eps: value added to the variance for numerical stability
            (defaults to the previously hard-coded 1e-5).
    """

    # Class-level fallback so instances unpickled from before ``eps`` became
    # an instance attribute still normalize with the original constant.
    eps = 1.E-5

    def __init__(self, ndim, bias, eps=1.E-5):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps

    def forward(self, x):
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias,
            self.eps)


class MLP(nn.Module):
    """Feed-forward sub-layer: n_embd -> 4*n_embd -> GELU -> n_embd -> dropout."""

    def __init__(self, config):
        super().__init__()

        hidden = 4 * config.n_embd
        # Keep this registration order: it fixes the parameter order.
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        activated = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))


class ModelGPT(nn.Module):
    """GPT language model with a next-token-prediction head.

    Token and learned absolute position embeddings feed a stack of
    ``Block``s, a final LayerNorm, and a bias-free linear ``lm_head``.
    ``manager`` is stored as-is and not otherwise used by this class
    (presumably consumed by subclasses or callers; confirm).
    """

    def __init__(self, manager, config):
        super().__init__()

        self.manager = manager

        self.block_size = config.block_size

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))

        self._head_init(config)

        self.apply(self._init_weights)
        # Apply special scaled init to residual projections, per GPT-2 paper:
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p,
                    mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        For the non-embedding count (default), the position-embedding table
        is subtracted. The token-embedding table is subtracted too, unless its
        weight is shared with another module (e.g. tied to ``lm_head``). In
        that case it doubles as output-layer weights and stays in the count.
        Shared parameters are counted once.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
            if not self._wte_is_shared():
                n_params -= self.transformer.wte.weight.numel()
        return n_params

    def _wte_is_shared(self):
        """Return True if another module also owns the token-embedding weight."""
        wte = self.transformer.wte
        return any(
            p is wte.weight
            for module in self.modules() if module is not wte
            for p in module.parameters(recurse=False))

    def body_forward(self, idx):
        """Embed ``idx`` (b, t) and run it through the blocks and final norm.

        Returns the hidden states, shape (b, t, n_embd).
        """
        device = idx.device
        _, t = idx.size()

        # The message is only formatted when the check fails.
        assert t <= self.block_size, (
            f'Cannot forward sequence of length {t}, '
            f'block size is only {self.block_size}')

        pos = torch.arange(0, t, dtype=torch.long, device=device)

        # Forward the GPT model itself:
        # (token embeddings of shape (b, t, n_embd);
        #  position embeddings of shape (t, n_embd), broadcast over b)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)

        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)

        return self.transformer.ln_f(x)

    def forward(self, idx, targets=None, with_w_norm=True):
        """Return ``(logits, loss)``; see ``_head_forward`` for the shapes."""
        x = self.body_forward(idx)

        return self._head_forward(x, targets, with_w_norm=with_w_norm)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively extend ``idx`` by ``max_new_tokens`` tokens.

        Args:
            idx: LongTensor of shape (b, t) with the conditioning indices.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling; ``0``
                selects greedy (argmax) decoding instead of sampling.
            top_k: if given, sample only among the ``top_k`` likeliest tokens.

        Returns:
            LongTensor of shape (b, t + max_new_tokens).

        Most likely you will want ``model.eval()`` first, so dropout is off.
        """
        for _ in range(max_new_tokens):
            # If the context has grown too long, keep only the last block_size:
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]

            # Logits at the final position only:
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            if temperature == 0:
                # Greedy decoding; dividing by zero would yield inf/nan
                # probabilities that torch.multinomial rejects.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                # Division creates a new tensor, so the in-place top-k
                # masking below never touches the model's output.
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            # Append the chosen index to the running sequence and continue:
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def _head_forward(self, x, targets=None, with_w_norm=True):
        """Project hidden states ``x`` (b, t, n_embd) to vocabulary logits.

        With ``targets`` (b, t), returns logits for every position
        (b, t, vocab_size) and the mean cross-entropy loss; target value -1
        is ignored. Without targets, only the last position is projected
        (b, 1, vocab_size) and the loss is None. ``with_w_norm`` is unused by
        this base implementation (presumably consumed by overriding heads).
        """
        if targets is None:
            # Inference-time mini-optimization: the list index [-1] keeps
            # the time dimension.
            return self.lm_head(x[:, [-1], :]), None

        logits = self.lm_head(x)
        # reshape (not view): targets are often a non-contiguous slice.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1),
            ignore_index=-1)
        return logits, loss

    def _head_init(self, config):
        """Create the bias-free output projection n_embd -> vocab_size."""
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying with the token embedding is currently disabled:
        # self.transformer.wte.weight = self.lm_head.weight

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) for Linear/Embedding weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)