"""
Full definition of a GPT Language Model, all of it in this single file.

References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import math
import warnings
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

from mingpt.utils import CfgNode as CN

# -----------------------------------------------------------------------------

class NewGELU(nn.Module):
    """
    Tanh approximation of the GELU activation, as used in Google's BERT repo
    (identical to OpenAI GPT):

        gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        # scaled cubic correction term fed to tanh
        scale = math.sqrt(2.0 / math.pi)
        tanh_arg = scale * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(tanh_arg))

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.

    CHANGED: ``forward`` additionally accepts optional prefix keys/values, an
    override of the pre-softmax scores on the prefix positions, LoRA updates for
    the query/key projections, and a list in which the attention pattern is recorded.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    # CHANGED: Added optional prefixes and LoRA for the Q and K matrices and
    # optional record to store intermediate activations
    def forward(
            self,
            x,
            prefix_key: Optional[torch.Tensor],
            prefix_value: Optional[torch.Tensor],
            attention_lora: Optional[torch.Tensor],
            prefix_attention: Optional[torch.Tensor],
            activations_record: Optional[List[Tuple[str, torch.Tensor]]],
        ):
        """
        Apply causal multi-head self-attention to ``x``.

        Args:
            x: input of shape (B, T, C) with C == n_embd.
            prefix_key: optional (prefix_size, C) tensor; its rows replace the first
                prefix_size key vectors of every batch element.
            prefix_value: optional (prefix_size, C) tensor; its rows replace the first
                prefix_size value vectors of every batch element.
            attention_lora: optional (4, r, C) tensor stacking the LoRA factors
                (A_q, B_q, A_k, B_k). ``F.linear(x, A.T @ B) / r`` is added to q and k
                (lora_alpha is fixed to 1).
            prefix_attention: optional pre-softmax scores written into the first
                prefix_size key columns of the score matrix, before the causal mask is
                applied. After ``unsqueeze(1).repeat(B, 1, T, 1)`` it must broadcast to
                (B, n_head, T, prefix_size); e.g. an (n_head, prefix_size) tensor
                (TODO confirm the intended shape). Requires prefix_value, whose
                size(0) defines prefix_size.
            activations_record: optional list. ("attention", pattern) is appended,
                where pattern is the post-softmax, pre-dropout (B, n_head, T, T)
                attention, detached and moved to the CPU.

        Returns:
            Tensor of shape (B, T, C): output projection of the attended values,
            after residual dropout.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # CHANGED: if attention_lora is provided, compute the respective updates
        # to WQ and WK and their action on the input and add it to the pretrained values
        if attention_lora is not None:
            lora_alpha = 1
            lora_rank = attention_lora.size(1)
            q_matrix_A, q_matrix_B, k_matrix_A, k_matrix_B = (t.squeeze(0) for t in attention_lora.split(1, dim=0))
            lora_q = F.linear(x, q_matrix_A.T @ q_matrix_B)
            lora_k = F.linear(x, k_matrix_A.T @ k_matrix_B)
            q = q + lora_q * (lora_alpha/lora_rank)
            k = k + lora_k * (lora_alpha/lora_rank)

        # CHANGED: if prefix_key or prefix_value are provided, replace
        # the first positions of k and v with them (same prefix for every batch element).
        # The two prefixes are independent of each other.
        if prefix_value is not None:
            prefix_size = prefix_value.size(0)
            v = torch.cat([prefix_value.repeat(B, 1, 1), v[:, prefix_size:]], dim=1)

        if prefix_key is not None:
            # the key prefix goes into k (this branch previously re-applied prefix_value
            # to v, which silently ignored prefix_key)
            prefix_size = prefix_key.size(0)
            k = torch.cat([prefix_key.repeat(B, 1, 1), k[:, prefix_size:]], dim=1)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # CHANGED: if prefix_attention is provided, set the amount of attention
        # on each of the prefix positions before the mask is applied and the softmax is computed
        if prefix_attention is not None:
            assert prefix_value is not None
            prefix_size = prefix_value.size(0)
            att[:,:,:,:prefix_size] = prefix_attention.unsqueeze(1).repeat(B, 1, T, 1)

        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)

        # CHANGED: record the attention pattern
        if activations_record is not None:
            activations_record.append(("attention", att.detach().clone().cpu()))

        att = self.attn_dropout(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class Block(nn.Module):
    """ an unassuming Transformer block

    CHANGED: supports an optional input prefix for the attention block, passes the
    prefix/LoRA arguments and the activations record through to the attention layer,
    and can route the attention output through a learnable "factorized" projection
    (``projection=True``).
    """

    def __init__(self, config):
        super().__init__()

        # CHANGED: config.skip_layer_norm replaces the pre-attention layer norm with a
        # pass-through. nn.Identity is parameter-free (same state dict as an identity
        # lambda) but is a proper registered submodule.
        self.ln_1 = nn.Identity() if config.skip_layer_norm else nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward

        # CHANGED: adding two layers which will be used when we want to find a
        # more interpretable (n_embd + 2 dimensional) representation of the
        # attention block output before it is added to the residual stream
        self.to_factorized = nn.Linear(config.n_embd, config.n_embd+2)
        self.from_factorized = nn.Linear(config.n_embd+2, config.n_embd)

        # CHANGED: callback that can be used to modify the factorized representation
        # (identity by default; may be reassigned by the caller)
        self.projection_callback = lambda x: x

    def forward(
            self,
            x,
            prefix: Optional[torch.Tensor],
            prefix_key: Optional[torch.Tensor],
            prefix_value: Optional[torch.Tensor],
            prefix_attention: Optional[torch.Tensor],
            attention_lora: Optional[torch.Tensor],
            activations_record: Optional[List[Tuple[str, torch.Tensor]]],
            projection: bool = False,
        ):
        """
        Run one transformer block: x + attn(ln_1(x')) followed by x + mlp(ln_2(x)).

        Args:
            x: residual stream of shape (B, T, C).
            prefix: optional (prefix_size, C) tensor. Its rows replace the first
                prefix_size positions of the attention block's input for every batch
                element. The residual connection still adds to the original ``x``.
            prefix_key, prefix_value, prefix_attention, attention_lora: forwarded
                unchanged to CausalSelfAttention.
            activations_record: optional list. The attention layer appends its
                pattern; this block appends ("attn_block_output", ...) and, with
                projection, ("factorized_pred", ...) and ("attn_fwd_reconstructed", ...).
                All tensors are detached copies on the CPU.
            projection: if True, the attention output is replaced by
                from_factorized(projection_callback(to_factorized(output))).

        Returns:
            The updated residual stream, same shape as ``x``.
        """
        # CHANGED: add the prefix if provided
        if prefix is not None:
            prefix_size = prefix.size(0)
            input_for_attention = torch.cat([
                prefix[None, :, :].repeat(x.size(0), 1, 1),
                x[:, prefix_size:, :]
            ], dim=1)
        else:
            input_for_attention = x

        # CHANGED: pass the optional prefix/LoRA parameters and activation record.
        # ln_1 is a pass-through when config.skip_layer_norm is set.
        attn_block_output = self.attn(
            self.ln_1(input_for_attention),
            prefix_key=prefix_key,
            prefix_value=prefix_value,
            prefix_attention=prefix_attention,
            attention_lora=attention_lora,
            activations_record=activations_record
        )

        # CHANGED: record the attention block output if desired
        if activations_record is not None:
            activations_record.append(("attn_block_output", attn_block_output.detach().clone().cpu()))

        # CHANGED: if learning an interpretable projection is desired
        if projection:
            factorized_pred = self.to_factorized(attn_block_output)
            factorized_pred = self.projection_callback(factorized_pred)
            attn_fwd_reconstructed = self.from_factorized(factorized_pred)

            if activations_record is not None:
                activations_record.append(("factorized_pred", factorized_pred.detach().clone().cpu()))
                activations_record.append(("attn_fwd_reconstructed", attn_fwd_reconstructed.detach().clone().cpu()))
            else:
                # warnings.warn (deduplicated by the default filter) instead of print,
                # so repeated calls, e.g. per-step calls during generation, do not
                # flood the output once per layer per call
                warnings.warn("You probably want to pass an activation_record here if you use projection!")

            attn_block_output = attn_fwd_reconstructed

        x = x + attn_block_output
        x = x + self.mlpf(self.ln_2(x))
        return x

class GPT(nn.Module):
    """ GPT Language Model

    CHANGED: ``forward`` and ``generate`` accept optional per-layer prefixes, prefix
    keys/values, prefix-attention overrides and attention-LoRA factors, a shared
    activations record, and a ``projection`` flag; all are passed through to the Blocks.
    """

    @staticmethod
    def get_default_config() -> CN:
        C = CN()
        # either model_type or (n_layer, n_head, n_embd) must be given in the config
        # NOTE: the default 'gpt' is not a key of the model_type table in __init__;
        # set model_type to a listed name, or to None when giving the sizes explicitly
        C.model_type = 'gpt'
        C.n_layer = None
        C.n_head = None
        C.n_embd = None
        # these options must be filled in externally
        C.vocab_size = None
        C.block_size = None
        # dropout hyperparameters
        C.embd_pdrop = 0.1
        C.resid_pdrop = 0.1
        C.attn_pdrop = 0.1

        # CHANGED: added an option to skip the layer normalization in front of each attention block
        C.skip_layer_norm = False
        return C

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.block_size = config.block_size

        self.config = config

        type_given = config.model_type is not None
        params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
        assert type_given ^ params_given # exactly one of these (XOR)
        if type_given:
            # translate from model_type to detailed configuration (mutates the passed config)
            config.merge_from_dict({
                # names follow the huggingface naming conventions
                # GPT-1
                'openai-gpt':   dict(n_layer=12, n_head=12, n_embd=768),  # 117M params
                # GPT-2 configs
                'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
                'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
                'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
                'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
                # Gophers
                'gopher-44m':   dict(n_layer=8, n_head=16, n_embd=512),
                # (there are a number more...)
                # I made these tiny models up
                'gpt-mini':     dict(n_layer=6, n_head=6, n_embd=192),
                'gpt-micro':    dict(n_layer=4, n_head=4, n_embd=128),
                'gpt-nano':     dict(n_layer=3, n_head=3, n_embd=48),
            }[config.model_type])

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("number of parameters: %.2fM" % (n_params/1e6,))

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) for Linear/Embedding weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Initialize a pretrained GPT model by copying over the weights
        from a huggingface/transformers checkpoint.

        Parameters that exist only in this model (e.g. the Blocks' to_factorized /
        from_factorized layers) are absent from the checkpoint and keep their
        random initialization.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel

        # create a from-scratch initialized minGPT model (cls, so subclasses get their own type)
        config = cls.get_default_config()
        config.model_type = model_type
        config.vocab_size = 50257 # openai's model vocabulary
        config.block_size = 1024  # openai's model block_size
        model = cls(config)
        sd = model.state_dict()

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
        # this means that we have to transpose these weights when we import them.
        # NOTE: len(keys) == len(sd) does not hold because this model has extra
        # parameters (see docstring), so only the checkpoint's keys are copied.
        for k in keys:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(
            self,
            idx,
            targets=None,
            prefixes: Optional[torch.Tensor] = None,
            prefix_key: Optional[torch.Tensor] = None,
            prefix_value: Optional[torch.Tensor] = None,
            prefix_attention: Optional[torch.Tensor] = None,
            attention_lora: Optional[torch.Tensor] = None,
            activations_record: Optional[List[Tuple[str, torch.Tensor]]] = None,
            projection: bool = False,
        ):
        """
        Run the model on token indices ``idx`` of shape (b, t), with t <= block_size.

        Args:
            idx: LongTensor of token indices, shape (b, t).
            targets: optional targets; positions equal to -1 are ignored by the loss.
            prefixes, prefix_key, prefix_value, prefix_attention, attention_lora:
                optional sequences indexed by layer (entry i is passed to block i).
                None means "unused in every layer"; single entries may also be None.
            activations_record: optional list shared by all layers, to which the
                blocks append (name, tensor) pairs in execution order.
            projection: forwarded to every Block.

        Returns:
            (logits, loss): logits of shape (b, t, vocab_size); loss is the
            cross-entropy against ``targets`` if given, else None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # CHANGED: expand missing per-layer arguments to one None per block
        n_blocks = len(self.transformer.h)
        if prefixes is None:
            prefixes = [None] * n_blocks # type: ignore
        if prefix_key is None:
            prefix_key = [None] * n_blocks
        if prefix_value is None:
            prefix_value = [None] * n_blocks
        if prefix_attention is None:
            prefix_attention = [None] * n_blocks
        if attention_lora is None:
            attention_lora = [None] * n_blocks # type: ignore

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # CHANGED: passes through arguments
        for block_idx, block in enumerate(self.transformer.h):
            x = block(
                x,
                prefix=prefixes[block_idx],
                prefix_key=prefix_key[block_idx],
                prefix_value=prefix_value[block_idx],
                prefix_attention=prefix_attention[block_idx],
                attention_lora=attention_lora[block_idx],
                activations_record=activations_record,
                projection=projection,
            )
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

    # CHANGED: passes through arguments
    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens,
        temperature=1.0,
        do_sample=False,
        top_k=None,
        prefixes: Optional[torch.Tensor] = None,
        prefix_key: Optional[torch.Tensor] = None,
        prefix_value: Optional[torch.Tensor] = None,
        prefix_attention: Optional[torch.Tensor] = None,
        attention_lora: Optional[torch.Tensor] = None,
        activations_record: Optional[List[Tuple[str, torch.Tensor]]] = None,
        projection: bool = False,
    ):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        CHANGED: the prefix/LoRA/projection arguments are forwarded to every forward pass.
        ``activations_record`` is filled by exactly one pass: the one over the context
        that produced the final generated token (also when generation stops early).
        For batch size 1, generation stops once the sequence ends with ``eos_tokens``
        (presumably the GPT-2 BPE pieces of the literal text "<|endoftext|>" - TODO confirm).
        """

        eos_tokens = torch.tensor([27, 91, 437, 1659, 5239, 91, 29], dtype=torch.long, device=idx.device)
        n_eos = eos_tokens.numel()

        # arguments shared by every forward pass
        forward_kwargs = dict(
            prefixes=prefixes,
            prefix_key=prefix_key,
            prefix_value=prefix_value,
            prefix_attention=prefix_attention,
            attention_lora=attention_lora,
            projection=projection,
        )

        idx_cond = None
        recorded = False
        for pos_idx in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # only last pass is needed as it has all the information from the previous passes too
            is_last_pass = pos_idx == max_new_tokens - 1
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(
                idx_cond,
                activations_record=activations_record if is_last_pass else None,
                **forward_kwargs,
            )
            recorded = is_last_pass
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

            # Terminate early if batch size 1 and reached the eos tokens
            if idx.size(1) >= n_eos and idx.size(0) == 1 and torch.all(idx[0, -n_eos:] == eos_tokens):
                break

        # An early EOS break skips the pass that would have filled activations_record,
        # so replay the final conditioning context once to record it. In train mode,
        # dropout would make this replay differ from the pass that produced the token.
        if activations_record is not None and idx_cond is not None and not recorded:
            self(idx_cond, activations_record=activations_record, **forward_kwargs)

        return idx
