import numpy as np
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2Model
from torch.distributions.categorical import Categorical


def layer_init(layer, std=0.02, bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is drawn from a zero-mean normal distribution with standard
    deviation ``std``. The bias, when the layer has one, is filled with
    ``bias_const``.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor (e.g. ``nn.Linear``).
        std: Standard deviation of the normal weight initialisation.
        bias_const: Constant written to every bias entry.

    Returns:
        The same ``layer`` object, so the call can be nested inside a
        module constructor.
    """
    torch.nn.init.normal_(layer.weight, std=std)
    # Layers created with ``bias=False`` carry ``bias = None``; skip the fill
    # for them instead of failing inside ``constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer

# MLP
class Agent(nn.Module):
    """Three-layer GELU perceptron stored under the ``critic`` attribute.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_dim = 256
        # Layers are created and initialised front to back. The output head
        # is initialised with std=1.0 rather than the helper's default.
        first = layer_init(nn.Linear(input_dim, hidden_dim))
        second = layer_init(nn.Linear(hidden_dim, hidden_dim))
        head = layer_init(nn.Linear(hidden_dim, output_dim), std=1.0)
        self.critic = nn.Sequential(first, nn.GELU(), second, nn.GELU(), head)

    def forward(self, x):
        """Run ``x`` through the perceptron and return its output."""
        return self.critic(x)

# Transformer agent: Hugging Face GPT-2 backbone, or the handcrafted Decoder defined below
class TFAgent(nn.Module):
    """Linear read-in, transformer backbone, linear read-out.

    Args:
        args: Not referenced by this class; kept in the signature.
        input_dim: Feature size of each input element.
        output_dim: Feature size of each output element.
        config: Object providing ``max_len``, ``embed_dim``, ``num_layers``,
            ``num_heads`` and ``dropout``. It is also handed to ``Decoder``
            when that backbone is selected.
        mode: ``'gpt2'`` selects ``GPT2Model``; any other value selects
            ``Decoder``.
    """

    def __init__(self, args, input_dim, output_dim, config, mode='gpt2'):
        super().__init__()
        width = config.embed_dim
        drop = config.dropout
        # The GPT-2 configuration is built regardless of ``mode``; one
        # dropout rate is shared by residual, embedding and attention dropout.
        gpt2_settings = GPT2Config(
            n_positions=config.max_len,
            n_embd=width,
            n_layer=config.num_layers,
            n_head=config.num_heads,
            resid_pdrop=drop,
            embd_pdrop=drop,
            attn_pdrop=drop,
            use_cache=False,
        )

        self.n_dims = input_dim
        self.output_dim = output_dim
        self._read_in = layer_init(nn.Linear(input_dim, width), std=0.1)
        use_gpt2 = mode == 'gpt2'
        self._backbone = GPT2Model(gpt2_settings) if use_gpt2 else Decoder(config)
        self._read_out = layer_init(nn.Linear(width, output_dim))

    def forward(self, x):
        """Embed ``x``, run the backbone, and project its last hidden state."""
        backbone_out = self._backbone(inputs_embeds=self._read_in(x))
        return self._read_out(backbone_out.last_hidden_state)

# RNN
class RNNAgent(nn.Module):
    """Batch-first ReLU ``nn.RNN`` followed by a linear read-out per step.

    Args:
        device: Device the input is moved to at the start of ``forward``.
            The module's own parameters are not moved here.
        input_size: Feature size of each input step.
        hidden_size: Hidden-state size of the recurrent layers.
        num_layers: Number of stacked recurrent layers.
        output_size: Feature size of each output step.
    """

    def __init__(self, device, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            num_layers,
            nonlinearity='relu',
            batch_first=True,
        )
        self.linear = layer_init(nn.Linear(hidden_size, output_size))

    def forward(self, x):
        """Map ``x`` of shape (bs, length, input_dim) to (bs, length, output_dim)."""
        # Only the per-step outputs are used; the final hidden state is dropped.
        step_states, _ = self.rnn(x.to(self.device))
        return self.linear(step_states)
    
# LSTM
class LSTMAgent(nn.Module):
    """Batch-first ``nn.LSTM`` followed by a linear read-out per step.

    Args:
        device: Device the input is moved to at the start of ``forward``.
            The module's own parameters are not moved here.
        input_size: Feature size of each input step.
        hidden_size: Hidden-state size of the LSTM layers.
        num_layers: Number of stacked LSTM layers.
        output_size: Feature size of each output step.
    """

    def __init__(self, device, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
        )
        self.linear = layer_init(nn.Linear(hidden_size, output_size))

    def forward(self, x):
        """Map ``x`` of shape (bs, length, input_dim) to (bs, length, output_dim)."""
        # Only the per-step outputs are used; the (h, c) state pair is dropped.
        step_states, _ = self.lstm(x.to(self.device))
        return self.linear(step_states)

# handcrafted transformer
class AttentionBlock(nn.Module):
    """Causal self-attention block with two residual branches (pre-LN).

    Each branch normalises its input, applies the sub-layer (multi-head
    attention, then a GELU MLP) and adds the result back onto the
    un-normalised stream.

    Args:
        config: Object providing ``embed_dim``, ``num_heads``, ``dropout``,
            ``mlp_layers`` (2 or 3 linear layers in the MLP) and ``used_ln``
            (when false, both normalisations are replaced by identity).

    Raises:
        ValueError: If ``config.mlp_layers`` is neither 2 nor 3.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if config.mlp_layers not in (2, 3):
            raise ValueError(
                f"mlp_layers must be 2 or 3, got {config.mlp_layers!r}"
            )
        self.embed_dim = config.embed_dim
        self.attn = nn.MultiheadAttention(
            config.embed_dim, config.num_heads, config.dropout, batch_first=True
        )

        width = config.embed_dim
        hidden = 4 * width
        # Build the MLP front to back: expand, optional hidden-to-hidden
        # layer (only for mlp_layers == 3), project back, dropout.
        mlp = [layer_init(nn.Linear(width, hidden)), nn.GELU()]
        for _ in range(config.mlp_layers - 2):
            mlp += [layer_init(nn.Linear(hidden, hidden)), nn.GELU()]
        mlp += [layer_init(nn.Linear(hidden, width)), nn.Dropout(config.dropout)]
        self.mlp = nn.Sequential(*mlp)

        self.ln_1 = nn.LayerNorm(config.embed_dim) if config.used_ln else nn.Identity()
        self.ln_2 = nn.LayerNorm(config.embed_dim) if config.used_ln else nn.Identity()

    # pre-LN architecture: LayerNorm is applied to the input of each sub-layer,
    # not to the residual sum.
    def forward(self, x):
        """Apply causal self-attention and the MLP to ``x``.

        ``x`` is indexed as (batch, sequence, embed): the attention layer is
        batch-first and ``x.shape[1]`` is used as the sequence length.
        """
        y = self.ln_1(x)
        # Additive causal mask so each position attends only to itself and
        # earlier positions.
        mask = nn.Transformer.generate_square_subsequent_mask(y.shape[1]).to(y.device)
        attn_out, _ = self.attn(y, y, y, is_causal=True, attn_mask=mask)
        x = x + attn_out
        return x + self.mlp(self.ln_2(x))

class PostionalEmbedding(nn.Module):
    """Positional-embedding table of shape ``(1, max_len, embed_dim)``.

    ``config.pos_embed_type`` selects the table:

    * ``'vanilla'``   - fixed sinusoidal table (sin in even columns, cos in
      odd columns); works for odd ``embed_dim`` as well.
    * ``'learnable'`` - trainable table initialised to zeros.
    * ``'none'``      - fixed all-zero table.
    * ``'additive'``  - not implemented.

    Args:
        config: Object providing ``pos_embed_type``, ``max_len`` and
            ``embed_dim``.

    Raises:
        NotImplementedError: If ``pos_embed_type`` is ``'additive'``.
        ValueError: If ``pos_embed_type`` is none of the values above.
    """

    def __init__(self, config) -> None:
        super().__init__()
        kind = config.pos_embed_type
        shape = (1, config.max_len, config.embed_dim)
        if kind == 'vanilla':
            # sinusoidal positional embedding
            pos = torch.arange(0, config.max_len, dtype=torch.float32).unsqueeze(1)
            div = torch.exp(
                torch.arange(0, config.embed_dim, 2).float()
                * (-np.log(10000.0) / config.embed_dim)
            )
            angles = pos * div
            table = torch.zeros(shape)
            table[:, :, 0::2] = torch.sin(angles)
            # With an odd embed_dim there is one fewer odd column than even
            # column, so drop the surplus cosine column before assigning.
            table[:, :, 1::2] = torch.cos(angles)[:, :config.embed_dim // 2]
            # requires_grad=False: no gradient is computed for the table.
            self.pos_embd = nn.Parameter(table, requires_grad=False)
        elif kind == 'learnable':
            self.pos_embd = nn.Parameter(torch.zeros(shape))
        elif kind == 'additive':
            raise NotImplementedError
        elif kind == 'none':
            # All-zero table excluded from gradient computation, so adding it
            # leaves the input unchanged.
            self.pos_embd = nn.Parameter(torch.zeros(shape), requires_grad=False)
        else:
            # Explicit check instead of ``assert`` so it still runs under ``-O``.
            raise ValueError(f"unknown pos_embed_type: {kind!r}")

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the table, shape (1, n, embed_dim)."""
        return self.pos_embd[:, :x.shape[1], :]
    
class TFOutput:
    """Minimal result wrapper whose only field is ``last_hidden_state``.

    Args:
        x: Value stored, unchanged, as ``last_hidden_state``.
    """

    def __init__(self, x) -> None:
        self.last_hidden_state = x
    
class Decoder(nn.Module):
    """Stack of ``AttentionBlock`` layers applied to pre-embedded inputs.

    Args:
        config: Object providing ``num_layers``, ``embed_dim`` and
            ``used_ln``. It is also passed to ``PostionalEmbedding`` and
            to every ``AttentionBlock``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        # There is no vocabulary embedding: inputs arrive already embedded.
        self.pos_embd = PostionalEmbedding(config)
        layers = [AttentionBlock(config) for _ in range(config.num_layers)]
        self.blocks = nn.Sequential(*layers)
        # Final normalisation is replaced by identity when layer norm is off.
        if config.used_ln:
            self.ln = nn.LayerNorm(config.embed_dim)
        else:
            self.ln = nn.Identity()

    def forward(self, inputs_embeds):
        """Add positional embeddings, run the blocks, and wrap the result.

        Returns:
            A ``TFOutput`` whose ``last_hidden_state`` is the final
            (optionally normalised) hidden tensor.
        """
        hidden = inputs_embeds + self.pos_embd(inputs_embeds)
        hidden = self.blocks(hidden)
        return TFOutput(self.ln(hidden))

class TFConfig:
    """Plain container for the transformer hyper-parameters.

    Every constructor argument is stored unchanged as an attribute of the
    same name.

    Args:
        embed_dim: Embedding width.
        num_heads: Number of attention heads.
        dropout: Dropout probability.
        mlp_layers: Number of linear layers in each block's MLP.
        used_ln: Whether layer normalisation is used.
        num_layers: Number of transformer blocks.
        max_len: Maximum sequence length.
        pos_embed_type: Name of the positional-embedding variant.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.0,
        mlp_layers: int = 2,
        used_ln: bool = True,
        num_layers: int = 4,
        max_len: int = 256,
        pos_embed_type: str = 'learnable',
    ) -> None:
        # Model width and attention layout.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        # Per-block structure.
        self.mlp_layers = mlp_layers
        self.used_ln = used_ln
        # Depth and sequence handling.
        self.num_layers = num_layers
        self.max_len = max_len
        self.pos_embed_type = pos_embed_type