import time
import math

import torch
import torch.nn.functional as F
import torch.nn as nn

class GPT2(nn.Module):
    """Decoder-only GPT-2 style language model with a greedy decoding helper.

    The model sums token and learned position embeddings, runs the result
    through ``config.num_hidden_layers`` ``GPT2Block`` layers, applies a final
    LayerNorm and projects to vocabulary logits with ``lm_head``.

    ``lm_head`` is created as an independent ``nn.Linear``; no weight tying
    with ``wte`` is performed in this class.
    """

    def __init__(self, config):
        """Build the embeddings, the block stack and the output head.

        Args:
            config: object exposing ``hidden_size``, ``vocab_size``,
                ``max_position_embeddings``, ``num_hidden_layers`` and
                ``layer_norm_epsilon`` (plus whatever ``GPT2Block`` reads).
        """
        super(GPT2, self).__init__()
        self.config = config

        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        # Dropout is disabled: the slot is kept as a no-op module.
        self.drop = nn.Identity()
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(self.embed_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, past_list=None, past_length=0):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: integer token ids; all leading dimensions are flattened
                into one batch dimension, the last dimension is the sequence.
            attention_mask: forwarded unchanged to every block.
            past_list: optional per-layer cache, indexed by layer id and
                forwarded to the matching block. ``None`` means no cache.
            past_length: offset of the first position id. The caller is
                responsible for passing the cached length here; it is not
                derived from ``past_list``.

        Returns:
            ``(logits, presents)`` where ``logits`` has the input shape plus a
            trailing vocabulary dimension and ``presents`` is a list holding
            the cache entry returned by each block, in layer order.
        """
        if past_list is None:
            past_list = [None for _ in range(self.config.num_hidden_layers)]

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        # Positions continue after whatever is already cached.
        position_ids = torch.arange(
            past_length,
            input_shape[-1] + past_length,
            dtype=torch.long,
            device=input_ids.device,
        )
        position_ids = position_ids.unsqueeze(0)

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        hidden_states = self.drop(hidden_states)
        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        presents = []
        for layer_id, layer in enumerate(self.h):
            hidden_states, present = layer(hidden_states, attention_mask, past_list[layer_id])
            presents.append(present)

        output = self.ln_f(hidden_states)
        output = output.view(output_shape)
        output = self.lm_head(output)

        return output, presents

    def generate(self, idx, max_new_tokens, attention_mask=None, temperature=1.0, kv_cache=None):
        """Greedily decode ``max_new_tokens`` tokens after the prompt ``idx``.

        The first step feeds the whole of ``idx``; every later step feeds only
        the token chosen in the previous step and relies on the key/value
        cache for the earlier context. The next token is the argmax of the
        logits, so ``temperature`` only shapes the returned probabilities (for
        a positive temperature the argmax is unaffected by the scaling).

        Most likely you'll want to be in ``model.eval()`` mode for this.

        Args:
            idx: LongTensor of token ids, shape ``(b, s)``.
            max_new_tokens: number of decoding steps.
            attention_mask: forwarded unchanged to ``forward`` on every step.
            temperature: divisor applied to the last-position logits before
                the softmax.
            kv_cache: optional per-layer cache from an earlier call. When
                given, decoding resumes after it: the starting position is
                read from ``kv_cache[0][0].size(2)``.

        Returns:
            ``(tokens, kv_cache, all_probs)``:
            ``tokens`` holds only the newly generated ids, shape
            ``(b, max_new_tokens)``; ``kv_cache`` is the cache after the last
            step; ``all_probs`` is a CPU tensor of shape
            ``(b, max_new_tokens, vocab)``. With ``max_new_tokens == 0`` the
            inputs are returned untouched and ``all_probs`` is ``None``.
        """
        # Resume after a caller-supplied cache instead of discarding it.
        past_length = 0
        if kv_cache is not None and len(kv_cache) > 0:
            past_length = kv_cache[0][0].size(2)

        step_input = idx
        new_tokens = []
        step_probs = []
        for _ in range(max_new_tokens):
            logits, kv_cache = self(step_input, attention_mask, kv_cache, past_length)
            past_length += step_input.size(1)

            # Only the last position predicts the next token.
            logits = logits[:, -1:, :] / temperature
            step_probs.append(logits.softmax(dim=-1).cpu())

            step_input = torch.argmax(logits, dim=-1)
            new_tokens.append(step_input)

        if not new_tokens:
            return idx, kv_cache, None

        # Concatenate once at the end rather than re-copying on every step.
        return torch.cat(new_tokens, dim=1), kv_cache, torch.cat(step_probs, dim=1)


class GPT2Block(nn.Module):
    """One transformer layer: LayerNorm -> attention -> residual add, then
    LayerNorm -> MLP -> residual add."""

    def __init__(self, config):
        """Create the two LayerNorms, the attention module and the MLP.

        The MLP width is ``config.n_inner`` when that is set, otherwise four
        times ``config.hidden_size``.
        """
        super().__init__()
        width = config.hidden_size
        if config.n_inner is None:
            ff_width = 4 * width
        else:
            ff_width = config.n_inner

        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config=config)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(ff_width, config)

    def forward(self, hidden_states, attention_mask=None, past=None):
        """Run the layer and return ``(hidden_states, present)``.

        ``present`` is whatever the attention module returns as its second
        value; it is passed through untouched.
        """
        # Attention sub-layer: normalise first, add the untouched input back.
        attn_out, present = self.attn(self.ln_1(hidden_states), attention_mask, past)
        after_attn = attn_out + hidden_states

        # Feed-forward sub-layer, same pre-norm + residual pattern.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out, present
        

class GPT2Attention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    A single ``c_attn`` projection produces query, key and value; the scaled
    dot-product scores are masked causally, soft-maxed, applied to the values
    and projected back with ``c_proj``.
    """

    def __init__(self, config):
        """Create the projections and mask constants.

        Args:
            config: object exposing ``hidden_size``, ``num_attention_heads``
                and ``max_position_embeddings``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super(GPT2Attention, self).__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # Plain tensor attributes (not registered buffers). ``bias`` is not
        # read by ``forward``; within this class only ``mask_to_cpu`` touches it.
        self.bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
            1, 1, max_positions, max_positions
        )
        # Score assigned to masked (future) positions before the softmax.
        self.masked_bias = torch.tensor(-1e4)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        # Dropout is disabled. ``attn_dropout`` is kept as an attribute but is
        # not called in ``forward``.
        self.attn_dropout = nn.Identity()
        self.resid_dropout = nn.Identity()

    def mask_to_cpu(self):
        """Move the ``bias`` tensor to CPU."""
        self.bias = self.bias.cpu()

    def gen_attention_mask(self, nd, ns, device=None):
        """Return an ``(nd, ns)`` boolean causal mask.

        True in the lower triangle, counting from the lower right corner, so
        the last query row may see every key. Same as
        ``tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd)``.

        Args:
            nd: number of query positions.
            ns: number of key positions (cached plus current).
            device: where to create the mask; ``None`` uses the default
                device, as before.
        """
        rows = torch.arange(nd, device=device)[:, None]
        cols = torch.arange(ns, device=device)
        return rows >= cols - ns + nd

    def forward(self, hidden_states, attention_mask=None, layer_past=None):
        """Attend over ``hidden_states`` plus any cached keys/values.

        Args:
            hidden_states: input of shape ``(batch, seq, embed_dim)``.
            attention_mask: accepted for interface compatibility only. It is
                NOT applied; the scores are always masked with the causal mask
                generated here.
            layer_past: optional ``(past_key, past_value)`` pair that is
                concatenated in front of the new keys/values along the
                sequence axis.

        Returns:
            ``(attn_output, present)`` where ``present`` is the
            ``(key, value)`` pair including the cached part, each shaped
            ``(batch, heads, total_seq, head_dim)``.
        """
        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_states = torch.cat((past_key, key_states), dim=-2)
            value_states = torch.cat((past_value, value_states), dim=-2)
        present = (key_states, value_states)

        w = query_states.matmul(key_states.transpose(-1, -2))
        w = w / math.sqrt(self.head_dim)

        _, _, nd, ns = w.shape
        # Build the mask directly on the score device to avoid creating it on
        # CPU and copying it over on every call.
        causal = self.gen_attention_mask(nd, ns, device=w.device).to(dtype=w.dtype)
        causal = causal.reshape(1, 1, nd, ns)

        # Keep visible scores, replace future positions with ``masked_bias``.
        w = w * causal + self.masked_bias * (1 - causal)

        w = w.softmax(dim=-1)
        attn_output = torch.matmul(w, value_states)

        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*attn_output.shape[:2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, present


class GPT2MLP(nn.Module):
    """Position-wise feed-forward network: expand, tanh-approximated GELU,
    project back."""

    def __init__(self, intermediate_size, config):
        """Create the two projections.

        Args:
            intermediate_size: width of the hidden (expanded) layer.
            config: object exposing ``hidden_size``.
        """
        super().__init__()
        width = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, width)
        self.c_proj = Conv1D(width, intermediate_size)
        self.act = nn.GELU(approximate="tanh")
        # Dropout is disabled: the slot is kept as a no-op module.
        self.dropout = nn.Identity()

    def forward(self, hidden_states):
        """Apply ``c_fc`` -> activation -> ``c_proj`` -> ``dropout``."""
        expanded = self.act(self.c_fc(hidden_states))
        return self.dropout(self.c_proj(expanded))
    

class Conv1D(nn.Module):
    """Linear layer whose weight is stored transposed, as ``(nx, nf)``.

    Computes ``x @ weight + bias`` over the last dimension of ``x``.
    """

    def __init__(self, nf, nx):
        """Create the layer.

        Args:
            nf: number of output features.
            nx: number of input features.
        """
        super().__init__()
        self.nf = nf
        self.nx = nx
        self.weight = nn.Parameter(torch.empty(nx, nf))
        # The bias starts at zero; no further initialisation is needed.
        self.bias = nn.Parameter(torch.zeros(nf))

        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``nx`` to ``nf``.

        All leading dimensions are preserved. Non-contiguous inputs (for
        example transposed tensors) are accepted.
        """
        size_out = x.size()[:-1] + (self.nf,)
        # ``reshape`` falls back to a copy when the input cannot be viewed
        # as 2-D, where ``view`` would raise.
        flat = x.reshape(-1, x.size(-1))
        out = torch.addmm(self.bias, flat, self.weight)
        return out.view(size_out)
