import torch
from torch import nn
import math
import torch.nn.functional as F


# Default hyper-parameters shared by the modules in this file.
config = dict(
    # Optimisation / schedule settings. Not read by the code in this file;
    # presumably consumed by the training script -- TODO confirm.
    lr=5e-5,
    batch_size=1024,
    warmup_steps=128,
    # Width of every embedding / Transformer layer.
    d_model=512,
    total_epochs=1 << 15,  # 32768
    # Transformer encoder shape used by Embedder.
    n_heads=16,
    n_layers=4,
    # Half-width of the uniform(-r, r) weight initialisation in Embedder.
    init_range=0.2,
    # Dropout probabilities applied inside Embedder.forward.
    scalar_dropout=0.1,
    embed_dropout=0.1,
    final_dropout=0.1,
    # Predictor head options: toggle Dropout / BatchNorm1d after the wide
    # hidden layers, and the Dropout probability used when enabled.
    pred_dropout=True,
    pred_batchnorm=False,
    pred_dropout_p=0.1,
)


class Embedder(nn.Module):
    """Encode a padded sequence of token ids into one ``d_model``-sized vector.

    The forward pass embeds the tokens, adds positional encodings, optionally
    adds an embedding of a per-sample scalar to every position, runs a
    Transformer encoder (padding positions masked out as attention keys) and
    projects the encoder output at sequence position 0 through a linear layer.

    Args:
        num_embeds: Number of rows in the token embedding table. Token id 0
            is the padding index.
        max_len: Length of the positional-encoding table.
        config: Hyper-parameter mapping. Reads ``d_model``, ``n_heads``,
            ``n_layers``, ``init_range``, ``scalar_dropout``,
            ``embed_dropout`` and ``final_dropout``.
    """

    def __init__(self, num_embeds, max_len, config=config):
        super().__init__()
        self.embed = nn.Embedding(num_embeds, config['d_model'], padding_idx=0)
        self.pos_enc = PositionalEncoding(config['d_model'], max_len)
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(config['d_model'], config['n_heads'], batch_first=True), config['n_layers'])
        self.fc = nn.Linear(config['d_model'], config['d_model'])
        # Maps a single scalar feature to a d_model-sized vector.
        self.affinity_embed = nn.Sequential(nn.Linear(1, config['d_model']))
        self.init_weights(config)
        self.config = config

    def init_weights(self, config):
        """Re-initialise embedding / linear weights uniformly in ``±init_range``.

        Biases of ``fc`` and of every linear layer in ``affinity_embed`` are
        set to zero. The Transformer encoder keeps its default initialisation.
        """
        initrange = config['init_range']
        # NOTE: this overwrites the whole table, including the padding row
        # (index 0) that nn.Embedding zero-initialises.
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)
        for layer in self.affinity_embed:
            if isinstance(layer, nn.Linear):
                layer.bias.data.zero_()
                layer.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, scalar=None):
        """Return one feature vector per input sequence.

        Args:
            x: Integer token ids; positions equal to 0 are treated as
                padding. Assumed to be ``(batch, seq_len)`` -- TODO confirm
                against callers.
            scalar: Optional per-sample scalar feature whose last dimension
                is 1 (required by ``nn.Linear(1, d_model)``); presumably
                ``(batch, 1)``. Skipped when ``None``.

        Returns:
            Tensor with ``d_model`` features per sequence, taken from the
            encoder output at position 0 and passed through ``fc`` + dropout.
        """
        embeds = self.pos_enc(self.embed(x))
        # Identity check: `!= None` on a tensor goes through tensor comparison
        # machinery, which is not the intended "was an argument given" test.
        if scalar is not None:
            scalar_embed = self.affinity_embed(scalar).unsqueeze(1)
            # Repeat across the sequence axis *before* dropout so that every
            # position gets its own dropout mask.
            scalar_embed = torch.tile(scalar_embed, (1, embeds.shape[1], 1))
            embeds += F.dropout(scalar_embed, self.config['scalar_dropout'], training=self.training)
        embeds = F.dropout(embeds, self.config['embed_dropout'], training=self.training)
        # True entries in the mask mark padding keys the encoder must ignore.
        transformer_out = self.encoder(embeds, src_key_padding_mask=(x == 0))
        # Use the first position as the sequence summary.
        pooled = self.fc(transformer_out[:, 0, :])
        return F.dropout(pooled, self.config['final_dropout'], training=self.training)


