import time
import math

import torch
import torch.nn.functional as F

import crypten
import crypten.nn as cnn
import crypten.communicator as comm
from crypten.common.functions import maximum

from utils import MPCIdentity, MPCLinear, encrypt_tensor, PumaGeLU, PumaLayerNorm, PumaSoftmax


class GPT2(cnn.Module):
    """GPT-2 style language model assembled from CrypTen (MPC) modules.

    The token embedding is a bias-free ``MPCLinear`` applied to ``input_ids``,
    so the ids are presumably supplied as one-hot vectors over the vocabulary
    (``generate`` also indexes them as 3-D tensors) -- TODO confirm with the
    caller.
    """

    def __init__(self, config):
        """Build the embeddings, the stack of blocks and the LM head.

        Args:
            config: Object exposing ``hidden_size``, ``vocab_size``,
                ``max_position_embeddings``, ``num_hidden_layers`` and
                ``layer_norm_epsilon`` (plus whatever ``GPT2Block`` reads).
        """
        super(GPT2, self).__init__()
        self.config = config

        self.embed_dim = config.hidden_size
        self.wte = MPCLinear(config.vocab_size, self.embed_dim, bias=False)
        # Only ``wpe.weight`` is used (sliced by position in ``forward``); the
        # layer itself is never called.
        self.wpe = MPCLinear(self.embed_dim, config.max_position_embeddings, bias=False)

        self.drop = MPCIdentity()
        self.h = cnn.ModuleList([GPT2Block(config, i) for i in range(config.num_hidden_layers)])
        self.ln_f = PumaLayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = MPCLinear(self.embed_dim, config.vocab_size, bias=False)

        self.smax = PumaSoftmax(dim=-1)
        self.cat = cnn.Concat(dimension=1)
        self.cat_last = cnn.Concat(dimension=2)


    def forward(self, input_ids, attention_mask=None, past_list=None, past_length=0):
        """Run the model over ``input_ids`` and return logits plus the KV cache.

        Args:
            input_ids: Encoded tokens; the first two dimensions are taken as
                (batch, sequence).
            attention_mask: Forwarded unchanged to every block.
            past_list: Optional per-layer cache, one ``(key, value)`` entry per
                block. NOTE: the entries of this list are overwritten with
                ``None`` as they are consumed, so the caller's list is
                modified in place.
            past_length: Number of positions already held in ``past_list``;
                used as the offset into the position-embedding table.

        Returns:
            ``(logits, presents)`` where ``presents`` is the list of per-layer
            ``(key, value)`` pairs returned by the blocks.
        """
        if past_list is None:
            past_list = [None for _ in range(self.config.num_hidden_layers)]

        input_shape = input_ids.size()[:2]

        inputs_embeds = self.wte(input_ids)
        # Rows [past_length, past_length + seq_len) of the position table.
        position_embeds = self.wpe.weight[past_length:input_shape[-1] + past_length]
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        presents = []
        for layer_id, layer in enumerate(self.h):
            hidden_states, present = layer(hidden_states, attention_mask, past_list[layer_id])
            presents.append(present)
            # Drop the reference to the consumed cache entry.
            past_list[layer_id] = None

        output = self.ln_f(hidden_states)
        output = output.reshape(output_shape)
        output = self.lm_head(output)

        return output, presents


    @crypten.no_grad()
    def generate(self, idx, max_new_tokens, attention_mask=None, kv_cache=None, target_ids=None):
        """Greedily decode ``max_new_tokens`` tokens after the prompt ``idx``.

        The first iteration runs the whole prompt; every later iteration feeds
        only the most recently produced token, because all earlier positions
        are already stored in the KV cache. (Feeding the accumulated tokens
        again would append duplicate keys/values to the cache and leave
        ``past_length`` out of step with the real cache length.)

        Args:
            idx: Encrypted prompt, indexed as a 3-D tensor
                (batch, sequence, vocab-sized last axis -- presumably one-hot).
            max_new_tokens: Number of decoding steps to run.
            attention_mask: Forwarded to ``forward`` on every step.
            kv_cache: Optional per-layer ``(key, value)`` cache to continue
                from; its sequence length is read from ``kv_cache[0][0].size(2)``.
            target_ids: Optional iterable of vocabulary indices. When given,
                the probability of each of them is decrypted at every step.

        Returns:
            A 3-tuple ``(tokens, None, all_probs)``: ``tokens`` is the
            decrypted CPU tensor of the generated tokens concatenated along
            dim 1 (the prompt itself if ``max_new_tokens`` is 0), and
            ``all_probs`` holds the collected target probabilities
            concatenated along dim 1, or ``None`` if ``target_ids`` is None.
        """
        all_probs = None
        past_length = kv_cache[0][0].size(2) if kv_cache is not None else 0
        max_positions = self.config.max_position_embeddings

        step_input = idx   # what is fed to the model on the next step
        generated = None   # running concatenation of the produced tokens

        for _ in range(max_new_tokens):
            if step_input.size(1) <= max_positions:
                idx_cond = step_input.clone()
            else:
                idx_cond = step_input[:, -max_positions:, :]

            logits, kv_cache = self(idx_cond, attention_mask, kv_cache, past_length)
            # The cache grew by exactly the number of positions just processed.
            past_length += idx_cond.size(1)

            logits = logits[:, -1:, :]
            probs = self.smax(logits)
            if target_ids is not None:
                target_probs = torch.stack(
                    [probs[:, :, tgt_idx].get_plain_text().to(torch.float32).cpu() for tgt_idx in target_ids],
                    dim=-1,
                )
                if all_probs is None:
                    all_probs = target_probs
                else:
                    all_probs = torch.cat([all_probs, target_probs], dim=1)

            idx_next = maximum.argmax(probs, dim=-1, one_hot=True)

            if generated is None:
                generated = idx_next
            else:
                generated = crypten.cat([generated, idx_next], dim=1)
            # Only the new token is unseen by the cache.
            step_input = idx_next

        result = idx if generated is None else generated
        return result.get_plain_text().cpu(), None, all_probs


