#%%
"""
Full definition of our Energy GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from opt_einsum import contract as einsum
from torch.func import grad
# from torch.nn import LayerNorm


### Base components for the transformer
#--------------------------------------------------------------------

class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension with an optional bias.

    Args:
        ndim: size of the trailing dimension that is normalized.
        bias: when truthy, a learnable bias initialised to zeros is created;
            otherwise ``self.bias`` is ``None`` and no shift is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # learnable scale, initialised to ones
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        """Return ``input`` layer-normalized with eps=1e-5, then scaled (and shifted if a bias exists)."""
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using one fused query/key/value weight.

    Reads ``n_embed``, ``n_head``, ``bias`` and ``dropout`` from ``config``
    (plus ``block_size`` when the manual attention fallback is needed).
    ``n_embed`` must be divisible by ``n_head``.
    """
    INIT_STD = 0.02

    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0
        head_size = config.n_embed // config.n_head

        # projection applied after the heads have been concatenated
        self.c_proj = nn.Linear(config.n_embed, config.n_embed, bias=config.bias)

        # dropout on the attention weights and on the projected output
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # q, k and v weights stacked in one tensor: (3, n_head, n_embed, head_size)
        qkv_shape = (3, config.n_head, config.n_embed, head_size)
        self.qkv_weight = nn.Parameter(torch.randn(*qkv_shape) * self.INIT_STD)

        # use the built-in scaled_dot_product_attention when this PyTorch has it (>= 2.0)
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # lower-triangular mask so a position only attends to itself and earlier ones
            n_ctx = config.block_size
            causal = torch.tril(torch.ones(n_ctx, n_ctx))
            self.register_buffer("bias", causal.view(1, 1, n_ctx, n_ctx))

        # only its probability `p` is read, by the built-in attention path
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, time-step, channels).

        Returns:
            Tensor of shape (batch, time-step, channels).
        """
        B, T, C = x.shape
        # (B, 3, n_head, T, head_size), split into three (B, n_head, T, head_size)
        fused = torch.einsum('btc,ahcs->bahts', x, self.qkv_weight)
        q, k, v = fused.unbind(dim=1)

        if not self.flash:
            # manual scaled dot-product attention
            scale = 1.0 / math.sqrt(k.size(-1))
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            blocked = self.bias[:, :, :T, :T] == 0
            scores = scores.masked_fill(blocked, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = torch.matmul(weights, v)  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        else:
            # dropout is only active while training
            drop_p = self.dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=drop_p,
                is_causal=True,
            )  # (B, n_head, T, head_size)

        # put the heads back side by side: (B, T, n_head * head_size) == (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection followed by residual dropout
        return self.resid_dropout(self.c_proj(y))

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    The hidden width is ``config.ff_hid_factor * config.n_embed``; input and
    output widths are both ``config.n_embed``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden = config.ff_hid_factor * width  # usually 4x n_embed
        layers = [
            nn.Linear(width, hidden),
            nn.ReLU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps its size."""
        out = self.net(x)
        return out



class Block_parallel(nn.Module):
    """Transformer block with attention and feed-forward applied in parallel.

    Both branches read the same block input (each through its own layer norm)
    and their outputs are added to the residual stream in a single step:
    ``x + sa(ln1(x)) + ffwd(ln2(x))``.

    Args:
        config: object providing ``n_embed`` plus whatever fields
            ``MultiHeadAttention`` and ``FeedForward`` read from it.
    """

    def __init__(self, config):
        super().__init__()
        # MultiHeadAttention takes only `config` and derives the head size
        # itself; passing a second positional argument raises TypeError.
        self.sa = MultiHeadAttention(config)
        self.ffwd = FeedForward(config)
        # NOTE: these are torch's nn.LayerNorm (always with a bias), so
        # config.bias does not affect the norms of this block.
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x, **kwargs):
        """Return the block output for ``x``; extra keyword arguments are ignored."""
        x = x + self.sa(self.ln1(x)) + self.ffwd(self.ln2(x))
        return x