class Predictor(nn.Module):
    """MLP head mapping a ``2 * d_model`` feature vector to a single value.

    Architecture: three 2048-wide hidden layers, then two 1024-wide hidden
    layers, then a 1-unit output layer, with ReLU after every hidden layer.
    Each of the three 2048-wide layers is optionally followed by
    ``Dropout`` (``config['pred_dropout']``) and then ``BatchNorm1d``
    (``config['pred_batchnorm']``), in that order.

    The layers are appended in exactly the same order for every flag
    combination as a hand-written ``nn.Sequential`` would list them, so the
    sub-module indices (and therefore ``state_dict`` keys) are stable.

    Args:
        config: Hyper-parameter mapping. Reads ``d_model``, ``pred_dropout``,
            ``pred_batchnorm`` and, only when dropout is enabled,
            ``pred_dropout_p``.
    """

    def __init__(self, config=config):
        super().__init__()
        in_dim = config['d_model'] * 2
        use_dropout = config['pred_dropout']
        use_batchnorm = config['pred_batchnorm']

        layers = []
        fan_in = in_dim
        # Three wide blocks: Linear -> ReLU -> [Dropout] -> [BatchNorm1d].
        for _ in range(3):
            layers.append(nn.Linear(fan_in, 2048))
            layers.append(nn.ReLU())
            if use_dropout:
                layers.append(nn.Dropout(config['pred_dropout_p']))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(2048))
            fan_in = 2048

        # Narrowing tail; never regularised regardless of the flags.
        layers.extend([
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
        ])
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension ``2 * d_model``)."""
        return self.fc(x)


class ContextAttn(nn.Module):
    """Self-attend over a set of vectors and pool them into one vector.

    A 4-layer, 8-head Transformer encoder (``batch_first=False``, so dim 0 of
    the input is the sequence axis) is followed by a mean over dim 0 and a
    ``d_model -> d_model`` linear projection.

    Args:
        config: Hyper-parameter mapping; only ``d_model`` is read.
    """

    def __init__(self, config=config):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(config['d_model'], 8, batch_first=False)
        self.attn = nn.TransformerEncoder(encoder_layer, 4)
        self.fc = nn.Linear(config['d_model'], config['d_model'])

    def forward(self, x):
        """Return the projected mean (over dim 0) of the attended inputs."""
        attended = self.attn(x)
        pooled = attended.mean(0)
        return self.fc(pooled)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    A ``(max_len, embed_dim)`` table is precomputed once and stored as the
    buffer ``pe``: even feature columns hold ``sin(pos * w_k)`` and odd
    columns hold ``cos(pos * w_k)`` with
    ``w_k = exp(-2k * ln(10000) / embed_dim)``.

    Args:
        embed_dim: Feature size of the embeddings (even or odd).
        max_len: Number of positions in the table, i.e. the longest sequence
            that can be encoded.
    """

    def __init__(self, embed_dim, max_len):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        angles = position * div_term
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_dim there is one fewer cosine column than sine
        # columns, so trim the frequencies to fit.
        pe[:, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``seq_len`` positions.

        Args:
            x: Tensor whose last two dimensions are ``(seq_len, embed_dim)``,
                e.g. ``(batch, seq_len, embed_dim)``. ``seq_len`` may be any
                length up to ``max_len``; longer inputs fail with the usual
                broadcasting error.
        """
        # Slice the table to the actual sequence length instead of requiring
        # every input to be exactly max_len long.
        seq_len = x.size(-2)
        return x + self.pe[:seq_len]