class GPT2Block(cnn.Module):
    """One transformer block: pre-norm attention and pre-norm MLP, each with a residual add."""

    def __init__(self, config, layer_idx=None):
        """Create the two layer norms, the attention module and the MLP.

        Args:
            config: Object exposing ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon``. When ``n_inner`` is None the MLP width
                defaults to four times ``hidden_size``.
            layer_idx: Index of this block, stored and passed to the sub-modules.
        """
        super(GPT2Block, self).__init__()
        width = config.hidden_size
        if config.n_inner is not None:
            mlp_width = config.n_inner
        else:
            mlp_width = 4 * width
        self.layer_idx = layer_idx

        norm_eps = config.layer_norm_epsilon
        self.ln_1 = PumaLayerNorm(width, eps=norm_eps)
        self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
        self.ln_2 = PumaLayerNorm(width, eps=norm_eps)
        self.mlp = GPT2MLP(mlp_width, config, layer_idx=layer_idx)


    def forward(self, hidden_states, attention_mask=None, past=None):
        """Apply the block and return ``(hidden_states, present)``.

        Args:
            hidden_states: Input activations.
            attention_mask: Passed through to the attention module.
            past: Optional cached ``(key, value)`` pair for this layer.

        Returns:
            The block output and the ``present`` value returned by the
            attention module.
        """
        # Attention sub-layer: normalise, attend, add the input back.
        normed_in = self.ln_1(hidden_states)
        attn_out, present = self.attn(normed_in, attention_mask, past)
        after_attn = attn_out + hidden_states

        # Feed-forward sub-layer: normalise, MLP, add the attention result back.
        mlp_out = self.mlp(self.ln_2(after_attn))
        block_out = after_attn + mlp_out

        return block_out, present
        