class Block(nn.Module):
    """Sequential transformer block: attention residual, then feed-forward residual.

    Each sub-layer is preceded by its own layer norm and its output is added
    back onto the input it received.
    """

    def __init__(self, config):
        # config.n_embed: embedding dimension, config.bias: whether the norms get a bias
        super().__init__()
        width = config.n_embed
        use_bias = config.bias
        self.sa = MultiHeadAttention(config)
        self.ffwd = FeedForward(config)
        self.ln1 = LayerNorm(width, bias=use_bias)
        self.ln2 = LayerNorm(width, bias=use_bias)

    def forward(self, x, **kwargs):
        """Return the block output for ``x``; extra keyword arguments are ignored."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed


class GPT(nn.Module):
    """GPT-style language model: embeddings -> blocks -> final norm -> LM head.

    Args:
        config: object providing ``vocab_size``, ``n_embed``, ``block_size``,
            ``n_layer`` and ``bias`` (plus whatever ``block_class`` reads).
        block_class: class instantiated as ``block_class(config)`` for every
            layer. Subclasses can change how blocks are built and applied by
            overriding ``get_blocks`` / ``block_forward``.
    """

    def __init__(self, config, block_class=Block):
        super().__init__()
        self.config = config
        self.block_class = block_class
        # learned token embeddings and learned absolute position embeddings
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)
        # position indices kept as a buffer so they follow the module across devices
        self.register_buffer('pos_idx', torch.arange(config.block_size))  # (block_size,)
        self.blocks = self.get_blocks(config)
        self.ln_f = LayerNorm(config.n_embed, bias=config.bias)  # final layer norm
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        # NOTE: lm_head.weight is not tied to token_embedding_table.weight.

        # re-initialise every Linear / Embedding created above
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_blocks(self, config):
        """Build the stack of ``config.n_layer`` independent blocks."""
        return nn.Sequential(*[self.block_class(config) for _ in range(config.n_layer)])

    def block_forward(self, x):
        """Apply the block stack to ``x`` of shape (B, T, C)."""
        return self.blocks(x)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: (B, T) integer tensor of token indices, ``T <= block_size``.
            targets: optional (B, T) integer tensor; entries equal to -1 are
                ignored by the loss.

        Returns:
            ``(logits, loss)``. With targets, logits cover every position
            (B, T, vocab_size) and loss is a scalar. Without targets, logits
            cover only the last position (B, 1, vocab_size) and loss is None.

        Raises:
            ValueError: if ``T`` exceeds the number of learned positions.
        """
        _, T = idx.shape
        max_len = self.pos_idx.size(0)
        if T > max_len:
            # Without this check pos_idx[:T] silently truncates to max_len rows,
            # which gives an opaque broadcasting error (or, when max_len == 1,
            # silently wrong position embeddings).
            raise ValueError(
                f"Cannot forward sequence of length {T}, block size is only {max_len}"
            )

        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(self.pos_idx[:T])  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.block_forward(x)
        x = self.ln_f(x)  # (B,T,C)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def generate(self, idx, max_new_tokens, greedy=False):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: (B, T) integer tensor holding the current context.
            max_new_tokens: number of tokens to append.
            greedy: if true take the argmax of the logits at each step,
                otherwise sample from their softmax distribution.

        Returns:
            (B, T + max_new_tokens) tensor: the input followed by the new tokens.
        """
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.config.block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)

            if not greedy:
                # apply softmax to get probabilities
                probs = F.softmax(logits, dim=-1)  # (B, C)
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            else:
                # greedy shortcut -- just use argmax of the logits
                idx_next = torch.argmax(logits, dim=1).unsqueeze(-1)  # (B, 1)
            # append the chosen index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
    
class GPT_Rec(GPT):
    """GPT variant that reuses one block: the same module is applied n_layer times."""

    def __init__(self, config, block_class=Block):
        super().__init__(config, block_class=block_class)

    def get_blocks(self, config):
        """Build the single block whose weights are shared across all passes."""
        shared_block = self.block_class(config)
        return shared_block

    def block_forward(self, x):
        """Feed ``x`` through the shared block ``config.n_layer`` times in a row."""
        shared_block = self.blocks
        for _step in range(self.config.n_layer):
            x = shared_block(x)
        return x

class GPT_Rec_parallel(GPT_Rec):
    """GPT_Rec whose default block class is Block_parallel."""

    def __init__(self, config, block_class=Block_parallel):
        # everything else is inherited; only the default block class differs
        super().__init__(config, block_class)
        