class GPT2Attention(cnn.Module):
    """Causal multi-head self-attention with an optional key/value cache."""

    def __init__(self, config, layer_idx):
        """Create the fused QKV projection, the output projection and softmax.

        Args:
            config: Object exposing ``hidden_size`` and ``num_attention_heads``.
            layer_idx: Index of the enclosing block (stored only).

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super(GPT2Attention, self).__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Value written into masked-out score positions before the softmax.
        self.masked_bias = -1e4

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        # One projection produces query, key and value side by side.
        self.c_attn = MPCConv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = MPCConv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = MPCIdentity()
        self.resid_dropout = MPCIdentity()

        self.smax = PumaSoftmax(dim=-1)
        # Cached and new keys/values are joined along dimension 2.
        self.cat = cnn.Concat(dimension=2)


    def gen_attention_mask(self, nd, ns):
        """Return a boolean ``(nd, ns)`` causal mask aligned to the lower-right corner.

        Entry ``[i, j]`` is True when ``i >= j - ns + nd``, i.e. 1's in the
        lower triangle, counting from the lower right corner.

        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = torch.arange(nd)[:,None]
        j = torch.arange(ns)
        m = i >= j - ns + nd
        return m

    def forward(self, hidden_states, attention_mask=None, layer_past=None):
        """Compute causal self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations; indexed here as a 3-D tensor.
            attention_mask: Accepted for interface compatibility but NOT used:
                it is replaced below by a freshly generated causal mask.
            layer_past: Optional ``(past_key, past_value)`` pair that is
                prepended to the newly computed keys and values.

        Returns:
            ``(attn_output, present)`` where ``present`` is the
            ``(key_states, value_states)`` pair including any cached part.
        """
        qkv_states = self.c_attn(hidden_states)
        # The fused projection is split into three equal slices on the last axis.
        q_len = qkv_states.size(2) // 3
        query_states = qkv_states[:, :, :q_len]
        key_states = qkv_states[:, :, q_len:2 * q_len]
        value_states = qkv_states[:, :, 2 * q_len:]

        # All three slices share their leading dimensions, so one target shape
        # serves for splitting each of them into heads.
        heads_shape = (*query_states.shape[:-1], -1, self.head_dim)

        query_states = query_states.reshape(heads_shape).transpose(1, 2)
        key_states = key_states.reshape(heads_shape).transpose(1, 2)
        value_states = value_states.reshape(heads_shape).transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_states = self.cat([past_key, key_states])
            value_states = self.cat([past_value, value_states])
        present = (key_states, value_states)

        # Scaled dot-product scores.
        w = query_states.matmul(key_states.transpose(-1, -2))
        w = w / math.sqrt(self.head_dim)

        # nd = number of query positions, ns = number of key positions.
        _, _, nd, ns = w.shape
        causal = self.gen_attention_mask(nd, ns).to(device=w.device, dtype=torch.float32)
        attention_mask = crypten.cryptensor(causal, src=0)
        attention_mask = attention_mask.reshape(1, 1, nd, ns)

        # Keep allowed scores; set disallowed ones to ``masked_bias``.
        w = w * attention_mask + self.masked_bias * (1-attention_mask)
        w = self.smax(w)
        attn_output = w.matmul(value_states)

        # Merge the heads back into a single feature axis.
        attn_output = attn_output.transpose(1,2)
        attn_output = attn_output.reshape(*attn_output.shape[:2],-1)
        attn_output = self.c_proj(attn_output)

        return attn_output, present


class GPT2MLP(cnn.Module):
    """Position-wise feed-forward network: expand, GeLU, project back."""

    def __init__(self, intermediate_size, config, layer_idx=None):
        """Create the two projections and the activation.

        Args:
            intermediate_size: Width of the hidden (expanded) layer.
            config: Object exposing ``hidden_size``.
            layer_idx: Index of the enclosing block (stored only).
        """
        super().__init__()
        self.layer_idx = layer_idx
        model_width = config.hidden_size
        # MPCConv1D takes (output features, input features).
        self.c_fc = MPCConv1D(intermediate_size, model_width)
        self.c_proj = MPCConv1D(model_width, intermediate_size)
        self.act = PumaGeLU()
        self.dropout = MPCIdentity()


    def forward(self, hidden_states):
        """Return ``c_proj(act(c_fc(hidden_states)))``."""
        expanded = self.c_fc(hidden_states)
        activated = self.act(expanded)
        return self.c_proj(activated)
    

class MPCConv1D(cnn.Linear):
    """Linear projection over the last axis, with a ``(nf, nx)`` argument order."""

    def __init__(self, nf, nx):
        """Initialise the parent ``Linear`` with ``nx`` inputs and ``nf`` outputs.

        Args:
            nf: Number of output features.
            nx: Number of input features.
        """
        super(MPCConv1D, self).__init__(nx, nf)
        self.nf = nf
        self.nx = nx

    def forward(self, x):
        """Project the last axis of ``x`` to ``nf`` features, keeping all leading axes."""
        leading_dims = x.size()[:-1]
        # Collapse every leading axis so the projection is one 2-D matmul.
        flat = x.reshape(-1, x.size(-1))
        projected = flat.matmul(self.weight.transpose(0,1)) + self.bias
        return projected.reshape(leading_dims + (self.nf,))
